diff --git a/stage1.py b/stage1.py index 90acf44..26da3c8 100644 --- a/stage1.py +++ b/stage1.py @@ -6,6 +6,8 @@ import re from transformers import AutoProcessor, CLIPSegForImageSegmentation from PIL import Image +from transparent_background import Remover +from tqdm.auto import tqdm import torch import numpy as np @@ -21,7 +23,7 @@ def resize_img(img, w, h): def resize_all_img(path, frame_width, frame_height): if not os.path.isdir(path): return - + pngs = glob.glob( os.path.join(path, "*.png") ) img = cv2.imread(pngs[0]) org_h,org_w = img.shape[0],img.shape[1] @@ -44,7 +46,7 @@ def resize_all_img(path, frame_width, frame_height): def remove_pngs_in_dir(path): if not os.path.isdir(path): return - + pngs = glob.glob( os.path.join(path, "*.png") ) for png in pngs: os.remove(png) @@ -55,7 +57,7 @@ def create_and_mask(mask_dir1, mask_dir2, output_dir): for mask1 in masks: base_name = os.path.basename(mask1) print("combine {0}".format(base_name)) - + mask2 = os.path.join(mask_dir2, base_name) if not os.path.isfile(mask2): print("{0} not found!!! -> skip".format(mask2)) @@ -83,7 +85,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl imgs = glob.glob( os.path.join(input_dir, "*.png") ) texts = [x.strip() for x in clipseg_mask_prompt.split(',')] exclude_texts = [x.strip() for x in clipseg_exclude_prompt.split(',')] if clipseg_exclude_prompt else None - + if exclude_texts: all_texts = texts + exclude_texts else: @@ -99,7 +101,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl with torch.no_grad(), devices.autocast(): outputs = model(**inputs) - + if len(all_texts) == 1: preds = outputs.logits.unsqueeze(0) else: @@ -124,7 +126,7 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl mask_img = mask_img*255 mask_img = mask_img.astype(np.uint8) - + if mask_blur_size > 0: mask_blur_size = mask_blur_size//2 * 2 + 1 mask_img = cv2.medianBlur(mask_img, mask_blur_size) @@ -140,38 +142,25 @@ def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_excl cv2.imwrite(save_path, mask_img) print("{0} / {1}".format( img_count+1,len(imgs) )) - + devices.torch_gc() def create_mask_transparent_background(input_dir, output_dir, tb_use_fast_mode, tb_use_jit, st1_mask_threshold): - fast_str = " --fast" if tb_use_fast_mode else "" - jit_str = " --jit" if tb_use_jit else "" - venv = "venv" - if 'VIRTUAL_ENV' in os.environ: - venv = os.environ['VIRTUAL_ENV'] - bin_path = os.path.join(venv, "Scripts") - bin_path = os.path.join(bin_path, "transparent-background") + from modules import devices + remover = Remover(fast=tb_use_fast_mode, jit=tb_use_jit, device=devices.get_optimal_device_name()) - if os.path.isfile(bin_path) or os.path.isfile(bin_path + ".exe"): - subprocess.call(bin_path + " --source " + input_dir + " --dest " + output_dir + " --type map" + fast_str + jit_str, shell=True) - else: - subprocess.call("transparent-background --source " + input_dir + " --dest " + output_dir + " --type map" + fast_str + jit_str, shell=True) + original_imgs = glob.glob( os.path.join(input_dir, "*.png") ) - mask_imgs = glob.glob( os.path.join(output_dir, "*.png") ) - - for m in mask_imgs: - img = cv2.imread(m) - img[img < int( 255 * st1_mask_threshold )] = 0 - cv2.imwrite(m, img) + pbar_original_imgs = tqdm(original_imgs, bar_format='{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}') + for m in pbar_original_imgs: + base_name = os.path.basename(m) + pbar_original_imgs.set_description('{}'.format(base_name)) + img = Image.open(m).convert('RGB') + out = remover.process(img, type='map') + out[out < int( 255 * st1_mask_threshold )] = 0 + cv2.imwrite(os.path.join(output_dir, base_name), out) - p = re.compile(r'([0-9]+)_[a-z]*\.png') - - for mask in mask_imgs: - base_name = os.path.basename(mask) - m = p.fullmatch(base_name) - if m: - os.rename(mask, os.path.join(output_dir, m.group(1) + ".png")) def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_masking_method_index, st1_mask_threshold, tb_use_fast_mode, tb_use_jit, clipseg_mask_prompt, clipseg_exclude_prompt, clipseg_mask_threshold, clipseg_mask_blur_size, clipseg_mask_blur_size2, is_invert_mask): dbg.print("stage1") @@ -192,7 +181,7 @@ def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_mas if frame_mask_path: remove_pngs_in_dir(frame_mask_path) - + if frame_mask_path: os.makedirs(frame_mask_path, exist_ok=True) @@ -228,7 +217,7 @@ def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_mas dbg.print("mask created") - + dbg.print("") dbg.print("completed.") @@ -246,7 +235,7 @@ def ebsynth_utility_stage1_invert(dbg, frame_mask_path, inv_mask_path): os.makedirs(inv_mask_path, exist_ok=True) mask_imgs = glob.glob( os.path.join(frame_mask_path, "*.png") ) - + for m in mask_imgs: img = cv2.imread(m) inv = cv2.bitwise_not(img)