import numpy as np
import cv2
from PIL import Image
from prettytable import PrettyTable
from .prompt import split_weighted_subprompts
from .load_images import load_img, prepare_mask, check_mask_for_errors
from .webui_sd_pipeline import get_webui_sd_pipeline
from .animation import sample_from_cv2, sample_to_cv2

# Webui
from modules import processing, sd_models, masking
from modules.shared import opts, sd_model
import modules.shared as shared
from modules.processing import process_images, StableDiffusionProcessingTxt2Img

import math, json, itertools
import requests

def load_mask_latent(mask_input, shape):
    """Load a mask image and resize it to match `shape` (usually latent_image.shape).

    mask_input (str or PIL Image.Image): path or URL to the mask image, or a PIL Image object
    shape (list-like of length 4): shape of the image to match
    """
    if isinstance(mask_input, str):  # mask input is probably a file name
        if mask_input.startswith('http://') or mask_input.startswith('https://'):
            mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA')
        else:
            mask_image = Image.open(mask_input).convert('RGBA')
    elif isinstance(mask_input, Image.Image):
        mask_image = mask_input
    else:
        raise Exception("mask_input must be a PIL image or a file name")

    mask_w_h = (shape[-1], shape[-2])
    mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS)
    mask = mask.convert("L")
    return mask
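
# Usage sketch (hypothetical file name and shape, not part of the original module):
# for a latent tensor shaped (1, 4, 64, 96), load_mask_latent("mask.png", latent.shape)
# returns a 96x64 grayscale ("L" mode) PIL image resized with Lanczos resampling.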

def isJson(myjson):
    try:
        json.loads(myjson)
    except ValueError:
        return False
    return True
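
# Usage sketch: isJson('{"0": "first.png", "50": "second.png"}') -> True, while
# isJson("not json") -> False (json.JSONDecodeError is a ValueError subclass).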

# Add a local pairwise implementation here so we don't have to upgrade
# the whole Python install to 3.10 just for one function
def pairwise_repl(iterable):
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)
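
# Usage sketch: list(pairwise_repl([0, 30, 60])) -> [(0, 30), (30, 60)],
# mirroring itertools.pairwise from Python 3.10.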

def generate(args, anim_args, loop_args, root, frame=0, return_sample=False, sampler_name=None):
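    """Generate a single frame and return the resulting PIL image.

    Sets up the WebUI processing pipeline for the given frame, optionally applying
    init images, masks, keyframe-guidance ("looper") blending, and smart-border
    inpainting before running the diffusion pass.
    """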
    assert args.prompt is not None

    # Setup the pipeline
    p = get_webui_sd_pipeline(args, root, frame)
    p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame)

    if not args.use_init and args.strength > 0 and args.strength_0_no_init:
        print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.")
        print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n")
        args.strength = 0
    processed = None
    mask_image = None
    init_image = None
    image_init0 = None

    if loop_args.use_looper:
        # TODO find out why we need to set this in the init tab
        if args.strength == 0:
            raise RuntimeError("Strength needs to be greater than 0 in the Init tab, and strength_0_no_init should *not* be checked")
        if args.seed_behavior != "schedule":
            raise RuntimeError("seed_behavior needs to be set to 'schedule' under the 'Keyframes' tab --> 'Seed scheduling'")
        if not isJson(loop_args.imagesToKeyframe):
            raise RuntimeError("The images set for use with keyframe guidance are not in a proper JSON format")
        args.strength = loop_args.imageStrength
        tweeningFrames = loop_args.tweeningFrameSchedule
        blendFactor = .07
        colorCorrectionFactor = loop_args.colorCorrectionFactor
        jsonImages = json.loads(loop_args.imagesToKeyframe)
        framesToImageSwapOn = list(map(int, list(jsonImages.keys())))
        # find which image to show
        frameToChoose = 0
        for swappingFrame in framesToImageSwapOn[1:]:
            frameToChoose += (frame >= int(swappingFrame))

        # find which frame to do our swapping on for tweening
        skipFrame = 25
        for fs, fe in pairwise_repl(framesToImageSwapOn):
            if fs <= frame <= fe:
                skipFrame = fe - fs

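        # Cosine-based schedule for the tween blend factor between keyframe images
        # (derived from blendFactorMax and blendFactorSlope).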
        if frame % skipFrame <= tweeningFrames:  # number of tweening frames
            blendFactor = loop_args.blendFactorMax - loop_args.blendFactorSlope*math.cos((frame % tweeningFrames) / (tweeningFrames / 2))
        init_image2, _ = load_img(list(jsonImages.values())[frameToChoose],
                                  shape=(args.W, args.H),
                                  use_alpha_as_mask=args.use_alpha_as_mask)
        image_init0 = list(jsonImages.values())[0]

    else:  # they passed in a single init image
        image_init0 = args.init_image

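    # map Deforum's lowercase sampler keys to the WebUI sampler display names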
    available_samplers = {
        'euler a': 'Euler a',
        'euler': 'Euler',
        'lms': 'LMS',
        'heun': 'Heun',
        'dpm2': 'DPM2',
        'dpm2 a': 'DPM2 a',
        'dpm++ 2s a': 'DPM++ 2S a',
        'dpm++ 2m': 'DPM++ 2M',
        'dpm++ sde': 'DPM++ SDE',
        'dpm fast': 'DPM fast',
        'dpm adaptive': 'DPM adaptive',
        'lms karras': 'LMS Karras',
        'dpm2 karras': 'DPM2 Karras',
        'dpm2 a karras': 'DPM2 a Karras',
        'dpm++ 2s a karras': 'DPM++ 2S a Karras',
        'dpm++ 2m karras': 'DPM++ 2M Karras',
        'dpm++ sde karras': 'DPM++ SDE Karras'
    }
    if sampler_name is not None:
        if sampler_name in available_samplers:
            args.sampler = available_samplers[sampler_name]

    if args.checkpoint is not None:
        info = sd_models.get_closet_checkpoint_match(args.checkpoint)
        if info is None:
            raise RuntimeError(f"Unknown checkpoint: {args.checkpoint}")
        sd_models.reload_model_weights(info=info)

    if args.init_sample is not None:
        # TODO: cleanup init_sample remains later
        img = args.init_sample
        init_image = img
        image_init0 = img
        if loop_args.use_looper and isJson(loop_args.imagesToKeyframe):
            init_image = Image.blend(init_image, init_image2, blendFactor)
            correction_colors = Image.blend(init_image, init_image2, colorCorrectionFactor)
            p.color_corrections = [processing.setup_color_correction(correction_colors)]

    if anim_args.border == 'smart':
        # Inpaint the changed parts of the image,
        # that is to say, the zeros we got after the transformations.

        # It's important to note that the loop below creates a mask for inpainting zeros.
        # This mask, however, can also cover areas that were intended to be black.
        # A suggested fix is to pass the inpainting mask in as an argument,
        # built before add_noise and contrast_adjust are applied.
        mask_image = init_image.convert('L')
        for x in range(mask_image.width):
            for y in range(mask_image.height):
                if mask_image.getpixel((x, y)) < 4:
                    mask_image.putpixel((x, y), 255)
                else:
                    mask_image.putpixel((x, y), 0)

        # blend the two masks
        if root.warp_mask is not None:
            # TODO: I guess there is some built-in function for this
            warp_mask_image = Image.fromarray(root.warp_mask).convert('L')
            for x in range(mask_image.width):
                for y in range(mask_image.height):
                    if mask_image.getpixel((x, y)) > 0 or warp_mask_image.getpixel((x, y)) > 0:
                        mask_image.putpixel((x, y), 255)
                    else:
                        mask_image.putpixel((x, y), 0)
            root.warp_mask = None

        mask = prepare_mask(mask_image,
                            (args.W, args.H),
                            args.mask_contrast_adjust,
                            args.mask_brightness_adjust)

        # HACK: this is a hacky check to make the mask work with the new inpainting code
        crop_region = masking.get_crop_region(np.array(mask_image), args.full_res_mask_padding)
        crop_region = masking.expand_crop_region(crop_region, args.W, args.H, mask_image.width, mask_image.height)
        x1, y1, x2, y2 = crop_region

        too_small = (x2 - x1) < 1 or (y2 - y1) < 1

        if not too_small:
            p.do_not_save_samples = True
            p.inpainting_fill = args.smart_border_fill_mode
            p.inpaint_full_res = args.full_res_mask
            p.inpaint_full_res_padding = args.full_res_mask_padding
            p.init_images = [init_image]
            p.image_mask = mask_image

            # color correction for zeroes inpainting
            p.color_corrections = [processing.setup_color_correction(init_image)]

            print("Smart mode: inpainting border")

            processed = processing.process_images(p)
            init_image = processed.images[0].convert('RGB')

            p = get_webui_sd_pipeline(args, root, frame)
            p.init_images = [init_image]

            processed = None
        else:
            # fix tqdm total steps if we don't have to conduct a second pass
            tqdm_instance = shared.total_tqdm
            current_total = tqdm_instance.getTotal()
            if current_total != -1:
                tqdm_instance.updateTotal(current_total - int(math.ceil(args.steps * (1 - args.strength))))

        mask = None
        mask_image = None

    # this is the first pass
    elif loop_args.use_looper or (args.use_init and (args.init_image is not None and args.init_image != '')):
        init_image, mask_image = load_img(image_init0,  # initial init image
                                          shape=(args.W, args.H),
                                          use_alpha_as_mask=args.use_alpha_as_mask)

    else:
        if anim_args.animation_mode != 'Interpolation':
            print("Not using an init image (doing pure txt2img)")
        p_txt = StableDiffusionProcessingTxt2Img(
            sd_model=sd_model,
            outpath_samples=p.outpath_samples,
            outpath_grids=p.outpath_samples,
            prompt=p.prompt,
            styles=p.styles,
            negative_prompt=p.negative_prompt,
            seed=p.seed,
            subseed=p.subseed,
            subseed_strength=p.subseed_strength,
            seed_resize_from_h=p.seed_resize_from_h,
            seed_resize_from_w=p.seed_resize_from_w,
            sampler_name=p.sampler_name,
            batch_size=p.batch_size,
            n_iter=p.n_iter,
            steps=p.steps,
            cfg_scale=p.cfg_scale,
            width=p.width,
            height=p.height,
            restore_faces=p.restore_faces,
            tiling=p.tiling,
            enable_hr=None,
            denoising_strength=None,
        )
        # print dynamic table to cli
        print_generate_table(args, anim_args, p_txt)

        processed = processing.process_images(p_txt)

    if processed is None:
        # Mask functions
        if args.use_mask:
            mask = args.mask_image
            # assign masking options to pipeline
            if mask is not None:
                p.inpainting_mask_invert = args.invert_mask
                p.inpainting_fill = args.fill
                p.inpaint_full_res = args.full_res_mask
                p.inpaint_full_res_padding = args.full_res_mask_padding
        else:
            mask = None

        assert not ((mask is not None and args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), "Need an init image when use_mask == True and overlay_mask == True"

        p.init_images = [init_image]
        p.image_mask = mask
        p.image_cfg_scale = args.pix2pix_img_cfg_scale

        # print dynamic table to cli
        print_generate_table(args, anim_args, p)

        processed = processing.process_images(p)

    if root.initial_info is None:
        root.initial_seed = processed.seed
        root.initial_info = processed.info

    if root.first_frame is None:
        root.first_frame = processed.images[0]

    results = processed.images[0]

    return results

def print_generate_table(args, anim_args, p):
    x = PrettyTable(padding_width=0)
    field_names = ["Steps", "CFG"]
    if anim_args.animation_mode != 'Interpolation':
        field_names.append("Denoise")
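    # Multiplying a list by a bool (0 or 1) appends these columns only when the
    # corresponding feature is enabled.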
field_names += ["Subseed", "Subs. str"] * (args.seed_enable_extras)
|
|
field_names += ["Sampler"] * anim_args.enable_sampler_scheduling
|
|
field_names += ["Checkpoint"] * anim_args.enable_checkpoint_scheduling
|
|
x.field_names = field_names
|
|
row = [p.steps, p.cfg_scale]
|
|
if anim_args.animation_mode != 'Interpolation':
|
|
row.append(p.denoising_strength)
|
|
row += [p.subseed, p.subseed_strength] * (args.seed_enable_extras)
|
|
row += [p.sampler_name] * anim_args.enable_sampler_scheduling
|
|
row += [args.checkpoint] * anim_args.enable_checkpoint_scheduling
|
|
x.add_row(row)
|
|
print(x) |