AAAI Version
This commit is contained in:
271
AAAI Supplementary Material/ForNet Creation Code/infill_lama.py
Normal file
271
AAAI Supplementary Material/ForNet Creation Code/infill_lama.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Use the LaMa (*LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions*) model to infill the background of an image.
|
||||
|
||||
Based on: https://github.com/Sanster/IOPaint/blob/main/iopaint/model/lama.py
|
||||
"""
|
||||
|
||||
import abc
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.hub import download_url_to_file
|
||||
|
||||
# Location of the TorchScript big-lama checkpoint; both the URL and the
# expected checksum can be overridden via environment variables of the
# same name.
LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
# Hex MD5 digest used by download_model() to verify a fresh download.
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
|
||||
|
||||
|
||||
class LDMSampler(str, Enum):
    """Sampler names for the LDM inpainting model.

    NOTE(review): not referenced elsewhere in this file — presumably kept
    for compatibility with the upstream IOPaint code; confirm before removal.
    """

    ddim = "ddim"
    plms = "plms"
|
||||
|
||||
|
||||
class SDSampler(str, Enum):
    """Sampler names for Stable-Diffusion-based inpainting.

    NOTE(review): not referenced elsewhere in this file — presumably kept
    for compatibility with the upstream IOPaint code; confirm before removal.
    """

    ddim = "ddim"
    pndm = "pndm"
    k_lms = "k_lms"
    k_euler = "k_euler"
    k_euler_a = "k_euler_a"
    dpm_plus_plus = "dpm++"
    uni_pc = "uni_pc"
|
||||
|
||||
|
||||
def ceil_modulo(x, mod):
    """Round *x* up to the nearest multiple of *mod*.

    A value that is already a multiple of ``mod`` is returned unchanged.
    """
    # Ceiling division via negated floor division, scaled back up by mod.
    return -(-x // mod) * mod
|
||||
|
||||
|
||||
def pad_img_to_modulo(img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None):
    """Pad *img* on the bottom/right so both sides are multiples of *mod*.

    Args:
        img: [H, W, C] image; a 2-D [H, W] input gains a singleton channel.
        mod: modulo the padded height and width must satisfy.
        square: if True, pad further so the output is square.
        min_size: optional lower bound for each side; must be a multiple
            of ``mod``.

    Returns:
        The padded image, [H', W', C], using symmetric edge reflection.
    """

    def _round_up(v):
        # Smallest multiple of ``mod`` that is >= v.
        return -(-v // mod) * mod

    if img.ndim == 2:
        img = img[..., np.newaxis]
    in_h, in_w = img.shape[:2]

    target_h = _round_up(in_h)
    target_w = _round_up(in_w)

    if min_size is not None:
        assert min_size % mod == 0
        target_w = max(min_size, target_w)
        target_h = max(min_size, target_h)

    if square:
        side = max(target_h, target_w)
        target_h = side
        target_w = side

    pad_spec = ((0, target_h - in_h), (0, target_w - in_w), (0, 0))
    return np.pad(img, pad_spec, mode="symmetric")
|
||||
|
||||
|
||||
def undo_pad_to_mod(img: np.ndarray, height: int, width: int):
    """Crop a padded [H, W, C] image back to its original height and width."""
    cropped = img[0:height, 0:width, :]
    return cropped
|
||||
|
||||
|
||||
class InpaintModel:
    """Abstract base class for inpainting models.

    Subclasses implement :meth:`init_model`, :meth:`is_downloaded` and
    :meth:`forward`; inference goes through :meth:`_pad_forward`, which pads
    the inputs to the model's modulo, runs the model, crops the result and
    composites it back into the original image.
    """

    # Human-readable model identifier.
    name = "base"
    # Optional lower bound (multiple of pad_mod) for each padded side.
    min_size: Optional[int] = None
    # Height/width are padded up to a multiple of this before inference.
    pad_mod = 8
    # When True, inputs are padded to a square before inference.
    pad_to_square = False

    def __init__(self, device, **kwargs):
        """
        Args:
            device: torch device; stored and forwarded to :meth:`init_model`.
            **kwargs: forwarded unchanged to :meth:`init_model`.
        """
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs): ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool: ...

    @abc.abstractmethod
    def forward(self, image, mask):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W, 1] where 255 marks the region to inpaint
        return: BGR IMAGE
        """
        ...

    def _pad_forward(self, image, mask):
        # Pad -> run model -> crop back -> composite into the original image.
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
        pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)

        result = self.forward(pad_image, pad_mask)
        # result = result[0:origin_height, 0:origin_width, :]

        # result, image, mask = self.forward_post_process(result, image, mask)

        # Give the [H, W] mask a channel axis so it broadcasts against result.
        mask = mask[:, :, np.newaxis]
        # Tensor [C, H, W] -> numpy [H, W, C].
        result = result[0].permute(1, 2, 0).cpu().numpy()  # 400 x 504 x 3
        result = undo_pad_to_mod(result, origin_height, origin_width)
        assert result.shape[:2] == image.shape[:2], f"{result.shape[:2]} != {image.shape[:2]}"
        # result = result * 255
        mask = (mask > 0) * 255
        # NOTE(review): the 0/255 mask both selects the inpainted region and
        # rescales the model output to 0-255 — this assumes `forward` returns
        # values in [0, 1] (note the commented-out `result * 255` above).
        # Confirm for any new subclass whose output range differs.
        result = result * mask + image * (1 - (mask / 255))
        result = np.clip(result, 0, 255).astype("uint8")
        return result

    def forward_post_process(self, result, image, mask):
        # Subclass hook; default is identity (the call site above is
        # currently commented out).
        return result, image, mask
|
||||
|
||||
|
||||
def resize_np_img(np_img, size, interpolation="bicubic"):
    """Resize an image with ``torch.nn.functional.interpolate``.

    Args:
        np_img: [H, W, C] uint8 image.
        size: target (out_H, out_W), as expected by ``F.interpolate``.
        interpolation: one of "nearest", "bilinear", "bicubic".

    Returns:
        The resized [out_H, out_W, C] uint8 image.

    Raises:
        AssertionError: on an unsupported interpolation mode.
    """
    assert interpolation in [
        "nearest",
        "bilinear",
        "bicubic",
    ], f"Unsupported interpolation: {interpolation}, use nearest, bilinear or bicubic."
    # HWC uint8 -> NCHW float in [0, 1].
    torch_img = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).float() / 255
    # Bug fix: align_corners is only legal for the interpolating modes;
    # passing align_corners=True with mode="nearest" raises a ValueError,
    # even though the assert above explicitly allows "nearest".
    align = True if interpolation in ("bilinear", "bicubic") else None
    interp_img = torch.nn.functional.interpolate(torch_img, size=size, mode=interpolation, align_corners=align)
    return (interp_img.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
|
||||
|
||||
|
||||
def resize_max_size(np_img, size_limit: int, interpolation="bicubic") -> np.ndarray:
    """Downscale *np_img* so its longer side is at most *size_limit*.

    Args:
        np_img: [H, W, C] uint8 image.
        size_limit: maximum allowed length of the longer side.
        interpolation: passed through to :func:`resize_np_img`.

    Returns:
        The resized image, or ``np_img`` unchanged when already within
        the limit.
    """
    h, w = np_img.shape[:2]
    if max(h, w) > size_limit:
        ratio = size_limit / max(h, w)
        new_w = int(w * ratio + 0.5)
        new_h = int(h * ratio + 0.5)
        # Bug fix: F.interpolate (inside resize_np_img) takes size=(out_H,
        # out_W). The original passed (new_w, new_h), which swapped the axes
        # for every non-square image.
        return resize_np_img(np_img, size=(new_h, new_w), interpolation=interpolation)
    else:
        return np_img
|
||||
|
||||
|
||||
def norm_img(np_img):
    """Convert an HWC (or HW) uint8 image to a CHW float32 array in [0, 1]."""
    # Promote a 2-D image to a single-channel 3-D one.
    channelled = np_img[:, :, np.newaxis] if np_img.ndim == 2 else np_img
    chw = channelled.transpose(2, 0, 1)
    return chw.astype("float32") / 255
|
||||
|
||||
|
||||
def get_cache_path_by_url(url):
    """Map a model URL to its local cache path, creating the cache directory.

    Args:
        url: download URL; only the basename of its path component is used.

    Returns:
        Path of the (possibly not yet downloaded) cached file inside
        ``~/.TORCH_HUB_CACHE/checkpoints``.
    """
    parts = urlparse(url)
    # Bug fix: os.makedirs does NOT expand "~" — the original created a
    # literal "~" directory in the current working directory.
    hub_dir = os.path.expanduser("~/.TORCH_HUB_CACHE")
    model_dir = os.path.join(hub_dir, "checkpoints")
    # exist_ok avoids the check-then-create race of the original isdir guard.
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file
|
||||
|
||||
|
||||
def md5sum(filename):
    """Return the hex MD5 digest of the file at *filename*."""
    digest = hashlib.md5()
    chunk_size = 128 * digest.block_size
    with open(filename, "rb") as fh:
        while True:
            chunk = fh.read(chunk_size)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def download_model(url, model_md5: str = None):
    """Download *url* into the torch-hub cache, verifying its MD5.

    Args:
        url: model file URL.
        model_md5: expected hex MD5 digest; checked only right after a fresh
            download. ``None`` skips verification.

    Returns:
        Path of the cached model file.

    Raises:
        SystemExit: when a freshly downloaded file fails MD5 verification.
    """
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                print(f"Download model success, md5: {_md5}")
            else:
                try:
                    os.remove(cached_file)
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart"
                        " lama-cleaner.If you still have errors, please try download model manually first"
                        " https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                    )
                # Bug fix: the original bare `except:` also swallowed
                # KeyboardInterrupt/SystemExit; only filesystem failures
                # should take the fallback message path.
                except OSError:
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart"
                        " lama-cleaner."
                    )
                # sys.exit instead of the site-injected exit() builtin,
                # which is not guaranteed to exist in all interpreters.
                sys.exit(-1)

    return cached_file
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device, model_md5: str):
    """Load a TorchScript model from a local path or a download URL.

    Args:
        url_or_path: filesystem path to the checkpoint, or a URL that is
            fetched through :func:`download_model`.
        device: torch device the loaded model is moved to.
        model_md5: expected MD5 digest, forwarded to the downloader.

    Returns:
        The loaded model, switched to eval mode.
    """
    is_local = os.path.exists(url_or_path)
    model_path = url_or_path if is_local else download_model(url_or_path, model_md5)

    print(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    model.eval()
    return model
|
||||
|
||||
|
||||
class LaMa(InpaintModel):
    """LaMa inpainting model, loaded as a TorchScript checkpoint."""

    name = "lama"
    # LaMa requires height/width to be multiples of 8.
    pad_mod = 8

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        # boxes = boxes_from_mask(mask)
        inpaint_result = self._pad_forward(image, mask)

        return inpaint_result

    def init_model(self, device, **kwargs):
        # NOTE(review): load_jit_model already calls eval(); the extra
        # .eval() here is redundant but harmless.
        self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()

    @staticmethod
    def is_downloaded() -> bool:
        # True when the checkpoint already exists in the local cache.
        return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))

    def forward(self, image, mask):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)

        # Binarize the mask: any non-zero value marks the inpaint region.
        mask = (mask > 0) * 1
        # Add a batch dimension: [C, H, W] -> [1, C, H, W].
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        inpainted_image = self.model(image, mask)
        return inpainted_image
|
||||
Reference in New Issue
Block a user