# sd-webui-inpaint-anything/scripts/inpaint_anything.py
# (1212 lines, 55 KiB, Python)

import os
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline, DDIMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from get_dataset_colormap import create_pascal_label_colormap
from torch.hub import download_url_to_file
from torchvision import transforms
from datetime import datetime
import gc
import argparse
import platform
from PIL.PngImagePlugin import PngInfo
import time
import random
import cv2
from huggingface_hub import snapshot_download
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
import modules.scripts as scripts
from modules import shared, script_callbacks
try:
from modules.paths_internal import extensions_dir
except Exception:
from modules.extensions import extensions_dir
from modules.devices import device, torch_gc
from modules.safe import unsafe_torch_load, load
import re
from webui_controlnet import (find_controlnet, get_sd_img2img_processing,
backup_alwayson_scripts, disable_alwayson_scripts, restore_alwayson_scripts, disable_all_alwayson_scripts,
get_controlnet_args_to, clear_controlnet_cache, get_max_args_to)
from modules.processing import process_images, create_infotext
from modules.sd_samplers import samplers_for_img2img
from modules.images import resize_image
from modules.sd_models import unload_model_weights, reload_model_weights, get_closet_checkpoint_match
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
from segment_anything_hq import SamPredictor as SamPredictorHQ
from ia_logging import ia_logging
from ia_ui_items import (get_sampler_names, get_sam_model_ids, get_model_ids, get_cleaner_model_ids, get_padding_mode_names)
# Status string returned by the download helpers on success; callers compare
# against this to decide whether a model is locally available.
_DOWNLOAD_COMPLETE = "Download complete"
def download_model(sam_model_id):
    """Fetch a SAM checkpoint into the extension's models directory.

    Args:
        sam_model_id (str): SAM model id (checkpoint file name)

    Returns:
        str: download status message
    """
    # HQ checkpoints are hosted on HuggingFace, vanilla SAM on fbaipublicfiles
    if "_hq_" in sam_model_id:
        url_sam = "https://huggingface.co/Uminosachi/sam-hq/resolve/main/" + sam_model_id
    else:
        url_sam = "https://dl.fbaipublicfiles.com/segment_anything/" + sam_model_id

    models_dir = os.path.join(extensions_dir, "sd-webui-inpaint-anything", "models")
    sam_checkpoint = os.path.join(models_dir, sam_model_id)
    if os.path.isfile(sam_checkpoint):
        return "Model already exists"

    os.makedirs(models_dir, exist_ok=True)
    download_url_to_file(url_sam, sam_checkpoint)
    return _DOWNLOAD_COMPLETE
def download_model_from_hf(hf_model_id, local_files_only=False):
    """Download model from HuggingFace Hub.

    Args:
        hf_model_id (str): HuggingFace model id (repo id)
        local_files_only (bool, optional): If True, use only local files. Defaults to False.

    Returns:
        str: download status
    """
    if not local_files_only:
        ia_logging.info(f"Downloading {hf_model_id}")
    try:
        snapshot_download(repo_id=hf_model_id, local_files_only=local_files_only)
    except FileNotFoundError:
        # raised when local_files_only=True and the snapshot is not cached
        return f"{hf_model_id} not found, please download"
    except Exception as e:
        # network/auth errors etc.; surface the message to the UI
        return str(e)
    return _DOWNLOAD_COMPLETE
def get_sam_mask_generator(sam_checkpoint):
    """Get SAM mask generator.

    Args:
        sam_checkpoint (str): SAM checkpoint path

    Returns:
        SamAutomaticMaskGenerator or None: SAM mask generator, or None when
        the checkpoint file does not exist
    """
    basename = os.path.basename(sam_checkpoint)
    if "_hq_" in basename:
        # file names look like "sam_hq_vit_h_....pth"; chars 7:12 are the model type
        model_type = basename[7:12]
        sam_model_registry_local = sam_model_registry_hq
        SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorHQ
        points_per_batch = 32
    else:
        # "sam_vit_h_....pth" -> chars 4:9 are the model type
        model_type = basename[4:9]
        sam_model_registry_local = sam_model_registry
        SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGenerator
        points_per_batch = 64

    if not os.path.isfile(sam_checkpoint):
        return None

    # Temporarily bypass webui's restricted torch.load to deserialize the
    # checkpoint. Restore it in a finally block: the original code left the
    # unsafe loader installed if model construction raised.
    torch.load = unsafe_torch_load
    try:
        sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
        if platform.system() == "Darwin":
            # run on CPU on macOS
            sam.to(device="cpu")
        else:
            sam.to(device=device)
        sam_mask_generator = SamAutomaticMaskGeneratorLocal(sam, points_per_batch=points_per_batch)
    finally:
        torch.load = load
    return sam_mask_generator
def get_sam_predictor(sam_checkpoint):
    """Get SAM predictor.

    Args:
        sam_checkpoint (str): SAM checkpoint path

    Returns:
        SamPredictor or None: SAM predictor, or None when the checkpoint
        file does not exist
    """
    basename = os.path.basename(sam_checkpoint)
    if "_hq_" in basename:
        # file names look like "sam_hq_vit_h_....pth"; chars 7:12 are the model type
        model_type = basename[7:12]
        sam_model_registry_local = sam_model_registry_hq
        SamPredictorLocal = SamPredictorHQ
    else:
        # "sam_vit_h_....pth" -> chars 4:9 are the model type
        model_type = basename[4:9]
        sam_model_registry_local = sam_model_registry
        SamPredictorLocal = SamPredictor

    if not os.path.isfile(sam_checkpoint):
        return None

    # Temporarily bypass webui's restricted torch.load to deserialize the
    # checkpoint. Restore it in a finally block: the original code left the
    # unsafe loader installed if model construction raised.
    torch.load = unsafe_torch_load
    try:
        sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
        if platform.system() == "Darwin":
            # run on CPU on macOS
            sam.to(device="cpu")
        else:
            sam.to(device=device)
        sam_predictor = SamPredictorLocal(sam)
    finally:
        torch.load = load
    return sam_predictor
# Default output directory: <webui root>/outputs/inpaint-anything/<YYYY-MM-DD>
# (extensions_dir's parent is the webui root). May be redirected later by
# update_ia_outputs_dir() based on user settings.
ia_outputs_dir = os.path.join(os.path.dirname(extensions_dir),
                              "outputs", "inpaint-anything",
                              datetime.now().strftime("%Y-%m-%d"))

# Shared mutable state passed between the UI callbacks:
#   sam_masks:  list of masks produced by run_sam()
#   mask_image: currently selected/edited mask (np.ndarray)
#   cnet:       ControlNet external-code handle (set in on_ui_tabs)
#   orig_image: last uploaded input image
#   pad_mask:   padding mask produced by run_padding()
sam_dict = dict(sam_masks=None, mask_image=None, cnet=None, orig_image=None, pad_mask=None)
def update_ia_outputs_dir():
    """Refresh the global outputs directory from user settings.

    Reads the "inpaint_anything_save_folder" option and, when it names a
    known folder, points ia_outputs_dir at that folder with a per-day
    subdirectory.

    Returns:
        None
    """
    global ia_outputs_dir
    config_save_folder = shared.opts.data.get("inpaint_anything_save_folder", "inpaint-anything")
    if config_save_folder not in ("inpaint-anything", "img2img-images"):
        return
    today = datetime.now().strftime("%Y-%m-%d")
    ia_outputs_dir = os.path.join(
        os.path.dirname(extensions_dir), "outputs", config_save_folder, today)
def save_mask_image(mask_image, save_mask_chk=False):
    """Optionally save the given mask image to the outputs directory.

    Args:
        mask_image (np.ndarray): mask image
        save_mask_chk (bool, optional): save only when True. Defaults to False.

    Returns:
        None
    """
    global ia_outputs_dir
    if not save_mask_chk:
        return
    os.makedirs(ia_outputs_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    save_path = os.path.join(ia_outputs_dir, timestamp + "_" + "created_mask" + ".png")
    Image.fromarray(mask_image).save(save_path)
def pre_unload_model_weights():
    # Free the webui's SD checkpoint (VRAM) before loading a standalone
    # model (SAM / diffusers / lama-cleaner), then collect garbage.
    unload_model_weights()
    clear_cache()
# Checkpoint info stashed by run_webui_inpaint() so the checkpoint that was
# active before inpainting can be restored afterwards.
backup_ckpt_info = None
def post_reload_model_weights():
    # Restore webui model weights after a standalone-model operation.
    global backup_ckpt_info
    if shared.sd_model is None:
        # weights were unloaded entirely; load the current checkpoint back
        reload_model_weights()
    elif backup_ckpt_info is not None:
        # a different (inpaint) checkpoint was swapped in; switch back to
        # the previously active checkpoint and clear the backup
        unload_model_weights()
        reload_model_weights(sd_model=None, info=backup_ckpt_info)
        backup_ckpt_info = None
def clear_cache():
    # Run Python GC, then release cached GPU memory via webui's torch_gc().
    gc.collect()
    torch_gc()
def sleep_clear_cache_and_reload_model():
    # Called after UI events: a brief pause lets the triggering callback
    # finish before memory is freed and webui weights are restored.
    time.sleep(0.1)
    clear_cache()
    post_reload_model_weights()
def input_image_upload(input_image):
    """Record a freshly uploaded image and drop any stale padding mask."""
    clear_cache()
    global sam_dict
    sam_dict.update(orig_image=input_image, pad_mask=None)
def run_padding(input_image, pad_scale_width, pad_scale_height, pad_lr_barance, pad_tb_barance, padding_mode="edge"):
    """Pad the original image to a larger canvas and remember the pad mask.

    Args:
        input_image (np.ndarray): image currently shown in the UI (presence
            check only; padding is applied to sam_dict["orig_image"])
        pad_scale_width (float): target width as a multiple of the original
        pad_scale_height (float): target height as a multiple of the original
        pad_lr_barance (float): 0..1 share of horizontal padding on the left
        pad_tb_barance (float): 0..1 share of vertical padding on the top
        padding_mode (str, optional): np.pad mode. Defaults to "edge".

    Returns:
        tuple: (padded image np.ndarray or None, status message str)
    """
    clear_cache()
    global sam_dict
    if input_image is None or sam_dict["orig_image"] is None:
        sam_dict["orig_image"] = None
        sam_dict["pad_mask"] = None
        return None, "Input image not found"
    orig_image = sam_dict["orig_image"]

    height, width = orig_image.shape[:2]
    pad_width, pad_height = (int(width * pad_scale_width), int(height * pad_scale_height))
    ia_logging.info(f"resize by padding: ({height}, {width}) -> ({pad_height}, {pad_width})")

    # split total padding between the two sides according to the balance sliders
    pad_size_w, pad_size_h = (pad_width - width, pad_height - height)
    pad_size_l = int(pad_size_w * pad_lr_barance)
    pad_size_r = pad_size_w - pad_size_l
    pad_size_t = int(pad_size_h * pad_tb_barance)
    pad_size_b = pad_size_h - pad_size_t

    # NOTE: the original code reused the name `pad_width` (pixel width above)
    # for this np.pad spec; use a distinct name to avoid the shadowing bug trap.
    image_pad_width = [(pad_size_t, pad_size_b), (pad_size_l, pad_size_r), (0, 0)]
    if padding_mode == "constant":
        fill_value = shared.opts.data.get("inpaint_anything_padding_fill", 127)
        pad_image = np.pad(orig_image, pad_width=image_pad_width, mode=padding_mode, constant_values=fill_value)
    else:
        pad_image = np.pad(orig_image, pad_width=image_pad_width, mode=padding_mode)

    # mask marking the padded border (255) vs the original area (0)
    mask_pad_width = [(pad_size_t, pad_size_b), (pad_size_l, pad_size_r)]
    pad_mask = np.zeros((height, width), dtype=np.uint8)
    pad_mask = np.pad(pad_mask, pad_width=mask_pad_width, mode="constant", constant_values=255)
    sam_dict["pad_mask"] = dict(segmentation=pad_mask.astype(bool))

    return pad_image, "Padding done"
def run_sam(input_image, sam_model_id, sam_image):
    """Run Segment Anything on the input image and render all masks.

    Args:
        input_image (np.ndarray): uploaded image
        sam_model_id (str): SAM checkpoint file name
        sam_image (dict | None): current segmentation-canvas value, used to
            suppress redundant UI updates

    Returns:
        tuple: (segmentation image / gr.update() / None, status message str)
    """
    clear_cache()
    global sam_dict
    if sam_dict["sam_masks"] is not None:
        # drop the previous run's masks before generating new ones
        sam_dict["sam_masks"] = None
        clear_cache()
    sam_checkpoint = os.path.join(extensions_dir, "sd-webui-inpaint-anything", "models", sam_model_id)
    if not os.path.isfile(sam_checkpoint):
        return None, f"{sam_model_id} not found, please download"
    if input_image is None:
        return None, "Input image not found"

    # free the webui SD model's VRAM before running SAM
    pre_unload_model_weights()
    ia_logging.info(f"input_image: {input_image.shape} {input_image.dtype}")

    # Pascal VOC colormap; keep only reasonably bright colors so every
    # rendered mask stays visible against the black canvas
    cm_pascal = create_pascal_label_colormap()
    seg_colormap = cm_pascal
    seg_colormap = [c for c in seg_colormap if max(c) >= 64]

    sam_mask_generator = get_sam_mask_generator(sam_checkpoint)
    ia_logging.info(f"{sam_mask_generator.__class__.__name__} {sam_model_id}")
    sam_masks = sam_mask_generator.generate(input_image)

    canvas_image = np.zeros_like(input_image, dtype=np.uint8)

    ia_logging.info("sam_masks: {}".format(len(sam_masks)))
    # smallest masks first so large masks do not cover them when drawing
    sam_masks = sorted(sam_masks, key=lambda x: np.sum(x.get("segmentation").astype(np.uint32)))
    if sam_dict["pad_mask"] is not None:
        # make the padded border selectable as its own mask (front of list)
        if len(sam_masks) > 0 and sam_masks[0]["segmentation"].shape == sam_dict["pad_mask"]["segmentation"].shape:
            sam_masks.insert(0, sam_dict["pad_mask"])
            ia_logging.info("insert pad_mask to sam_masks")
    # one color per mask; masks beyond the palette size are discarded
    sam_masks = sam_masks[:len(seg_colormap)]

    for idx, seg_dict in enumerate(sam_masks):
        seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
        # only paint pixels that are still black on the canvas
        canvas_mask = np.logical_not(canvas_image.astype(bool).any(axis=-1, keepdims=True)).astype(np.uint8)
        seg_color = seg_colormap[idx] * seg_mask * canvas_mask
        canvas_image = canvas_image + seg_color
    seg_image = canvas_image.astype(np.uint8)

    sam_dict["sam_masks"] = sam_masks

    del sam_mask_generator
    if sam_image is None:
        return seg_image, "Segment Anything complete"
    else:
        # skip the UI update when the rendered image did not change
        if sam_image["image"].shape == seg_image.shape and np.all(sam_image["image"] == seg_image):
            return gr.update(), "Segment Anything complete"
        else:
            return gr.update(value=seg_image), "Segment Anything complete"
def select_mask(input_image, sam_image, invert_chk, sel_mask):
    """Build a binary mask from the SAM segments the user sketched over.

    Args:
        input_image (np.ndarray | None): original image for the blended preview
        sam_image (dict): segmentation canvas; "mask" holds the user's sketch
        invert_chk (bool): if True, invert the resulting mask
        sel_mask (dict | None): current preview value, used to suppress
            redundant UI updates

    Returns:
        np.ndarray or gr.update() or None
    """
    clear_cache()
    global sam_dict
    if sam_dict["sam_masks"] is None or sam_image is None:
        return None
    sam_masks = sam_dict["sam_masks"]

    image = sam_image["image"]
    mask = sam_image["mask"][:, :, 0:3]

    canvas_image = np.zeros_like(image, dtype=np.uint8)
    mask_region = np.zeros_like(image, dtype=np.uint8)
    for idx, seg_dict in enumerate(sam_masks):
        seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
        # pixels not yet claimed by an earlier (smaller) segment
        canvas_mask = np.logical_not(canvas_image.astype(bool).any(axis=-1, keepdims=True)).astype(np.uint8)
        if (seg_mask * canvas_mask * mask).astype(bool).any():
            # the user's sketch touches this segment: add it to the selection
            mask_region = mask_region + (seg_mask * canvas_mask * 255)
        # mark the segment's pixels as claimed (uniform gray)
        seg_color = [127, 127, 127] * seg_mask * canvas_mask
        canvas_image = canvas_image + seg_color

    # a sketch over still-unclaimed pixels selects the leftover background
    canvas_mask = np.logical_not(canvas_image.astype(bool).any(axis=-1, keepdims=True)).astype(np.uint8)
    if (canvas_mask * mask).astype(bool).any():
        mask_region = mask_region + (canvas_mask * 255)

    seg_image = mask_region.astype(np.uint8)

    if invert_chk:
        seg_image = np.logical_not(seg_image.astype(bool)).astype(np.uint8) * 255

    sam_dict["mask_image"] = seg_image

    # blend the mask over the input image for the preview when sizes match
    if input_image is not None and input_image.shape == seg_image.shape:
        ret_image = cv2.addWeighted(input_image, 0.5, seg_image, 0.5, 0)
    else:
        ret_image = seg_image

    clear_cache()
    if sel_mask is None:
        return ret_image
    else:
        # skip the UI update when the preview did not change
        if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
            return gr.update()
        else:
            return gr.update(value=ret_image)
def expand_mask(input_image, sel_mask, expand_iteration=1):
    """Dilate the currently selected mask and return the updated preview."""
    clear_cache()
    global sam_dict
    if sam_dict["mask_image"] is None or sel_mask is None:
        return None

    # clamp the dilation count to a sane range
    iterations = int(np.clip(expand_iteration, 1, 5))
    kernel = np.ones((3, 3), dtype=np.uint8)
    dilated_mask = cv2.dilate(sam_dict["mask_image"], kernel, iterations=iterations)
    sam_dict["mask_image"] = dilated_mask

    if input_image is not None and input_image.shape == dilated_mask.shape:
        ret_image = cv2.addWeighted(input_image, 0.5, dilated_mask, 0.5, 0)
    else:
        ret_image = dilated_mask

    clear_cache()
    unchanged = (sel_mask["image"].shape == ret_image.shape
                 and np.all(sel_mask["image"] == ret_image))
    return gr.update() if unchanged else gr.update(value=ret_image)
def apply_mask(input_image, sel_mask):
    """Erase the user-sketched region from the selected mask and preview it."""
    clear_cache()
    global sam_dict
    if sam_dict["mask_image"] is None or sel_mask is None:
        return None

    # keep only the pixels the user did NOT sketch over
    keep = np.logical_not(sel_mask["mask"][:, :, 0:3].astype(bool)).astype(np.uint8)
    new_sel_mask = sam_dict["mask_image"] * keep
    sam_dict["mask_image"] = new_sel_mask

    if input_image is not None and input_image.shape == new_sel_mask.shape:
        ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0)
    else:
        ret_image = new_sel_mask

    clear_cache()
    unchanged = (sel_mask["image"].shape == ret_image.shape
                 and np.all(sel_mask["image"] == ret_image))
    return gr.update() if unchanged else gr.update(value=ret_image)
def auto_resize_to_pil(input_image, mask_image):
    """Convert image and mask arrays to PIL, snapped to multiples of 8.

    Stable Diffusion requires dimensions divisible by 8. When rounding down
    would shrink the image, both are first rescaled by the larger of the two
    target/source ratios (preserving aspect ratio) and then center-cropped
    to the target size.

    Args:
        input_image (np.ndarray): input image
        mask_image (np.ndarray): mask of the same size

    Returns:
        tuple: (init_image, mask_image) as PIL RGB images
    """
    init_image = Image.fromarray(input_image).convert("RGB")
    mask_image = Image.fromarray(mask_image).convert("RGB")
    assert init_image.size == mask_image.size, "The size of image and mask do not match"

    width, height = init_image.size
    new_width = (width // 8) * 8
    new_height = (height // 8) * 8
    if new_width < width or new_height < height:
        # scale by the larger ratio so the result still covers the target box
        scale = max(new_width / width, new_height / height)
        resize_width = int(width * scale + 0.5)
        resize_height = int(height * scale + 0.5)
        ia_logging.info(f"resize: ({height}, {width}) -> ({resize_height}, {resize_width})")
        init_image = transforms.functional.resize(init_image, (resize_height, resize_width), transforms.InterpolationMode.LANCZOS)
        mask_image = transforms.functional.resize(mask_image, (resize_height, resize_width), transforms.InterpolationMode.LANCZOS)
        ia_logging.info(f"center_crop: ({resize_height}, {resize_width}) -> ({new_height}, {new_width})")
        init_image = transforms.functional.center_crop(init_image, (new_height, new_width))
        mask_image = transforms.functional.center_crop(mask_image, (new_height, new_width))
    assert init_image.size == mask_image.size, "The size of image and mask do not match"

    return init_image, mask_image
def run_inpaint(input_image, sel_mask, prompt, n_prompt, ddim_steps, cfg_scale, seed, model_id, save_mask_chk, composite_chk, sampler_name="DDIM"):
    """Inpaint the masked region with a diffusers inpainting pipeline.

    Args:
        input_image (np.ndarray): source image
        sel_mask (dict | None): mask preview (presence check only)
        prompt (str): positive prompt
        n_prompt (str): negative prompt
        ddim_steps (int): number of inference steps
        cfg_scale (float): guidance scale
        seed (int): RNG seed; a random seed is drawn when negative
        model_id (str): HuggingFace inpainting model id
        save_mask_chk (bool): also save the mask image
        composite_chk (bool): paste the result back only inside the mask
        sampler_name (str, optional): scheduler name. Defaults to "DDIM".

    Returns:
        PIL.Image.Image or None: output image, or None on failure
    """
    clear_cache()
    global sam_dict
    if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
        return None
    mask_image = sam_dict["mask_image"]
    if input_image.shape != mask_image.shape:
        ia_logging.warning("The size of image and mask do not match")
        return None

    global ia_outputs_dir
    update_ia_outputs_dir()
    save_mask_image(mask_image, save_mask_chk)

    # free the webui SD model's VRAM before loading the diffusers pipeline
    pre_unload_model_weights()
    ia_logging.info(f"Loading model {model_id}")
    config_offline_inpainting = shared.opts.data.get("inpaint_anything_offline_inpainting", False)
    if config_offline_inpainting:
        ia_logging.info("Enable offline network Inpainting: {}".format(str(config_offline_inpainting)))
    local_files_only = False
    # probe the local HF cache first; only go online when allowed
    local_file_status = download_model_from_hf(model_id, local_files_only=True)
    if local_file_status != _DOWNLOAD_COMPLETE:
        if config_offline_inpainting:
            ia_logging.warning(local_file_status)
            return None
    else:
        local_files_only = True
        ia_logging.info("local_files_only: {}".format(str(local_files_only)))
    if platform.system() == "Darwin":
        # fp16 is not reliable on macOS (MPS/CPU); use fp32 there
        torch_dtype = torch.float32
    else:
        torch_dtype = torch.float16
    try:
        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, local_files_only=local_files_only)
    except Exception as e:
        ia_logging.error(str(e))
        if not config_offline_inpainting:
            # retry online with increasingly aggressive download options
            try:
                pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, resume_download=True)
            except Exception as e:
                ia_logging.error(str(e))
                try:
                    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, force_download=True)
                except Exception as e:
                    ia_logging.error(str(e))
                    return None
        else:
            return None
    # disable the NSFW checker; this UI shows the raw result
    pipe.safety_checker = None

    ia_logging.info(f"Using sampler {sampler_name}")
    if sampler_name == "DDIM":
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    elif sampler_name == "Euler":
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    elif sampler_name == "Euler a":
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    elif sampler_name == "DPM2 Karras":
        pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config)
    elif sampler_name == "DPM2 a Karras":
        pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    else:
        ia_logging.info("Sampler fallback to DDIM")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    if seed < 0:
        seed = random.randint(0, 2147483647)

    if platform.system() == "Darwin":
        pipe = pipe.to("mps")
        pipe.enable_attention_slicing()
        # seed on CPU; MPS generators are not supported
        generator = torch.Generator("cpu").manual_seed(seed)
    else:
        pipe = pipe.to(device)
        if shared.xformers_available:
            pipe.enable_xformers_memory_efficient_attention()
        else:
            pipe.enable_attention_slicing()
        generator = torch.Generator(device).manual_seed(seed)

    init_image, mask_image = auto_resize_to_pil(input_image, mask_image)
    width, height = init_image.size

    pipe_args_dict = {
        "prompt": prompt,
        "image": init_image,
        "width": width,
        "height": height,
        "mask_image": mask_image,
        "num_inference_steps": ddim_steps,
        "guidance_scale": cfg_scale,
        "negative_prompt": n_prompt,
        "generator": generator,
    }

    output_image = pipe(**pipe_args_dict).images[0]

    if composite_chk:
        # blend only the (slightly dilated, blurred) masked area onto the
        # original so unmasked pixels are preserved exactly
        mask_image = Image.fromarray(cv2.dilate(np.array(mask_image), np.ones((3, 3), dtype=np.uint8), iterations=4))
        output_image = Image.composite(output_image, init_image, mask_image.convert("L").filter(ImageFilter.GaussianBlur(3)))

    # webui-style "parameters" PNG info block for the saved image
    generation_params = {
        "Steps": ddim_steps,
        "Sampler": pipe.scheduler.__class__.__name__,
        "CFG scale": cfg_scale,
        "Seed": seed,
        "Size": f"{width}x{height}",
        "Model": model_id,
    }
    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
    prompt_text = prompt if prompt else ""
    negative_prompt_text = "Negative prompt: " + n_prompt if n_prompt else ""
    infotext = f"{prompt_text}\n{negative_prompt_text}\n{generation_params_text}".strip()
    metadata = PngInfo()
    metadata.add_text("parameters", infotext)

    if not os.path.isdir(ia_outputs_dir):
        os.makedirs(ia_outputs_dir, exist_ok=True)
    save_name = datetime.now().strftime("%Y%m%d-%H%M%S") + "_" + os.path.basename(model_id) + "_" + str(seed) + ".png"
    save_name = os.path.join(ia_outputs_dir, save_name)
    output_image.save(save_name, pnginfo=metadata)

    del pipe
    return output_image
def run_cleaner(input_image, sel_mask, cleaner_model_id, cleaner_save_mask_chk):
    """Remove the masked region with a lama-cleaner model.

    Args:
        input_image (np.ndarray): source image
        sel_mask (dict | None): mask preview (presence check only)
        cleaner_model_id (str): lama-cleaner model name
        cleaner_save_mask_chk (bool): also save the mask image

    Returns:
        PIL.Image.Image or None: cleaned image, or None on failure
    """
    clear_cache()
    global sam_dict
    if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
        return None
    mask_image = sam_dict["mask_image"]
    if input_image.shape != mask_image.shape:
        ia_logging.warning("The size of image and mask do not match")
        return None

    global ia_outputs_dir
    update_ia_outputs_dir()
    save_mask_image(mask_image, cleaner_save_mask_chk)

    # free the webui SD model's VRAM before loading the cleaner model
    pre_unload_model_weights()
    ia_logging.info(f"Loading model {cleaner_model_id}")
    if platform.system() == "Darwin":
        # run on CPU on macOS
        model = ModelManager(name=cleaner_model_id, device="cpu")
    else:
        model = ModelManager(name=cleaner_model_id, device=device)

    init_image, mask_image = auto_resize_to_pil(input_image, mask_image)
    width, height = init_image.size

    init_image = np.array(init_image)
    mask_image = np.array(mask_image.convert("L"))

    config = Config(
        ldm_steps=20,
        ldm_sampler=LDMSampler.ddim,
        hd_strategy=HDStrategy.ORIGINAL,
        hd_strategy_crop_margin=32,
        hd_strategy_crop_trigger_size=512,
        hd_strategy_resize_limit=512,
        prompt="",
        sd_steps=20,
        sd_sampler=SDSampler.ddim
    )

    output_image = model(image=init_image, mask=mask_image, config=config)
    # the model returns a BGR float array; convert to RGB uint8 for PIL
    output_image = cv2.cvtColor(output_image.astype(np.uint8), cv2.COLOR_BGR2RGB)
    output_image = Image.fromarray(output_image)

    if not os.path.isdir(ia_outputs_dir):
        os.makedirs(ia_outputs_dir, exist_ok=True)
    save_name = datetime.now().strftime("%Y%m%d-%H%M%S") + "_" + os.path.basename(cleaner_model_id) + ".png"
    save_name = os.path.join(ia_outputs_dir, save_name)
    output_image.save(save_name)

    del model
    return output_image
def run_get_alpha_image(input_image, sel_mask):
    """Cut out the masked region as an RGBA image over a checkerboard.

    Saves the RGBA cut-out to the outputs directory and returns it
    composited onto a gray checkerboard for display.

    Args:
        input_image (np.ndarray): source image
        sel_mask (dict | None): mask preview (presence check only)

    Returns:
        PIL.Image.Image or None
    """
    clear_cache()
    global sam_dict
    if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
        return None
    mask_image = sam_dict["mask_image"]
    if input_image.shape != mask_image.shape:
        ia_logging.warning("The size of image and mask do not match")
        return None

    # use the selected mask as the alpha channel
    alpha_image = Image.fromarray(input_image).convert("RGBA")
    mask_image = Image.fromarray(mask_image).convert("L")
    alpha_image.putalpha(mask_image)

    global ia_outputs_dir
    update_ia_outputs_dir()
    if not os.path.isdir(ia_outputs_dir):
        os.makedirs(ia_outputs_dir, exist_ok=True)
    save_name = datetime.now().strftime("%Y%m%d-%H%M%S") + "_" + "rgba_image" + ".png"
    save_name = os.path.join(ia_outputs_dir, save_name)
    alpha_image.save(save_name)

    def make_checkerboard(n_rows, n_columns, square_size):
        # Boolean checkerboard built at square resolution, upscaled with a
        # Kronecker product and cropped to exactly (n_rows, n_columns).
        n_rows_, n_columns_ = int(n_rows/square_size + 1), int(n_columns/square_size + 1)
        rows_grid, columns_grid = np.meshgrid(range(n_rows_), range(n_columns_), indexing='ij')
        high_res_checkerboard = (np.mod(rows_grid, 2) + np.mod(columns_grid, 2)) == 1
        square = np.ones((square_size,square_size))
        checkerboard = np.kron(high_res_checkerboard, square)[:n_rows,:n_columns]
        return checkerboard

    # light/dark gray 16px checkerboard, visible only outside the mask
    checkerboard = make_checkerboard(alpha_image.size[1], alpha_image.size[0], 16)
    checkerboard = np.clip((checkerboard * 255), 128, 192).astype(np.uint8)
    checkerboard = Image.fromarray(checkerboard).convert("RGBA")
    checkerboard.putalpha(ImageOps.invert(mask_image))
    output_image = Image.alpha_composite(alpha_image, checkerboard)

    clear_cache()
    return output_image
def run_get_mask(sel_mask):
    """Save the currently selected mask as a PNG and return it for display."""
    clear_cache()
    global sam_dict
    if sam_dict["mask_image"] is None or sel_mask is None:
        return None
    mask_image = sam_dict["mask_image"]

    global ia_outputs_dir
    update_ia_outputs_dir()
    os.makedirs(ia_outputs_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    save_path = os.path.join(ia_outputs_dir, timestamp + "_" + "created_mask" + ".png")
    Image.fromarray(mask_image).save(save_path)

    clear_cache()
    return mask_image
def run_cn_inpaint(input_image, sel_mask,
                   cn_prompt, cn_n_prompt, cn_sampler_id, cn_ddim_steps, cn_cfg_scale, cn_strength, cn_seed, cn_module_id, cn_model_id, cn_save_mask_chk,
                   cn_low_vram_chk, cn_weight, cn_mode, cn_ref_module_id=None, cn_ref_image=None, cn_ref_weight=1.0, cn_ref_mode="Balanced"):
    """Inpaint via webui img2img with a ControlNet inpaint unit.

    Optionally adds a second reference-only ControlNet unit when both
    cn_ref_module_id and cn_ref_image are provided.

    Args:
        input_image (np.ndarray): source image
        sel_mask (dict | None): mask preview (presence check only)
        cn_prompt, cn_n_prompt (str): positive/negative prompts
        cn_sampler_id (str): webui sampler name
        cn_ddim_steps (int): steps
        cn_cfg_scale (float): CFG scale
        cn_strength (float): denoising strength
        cn_seed (int): seed; random when negative
        cn_module_id (str): ControlNet inpaint preprocessor name
        cn_model_id (str): ControlNet model name (may carry a " [hash]" suffix)
        cn_save_mask_chk (bool): also save the mask image
        cn_low_vram_chk (bool): ControlNet low-VRAM mode
        cn_weight (float): ControlNet unit weight
        cn_mode (str): ControlNet control mode
        cn_ref_module_id (str, optional): reference-only module name
        cn_ref_image (np.ndarray, optional): reference image
        cn_ref_weight (float, optional): reference unit weight
        cn_ref_mode (str, optional): reference unit control mode

    Returns:
        PIL.Image.Image or None: output image, or None on failure
    """
    clear_cache()
    global sam_dict
    if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
        return None
    mask_image = sam_dict["mask_image"]
    if input_image.shape != mask_image.shape:
        ia_logging.warning("The size of image and mask do not match")
        return None

    if shared.sd_model is None:
        reload_model_weights()
    if (shared.sd_model.parameterization == "v" and "sd15" in cn_model_id):
        # v-prediction base models cannot drive sd15 ControlNets
        ia_logging.warning("The SD model is not compatible with the ControlNet model")
        return None

    cnet = sam_dict.get("cnet", None)
    if cnet is None:
        ia_logging.warning("The ControlNet extension is not loaded")
        return None

    global ia_outputs_dir
    update_ia_outputs_dir()
    save_mask_image(mask_image, cn_save_mask_chk)

    if cn_seed < 0:
        cn_seed = random.randint(0, 2147483647)

    init_image, mask_image = auto_resize_to_pil(input_image, mask_image)
    width, height = init_image.size

    p = get_sd_img2img_processing(init_image, None, cn_prompt, cn_n_prompt, cn_sampler_id, cn_ddim_steps, cn_cfg_scale, cn_strength, cn_seed, 1)

    # run with only the ControlNet script enabled, restore scripts afterwards
    backup_alwayson_scripts(p.scripts)
    disable_alwayson_scripts(p.scripts)

    cn_units = [cnet.ControlNetUnit(
        enabled=True,
        module=cn_module_id,
        model=cn_model_id,
        weight=cn_weight,
        image={"image": np.array(init_image), "mask": np.array(mask_image)},
        resize_mode=cnet.ResizeMode.RESIZE,
        low_vram=cn_low_vram_chk,
        processor_res=min(width, height),
        guidance_start=0.0,
        guidance_end=1.0,
        pixel_perfect=True,
        control_mode=cn_mode,
    )]

    if cn_ref_module_id is not None and cn_ref_image is not None:
        cn_ref_image = resize_image(1, Image.fromarray(cn_ref_image), width=width, height=height)
        cn_units.append(cnet.ControlNetUnit(
            enabled=True,
            module=cn_ref_module_id,
            model=None,
            weight=cn_ref_weight,
            image=np.array(cn_ref_image),
            resize_mode=cnet.ResizeMode.RESIZE,
            low_vram=cn_low_vram_chk,
            processor_res=min(width, height),
            guidance_start=0.0,
            guidance_end=1.0,
            pixel_perfect=True,
            control_mode=cn_ref_mode,
        ))

    p.script_args = np.zeros(get_controlnet_args_to(p.scripts))
    cnet.update_cn_script_in_processing(p, cn_units)

    processed = process_images(p)

    clear_controlnet_cache(p.scripts)
    restore_alwayson_scripts(p.scripts)

    # Strip a trailing " [hash]" suffix from the model name for the file name.
    # Raw string fixes the original "\s" invalid escape sequence (SyntaxWarning
    # on Python >= 3.12); the pattern value is unchanged.
    no_hash_cn_model_id = re.sub(r"\s\[[0-9a-f]{8,10}\]", "", cn_model_id).strip()

    output_image = None
    if processed is not None and len(processed.images) > 0:
        output_image = processed.images[0]
        infotext = create_infotext(p, all_prompts=[cn_prompt], all_seeds=[cn_seed], all_subseeds=[-1])
        metadata = PngInfo()
        metadata.add_text("parameters", infotext)
        if not os.path.isdir(ia_outputs_dir):
            os.makedirs(ia_outputs_dir, exist_ok=True)
        save_name = datetime.now().strftime("%Y%m%d-%H%M%S") + "_" + os.path.basename(no_hash_cn_model_id) + "_" + str(cn_seed) + ".png"
        save_name = os.path.join(ia_outputs_dir, save_name)
        output_image.save(save_name, pnginfo=metadata)

    return output_image
def run_webui_inpaint(input_image, sel_mask,
                      webui_prompt, webui_n_prompt, webui_sampler_id, webui_ddim_steps, webui_cfg_scale, webui_strength, webui_seed, webui_model_id, webui_save_mask_chk,
                      webui_fill_mode):
    """Inpaint via webui img2img with a dedicated inpainting checkpoint.

    Temporarily swaps the active SD checkpoint for webui_model_id; the
    previously active checkpoint info is stashed in backup_ckpt_info so
    post_reload_model_weights() can restore it afterwards.

    Args:
        input_image (np.ndarray): source image
        sel_mask (dict | None): mask preview (presence check only)
        webui_prompt, webui_n_prompt (str): positive/negative prompts
        webui_sampler_id (str): webui sampler name
        webui_ddim_steps (int): steps
        webui_cfg_scale (float): CFG scale
        webui_strength (float): denoising strength
        webui_seed (int): seed; random when negative
        webui_model_id (str): inpainting checkpoint title (may carry a hash)
        webui_save_mask_chk (bool): also save the mask image
        webui_fill_mode (int): img2img inpaint fill mode

    Returns:
        PIL.Image.Image or None: output image, or None on failure
    """
    clear_cache()
    global sam_dict
    if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
        return None
    mask_image = sam_dict["mask_image"]
    if input_image.shape != mask_image.shape:
        ia_logging.warning("The size of image and mask do not match")
        return None

    global ia_outputs_dir
    update_ia_outputs_dir()
    save_mask_image(mask_image, webui_save_mask_chk)

    info = get_closet_checkpoint_match(webui_model_id)
    if info is None:
        ia_logging.error(f"No model found: {webui_model_id}")
        return None

    global backup_ckpt_info
    if shared.sd_model is not None:
        # remember the active checkpoint so it can be restored later
        backup_ckpt_info = shared.sd_model.sd_checkpoint_info
        unload_model_weights()
    reload_model_weights(sd_model=None, info=info)

    if webui_seed < 0:
        webui_seed = random.randint(0, 2147483647)

    init_image, mask_image = auto_resize_to_pil(input_image, mask_image)
    width, height = init_image.size

    p = get_sd_img2img_processing(init_image, mask_image, webui_prompt, webui_n_prompt, webui_sampler_id, webui_ddim_steps, webui_cfg_scale, webui_strength, webui_seed, webui_fill_mode)

    # run plain img2img inpainting with every alwayson script disabled
    backup_alwayson_scripts(p.scripts)
    disable_all_alwayson_scripts(p.scripts)
    p.script_args = np.zeros(get_max_args_to(p.scripts))

    processed = process_images(p)

    restore_alwayson_scripts(p.scripts)

    # Strip a trailing " [hash]" suffix and the extension for the file name.
    # Raw string fixes the original "\s" invalid escape sequence (SyntaxWarning
    # on Python >= 3.12); the pattern value is unchanged.
    no_hash_webui_model_id = re.sub(r"\s\[[0-9a-f]{8,10}\]", "", webui_model_id).strip()
    no_hash_webui_model_id = os.path.splitext(no_hash_webui_model_id)[0]

    output_image = None
    if processed is not None and len(processed.images) > 0:
        output_image = processed.images[0]
        infotext = create_infotext(p, all_prompts=[webui_prompt], all_seeds=[webui_seed], all_subseeds=[-1])
        metadata = PngInfo()
        metadata.add_text("parameters", infotext)
        if not os.path.isdir(ia_outputs_dir):
            os.makedirs(ia_outputs_dir, exist_ok=True)
        save_name = datetime.now().strftime("%Y%m%d-%H%M%S") + "_" + os.path.basename(no_hash_webui_model_id) + "_" + str(webui_seed) + ".png"
        save_name = os.path.join(ia_outputs_dir, save_name)
        output_image.save(save_name, pnginfo=metadata)

    return output_image
# class Script(scripts.Script):
# def __init__(self) -> None:
# super().__init__()
# def title(self):
# return "Inpaint Anything"
# def show(self, is_img2img):
# return scripts.AlwaysVisible
# def ui(self, is_img2img):
# return ()
def on_ui_tabs():
    """Build the "Inpaint Anything" tab and wire all of its event handlers.

    Returns:
        A single-element list of (gr.Blocks, tab title, tab elem_id) tuples,
        the shape expected by the Web UI's `script_callbacks.on_ui_tabs`.
    """
    global sam_dict
    # Gather the choice lists for every dropdown up front.
    sampler_names = get_sampler_names()
    sam_model_ids = get_sam_model_ids()
    # Default to the ViT-L SAM checkpoint when present; otherwise fall back to index 1.
    sam_model_index = sam_model_ids.index("sam_vit_l_0b3195.pth") if "sam_vit_l_0b3195.pth" in sam_model_ids else 1
    model_ids = get_model_ids()
    cleaner_model_ids = get_cleaner_model_ids()
    padding_mode_names = get_padding_mode_names()
    # Detect the sd-webui-controlnet extension; None when it is not installed.
    sam_dict["cnet"] = find_controlnet()
    cn_enabled = False
    if sam_dict["cnet"] is not None:
        # Only inpaint-capable preprocessors/models are offered in this tab.
        cn_module_ids = [cn for cn in sam_dict["cnet"].get_modules() if "inpaint" in cn]
        cn_module_index = cn_module_ids.index("inpaint_only") if "inpaint_only" in cn_module_ids else 0
        cn_model_ids = [cn for cn in sam_dict["cnet"].get_models() if "inpaint" in cn]
        cn_modes = [mode.value for mode in sam_dict["cnet"].ControlMode]
        # ControlNet inpainting needs both a preprocessor and a model available.
        if len(cn_module_ids) > 0 and len(cn_model_ids) > 0:
            cn_enabled = True
        if samplers_for_img2img is not None and len(samplers_for_img2img) > 0:
            cn_sampler_ids = [sampler.name for sampler in samplers_for_img2img]
        else:
            cn_sampler_ids = ["DDIM"]
        # NOTE: when "DDIM" is absent, index -1 selects the LAST sampler in the list.
        cn_sampler_index = cn_sampler_ids.index("DDIM") if "DDIM" in cn_sampler_ids else -1
    # Reference-Only control needs a second ControlNet unit (Multi ControlNet >= 2)
    # and at least one "reference" preprocessor.
    cn_ref_only = False
    if cn_enabled and sam_dict["cnet"].get_max_models_num() > 1:
        cn_ref_module_ids = [cn for cn in sam_dict["cnet"].get_modules() if "reference" in cn]
        if len(cn_ref_module_ids) > 0:
            cn_ref_only = True
    # The "Inpainting webui" tab is shown only if the user has at least one
    # checkpoint whose title contains "inpaint".
    webui_inpaint_enabled = False
    list_ckpt = shared.list_checkpoint_tiles()
    webui_model_ids = [ckpt for ckpt in list_ckpt if "inpaint" in ckpt.lower()]
    if len(webui_model_ids) > 0:
        webui_inpaint_enabled = True
    if samplers_for_img2img is not None and len(samplers_for_img2img) > 0:
        webui_sampler_ids = [sampler.name for sampler in samplers_for_img2img]
    else:
        webui_sampler_ids = ["DDIM"]
    # Same -1 fallback as cn_sampler_index above: last sampler when DDIM is missing.
    webui_sampler_index = webui_sampler_ids.index("DDIM") if "DDIM" in webui_sampler_ids else -1
    # ---- UI layout: left column = inputs + per-backend tabs, right column = masks ----
    with gr.Blocks(analytics_enabled=False) as inpaint_anything_interface:
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        sam_model_id = gr.Dropdown(label="Segment Anything Model ID", elem_id="sam_model_id", choices=sam_model_ids,
                                                   value=sam_model_ids[sam_model_index], show_label=True)
                    with gr.Column():
                        with gr.Row():
                            load_model_btn = gr.Button("Download model", elem_id="load_model_btn")
                        with gr.Row():
                            status_text = gr.Textbox(label="", max_lines=1, show_label=False, interactive=False)
                with gr.Row():
                    input_image = gr.Image(label="Input image", elem_id="input_image", source="upload", type="numpy", interactive=True)
                # Optional outpainting-style padding applied before segmentation.
                # NOTE(review): "barance" (sic) is the existing elem_id spelling;
                # it is kept because external CSS/JS may reference it.
                with gr.Row():
                    with gr.Accordion("Padding options", elem_id="padding_options", open=False):
                        with gr.Row():
                            with gr.Column():
                                pad_scale_width = gr.Slider(label="Scale Width", elem_id="pad_scale_width", minimum=1.0, maximum=1.5, value=1.0, step=0.01)
                            with gr.Column():
                                pad_lr_barance = gr.Slider(label="Left/Right Balance", elem_id="pad_lr_barance", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
                        with gr.Row():
                            with gr.Column():
                                pad_scale_height = gr.Slider(label="Scale Height", elem_id="pad_scale_height", minimum=1.0, maximum=1.5, value=1.0, step=0.01)
                            with gr.Column():
                                pad_tb_barance = gr.Slider(label="Top/Bottom Balance", elem_id="pad_tb_barance", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
                        with gr.Row():
                            with gr.Column():
                                padding_mode = gr.Dropdown(label="Padding Mode", elem_id="padding_mode", choices=padding_mode_names, value="edge")
                            with gr.Column():
                                padding_btn = gr.Button("Run Padding", elem_id="padding_btn")
                with gr.Row():
                    sam_btn = gr.Button("Run Segment Anything", elem_id="sam_btn")
                # -- Tab 1: diffusers-based inpainting pipeline --
                with gr.Tab("Inpainting", elem_id="inpainting_tab"):
                    prompt = gr.Textbox(label="Inpainting Prompt", elem_id="sd_prompt")
                    n_prompt = gr.Textbox(label="Negative Prompt", elem_id="sd_n_prompt")
                    with gr.Accordion("Advanced options", elem_id="inp_advanced_options", open=False):
                        with gr.Row():
                            with gr.Column():
                                sampler_name = gr.Dropdown(label="Sampler", elem_id="sampler_name", choices=sampler_names,
                                                           value=sampler_names[0], show_label=True)
                            with gr.Column():
                                ddim_steps = gr.Slider(label="Sampling Steps", elem_id="ddim_steps", minimum=1, maximum=100, value=20, step=1)
                        cfg_scale = gr.Slider(label="Guidance Scale", elem_id="cfg_scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                        seed = gr.Slider(
                            label="Seed",
                            elem_id="sd_seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1,
                        )
                    with gr.Row():
                        with gr.Column():
                            model_id = gr.Dropdown(label="Inpainting Model ID", elem_id="model_id", choices=model_ids, value=model_ids[0], show_label=True)
                        with gr.Column():
                            with gr.Row():
                                inpaint_btn = gr.Button("Run Inpainting", elem_id="inpaint_btn")
                            with gr.Row():
                                composite_chk = gr.Checkbox(label="Mask area Only", elem_id="composite_chk", value=True, show_label=True, interactive=True)
                                save_mask_chk = gr.Checkbox(label="Save mask", elem_id="save_mask_chk", show_label=True, interactive=True)
                    with gr.Row():
                        out_image = gr.Image(label="Inpainted image", elem_id="out_image", type="pil", interactive=False).style(height=480)
                # -- Tab 2: lama-cleaner object removal --
                with gr.Tab("Cleaner", elem_id="cleaner_tab"):
                    with gr.Row():
                        with gr.Column():
                            cleaner_model_id = gr.Dropdown(label="Cleaner Model ID", elem_id="cleaner_model_id", choices=cleaner_model_ids, value=cleaner_model_ids[0], show_label=True)
                        with gr.Column():
                            with gr.Row():
                                cleaner_btn = gr.Button("Run Cleaner", elem_id="cleaner_btn")
                            with gr.Row():
                                cleaner_save_mask_chk = gr.Checkbox(label="Save mask", elem_id="cleaner_save_mask_chk", show_label=True, interactive=True)
                    with gr.Row():
                        cleaner_out_image = gr.Image(label="Cleaned image", elem_id="cleaner_out_image", type="pil", interactive=False).style(height=480)
                # -- Tab 3 (conditional): inpainting through the Web UI's own checkpoints --
                if webui_inpaint_enabled:
                    with gr.Tab("Inpainting webui", elem_id="webui_inpainting_tab"):
                        webui_prompt = gr.Textbox(label="Inpainting Prompt", elem_id="webui_sd_prompt")
                        webui_n_prompt = gr.Textbox(label="Negative Prompt", elem_id="webui_sd_n_prompt")
                        with gr.Accordion("Advanced options", elem_id="webui_advanced_options", open=False):
                            webui_fill_mode = gr.Radio(label="Masked content", choices=["fill", "original", "latent noise", "latent nothing"], value="original", type="index", elem_id="webui_fill_mode")
                            with gr.Row():
                                with gr.Column():
                                    webui_sampler_id = gr.Dropdown(label="Sampling method", elem_id="webui_sampler_id", choices=webui_sampler_ids, value=webui_sampler_ids[webui_sampler_index], show_label=True)
                                with gr.Column():
                                    webui_ddim_steps = gr.Slider(label="Sampling steps", elem_id="webui_ddim_steps", minimum=1, maximum=150, value=30, step=1)
                            webui_cfg_scale = gr.Slider(label="Guidance scale", elem_id="webui_cfg_scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                            webui_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.6, elem_id="webui_strength")
                            webui_seed = gr.Slider(
                                label="Seed",
                                elem_id="webui_sd_seed",
                                minimum=-1,
                                maximum=2147483647,
                                step=1,
                                value=-1,
                            )
                        with gr.Row():
                            with gr.Column():
                                webui_model_id = gr.Dropdown(label="Inpainting Model ID", elem_id="webui_model_id", choices=webui_model_ids, value=webui_model_ids[0], show_label=True)
                            with gr.Column():
                                with gr.Row():
                                    webui_inpaint_btn = gr.Button("Run Inpainting", elem_id="webui_inpaint_btn")
                                with gr.Row():
                                    webui_save_mask_chk = gr.Checkbox(label="Save mask", elem_id="webui_save_mask_chk", show_label=True, interactive=True)
                        with gr.Row():
                            webui_out_image = gr.Image(label="Inpainted image", elem_id="webui_out_image", type="pil", interactive=False).style(height=480)
                # -- Tab 4: ControlNet inpaint (full UI, or an explanatory message when unavailable) --
                with gr.Tab("ControlNet Inpaint", elem_id="cn_inpaint_tab"):
                    if cn_enabled:
                        cn_prompt = gr.Textbox(label="Inpainting Prompt", elem_id="cn_sd_prompt")
                        cn_n_prompt = gr.Textbox(label="Negative Prompt", elem_id="cn_sd_n_prompt")
                        with gr.Accordion("Advanced options", elem_id="cn_advanced_options", open=False):
                            with gr.Row():
                                with gr.Column():
                                    cn_sampler_id = gr.Dropdown(label="Sampling method", elem_id="cn_sampler_id", choices=cn_sampler_ids, value=cn_sampler_ids[cn_sampler_index], show_label=True)
                                with gr.Column():
                                    cn_ddim_steps = gr.Slider(label="Sampling steps", elem_id="cn_ddim_steps", minimum=1, maximum=150, value=30, step=1)
                            cn_cfg_scale = gr.Slider(label="Guidance scale", elem_id="cn_cfg_scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                            cn_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.6, elem_id="cn_strength")
                            cn_seed = gr.Slider(
                                label="Seed",
                                elem_id="cn_sd_seed",
                                minimum=-1,
                                maximum=2147483647,
                                step=1,
                                value=-1,
                            )
                        with gr.Accordion("ControlNet options", elem_id="cn_cn_options", open=False):
                            with gr.Row():
                                with gr.Column():
                                    cn_low_vram_chk = gr.Checkbox(label="Low VRAM", elem_id="cn_low_vram_chk", show_label=True, interactive=True)
                                    cn_weight = gr.Slider(label="Control Weight", elem_id="cn_weight", minimum=0.0, maximum=2.0, value=1.0, step=0.05)
                                with gr.Column():
                                    cn_mode = gr.Dropdown(label="Control Mode", elem_id="cn_mode", choices=cn_modes, value=cn_modes[-1], show_label=True)
                            # Reference-Only controls exist only when a second
                            # ControlNet unit is available (see cn_ref_only above).
                            if cn_ref_only:
                                with gr.Row():
                                    gr.Markdown("Reference-Only Control (enabled only when a reference image below is present)")
                                with gr.Row():
                                    with gr.Column():
                                        cn_ref_image = gr.Image(label="Reference Image", elem_id="cn_ref_image", source="upload", type="numpy", interactive=True)
                                    with gr.Column():
                                        cn_ref_module_id = gr.Dropdown(label="Reference Type", elem_id="cn_ref_module_id", choices=cn_ref_module_ids, value=cn_ref_module_ids[-1], show_label=True)
                                        cn_ref_weight = gr.Slider(label="Reference Control Weight", elem_id="cn_ref_weight", minimum=0.0, maximum=2.0, value=1.0, step=0.05)
                                        cn_ref_mode = gr.Dropdown(label="Reference Control Mode", elem_id="cn_ref_mode", choices=cn_modes, value=cn_modes[0], show_label=True)
                            else:
                                with gr.Row():
                                    gr.Markdown("The Multi ControlNet setting is currently set to 1.<br>" + \
                                                "If you wish to use the Reference-Only Control, please adjust the Multi ControlNet setting to 2 or more and restart the Web UI.")
                        with gr.Row():
                            with gr.Column():
                                cn_module_id = gr.Dropdown(label="ControlNet Preprocessor", elem_id="cn_module_id", choices=cn_module_ids, value=cn_module_ids[cn_module_index], show_label=True)
                                cn_model_id = gr.Dropdown(label="ControlNet Model ID", elem_id="cn_model_id", choices=cn_model_ids, value=cn_model_ids[0], show_label=True)
                            with gr.Column():
                                with gr.Row():
                                    cn_inpaint_btn = gr.Button("Run ControlNet Inpaint", elem_id="cn_inpaint_btn")
                                with gr.Row():
                                    cn_save_mask_chk = gr.Checkbox(label="Save mask", elem_id="cn_save_mask_chk", show_label=True, interactive=True)
                        with gr.Row():
                            cn_out_image = gr.Image(label="Inpainted image", elem_id="cn_out_image", type="pil", interactive=False).style(height=480)
                    else:
                        # ControlNet unusable: explain exactly which piece is missing.
                        if sam_dict["cnet"] is None:
                            gr.Markdown("ControlNet extension is not available.<br>" + \
                                        "Requires the [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) extension.")
                        elif len(cn_module_ids) > 0:
                            cn_models_directory = os.path.join("extensions", "sd-webui-controlnet", "models")
                            gr.Markdown("ControlNet inpaint model is not available.<br>" + \
                                        f"Requires the [ControlNet-v1-1](https://huggingface.co/lllyasviel/ControlNet-v1-1) inpaint model in the {cn_models_directory} directory.")
                        else:
                            gr.Markdown("ControlNet inpaint preprocessor is not available.<br>" + \
                                        "The local version of [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) extension may be old.")
                # -- Tab 5: export the mask alone (as alpha channel or grayscale image) --
                with gr.Tab("Mask only", elem_id="mask_only_tab"):
                    with gr.Row():
                        with gr.Column():
                            get_alpha_image_btn = gr.Button("Get mask as alpha of image", elem_id="get_alpha_image_btn")
                        with gr.Column():
                            get_mask_btn = gr.Button("Get mask", elem_id="get_mask_btn")
                    with gr.Row():
                        with gr.Column():
                            alpha_out_image = gr.Image(label="Alpha channel image", elem_id="alpha_out_image", type="pil", interactive=False)
                        with gr.Column():
                            mask_out_image = gr.Image(label="Mask image", elem_id="mask_out_image", type="numpy", interactive=False)
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("")
                        with gr.Column():
                            mask_send_to_inpaint_btn = gr.Button("Send to img2img inpaint", elem_id="mask_send_to_inpaint_btn")
            # Right-hand column: SAM segmentation sketchpad and the selected mask.
            with gr.Column():
                sam_image = gr.Image(label="Segment Anything image", elem_id="sam_image", type="numpy", tool="sketch", brush_radius=8,
                                     interactive=True).style(height=480)
                with gr.Row():
                    with gr.Column():
                        select_btn = gr.Button("Create mask", elem_id="select_btn")
                    with gr.Column():
                        invert_chk = gr.Checkbox(label="Invert mask", elem_id="invert_chk", show_label=True, interactive=True)
                sel_mask = gr.Image(label="Selected mask image", elem_id="sel_mask", type="numpy", tool="sketch", brush_radius=12,
                                    interactive=True).style(height=480)
                with gr.Row():
                    with gr.Column():
                        expand_mask_btn = gr.Button("Expand mask region", elem_id="expand_mask_btn")
                    with gr.Column():
                        apply_mask_btn = gr.Button("Trim mask by sketch", elem_id="apply_mask_btn")
        # ---- Event wiring (runs inside the Blocks context so components resolve) ----
        load_model_btn.click(download_model, inputs=[sam_model_id], outputs=[status_text])
        input_image.upload(input_image_upload, inputs=[input_image], outputs=None)
        padding_btn.click(run_padding, inputs=[input_image, pad_scale_width, pad_scale_height, pad_lr_barance, pad_tb_barance, padding_mode], outputs=[input_image, status_text])
        sam_btn.click(run_sam, inputs=[input_image, sam_model_id, sam_image], outputs=[sam_image, status_text]).then(
            fn=sleep_clear_cache_and_reload_model, inputs=None, outputs=None, _js="inpaintAnything_clearSamMask")
        select_btn.click(select_mask, inputs=[input_image, sam_image, invert_chk, sel_mask], outputs=[sel_mask]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
        expand_mask_btn.click(expand_mask, inputs=[input_image, sel_mask], outputs=[sel_mask]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
        apply_mask_btn.click(apply_mask, inputs=[input_image, sel_mask], outputs=[sel_mask]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
        inpaint_btn.click(
            run_inpaint,
            inputs=[input_image, sel_mask, prompt, n_prompt, ddim_steps, cfg_scale, seed, model_id, save_mask_chk, composite_chk, sampler_name],
            outputs=[out_image]).then(
            fn=sleep_clear_cache_and_reload_model, inputs=None, outputs=None)
        cleaner_btn.click(
            run_cleaner,
            inputs=[input_image, sel_mask, cleaner_model_id, cleaner_save_mask_chk],
            outputs=[cleaner_out_image]).then(
            fn=sleep_clear_cache_and_reload_model, inputs=None, outputs=None)
        get_alpha_image_btn.click(
            run_get_alpha_image,
            inputs=[input_image, sel_mask],
            outputs=[alpha_out_image])
        get_mask_btn.click(
            run_get_mask,
            inputs=[sel_mask],
            outputs=[mask_out_image])
        mask_send_to_inpaint_btn.click(
            fn=None,
            _js="inpaintAnything_sendToInpaint",
            inputs=None,
            outputs=None)
        # The ControlNet button passes a longer input list when Reference-Only
        # components were created; wiring must match the components that exist.
        if cn_enabled and not cn_ref_only:
            cn_inpaint_btn.click(
                run_cn_inpaint,
                inputs=[input_image, sel_mask,
                        cn_prompt, cn_n_prompt, cn_sampler_id, cn_ddim_steps, cn_cfg_scale, cn_strength, cn_seed, cn_module_id, cn_model_id, cn_save_mask_chk,
                        cn_low_vram_chk, cn_weight, cn_mode],
                outputs=[cn_out_image]).then(
                fn=sleep_clear_cache_and_reload_model, inputs=None, outputs=None)
        elif cn_enabled and cn_ref_only:
            cn_inpaint_btn.click(
                run_cn_inpaint,
                inputs=[input_image, sel_mask,
                        cn_prompt, cn_n_prompt, cn_sampler_id, cn_ddim_steps, cn_cfg_scale, cn_strength, cn_seed, cn_module_id, cn_model_id, cn_save_mask_chk,
                        cn_low_vram_chk, cn_weight, cn_mode, cn_ref_module_id, cn_ref_image, cn_ref_weight, cn_ref_mode],
                outputs=[cn_out_image]).then(
                fn=sleep_clear_cache_and_reload_model, inputs=None, outputs=None)
        if webui_inpaint_enabled:
            webui_inpaint_btn.click(
                run_webui_inpaint,
                inputs=[input_image, sel_mask,
                        webui_prompt, webui_n_prompt, webui_sampler_id, webui_ddim_steps, webui_cfg_scale, webui_strength, webui_seed, webui_model_id, webui_save_mask_chk,
                        webui_fill_mode],
                outputs=[webui_out_image]).then(
                fn=sleep_clear_cache_and_reload_model, inputs=None, outputs=None)
    return [(inpaint_anything_interface, "Inpaint Anything", "inpaint_anything")]
def on_ui_settings():
    """Register this extension's options on the Web UI settings page.

    Adds three options under the "Inpaint Anything" section: the output
    folder choice, the offline-inpainting toggle, and the constant-padding
    fill value.
    """
    section = ("inpaint_anything", "Inpaint Anything")
    # (option key, OptionInfo) pairs, registered in declaration order.
    option_specs = (
        ("inpaint_anything_save_folder",
         shared.OptionInfo("inpaint-anything", "Folder name where output images will be saved",
                           gr.Radio, {"choices": ["inpaint-anything", "img2img-images"]}, section=section)),
        ("inpaint_anything_offline_inpainting",
         shared.OptionInfo(False, "Enable offline network Inpainting",
                           gr.Checkbox, {"interactive": True}, section=section)),
        ("inpaint_anything_padding_fill",
         shared.OptionInfo(127, "Fill value used when Padding is set to constant",
                           gr.Slider, {"minimum": 0, "maximum": 255, "step": 1}, section=section)),
    )
    for option_key, option_info in option_specs:
        shared.opts.add_option(option_key, option_info)
# Register the settings and tab builders with the Web UI so the extension's
# options page entries and the "Inpaint Anything" tab are created at startup.
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)