68 lines
2.9 KiB
Python
68 lines
2.9 KiB
Python
import gradio as gr
|
|
import sys
|
|
import cv2
|
|
|
|
from scripts.td_abg import get_foreground
|
|
from scripts.convertor import pil2cv
|
|
|
|
|
|
|
|
|
|
class webui:
|
|
def __init__(self):
|
|
self.demo = gr.Blocks()
|
|
|
|
def processing(self, input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
|
|
image = pil2cv(input_image)
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
mask, image = get_foreground(image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L)
|
|
return image, mask
|
|
|
|
def launch(self, share):
|
|
with self.demo:
|
|
with gr.Row():
|
|
with gr.Column():
|
|
input_image = gr.Image(type="pil")
|
|
with gr.Accordion("tile division ABG", open=True):
|
|
with gr.Box():
|
|
td_abg_enabled = gr.Checkbox(label="enabled", show_label=True)
|
|
h_split = gr.Slider(1, 2048, value=256, step=4, label="horizontal split num", show_label=True)
|
|
v_split = gr.Slider(1, 2048, value=256, step=4, label="vertical split num", show_label=True)
|
|
|
|
n_cluster = gr.Slider(1, 1000, value=500, step=10, label="cluster num", show_label=True)
|
|
alpha = gr.Slider(1, 255, value=100, step=1, label="alpha threshold", show_label=True)
|
|
th_rate = gr.Slider(0, 1, value=0.1, step=0.01, label="mask content ratio", show_label=True)
|
|
|
|
with gr.Accordion("cascadePSP", open=True):
|
|
with gr.Box():
|
|
cascadePSP_enabled = gr.Checkbox(label="enabled", show_label=True)
|
|
fast = gr.Checkbox(label="fast", show_label=True)
|
|
psp_L = gr.Slider(1, 2048, value=900, step=1, label="Memory usage", show_label=True)
|
|
|
|
submit = gr.Button(value="Submit")
|
|
with gr.Row():
|
|
with gr.Column():
|
|
with gr.Tab("output"):
|
|
output_img = gr.Image()
|
|
with gr.Tab("mask"):
|
|
output_mask = gr.Image()
|
|
submit.click(
|
|
self.processing,
|
|
inputs=[input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L],
|
|
outputs=[output_img, output_mask]
|
|
)
|
|
|
|
self.demo.queue()
|
|
self.demo.launch(share=share)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
ui = webui()
|
|
if len(sys.argv) > 1:
|
|
if sys.argv[1] == "share":
|
|
ui.launch(share=True)
|
|
else:
|
|
ui.launch(share=False)
|
|
else:
|
|
ui.launch(share=False)
|