ebsynth_utility/stage1.py

248 lines
8.6 KiB
Python

import os
import subprocess
import glob
import cv2
import re
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from PIL import Image
from transparent_background import Remover
from tqdm.auto import tqdm
import torch
import numpy as np
def resize_img(img, w, h):
    """Resize ``img`` to ``w`` x ``h`` using a size-dependent interpolation.

    The sum of the source dimensions (``img.shape[0] + img.shape[1]``) is
    compared with ``h + w``. A smaller source is treated as an enlargement
    and resampled with ``cv2.INTER_CUBIC``; otherwise ``cv2.INTER_AREA`` is
    used.

    Args:
        img: Image array; only ``img.shape[0]`` and ``img.shape[1]`` are
            inspected here.
        w: Target width passed to ``cv2.resize``.
        h: Target height passed to ``cv2.resize``.

    Returns:
        Whatever ``cv2.resize(img, (w, h), interpolation=...)`` returns.
    """
    source_extent = img.shape[0] + img.shape[1]
    enlarging = source_extent < h + w
    interpolation = cv2.INTER_CUBIC if enlarging else cv2.INTER_AREA
    return cv2.resize(img, (w, h), interpolation=interpolation)
def resize_all_img(path, frame_width, frame_height):
    """Resize every ``*.png`` directly inside ``path`` in place.

    A dimension given as ``-1`` is derived from the other one so that the
    aspect ratio of the first PNG returned by ``glob`` is kept. When both
    dimensions are ``-1`` there is nothing to do.

    The function returns without modifying anything when ``path`` is not a
    directory, when it contains no PNG files, or when the reference PNG
    cannot be decoded. Individual files that cannot be decoded are reported
    and skipped.

    Args:
        path: Directory holding the PNG frames.
        frame_width: Target width in pixels, or ``-1`` to derive it.
        frame_height: Target height in pixels, or ``-1`` to derive it.

    Returns:
        None.
    """
    if not os.path.isdir(path):
        return

    # Nothing requested: bail out before touching the disk at all.
    if frame_width == -1 and frame_height == -1:
        return

    pngs = glob.glob(os.path.join(path, "*.png"))
    if not pngs:
        # Previously this case raised IndexError on pngs[0].
        print("no png files found in {0} -> skip resize".format(path))
        return

    # The first file serves as the aspect-ratio reference for all frames.
    img = cv2.imread(pngs[0])
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        print("{0} could not be read -> skip resize".format(pngs[0]))
        return
    org_h, org_w = img.shape[0], img.shape[1]

    if frame_width == -1:
        frame_width = int(frame_height * org_w / org_h)
    elif frame_height == -1:
        frame_height = int(frame_width * org_h / org_w)

    print("({0},{1}) resize to ({2},{3})".format(org_w, org_h, frame_width, frame_height))

    for png in pngs:
        img = cv2.imread(png)
        if img is None:
            print("{0} could not be read -> skip".format(png))
            continue
        img = resize_img(img, frame_width, frame_height)
        cv2.imwrite(png, img)
def remove_pngs_in_dir(path):
    """Delete every file matching ``*.png`` directly inside ``path``.

    Nothing happens when ``path`` is not an existing directory. Only paths
    returned by the ``*.png`` glob are removed; the directory itself is kept.

    Args:
        path: Directory to clean.

    Returns:
        None.
    """
    if os.path.isdir(path):
        png_pattern = os.path.join(path, "*.png")
        for png_file in glob.glob(png_pattern):
            os.remove(png_file)
def create_and_mask(mask_dir1, mask_dir2, output_dir):
    """Combine same-named masks from two directories with an element-wise minimum.

    For each ``*.png`` in ``mask_dir1`` the file with the same base name is
    looked up in ``mask_dir2``. When it exists, both images are read with
    ``cv2.imread``, merged with ``np.minimum`` and written to ``output_dir``
    under that base name. When it does not exist, a message is printed and
    the file is skipped.

    Args:
        mask_dir1: Directory that drives the iteration.
        mask_dir2: Directory searched for the matching mask.
        output_dir: Directory receiving the combined masks (may be one of
            the input directories).

    Returns:
        None.
    """
    for first_path in glob.glob(os.path.join(mask_dir1, "*.png")):
        base_name = os.path.basename(first_path)
        print("combine {0}".format(base_name))

        second_path = os.path.join(mask_dir2, base_name)
        if os.path.isfile(second_path):
            first_img = cv2.imread(first_path)
            second_img = cv2.imread(second_path)
            combined = np.minimum(first_img, second_img)
            cv2.imwrite(os.path.join(output_dir, base_name), combined)
        else:
            print("{0} not found!!! -> skip".format(second_path))
def create_mask_clipseg(input_dir, output_dir, clipseg_mask_prompt, clipseg_exclude_prompt, clipseg_mask_threshold, mask_blur_size, mask_blur_size2):
    """Write a text-prompted CLIPSeg mask for every ``*.png`` in ``input_dir``.

    Both prompt strings are split on commas and stripped. For each frame the
    model is run once per prompt; the sigmoid of each logit map is
    thresholded with ``clipseg_mask_threshold``. Maps belonging to
    ``clipseg_mask_prompt`` are merged with an element-wise maximum, and
    pixels selected by any ``clipseg_exclude_prompt`` entry are then cleared.
    The result is scaled to 0/255, optionally median- and Gaussian-blurred,
    resized to the frame size, expanded to three channels and written to
    ``output_dir`` under the frame's base name.

    Args:
        input_dir: Directory containing the source PNG frames.
        output_dir: Directory that receives the mask PNGs.
        clipseg_mask_prompt: Comma-separated prompts to include.
        clipseg_exclude_prompt: Comma-separated prompts to exclude; a falsy
            value disables exclusion.
        clipseg_mask_threshold: Cut-off applied to the sigmoid output.
        mask_blur_size: Median-blur kernel size; values <= 0 disable it.
        mask_blur_size2: Gaussian-blur kernel size; values <= 0 disable it.

    Returns:
        None.
    """
    from modules import devices

    devices.torch_gc()
    device = devices.get_optimal_device_name()

    processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
    model.to(device)

    imgs = glob.glob(os.path.join(input_dir, "*.png"))
    texts = [x.strip() for x in clipseg_mask_prompt.split(',')]
    exclude_texts = [x.strip() for x in clipseg_exclude_prompt.split(',')] if clipseg_exclude_prompt else []
    # Include prompts come first so an index below len(texts) means "include".
    all_texts = texts + exclude_texts

    # Force both kernel sizes to be odd. This does not depend on the frame,
    # so it is done once instead of on every loop iteration.
    if mask_blur_size > 0:
        mask_blur_size = mask_blur_size // 2 * 2 + 1
    if mask_blur_size2 > 0:
        mask_blur_size2 = mask_blur_size2 // 2 * 2 + 1

    for img_count, img in enumerate(imgs):
        # Convert to RGB like create_mask_transparent_background does, so
        # RGBA / palette / grayscale PNGs are fed to the processor in the
        # same layout as ordinary frames.
        with Image.open(img) as src:
            image = src.convert('RGB')
        base_name = os.path.basename(img)

        inputs = processor(text=all_texts, images=[image] * len(all_texts), padding="max_length", return_tensors="pt")
        inputs = inputs.to(device)

        with torch.no_grad(), devices.autocast():
            outputs = model(**inputs)

        # With a single prompt a leading axis is added so that preds[i]
        # can be indexed uniformly below.
        if len(all_texts) == 1:
            preds = outputs.logits.unsqueeze(0)
        else:
            preds = outputs.logits

        mask_img = None
        for i in range(len(all_texts)):
            prob = torch.sigmoid(preds[i]).to('cpu').detach().numpy()
            selected = prob > clipseg_mask_threshold

            if i < len(texts):
                # Include prompt: union with what has been selected so far.
                mask_img = selected if mask_img is None else np.maximum(mask_img, selected)
            else:
                # Exclude prompt: clear the pixels it selects.
                mask_img[selected] = 0

        mask_img = (mask_img * 255).astype(np.uint8)

        if mask_blur_size > 0:
            mask_img = cv2.medianBlur(mask_img, mask_blur_size)
        if mask_blur_size2 > 0:
            mask_img = cv2.GaussianBlur(mask_img, (mask_blur_size2, mask_blur_size2), 0)

        # Bring the mask back to the frame's own dimensions.
        mask_img = resize_img(mask_img, image.width, image.height)
        mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)

        save_path = os.path.join(output_dir, base_name)
        cv2.imwrite(save_path, mask_img)

        print("{0} / {1}".format(img_count + 1, len(imgs)))

    devices.torch_gc()
def create_mask_transparent_background(input_dir, output_dir, tb_use_fast_mode, tb_use_jit, st1_mask_threshold):
    """Write a ``transparent_background`` mask for every ``*.png`` in ``input_dir``.

    Each frame is opened as RGB and passed to ``Remover.process`` with
    ``type='map'``. Values below ``int(255 * st1_mask_threshold)`` are set to
    zero and the result is saved to ``output_dir`` under the frame's base
    name. Progress is shown with a ``tqdm`` bar labelled with the current
    file name.

    Args:
        input_dir: Directory containing the source PNG frames.
        output_dir: Directory that receives the mask PNGs.
        tb_use_fast_mode: Forwarded to ``Remover`` as ``fast``.
        tb_use_jit: Forwarded to ``Remover`` as ``jit``.
        st1_mask_threshold: Fraction of 255 below which mask values are
            zeroed.

    Returns:
        None.
    """
    from modules import devices

    device_name = devices.get_optimal_device_name()
    remover = Remover(fast=tb_use_fast_mode, jit=tb_use_jit, device=device_name)

    frame_files = glob.glob(os.path.join(input_dir, "*.png"))
    progress = tqdm(frame_files, bar_format='{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}')

    for frame_file in progress:
        base_name = os.path.basename(frame_file)
        progress.set_description('{}'.format(base_name))

        rgb_frame = Image.open(frame_file).convert('RGB')
        mask = remover.process(rgb_frame, type='map')

        # Suppress everything weaker than the configured threshold.
        cutoff = int(255 * st1_mask_threshold)
        mask[mask < cutoff] = 0

        target = os.path.join(output_dir, base_name)
        cv2.imwrite(target, mask)
def ebsynth_utility_stage1(dbg, project_args, frame_width, frame_height, st1_masking_method_index, st1_mask_threshold, tb_use_fast_mode, tb_use_jit, clipseg_mask_prompt, clipseg_exclude_prompt, clipseg_mask_threshold, clipseg_mask_blur_size, clipseg_mask_blur_size2, is_invert_mask):
    """Stage 1: extract frames from the source movie and generate masks.

    Steps, in order:
      1. Abort when the CLIPSeg-only method (index 1) has an empty prompt.
      2. In invert-mask mode, skip everything when both the frame and the
         mask directory already exist.
      3. Clear old PNGs from the mask directory and make sure it exists.
      4. If the frame directory does not exist, create it, extract frames
         with ffmpeg as ``%05d.png`` and optionally resize them.
      5. Build masks with the selected method: 0 = transparent-background,
         1 = CLIPSeg, 2 = both, combined by ``create_and_mask``.

    Args:
        dbg: Object exposing ``print(str)``, used for progress output.
        project_args: Seven-item sequence; the first four items are
            ``project_dir``, ``original_movie_path``, ``frame_path`` and
            ``frame_mask_path``. The rest are ignored here.
        frame_width: Target frame width; values below ``-1`` are clamped to
            ``-1`` (derive from height / keep).
        frame_height: Target frame height; same convention as the width.
        st1_masking_method_index: 0, 1 or 2 as described above.
        st1_mask_threshold: Forwarded to the transparent-background masker.
        tb_use_fast_mode: Forwarded to the transparent-background masker.
        tb_use_jit: Forwarded to the transparent-background masker.
        clipseg_mask_prompt: Forwarded to the CLIPSeg masker.
        clipseg_exclude_prompt: Forwarded to the CLIPSeg masker.
        clipseg_mask_threshold: Forwarded to the CLIPSeg masker.
        clipseg_mask_blur_size: Forwarded to the CLIPSeg masker.
        clipseg_mask_blur_size2: Forwarded to the CLIPSeg masker.
        is_invert_mask: Enables the skip check of step 2.

    Returns:
        None.
    """
    dbg.print("stage1")
    dbg.print("")

    if st1_masking_method_index == 1 and (not clipseg_mask_prompt):
        dbg.print("Error: clipseg_mask_prompt is Empty")
        return

    project_dir, original_movie_path, frame_path, frame_mask_path, _, _, _ = project_args

    if is_invert_mask:
        if os.path.isdir(frame_path) and os.path.isdir(frame_mask_path):
            dbg.print("Skip as it appears that the frame and normal masks have already been generated.")
            return

    if frame_mask_path:
        # Start from a clean mask directory; extracted frames are kept.
        remove_pngs_in_dir(frame_mask_path)
        os.makedirs(frame_mask_path, exist_ok=True)

    if os.path.isdir(frame_path):
        dbg.print("Skip frame extraction")
    else:
        os.makedirs(frame_path, exist_ok=True)

        png_path = os.path.join(frame_path, "%05d.png")
        # The command used to be a concatenated string run through the
        # shell, so the path was often pre-quoted by hand. Drop such quotes
        # now that the path is passed as a single argument.
        movie_path = original_movie_path.strip().strip('"')
        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact and cannot be used to inject commands.
        ffmpeg_cmd = [
            "ffmpeg",
            "-ss", "00:00:00",
            "-y",
            "-i", movie_path,
            "-qscale", "0",
            "-f", "image2",
            "-c:v", "png",
            png_path,
        ]
        try:
            return_code = subprocess.call(ffmpeg_cmd)
        except OSError as err:
            # Typically: the ffmpeg executable is not on PATH.
            dbg.print("Error: failed to launch ffmpeg: {0}".format(err))
            return
        if return_code != 0:
            # Keep going as before, but make the failure visible.
            dbg.print("Warning: ffmpeg exited with code {0}".format(return_code))

        dbg.print("frame extracted")

        # NOTE(review): resizing is done only for freshly extracted frames,
        # so re-running the stage does not rescale already-resized frames
        # again. Confirm this nesting against the upstream file.
        frame_width = max(frame_width, -1)
        frame_height = max(frame_height, -1)

        if frame_width != -1 or frame_height != -1:
            resize_all_img(frame_path, frame_width, frame_height)

    if frame_mask_path:
        if st1_masking_method_index == 0:
            create_mask_transparent_background(frame_path, frame_mask_path, tb_use_fast_mode, tb_use_jit, st1_mask_threshold)
        elif st1_masking_method_index == 1:
            create_mask_clipseg(frame_path, frame_mask_path, clipseg_mask_prompt, clipseg_exclude_prompt, clipseg_mask_threshold, clipseg_mask_blur_size, clipseg_mask_blur_size2)
        elif st1_masking_method_index == 2:
            tb_tmp_path = os.path.join(project_dir, "tb_mask_tmp")
            # NOTE(review): the transparent-background pass is treated as a
            # cache and only run when its directory is created. Confirm this
            # nesting against the upstream file.
            if not os.path.isdir(tb_tmp_path):
                os.makedirs(tb_tmp_path, exist_ok=True)
                create_mask_transparent_background(frame_path, tb_tmp_path, tb_use_fast_mode, tb_use_jit, st1_mask_threshold)
            create_mask_clipseg(frame_path, frame_mask_path, clipseg_mask_prompt, clipseg_exclude_prompt, clipseg_mask_threshold, clipseg_mask_blur_size, clipseg_mask_blur_size2)
            # Keep only pixels selected by both methods, overwriting the
            # CLIPSeg masks in place.
            create_and_mask(tb_tmp_path, frame_mask_path, frame_mask_path)

        dbg.print("mask created")

    dbg.print("")
    dbg.print("completed.")
def ebsynth_utility_stage1_invert(dbg, frame_mask_path, inv_mask_path):
    """Create inverted copies of the masks in ``frame_mask_path``.

    When ``frame_mask_path`` is not a directory, a hint about generating the
    normal masks first is printed through ``dbg`` and nothing is written.
    Otherwise ``inv_mask_path`` is created if needed and every ``*.png`` mask
    is read, inverted with ``cv2.bitwise_not`` and saved there under the same
    base name.

    Args:
        dbg: Object exposing ``print(str)``, used for progress output.
        frame_mask_path: Directory holding the normal masks.
        inv_mask_path: Directory that receives the inverted masks.

    Returns:
        None.
    """
    dbg.print("stage 1 create_invert_mask")
    dbg.print("")

    masks_available = os.path.isdir(frame_mask_path)
    if not masks_available:
        for message in (
            frame_mask_path + " not found",
            "Normal masks must be generated previously.",
            "Do stage 1 with [Ebsynth Utility] Tab -> [configuration] -> [etc]-> [Mask Mode] = Normal setting first",
        ):
            dbg.print(message)
        return

    os.makedirs(inv_mask_path, exist_ok=True)

    for mask_file in glob.glob(os.path.join(frame_mask_path, "*.png")):
        inverted = cv2.bitwise_not(cv2.imread(mask_file))
        target = os.path.join(inv_mask_path, os.path.basename(mask_file))
        cv2.imwrite(target, inverted)

    dbg.print("")
    dbg.print("completed.")