AAAI Version
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision.transforms.functional import gaussian_blur, to_tensor
|
||||
|
||||
|
||||
def _preprocess_image(image, device, dtype=torch.float16):
|
||||
image = to_tensor(image)
|
||||
image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
|
||||
if image.shape[1] != 3:
|
||||
image = image.expand(-1, 3, -1, -1)
|
||||
image = F.interpolate(image, (1024, 1024))
|
||||
return image.to(dtype).to(device)
|
||||
|
||||
|
||||
def _preprocess_mask(mask, device, dtype=torch.float16):
|
||||
mask = to_tensor(mask.convert("L"))
|
||||
mask = mask.unsqueeze_(0).float() # 0 or 1
|
||||
mask = F.interpolate(mask, (1024, 1024))
|
||||
mask = gaussian_blur(mask, kernel_size=(77, 77))
|
||||
mask[mask < 0.1] = 0
|
||||
mask[mask >= 0.1] = 1
|
||||
return mask.to(dtype).to(device)
|
||||
|
||||
|
||||
class AttentiveEraser:
|
||||
"""Attentive Eraser Pipeline + Pre- and Post-Processing."""
|
||||
|
||||
prompt = ""
|
||||
dtype = torch.float16
|
||||
|
||||
def __init__(self, num_steps=50, device=None):
|
||||
"""Create the attentive eraser.
|
||||
|
||||
Args:
|
||||
num_steps (int, optional): Number of steps in the diffusion process. Will start at 20%. Defaults to 100.
|
||||
device (_type_, optional): Device to run on. Defaults to 'cuda' if available, else 'cpu'.
|
||||
|
||||
"""
|
||||
self.num_steps = num_steps
|
||||
|
||||
self.device = device if device else torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
|
||||
)
|
||||
self.model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
self.pipeline = DiffusionPipeline.from_pretrained(
|
||||
self.model_path,
|
||||
custom_pipeline="./pipelines/pipeline_stable_diffusion_xl_attentive_eraser.py",
|
||||
scheduler=scheduler,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
torch_dtype=self.dtype,
|
||||
).to(self.device)
|
||||
self.pipeline.enable_attention_slicing()
|
||||
self.pipeline.enable_model_cpu_offload()
|
||||
|
||||
def __call__(self, image, mask):
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if isinstance(mask, np.ndarray):
|
||||
mask = Image.fromarray(mask)
|
||||
|
||||
prep_img = _preprocess_image(image, self.device, self.dtype)
|
||||
prep_mask = _preprocess_mask(mask, self.device, self.dtype)
|
||||
orig_shape = image.size
|
||||
|
||||
diff_img = (
|
||||
self.pipeline(
|
||||
prompt=self.prompt,
|
||||
image=prep_img,
|
||||
mask_image=prep_mask,
|
||||
height=1024,
|
||||
width=1024,
|
||||
AAS=True, # enable AAS
|
||||
strength=0.8, # inpainting strength
|
||||
rm_guidance_scale=9, # removal guidance scale
|
||||
ss_steps=9, # similarity suppression steps
|
||||
ss_scale=0.3, # similarity suppression scale
|
||||
AAS_start_step=0, # AAS start step
|
||||
AAS_start_layer=34, # AAS start layer
|
||||
AAS_end_layer=70, # AAS end layer
|
||||
num_inference_steps=self.num_steps, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
|
||||
guidance_scale=1,
|
||||
)
|
||||
.images[0]
|
||||
.resize(orig_shape)
|
||||
)
|
||||
|
||||
paste_mask = Image.fromarray(np.array(mask) > 0 * 255)
|
||||
paste_mask = paste_mask.convert("RGB").filter(ImageFilter.GaussianBlur(radius=10)).convert("L")
|
||||
out_img = copy(image)
|
||||
out_img.paste(diff_img, (0, 0), paste_mask)
|
||||
return out_img
|
||||
Reference in New Issue
Block a user