add scripts

pull/8/head
mattyamonaca 2023-04-05 09:16:05 +09:00
parent 2d93da9674
commit 2a75449d5f
5 changed files with 338 additions and 3 deletions

View File

@ -1,4 +1,2 @@
# PBRemTools
Precise background remover
test
Precise background remover tools.

63
scripts/convertor.py Normal file
View File

@ -0,0 +1,63 @@
import numpy as np
import pandas as pd
from PIL import Image
def rgb2df(img):
"""
Convert an RGB image to a DataFrame.
Args:
img (np.ndarray): RGB image.
Returns:
df (pd.DataFrame): DataFrame containing the image data.
"""
h, w, _ = img.shape
x_l, y_l = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
df = pd.DataFrame({
"x_l": x_l.ravel(),
"y_l": y_l.ravel(),
"r": r.ravel(),
"g": g.ravel(),
"b": b.ravel(),
})
return df
def df2rgba(img_df):
"""
Convert a DataFrame to an RGB image.
Args:
img_df (pd.DataFrame): DataFrame containing image data.
Returns:
img (np.ndarray): RGB image.
"""
r_img = img_df.pivot_table(index="x_l", columns="y_l",values= "r").reset_index(drop=True).values
g_img = img_df.pivot_table(index="x_l", columns="y_l",values= "g").reset_index(drop=True).values
b_img = img_df.pivot_table(index="x_l", columns="y_l",values= "b").reset_index(drop=True).values
a_img = img_df.pivot_table(index="x_l", columns="y_l",values= "a").reset_index(drop=True).values
df_img = np.stack([r_img, g_img, b_img, a_img], 2).astype(np.uint8)
return df_img
def pil2cv(image):
new_image = np.array(image, dtype=np.uint8)
if new_image.ndim == 2:
pass
elif new_image.shape[2] == 3:
new_image = new_image[:, :, ::-1]
elif new_image.shape[2] == 4:
new_image = new_image[:, :, [2, 1, 0, 3]]
return new_image
def cv2pil(image):
new_image = image.copy()
if new_image.ndim == 2:
pass
elif new_image.shape[2] == 3:
new_image = new_image[:, :, ::-1]
elif new_image.shape[2] == 4:
new_image = new_image[:, :, [2, 1, 0, 3]]
new_image = Image.fromarray(new_image)
return new_image

67
scripts/launch.py Normal file
View File

@ -0,0 +1,67 @@
import gradio as gr
import sys
import cv2
from td_abg import get_foreground
from convertor import pil2cv
class webui:
def __init__(self):
self.demo = gr.Blocks()
def processing(self, input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
image = pil2cv(input_image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask, image = get_foreground(image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L)
return image, mask
def launch(self, share):
with self.demo:
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil")
with gr.Accordion("tile division ABG", open=True):
with gr.Box():
td_abg_enabled = gr.Checkbox(label="enabled", show_label=True)
h_split = gr.Slider(1, 2048, value=256, step=4, label="horizontal split num", show_label=True)
v_split = gr.Slider(1, 2048, value=256, step=4, label="vertical split num", show_label=True)
n_cluster = gr.Slider(1, 1000, value=500, step=10, label="cluster num", show_label=True)
alpha = gr.Slider(1, 255, value=100, step=1, label="alpha threshold", show_label=True)
th_rate = gr.Slider(0, 1, value=0.1, step=0.01, label="mask content ratio", show_label=True)
with gr.Accordion("cascadePSP", open=True):
with gr.Box():
cascadePSP_enabled = gr.Checkbox(label="enabled", show_label=True)
fast = gr.Checkbox(label="fast", show_label=True)
psp_L = gr.Slider(1, 2048, value=900, step=1, label="Memory usage", show_label=True)
submit = gr.Button(value="Submit")
with gr.Row():
with gr.Column():
with gr.Tab("output"):
output_img = gr.Image()
with gr.Tab("mask"):
output_mask = gr.Image()
submit.click(
self.processing,
inputs=[input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L],
outputs=[output_img, output_mask]
)
self.demo.queue()
self.demo.launch(share=share)
if __name__ == "__main__":
ui = webui()
if len(sys.argv) > 1:
if sys.argv[1] == "share":
ui.launch(share=True)
else:
ui.launch(share=False)
else:
ui.launch(share=False)

85
scripts/main.py Normal file
View File

@ -0,0 +1,85 @@
import os
import io
import json
import numpy as np
import cv2
import gradio as gr
import modules.scripts as scripts
from modules import script_callbacks
from td_abg import get_foreground
from convertor import pil2cv
"""
body_estimation = None
presets_file = os.path.join(scripts.basedir(), "presets.json")
presets = {}
try:
with open(presets_file) as file:
presets = json.load(file)
except FileNotFoundError:
pass
"""
def processing(self, input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
image = pil2cv(input_image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask, image = get_foreground(image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L)
return image, mask
class Script(scripts.Script):
def __init__(self) -> None:
super().__init__()
def title(self):
return "PBRemTools"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
return ()
def on_ui_tabs():
with gr.Blocks(analytics_enabled=False) as PBRemTools:
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil")
with gr.Accordion("tile division ABG", open=True):
with gr.Box():
td_abg_enabled = gr.Checkbox(label="enabled", show_label=True)
h_split = gr.Slider(1, 2048, value=256, step=4, label="horizontal split num", show_label=True)
v_split = gr.Slider(1, 2048, value=256, step=4, label="vertical split num", show_label=True)
n_cluster = gr.Slider(1, 1000, value=500, step=10, label="cluster num", show_label=True)
alpha = gr.Slider(1, 255, value=100, step=1, label="alpha threshold", show_label=True)
th_rate = gr.Slider(0, 1, value=0.1, step=0.01, label="mask content ratio", show_label=True)
with gr.Accordion("cascadePSP", open=True):
with gr.Box():
cascadePSP_enabled = gr.Checkbox(label="enabled", show_label=True)
fast = gr.Checkbox(label="fast", show_label=True)
psp_L = gr.Slider(1, 2048, value=900, step=1, label="Memory usage", show_label=True)
submit = gr.Button(value="Submit")
with gr.Row():
with gr.Column():
with gr.Tab("output"):
output_img = gr.Image()
with gr.Tab("mask"):
output_mask = gr.Image()
#dummy_component = gr.Label(visible=False)
#preset = gr.Text(visible=False)
submit.click(
processing,
inputs=[input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L],
outputs=[output_img, output_mask]
)
return [(PBRemTools, "PBRemTools", "pbremtools")]
script_callbacks.on_ui_tabs(on_ui_tabs)

122
scripts/td_abg.py Normal file
View File

@ -0,0 +1,122 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans, MiniBatchKMeans
from convertor import rgb2df, df2rgba, cv2pil
import gradio as gr
import huggingface_hub
import onnxruntime as rt
import copy
from PIL import Image
import segmentation_refinement as refine
# Declare Execution Providers
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
# Download and host the model
model_path = huggingface_hub.hf_hub_download(
"skytnt/anime-seg", "isnetis.onnx")
rmbg_model = rt.InferenceSession(model_path, providers=providers)
def get_mask(img, s=1024):
img = (img / 255).astype(np.float32)
dim = img.shape[2]
if dim == 4:
img = img[..., :3]
dim = 3
h, w = h0, w0 = img.shape[:-1]
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
ph, pw = s - h, s - w
img_input = np.zeros([s, s, dim], dtype=np.float32)
img_input[ph // 2:ph // 2 + h, pw //
2:pw // 2 + w] = cv2.resize(img, (w, h))
img_input = np.transpose(img_input, (2, 0, 1))
img_input = img_input[np.newaxis, :]
mask = rmbg_model.run(None, {'img': img_input})[0][0]
mask = np.transpose(mask, (1, 2, 0))
mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
return mask
def assign_tile(row, tile_width, tile_height):
tile_x = row['x_l'] // tile_width
tile_y = row['y_l'] // tile_height
return f"tile_{tile_y}_{tile_x}"
def rmbg_fn(img):
mask = get_mask(img)
img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
mask = (mask * 255).astype(np.uint8)
img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
mask = mask.repeat(3, axis=2)
return mask, img
def refinement(img, mask, fast, psp_L):
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
refiner = refine.Refiner(device='cuda:0') # device can also be 'cpu'
# Fast - Global step only.
# Smaller L -> Less memory usage; faster in fast mode.
mask = refiner.refine(img, mask, fast=fast, L=psp_L)
return mask
def get_foreground(img, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
if td_abg_enabled == True:
mask = get_mask(img)
mask = (mask * 255).astype(np.uint8)
mask = mask.repeat(3, axis=2)
if cascadePSP_enabled == True:
mask = refinement(img, mask, fast, psp_L)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
df = rgb2df(img)
image_width = img.shape[1]
image_height = img.shape[0]
num_horizontal_splits = h_split
num_vertical_splits = v_split
tile_width = image_width // num_horizontal_splits
tile_height = image_height // num_vertical_splits
df['tile'] = df.apply(assign_tile, args=(tile_width, tile_height), axis=1)
cls = MiniBatchKMeans(n_clusters=n_cluster, batch_size=100)
cls.fit(df[["r","g","b"]])
df["label"] = cls.labels_
mask_df = rgb2df(mask)
mask_df['bg_label'] = (mask_df['r'] > alpha) & (mask_df['g'] > alpha) & (mask_df['b'] > alpha)
img_df = df.copy()
img_df["bg_label"] = mask_df["bg_label"]
img_df["label"] = img_df["label"].astype(str) + "-" + img_df["tile"]
bg_rate = img_df.groupby("label").sum()["bg_label"]/img_df.groupby("label").count()["bg_label"]
img_df['bg_cls'] = (img_df['label'].isin(bg_rate[bg_rate > th_rate].index)).astype(int)
img_df.loc[img_df['bg_cls'] == 0, ['a']] = 0
img_df.loc[img_df['bg_cls'] != 0, ['a']] = 255
img = df2rgba(img_df)
if cascadePSP_enabled == True and td_abg_enabled == False:
mask = get_mask(img)
mask = (mask * 255).astype(np.uint8)
refiner = refine.Refiner(device='cuda:0')
mask = refiner.refine(img, mask, fast=fast, L=psp_L)
img = np.dstack((img, mask))
if cascadePSP_enabled == False and td_abg_enabled == False:
mask, img = rmbg_fn(img)
return mask, img