Merge pull request #102 from wfjsw/transparent-background-api

use transparent-background api
pull/122/head
s9roll7 2023-09-19 16:30:02 +09:00 committed by GitHub
commit dbb877b20d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 23 additions and 34 deletions

View File

@ -6,6 +6,8 @@ import re
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from PIL import Image
from transparent_background import Remover
from tqdm.auto import tqdm
import torch
import numpy as np
@ -21,7 +23,7 @@ def resize_img(img, w, h):
def resize_all_img(path, frame_width, frame_height):
if not os.path.isdir(path):
return
pngs = glob.glob( os.path.join(path, "*.png") )
img = cv2.imread(pngs[0])
org_h,org_w = img.shape[0],img.shape[1]
@ -44,7 +46,7 @@ def resize_all_img(path, frame_width, frame_height):
def remove_pngs_in_dir(path):
if not os.path.isdir(path):
return
pngs = glob.glob( os.path.join(path, "*.png") )
for png in pngs:
os.remove(png)
@ -55,7 +57,7 @@ def create_and_mask(mask_dir1, mask_dir2, output_dir):
for mask1 in masks:
base_name = os.path.basename(mask1)
print("combine {0}".format(base_name))
mask2 = os.path.join(mask_dir2, base_name)
if not os.path.isfile(mask2):
print("{0} not found!!! -> skip".format(mask2))
@ -83,7 +85,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl
imgs = glob.glob( os.path.join(input_dir, "*.png") )
texts = [x.strip() for x in clipseg_mask_prompt.split(',')]
exclude_texts = [x.strip() for x in clipseg_exclude_prompt.split(',')] if clipseg_exclude_prompt else None
if exclude_texts:
all_texts = texts + exclude_texts
else:
@ -99,7 +101,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl
with torch.no_grad(), devices.autocast():
outputs = model(**inputs)
if len(all_texts) == 1:
preds = outputs.logits.unsqueeze(0)
else:
@ -124,7 +126,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl
mask_img = mask_img*255
mask_img = mask_img.astype(np.uint8)
if mask_blur_size > 0:
mask_blur_size = mask_blur_size//2 * 2 + 1
mask_img = cv2.medianBlur(mask_img, mask_blur_size)
@ -140,38 +142,25 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl
cv2.imwrite(save_path, mask_img)
print("{0} / {1}".format( img_count+1,len(imgs) ))
devices.torch_gc()
# NOTE(review): this span is a rendered diff hunk with the +/- markers and all
# indentation stripped, so lines from the old and the new revision of this
# function appear interleaved. It is not runnable as shown; the comments below
# describe what each group of lines does on its own.
#
# Purpose (from the visible lines): write one mask PNG per input frame from
# input_dir into output_dir using the transparent-background tool, zeroing
# mask values below 255 * st1_mask_threshold.
def create_mask_transparent_background(input_dir, output_dir, tb_use_fast_mode, tb_use_jit, st1_mask_threshold):
# CLI-based path (presumably the removed lines -- confirm against the raw diff):
# build the optional " --fast" / " --jit" command-line flags.
fast_str = " --fast" if tb_use_fast_mode else ""
jit_str = " --jit" if tb_use_jit else ""
# Locate the "transparent-background" executable under <venv>/Scripts, where
# <venv> is $VIRTUAL_ENV when that variable is set and "venv" otherwise.
venv = "venv"
if 'VIRTUAL_ENV' in os.environ:
venv = os.environ['VIRTUAL_ENV']
bin_path = os.path.join(venv, "Scripts")
bin_path = os.path.join(bin_path, "transparent-background")
# Library-based path (presumably the added lines -- confirm against the raw diff):
# construct a Remover with the fast/jit options and the device name returned by
# modules.devices.get_optimal_device_name().
from modules import devices
remover = Remover(fast=tb_use_fast_mode, jit=tb_use_jit, device=devices.get_optimal_device_name())
# CLI-based path: invoke the venv executable if it (or its ".exe" variant)
# exists, otherwise invoke "transparent-background" by bare name.
# NOTE(review): the command is assembled by string concatenation and run with
# shell=True, so input_dir/output_dir containing spaces or shell metacharacters
# are not quoted.
if os.path.isfile(bin_path) or os.path.isfile(bin_path + ".exe"):
subprocess.call(bin_path + " --source " + input_dir + " --dest " + output_dir + " --type map" + fast_str + jit_str, shell=True)
else:
subprocess.call("transparent-background --source " + input_dir + " --dest " + output_dir + " --type map" + fast_str + jit_str, shell=True)
# Collect the PNG frames in input_dir (consumed by the progress-bar loop below).
original_imgs = glob.glob( os.path.join(input_dir, "*.png") )
# CLI-based path: re-read every PNG already in output_dir, zero all values
# below 255 * st1_mask_threshold, and overwrite the file in place.
mask_imgs = glob.glob( os.path.join(output_dir, "*.png") )
for m in mask_imgs:
img = cv2.imread(m)
img[img < int( 255 * st1_mask_threshold )] = 0
cv2.imwrite(m, img)
# Library-based path: iterate the input frames with a tqdm progress bar whose
# description is set to the current file's base name.
pbar_original_imgs = tqdm(original_imgs, bar_format='{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}')
for m in pbar_original_imgs:
base_name = os.path.basename(m)
pbar_original_imgs.set_description('{}'.format(base_name))
# Load the frame as RGB and request the 'map' output type from the remover.
img = Image.open(m).convert('RGB')
out = remover.process(img, type='map')
# NOTE(review): the masked assignment and cv2.imwrite assume `out` is a numpy
# array here -- confirm what remover.process returns for type='map'.
out[out < int( 255 * st1_mask_threshold )] = 0
# The mask is saved under the same base name as the input frame.
cv2.imwrite(os.path.join(output_dir, base_name), out)
# CLI-based path: rename files named "<digits>_<lowercase letters>.png" in
# mask_imgs to "<digits>.png" inside output_dir; non-matching names are left alone.
p = re.compile(r'([0-9]+)_[a-z]*\.png')
for mask in mask_imgs:
base_name = os.path.basename(mask)
m = p.fullmatch(base_name)
if m:
os.rename(mask, os.path.join(output_dir, m.group(1) + ".png"))
def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_masking_method_index, st1_mask_threshold, tb_use_fast_mode, tb_use_jit, clipseg_mask_prompt, clipseg_exclude_prompt, clipseg_mask_threshold, clipseg_mask_blur_size, clipseg_mask_blur_size2, is_invert_mask):
dbg.print("stage1")
@ -192,7 +181,7 @@ def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_mas
if frame_mask_path:
remove_pngs_in_dir(frame_mask_path)
if frame_mask_path:
os.makedirs(frame_mask_path, exist_ok=True)
@ -228,7 +217,7 @@ def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_mas
dbg.print("mask created")
dbg.print("")
dbg.print("completed.")
@ -246,7 +235,7 @@ def ebsynth_utility_stage1_invert(dbg, frame_mask_path, inv_mask_path):
os.makedirs(inv_mask_path, exist_ok=True)
mask_imgs = glob.glob( os.path.join(frame_mask_path, "*.png") )
for m in mask_imgs:
img = cv2.imread(m)
inv = cv2.bitwise_not(img)