Merge pull request #102 from wfjsw/transparent-background-api
use transparent-background apipull/122/head
commit
dbb877b20d
57
stage1.py
57
stage1.py
|
|
@ -6,6 +6,8 @@ import re
|
|||
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||
from PIL import Image
|
||||
from transparent_background import Remover
|
||||
from tqdm.auto import tqdm
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -21,7 +23,7 @@ def resize_img(img, w, h):
|
|||
def resize_all_img(path, frame_width, frame_height):
|
||||
if not os.path.isdir(path):
|
||||
return
|
||||
|
||||
|
||||
pngs = glob.glob( os.path.join(path, "*.png") )
|
||||
img = cv2.imread(pngs[0])
|
||||
org_h,org_w = img.shape[0],img.shape[1]
|
||||
|
|
@ -44,7 +46,7 @@ def resize_all_img(path, frame_width, frame_height):
|
|||
def remove_pngs_in_dir(path):
|
||||
if not os.path.isdir(path):
|
||||
return
|
||||
|
||||
|
||||
pngs = glob.glob( os.path.join(path, "*.png") )
|
||||
for png in pngs:
|
||||
os.remove(png)
|
||||
|
|
@ -55,7 +57,7 @@ def create_and_mask(mask_dir1, mask_dir2, output_dir):
|
|||
for mask1 in masks:
|
||||
base_name = os.path.basename(mask1)
|
||||
print("combine {0}".format(base_name))
|
||||
|
||||
|
||||
mask2 = os.path.join(mask_dir2, base_name)
|
||||
if not os.path.isfile(mask2):
|
||||
print("{0} not found!!! -> skip".format(mask2))
|
||||
|
|
@ -83,7 +85,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl
|
|||
imgs = glob.glob( os.path.join(input_dir, "*.png") )
|
||||
texts = [x.strip() for x in clipseg_mask_prompt.split(',')]
|
||||
exclude_texts = [x.strip() for x in clipseg_exclude_prompt.split(',')] if clipseg_exclude_prompt else None
|
||||
|
||||
|
||||
if exclude_texts:
|
||||
all_texts = texts + exclude_texts
|
||||
else:
|
||||
|
|
@ -99,7 +101,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl
|
|||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
outputs = model(**inputs)
|
||||
|
||||
|
||||
if len(all_texts) == 1:
|
||||
preds = outputs.logits.unsqueeze(0)
|
||||
else:
|
||||
|
|
@ -124,7 +126,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl
|
|||
|
||||
mask_img = mask_img*255
|
||||
mask_img = mask_img.astype(np.uint8)
|
||||
|
||||
|
||||
if mask_blur_size > 0:
|
||||
mask_blur_size = mask_blur_size//2 * 2 + 1
|
||||
mask_img = cv2.medianBlur(mask_img, mask_blur_size)
|
||||
|
|
@ -140,38 +142,25 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl
|
|||
cv2.imwrite(save_path, mask_img)
|
||||
|
||||
print("{0} / {1}".format( img_count+1,len(imgs) ))
|
||||
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def create_mask_transparent_background(input_dir, output_dir, tb_use_fast_mode, tb_use_jit, st1_mask_threshold):
|
||||
fast_str = " --fast" if tb_use_fast_mode else ""
|
||||
jit_str = " --jit" if tb_use_jit else ""
|
||||
venv = "venv"
|
||||
if 'VIRTUAL_ENV' in os.environ:
|
||||
venv = os.environ['VIRTUAL_ENV']
|
||||
bin_path = os.path.join(venv, "Scripts")
|
||||
bin_path = os.path.join(bin_path, "transparent-background")
|
||||
from modules import devices
|
||||
remover = Remover(fast=tb_use_fast_mode, jit=tb_use_jit, device=devices.get_optimal_device_name())
|
||||
|
||||
if os.path.isfile(bin_path) or os.path.isfile(bin_path + ".exe"):
|
||||
subprocess.call(bin_path + " --source " + input_dir + " --dest " + output_dir + " --type map" + fast_str + jit_str, shell=True)
|
||||
else:
|
||||
subprocess.call("transparent-background --source " + input_dir + " --dest " + output_dir + " --type map" + fast_str + jit_str, shell=True)
|
||||
original_imgs = glob.glob( os.path.join(input_dir, "*.png") )
|
||||
|
||||
mask_imgs = glob.glob( os.path.join(output_dir, "*.png") )
|
||||
|
||||
for m in mask_imgs:
|
||||
img = cv2.imread(m)
|
||||
img[img < int( 255 * st1_mask_threshold )] = 0
|
||||
cv2.imwrite(m, img)
|
||||
pbar_original_imgs = tqdm(original_imgs, bar_format='{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}')
|
||||
for m in pbar_original_imgs:
|
||||
base_name = os.path.basename(m)
|
||||
pbar_original_imgs.set_description('{}'.format(base_name))
|
||||
img = Image.open(m).convert('RGB')
|
||||
out = remover.process(img, type='map')
|
||||
out[out < int( 255 * st1_mask_threshold )] = 0
|
||||
cv2.imwrite(os.path.join(output_dir, base_name), out)
|
||||
|
||||
p = re.compile(r'([0-9]+)_[a-z]*\.png')
|
||||
|
||||
for mask in mask_imgs:
|
||||
base_name = os.path.basename(mask)
|
||||
m = p.fullmatch(base_name)
|
||||
if m:
|
||||
os.rename(mask, os.path.join(output_dir, m.group(1) + ".png"))
|
||||
|
||||
def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_masking_method_index, st1_mask_threshold, tb_use_fast_mode, tb_use_jit, clipseg_mask_prompt, clipseg_exclude_prompt, clipseg_mask_threshold, clipseg_mask_blur_size, clipseg_mask_blur_size2, is_invert_mask):
|
||||
dbg.print("stage1")
|
||||
|
|
@ -192,7 +181,7 @@ def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_mas
|
|||
|
||||
if frame_mask_path:
|
||||
remove_pngs_in_dir(frame_mask_path)
|
||||
|
||||
|
||||
if frame_mask_path:
|
||||
os.makedirs(frame_mask_path, exist_ok=True)
|
||||
|
||||
|
|
@ -228,7 +217,7 @@ def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_mas
|
|||
|
||||
|
||||
dbg.print("mask created")
|
||||
|
||||
|
||||
dbg.print("")
|
||||
dbg.print("completed.")
|
||||
|
||||
|
|
@ -246,7 +235,7 @@ def ebsynth_utility_stage1_invert(dbg, frame_mask_path, inv_mask_path):
|
|||
os.makedirs(inv_mask_path, exist_ok=True)
|
||||
|
||||
mask_imgs = glob.glob( os.path.join(frame_mask_path, "*.png") )
|
||||
|
||||
|
||||
for m in mask_imgs:
|
||||
img = cv2.imread(m)
|
||||
inv = cv2.bitwise_not(img)
|
||||
|
|
|
|||
Loading…
Reference in New Issue