automatic/scripts/postprocessing_rembg.py

116 lines
5.7 KiB
Python

import os
import gradio as gr
from PIL import Image
from modules import scripts_postprocessing
# Choices offered by the "Model" dropdown of the background-removal script.
# In this file "none" makes processing a no-op, "ben2" is dispatched to
# modules.rembg.ben2, and every other name is passed to rembg.new_session().
models = [
    "none",
    "ben2",
    "silueta",
    "u2net", "u2net_human_seg",
    "isnet-general-use", "isnet-anime",
    # currently disabled entries:
    # "u2netp",
    # "u2net_cloth_seg",
    # "sam",
]
class ScriptPostprocessingRembg(scripts_postprocessing.ScriptPostprocessing):
    """Postprocessing script that removes the background from an image.

    The "ben2" model is handled by ``modules.rembg.ben2``; every other model
    name is handed to the ``rembg`` package, which is installed on first use.
    """

    name = "Remove background"
    order = 20000  # presumably the position among postprocessing scripts -- confirm against the base class
    model = None  # not read anywhere in this class

    def ui(self):
        """Build the Gradio controls and return them keyed by ``process`` argument name."""

        def visible(model):
            # Visibility updates for, in order: the shared options group,
            # the ben2-only group and the rembg-only group.
            return gr.update(visible=model not in ["none"]), gr.update(visible=model in ["ben2"]), gr.update(visible=model not in ["none", "ben2"])

        with gr.Accordion('Remove background', open=True, elem_id="postprocess_rembg_accordion"):
            with gr.Row():
                model = gr.Dropdown(label="Model", choices=models, value="none", elem_id="extras_rembg_model")
            with gr.Group(visible=False) as group_base:
                with gr.Row():
                    merge_alpha = gr.Checkbox(label="Merge alpha", value=False)
                    mask_only = gr.Checkbox(label="Mask only", value=False)
            with gr.Group(visible=False) as group_ben2:
                with gr.Row():
                    refine = gr.Checkbox(label="Refine foreground", value=False)
            with gr.Group(visible=False) as group_rembg:
                with gr.Row():
                    postprocess_mask = gr.Checkbox(label="Postprocess mask", value=False, elem_id="extras_rembg_process_mask")
                    alpha_matting = gr.Checkbox(label="Alpha matting", value=False, elem_id="extras_rembg_alpha")
                with gr.Row(visible=True) as alpha_mask_row:
                    alpha_matting_erode_size = gr.Slider(label="Erode size", minimum=0, maximum=40, step=1, value=10, elem_id="extras_rembg_alpha_erode")
                    alpha_matting_foreground_threshold = gr.Slider(label="Foreground threshold", minimum=0, maximum=255, step=1, value=240, elem_id="extras_rembg_alpha_foreground")
                    alpha_matting_background_threshold = gr.Slider(label="Background threshold", minimum=0, maximum=255, step=1, value=10, elem_id="extras_rembg_alpha_background")
            # The slider row follows the "Alpha matting" checkbox state.
            alpha_matting.change(fn=lambda x: gr.update(visible=x), inputs=[alpha_matting], outputs=[alpha_mask_row])
            # Switching the model toggles which option groups are shown.
            model.change(fn=visible, inputs=[model], outputs=[group_base, group_ben2, group_rembg])
        return {
            "model": model,
            "merge_alpha": merge_alpha,
            "refine": refine,
            "mask_only": mask_only,
            "postprocess_mask": postprocess_mask,
            "alpha_matting": alpha_matting,
            "alpha_matting_foreground_threshold": alpha_matting_foreground_threshold,
            "alpha_matting_background_threshold": alpha_matting_background_threshold,
            "alpha_matting_erode_size": alpha_matting_erode_size,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, model, merge_alpha, refine, mask_only, postprocess_mask, alpha_matting, alpha_matting_foreground_threshold, alpha_matting_background_threshold, alpha_matting_erode_size): # pylint: disable=arguments-differ
        """Remove the background of ``pp`` with the selected model.

        Args:
            pp: Either a ``PostprocessedImage`` or a bare ``PIL.Image.Image``.
            model: Entry of ``models``; empty or ``"none"`` disables processing.
            merge_alpha: Flatten an RGBA result onto a black background and
                return it as RGB.
            refine: Passed to ``ben2.remove`` as ``refine`` (ben2 only).
            mask_only: Return only the alpha channel of an RGBA result.
            postprocess_mask: Passed to ``rembg.remove`` as ``post_process_mask``.
            alpha_matting: Passed through to ``rembg.remove``.
            alpha_matting_foreground_threshold: Passed through to ``rembg.remove``.
            alpha_matting_background_threshold: Passed through to ``rembg.remove``.
            alpha_matting_erode_size: Passed through to ``rembg.remove``.

        Returns:
            ``pp`` unchanged when the model is disabled or removal fails
            (the error is logged); otherwise a ``PostprocessedImage`` holding
            the processed image with ``info["Rembg"]`` set to the model name.
        """
        from modules.logger import log
        if not model or model == "none":
            return pp
        # Callers may pass a bare PIL image instead of a PostprocessedImage.
        if isinstance(pp, Image.Image):
            image = pp
            info = {}
        else:
            image = pp.image
            info = pp.info
        log.info(f'RemoveBackground: model={model} merge_alpha={merge_alpha} refine={refine} mask_only={mask_only} postprocess_mask={postprocess_mask} alpha_matting={alpha_matting} alpha_matting_foreground_threshold={alpha_matting_foreground_threshold} alpha_matting_background_threshold={alpha_matting_background_threshold} alpha_matting_erode_size={alpha_matting_erode_size}')
        if model == 'ben2':
            try:
                from modules.rembg import ben2
                image = ben2.remove(image, refine=refine)
            except Exception as e:
                # Best effort: log and hand back the untouched input.
                log.error(f'RemoveBackground: model={model} {e}')
                return pp
        else:
            try:
                # rembg and its helper packages are installed lazily on first use.
                from installer import install
                for pkg in ["dctorch==0.1.2", "pymatting", "pooch", "rembg"]:
                    install(pkg, no_deps=True, ignore=False)
                import rembg
                if "U2NET_HOME" not in os.environ:
                    # Point rembg at the application's models folder unless the user overrides it.
                    from modules.paths import models_path
                    os.environ["U2NET_HOME"] = os.path.join(models_path, "Rembg")
                image = rembg.remove(
                    image,
                    post_process_mask=postprocess_mask,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    session=rembg.new_session(model),
                )
            except Exception as e:
                # Best effort: log and hand back the untouched input.
                log.error(f'RemoveBackground: model={model} {e}')
                return pp
        # Both options operate on the processed image, not on the input held by
        # ``pp`` (which may also be a bare PIL image without an ``image`` attribute).
        if mask_only and image.mode == "RGBA":
            _r, _g, _b, alpha = image.split()
            image = alpha
        # After "mask only" the image is single-channel, so this step no longer applies.
        if merge_alpha and image.mode == "RGBA":
            flattened = Image.new("RGBA", image.size, "BLACK")
            flattened.paste(image, mask=image)
            # convert() returns a new image; keep its result.
            image = flattened.convert("RGB")
        if isinstance(pp, Image.Image):
            pp = scripts_postprocessing.PostprocessedImage(image=image, info={**info, "Rembg": model})
        else:
            pp.image = image
            pp.info['Rembg'] = model
        return pp