mirror of https://github.com/vladmandic/automatic
116 lines
5.7 KiB
Python
116 lines
5.7 KiB
Python
import os
|
|
import gradio as gr
|
|
from PIL import Image
|
|
from modules import scripts_postprocessing
|
|
|
|
|
|
models = [
|
|
"none",
|
|
"ben2",
|
|
"silueta",
|
|
"u2net",
|
|
"u2net_human_seg",
|
|
"isnet-general-use",
|
|
"isnet-anime",
|
|
# "u2netp",
|
|
# "u2net_cloth_seg",
|
|
# "sam",
|
|
]
|
|
|
|
|
|
class ScriptPostprocessingRembg(scripts_postprocessing.ScriptPostprocessing):
|
|
name = "Remove background"
|
|
order = 20000
|
|
model = None
|
|
|
|
def ui(self):
|
|
def visible(model):
|
|
return gr.update(visible=model not in ["none"]), gr.update(visible=model in ["ben2"]), gr.update(visible=model not in ["none", "ben2"])
|
|
|
|
with gr.Accordion('Remove background', open = True, elem_id="postprocess_rembg_accordion"):
|
|
with gr.Row():
|
|
model = gr.Dropdown(label="Model", choices=models, value="none", elem_id="extras_rembg_model")
|
|
with gr.Group(visible=False) as group_base:
|
|
with gr.Row():
|
|
merge_alpha = gr.Checkbox(label="Merge alpha", value=False)
|
|
mask_only = gr.Checkbox(label="Mask only", value=False)
|
|
with gr.Group(visible=False) as group_ben2:
|
|
with gr.Row():
|
|
refine = gr.Checkbox(label="Refine foreground", value=False)
|
|
with gr.Group(visible=False) as group_rembg:
|
|
with gr.Row():
|
|
postprocess_mask = gr.Checkbox(label="Postprocess mask", value=False, elem_id="extras_rembg_process_mask")
|
|
alpha_matting = gr.Checkbox(label="Alpha matting", value=False, elem_id="extras_rembg_alpha")
|
|
with gr.Row(visible=True) as alpha_mask_row:
|
|
alpha_matting_erode_size = gr.Slider(label="Erode size", minimum=0, maximum=40, step=1, value=10, elem_id="extras_rembg_alpha_erode")
|
|
alpha_matting_foreground_threshold = gr.Slider(label="Foreground threshold", minimum=0, maximum=255, step=1, value=240, elem_id="extras_rembg_alpha_foreground")
|
|
alpha_matting_background_threshold = gr.Slider(label="Background threshold", minimum=0, maximum=255, step=1, value=10, elem_id="extras_rembg_alpha_background")
|
|
alpha_matting.change(fn=lambda x: gr.update(visible=x), inputs=[alpha_matting], outputs=[alpha_mask_row])
|
|
model.change(fn=lambda x: visible(x), inputs=[model], outputs=[group_base, group_ben2, group_rembg])
|
|
return {
|
|
"model": model,
|
|
"merge_alpha": merge_alpha,
|
|
"refine": refine,
|
|
"mask_only": mask_only,
|
|
"postprocess_mask": postprocess_mask,
|
|
"alpha_matting": alpha_matting,
|
|
"alpha_matting_foreground_threshold": alpha_matting_foreground_threshold,
|
|
"alpha_matting_background_threshold": alpha_matting_background_threshold,
|
|
"alpha_matting_erode_size": alpha_matting_erode_size,
|
|
}
|
|
|
|
def process(self, pp: scripts_postprocessing.PostprocessedImage, model, merge_alpha, refine, mask_only, postprocess_mask, alpha_matting, alpha_matting_foreground_threshold, alpha_matting_background_threshold, alpha_matting_erode_size): # pylint: disable=arguments-differ
|
|
from modules.logger import log
|
|
if not model or model == "none":
|
|
return pp
|
|
if isinstance(pp, Image.Image):
|
|
image = pp
|
|
info = {}
|
|
else:
|
|
image = pp.image
|
|
info = pp.info
|
|
|
|
log.info(f'RemoveBackground: model={model} merge_alpha={merge_alpha} refine={refine} mask_only={mask_only} postprocess_mask={postprocess_mask} alpha_matting={alpha_matting} alpha_matting_foreground_threshold={alpha_matting_foreground_threshold} alpha_matting_background_threshold={alpha_matting_background_threshold} alpha_matting_erode_size={alpha_matting_erode_size}')
|
|
if model == 'ben2':
|
|
try:
|
|
from modules.rembg import ben2
|
|
image = ben2.remove(image, refine=refine)
|
|
except Exception as e:
|
|
log.error(f'RemoveBackground: model={model} {e}')
|
|
return pp
|
|
else:
|
|
try:
|
|
from installer import install
|
|
for pkg in ["dctorch==0.1.2", "pymatting", "pooch", "rembg"]:
|
|
install(pkg, no_deps=True, ignore=False)
|
|
import rembg
|
|
if "U2NET_HOME" not in os.environ:
|
|
from modules.paths import models_path
|
|
os.environ["U2NET_HOME"] = os.path.join(models_path, "Rembg")
|
|
image = rembg.remove(image,
|
|
post_process_mask=postprocess_mask,
|
|
alpha_matting=alpha_matting,
|
|
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
|
|
alpha_matting_background_threshold=alpha_matting_background_threshold,
|
|
alpha_matting_erode_size=alpha_matting_erode_size,
|
|
session=rembg.new_session(model))
|
|
except Exception as e:
|
|
log.error(f'RemoveBackground: model={model} {e}')
|
|
return pp
|
|
|
|
if mask_only and pp.image.mode == "RGBA":
|
|
_r, _g, _b, alpha = pp.image.split()
|
|
image = alpha
|
|
if merge_alpha and pp.image.mode == "RGBA":
|
|
flattened = Image.new("RGBA", pp.image.size, "BLACK")
|
|
flattened.paste(pp.image, mask=pp.image)
|
|
flattened.convert("RGB")
|
|
image = flattened
|
|
|
|
if isinstance(pp, Image.Image):
|
|
pp = scripts_postprocessing.PostprocessedImage(image=image, info={**info, "Rembg": model})
|
|
else:
|
|
pp.image = image
|
|
pp.info['Rembg'] = model
|
|
return pp
|