mirror of https://github.com/vladmandic/automatic
parent 35803746df
commit e0faa149dd
@@ -206,7 +206,6 @@ def post_unload_checkpoint():
     from modules import sd_models
     sd_models.unload_model_weights(op='model')
     sd_models.unload_model_weights(op='refiner')
-    sd_models.unload_auxiliary_models()
     return {}


 def post_reload_checkpoint(force:bool=False):
@@ -284,7 +284,7 @@ def process_init(p: StableDiffusionProcessing):
     p.negative_prompts = p.all_negative_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]


-def _p_or_opt(p, key):
+def get_opt(p, key):
     val = getattr(p, key, None)
     return val if val is not None else getattr(shared.opts, key)
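Note: the hunk above is this commit's central refactor — the option-lookup helper `_p_or_opt` is renamed to `get_opt`, resolving a value from the processing object `p` first and falling back to global `shared.opts`. A minimal self-contained sketch of that fallback pattern (the `Opts` and `Job` classes below are illustrative stand-ins, not repo code):

class Opts:                     # stand-in for shared.opts (global settings)
    samples_format = 'jpg'
    save_mask = False

class Job:                      # stand-in for StableDiffusionProcessing
    samples_format = None       # None means "not overridden for this job"
    save_mask = True            # per-job override

opts = Opts()

def get_opt(p, key):
    val = getattr(p, key, None)                             # per-job value wins when set
    return val if val is not None else getattr(opts, key)   # else global default

assert get_opt(Job(), 'samples_format') == 'jpg'  # falls back to the global
assert get_opt(Job(), 'save_mask') is True        # per-job override respected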
@@ -313,9 +313,9 @@ def process_samples(p: StableDiffusionProcessing, samples):

         if p.detailer_enabled:
             p.ops.append('detailer')
-            if not p.do_not_save_samples and _p_or_opt(p, 'save_images_before_detailer'):
+            if not p.do_not_save_samples and get_opt(p, 'save_images_before_detailer'):
                 info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
-                images.save_image(Image.fromarray(sample), path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=_p_or_opt(p, 'samples_format'), info=info, p=p, suffix="-before-detailer")
+                images.save_image(Image.fromarray(sample), path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=get_opt(p, 'samples_format'), info=info, p=p, suffix="-before-detailer")
             sample = detailer.detail(sample, p)
             if isinstance(sample, list):
                 if len(sample) > 0:
@@ -329,10 +329,10 @@ def process_samples(p: StableDiffusionProcessing, samples):

         if p.color_corrections is not None and i < len(p.color_corrections):
             p.ops.append('color')
-            if not p.do_not_save_samples and _p_or_opt(p, 'save_images_before_color_correction'):
+            if not p.do_not_save_samples and get_opt(p, 'save_images_before_color_correction'):
                 image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                 info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
-                images.save_image(image_without_cc, path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=_p_or_opt(p, 'samples_format'), info=info, p=p, suffix="-before-color-correct")
+                images.save_image(image_without_cc, path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=get_opt(p, 'samples_format'), info=info, p=p, suffix="-before-color-correct")
             method = p.color_correction_method if p.color_correction_method is not None else getattr(shared.opts, 'color_correction_method', 'histogram')
             image = apply_color_correction(p.color_corrections[i], image, method=method)

@@ -373,10 +373,10 @@ def process_samples(p: StableDiffusionProcessing, samples):
         if _overlay:
             image = apply_overlay(image, p.paste_to, i, p.overlay_images)

-        _save_mask = _p_or_opt(p, 'save_mask')
-        _save_mask_composite = _p_or_opt(p, 'save_mask_composite')
-        _return_mask = _p_or_opt(p, 'return_mask')
-        _return_mask_composite = _p_or_opt(p, 'return_mask_composite')
+        _save_mask = get_opt(p, 'save_mask')
+        _save_mask_composite = get_opt(p, 'save_mask_composite')
+        _return_mask = get_opt(p, 'return_mask')
+        _return_mask_composite = get_opt(p, 'return_mask_composite')
         if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([_save_mask, _save_mask_composite, _return_mask, _return_mask_composite]):
             image_mask = p.mask_for_overlay.convert('RGB')
             image1 = image.convert('RGBA').convert('RGBa')
@@ -384,7 +384,7 @@ def process_samples(p: StableDiffusionProcessing, samples):
             mask = images.resize_image(3, p.mask_for_overlay, image.width, image.height).convert('L')
             image_mask_composite = Image.composite(image1, image2, mask).convert('RGBA')
             info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
-            _fmt = _p_or_opt(p, 'samples_format')
+            _fmt = get_opt(p, 'samples_format')
             if _save_mask:
                 images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], _fmt, info=info, p=p, suffix="-mask")
             if _save_mask_composite:
@@ -420,8 +420,8 @@ def process_samples(p: StableDiffusionProcessing, samples):
             image = images.resize_image(p.resize_mode_after, image, p.width_after, p.height_after, p.resize_name_after, context=p.resize_context_after)

         info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
-        if _p_or_opt(p, 'samples_save') and not p.do_not_save_samples and p.outpath_samples is not None:
-            images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], _p_or_opt(p, 'samples_format'), info=info, p=p) # main save image
+        if get_opt(p, 'samples_save') and not p.do_not_save_samples and p.outpath_samples is not None:
+            images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], get_opt(p, 'samples_format'), info=info, p=p) # main save image

         image.info["parameters"] = info
         out_infotexts.append(info)
@@ -499,7 +499,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

         if shared.state.interrupted:
             log.debug(f'Process: batch={n+1}/{p.n_iter} interrupted')
-            _keep = _p_or_opt(p, 'keep_incomplete')
+            _keep = get_opt(p, 'keep_incomplete')
             p.do_not_save_samples = not _keep
             if shared.state.current_image is not None and isinstance(shared.state.current_image, Image.Image):
                 samples = [shared.state.current_image]
@@ -544,8 +544,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

     p.color_corrections = None
     index_of_first_image = 0
-    _return_grid = _p_or_opt(p, 'return_grid')
-    _grid_save = _p_or_opt(p, 'grid_save')
+    _return_grid = get_opt(p, 'return_grid')
+    _grid_save = get_opt(p, 'grid_save')
     if (_return_grid or _grid_save) and (not p.do_not_save_grid) and (len(output_images) > 1):
         if images.check_grid_size(output_images):
             r, c = images.get_grid_size(output_images, p.batch_size)
@@ -557,7 +557,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             output_images.insert(0, grid)
             index_of_first_image = 1
             if _grid_save:
-                images.save_image(grid, p.outpath_grids, "", p.all_seeds[0], p.all_prompts[0], _p_or_opt(p, 'grid_format'), info=grid_info, p=p, grid=True) # main save grid
+                images.save_image(grid, p.outpath_grids, "", p.all_seeds[0], p.all_prompts[0], get_opt(p, 'grid_format'), info=grid_info, p=p, grid=True) # main save grid

     results = get_processed(
         p,
@@ -75,7 +75,7 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = None):
         time.sleep(0.1)
     if latents is None:
         return kwargs
-    elif (getattr(p, 'nan_skip', None) if (p is not None and getattr(p, 'nan_skip', None) is not None) else shared.opts.nan_skip):
+    elif shared.opts.nan_skip:
         assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...'
     if p is None:
         return kwargs
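Note: the `nan_skip` guard asserts the latents have not collapsed to NaN; probing only `latents[..., 0, 0]` keeps the check cheap by sampling one spatial position across batch and channels. A standalone sketch of that check with a synthetic tensor:

import torch

def nan_skip_check(latents: torch.Tensor, step: int):
    # probe one spatial position across batch/channels; all-NaN means the
    # denoising run has collapsed and further steps are pointless
    assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...'

latents = torch.randn(2, 4, 64, 64)
nan_skip_check(latents, step=5)        # passes: healthy latents

latents[:] = float('nan')
try:
    nan_skip_check(latents, step=6)    # raises: every probed value is NaN
except AssertionError as e:
    print(e)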
@@ -143,6 +143,7 @@ def process_base(p: processing.StableDiffusionProcessing):
     update_sampler(p, shared.sd_model)
     timer.process.record('prepare')
     process_pre(p)
+    sched_eta = p.scheduler_eta if p.scheduler_eta is not None else shared.opts.scheduler_eta
     desc = 'Base'
     if 'detailer' in p.ops:
         desc = 'Detail'
@@ -154,7 +155,7 @@ def process_base(p: processing.StableDiffusionProcessing):
         prompts_2=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
         negative_prompts_2=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
         num_inference_steps=calculate_base_steps(p, use_refiner_start=use_refiner_start, use_denoise_start=use_denoise_start),
-        eta=shared.opts.scheduler_eta,
+        eta=sched_eta,
         guidance_scale=p.cfg_scale,
         guidance_rescale=p.diffusers_guidance_rescale,
         true_cfg_scale=p.pag_scale,
@@ -163,12 +164,13 @@ def process_base(p: processing.StableDiffusionProcessing):
         num_frames=getattr(p, 'frames', 1),
         output_type=output_type,
         clip_skip=p.clip_skip,
+        prompt_attention=getattr(p, 'prompt_attention', None),
         desc=desc,
     )
     base_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None)
     shared.state.update(get_job_name(p, shared.sd_model), base_steps, 1)
-    if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1:
-        p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta
+    if sched_eta is not None and sched_eta > 0 and sched_eta < 1:
+        p.extra_generation_params["Sampler Eta"] = sched_eta
     output = None
     if debug:
         modelstats.analyze()
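Note: this and the following hunks thread scheduler eta through a local `sched_eta`, letting a per-request `p.scheduler_eta` override the global option, and record it in the infotext only when strictly between 0 and 1. A hedged sketch of the resolution logic (SimpleNamespace stands in for the real `p` and `shared.opts`):

from types import SimpleNamespace

opts = SimpleNamespace(scheduler_eta=1.0)   # stand-in for shared.opts

def resolve_eta(p):
    return p.scheduler_eta if p.scheduler_eta is not None else opts.scheduler_eta

p = SimpleNamespace(scheduler_eta=0.3, extra_generation_params={})
sched_eta = resolve_eta(p)                       # 0.3: request override wins
if sched_eta is not None and 0 < sched_eta < 1:  # record only non-default eta
    p.extra_generation_params["Sampler Eta"] = sched_eta
print(p.extra_generation_params)                 # {'Sampler Eta': 0.3}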
@@ -304,6 +306,7 @@ def process_hires(p: processing.StableDiffusionProcessing, output):

     prompts = p.prompts
     reset_prompts = False
+    sched_eta = p.scheduler_eta if p.scheduler_eta is not None else shared.opts.scheduler_eta
     if len(p.refiner_prompt) > 0:
         prompts = len(output.images)* [p.refiner_prompt]
         prompts, p.network_data = extra_networks.parse_prompts(prompts)
@@ -319,13 +322,14 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
         prompts_2=len(output.images) * [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
         negative_prompts_2=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
         num_inference_steps=calculate_hires_steps(p),
-        eta=shared.opts.scheduler_eta,
+        eta=sched_eta,
         guidance_scale=p.image_cfg_scale if p.image_cfg_scale is not None else p.cfg_scale,
         guidance_rescale=p.diffusers_guidance_rescale,
         output_type=output_type,
         clip_skip=p.clip_skip,
         image=output.images,
         strength=strength,
+        prompt_attention=getattr(p, 'prompt_attention', None),
         desc='Hires',
     )

@@ -397,15 +401,14 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
             p.extra_generation_params['Noise level'] = noise_level
             refiner_output_type = 'np'
         update_sampler(p, shared.sd_refiner, second_pass=True)
-        shared.opts.prompt_attention = 'fixed'
+        sched_eta = p.scheduler_eta if p.scheduler_eta is not None else shared.opts.scheduler_eta
         refiner_args = set_pipeline_args(
             p=p,
             model=shared.sd_refiner,
             prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts[i],
             negative_prompts=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts[i],
             num_inference_steps=calculate_refiner_steps(p),
-            eta=shared.opts.scheduler_eta,
-            # strength=p.denoising_strength,
+            eta=sched_eta,
             noise_level=noise_level, # StableDiffusionUpscalePipeline only
             guidance_scale=p.image_cfg_scale if p.image_cfg_scale is not None else p.cfg_scale,
             guidance_rescale=p.diffusers_guidance_rescale,
@@ -567,8 +567,7 @@ def save_intermediate(p, latents, suffix):
         info=create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, [], iteration=p.iteration, position_in_batch=i)
         decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, output_type='pil', vae_type=p.vae_type, width=p.width, height=p.height)
         for j in range(len(decoded)):
-            _fmt = p.samples_format if p.samples_format is not None else shared.opts.samples_format
-            images.save_image(decoded[j], path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=_fmt, info=info, p=p, suffix=suffix)
+            images.save_image(decoded[j], path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix=suffix)


 def update_sampler(p, sd_model, second_pass=False):
@@ -161,15 +161,15 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=None
     if 'color' in p.ops:
         args["Color correction"] = True

-    def _p_or_opt(key):
+    def get_opt(key):
         val = getattr(p, key, None)
         if val is not None:
             return val
         return getattr(shared.opts, key, None)

-    _token_method = _p_or_opt('token_merging_method')
-    _tome = _p_or_opt('tome_ratio')
-    _todo = _p_or_opt('todo_ratio')
+    _token_method = get_opt('token_merging_method')
+    _tome = get_opt('tome_ratio')
+    _todo = get_opt('todo_ratio')
     if _token_method == 'ToMe': # tome/todo
         args['ToMe'] = _tome if _tome != 0 else None
     elif _token_method == 'ToDo':
@@ -179,23 +179,23 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=None

     # samplers
     if getattr(p, 'sampler_name', None) is not None and p.sampler_name.lower() != 'default':
-        _eta_delta = _p_or_opt('eta_noise_seed_delta')
+        _eta_delta = get_opt('eta_noise_seed_delta')
         args["Sampler eta delta"] = _eta_delta if _eta_delta != 0 and sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) else None
         args["Sampler eta multiplier"] = p.initial_noise_multiplier if getattr(p, 'initial_noise_multiplier', 1.0) != 1.0 else None
-        args['Sampler timesteps'] = _p_or_opt('schedulers_timesteps') if _p_or_opt('schedulers_timesteps') != shared.opts.data_labels.get('schedulers_timesteps').default else None
-        args['Sampler spacing'] = _p_or_opt('schedulers_timestep_spacing') if _p_or_opt('schedulers_timestep_spacing') != shared.opts.data_labels.get('schedulers_timestep_spacing').default else None
-        args['Sampler sigma'] = _p_or_opt('schedulers_sigma') if _p_or_opt('schedulers_sigma') != shared.opts.data_labels.get('schedulers_sigma').default else None
-        args['Sampler order'] = _p_or_opt('schedulers_solver_order') if _p_or_opt('schedulers_solver_order') != shared.opts.data_labels.get('schedulers_solver_order').default else None
-        args['Sampler type'] = _p_or_opt('schedulers_prediction_type') if _p_or_opt('schedulers_prediction_type') != shared.opts.data_labels.get('schedulers_prediction_type').default else None
-        args['Sampler beta schedule'] = _p_or_opt('schedulers_beta_schedule') if _p_or_opt('schedulers_beta_schedule') != shared.opts.data_labels.get('schedulers_beta_schedule').default else None
-        args['Sampler low order'] = _p_or_opt('schedulers_use_loworder') if _p_or_opt('schedulers_use_loworder') != shared.opts.data_labels.get('schedulers_use_loworder').default else None
-        args['Sampler dynamic'] = _p_or_opt('schedulers_use_thresholding') if _p_or_opt('schedulers_use_thresholding') != shared.opts.data_labels.get('schedulers_use_thresholding').default else None
-        args['Sampler rescale'] = _p_or_opt('schedulers_rescale_betas') if _p_or_opt('schedulers_rescale_betas') != shared.opts.data_labels.get('schedulers_rescale_betas').default else None
-        args['Sampler beta start'] = _p_or_opt('schedulers_beta_start') if _p_or_opt('schedulers_beta_start') != shared.opts.data_labels.get('schedulers_beta_start').default else None
-        args['Sampler beta end'] = _p_or_opt('schedulers_beta_end') if _p_or_opt('schedulers_beta_end') != shared.opts.data_labels.get('schedulers_beta_end').default else None
-        args['Sampler range'] = _p_or_opt('schedulers_timesteps_range') if _p_or_opt('schedulers_timesteps_range') != shared.opts.data_labels.get('schedulers_timesteps_range').default else None
-        args['Sampler shift'] = _p_or_opt('schedulers_shift') if _p_or_opt('schedulers_shift') != shared.opts.data_labels.get('schedulers_shift').default else None
-        args['Sampler dynamic shift'] = _p_or_opt('schedulers_dynamic_shift') if _p_or_opt('schedulers_dynamic_shift') != shared.opts.data_labels.get('schedulers_dynamic_shift').default else None
+        args['Sampler timesteps'] = get_opt('schedulers_timesteps') if get_opt('schedulers_timesteps') != shared.opts.data_labels.get('schedulers_timesteps').default else None
+        args['Sampler spacing'] = get_opt('schedulers_timestep_spacing') if get_opt('schedulers_timestep_spacing') != shared.opts.data_labels.get('schedulers_timestep_spacing').default else None
+        args['Sampler sigma'] = get_opt('schedulers_sigma') if get_opt('schedulers_sigma') != shared.opts.data_labels.get('schedulers_sigma').default else None
+        args['Sampler order'] = get_opt('schedulers_solver_order') if get_opt('schedulers_solver_order') != shared.opts.data_labels.get('schedulers_solver_order').default else None
+        args['Sampler type'] = get_opt('schedulers_prediction_type') if get_opt('schedulers_prediction_type') != shared.opts.data_labels.get('schedulers_prediction_type').default else None
+        args['Sampler beta schedule'] = get_opt('schedulers_beta_schedule') if get_opt('schedulers_beta_schedule') != shared.opts.data_labels.get('schedulers_beta_schedule').default else None
+        args['Sampler low order'] = get_opt('schedulers_use_loworder') if get_opt('schedulers_use_loworder') != shared.opts.data_labels.get('schedulers_use_loworder').default else None
+        args['Sampler dynamic'] = get_opt('schedulers_use_thresholding') if get_opt('schedulers_use_thresholding') != shared.opts.data_labels.get('schedulers_use_thresholding').default else None
+        args['Sampler rescale'] = get_opt('schedulers_rescale_betas') if get_opt('schedulers_rescale_betas') != shared.opts.data_labels.get('schedulers_rescale_betas').default else None
+        args['Sampler beta start'] = get_opt('schedulers_beta_start') if get_opt('schedulers_beta_start') != shared.opts.data_labels.get('schedulers_beta_start').default else None
+        args['Sampler beta end'] = get_opt('schedulers_beta_end') if get_opt('schedulers_beta_end') != shared.opts.data_labels.get('schedulers_beta_end').default else None
+        args['Sampler range'] = get_opt('schedulers_timesteps_range') if get_opt('schedulers_timesteps_range') != shared.opts.data_labels.get('schedulers_timesteps_range').default else None
+        args['Sampler shift'] = get_opt('schedulers_shift') if get_opt('schedulers_shift') != shared.opts.data_labels.get('schedulers_shift').default else None
+        args['Sampler dynamic shift'] = get_opt('schedulers_dynamic_shift') if get_opt('schedulers_dynamic_shift') != shared.opts.data_labels.get('schedulers_dynamic_shift').default else None

     # model specific
     if shared.sd_model_type == 'h1':
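Note: every entry in the hunk above follows one shape — resolve the effective value with `get_opt`, then record it only when it differs from the registered default, so the infotext stays minimal. A simplified sketch (the `defaults` dict below stands in for the real `shared.opts.data_labels` registry):

# Record only non-default settings in the infotext args.
defaults = {'schedulers_shift': 3.0, 'schedulers_sigma': 'default'}  # stand-in registry
current = {'schedulers_shift': 5.0, 'schedulers_sigma': 'default'}   # effective values

args = {}
for key, label in [('schedulers_shift', 'Sampler shift'), ('schedulers_sigma', 'Sampler sigma')]:
    val = current[key]
    args[label] = val if val != defaults[key] else None  # None entries are later dropped

print(args)  # {'Sampler shift': 5.0, 'Sampler sigma': None}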
@@ -114,7 +114,7 @@ class CheckpointInfo:

 def setup_model():
     list_models()
-    sd_hijack_accelerate.hijack_hfhub()
+    # sd_hijack_accelerate.hijack_hfhub()
     # sd_hijack_accelerate.hijack_torch_conv()

@@ -80,12 +80,6 @@ def restore_accelerate():
     accelerate.utils.set_module_tensor_to_device = orig_set_module


-def hijack_hfhub():
-    import contextlib
-    import huggingface_hub.file_download
-    huggingface_hub.file_download.FileLock = contextlib.nullcontext
-
-
 def torch_conv_forward(self, input, weight, bias): # pylint: disable=redefined-builtin
     if self.padding_mode != 'zeros':
         return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), weight, bias, self.stride, _pair(0), self.dilation, self.groups) # pylint: disable=protected-access
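Note: the deleted `hijack_hfhub` monkey-patched huggingface_hub to disable file locking during downloads by swapping its `FileLock` for a no-op context manager; the commit retires that hijack (its call in `setup_model` is commented out above). A sketch of the underlying pattern (the `FileLock` class below is a stand-in, not the real one):

import contextlib

class FileLock:                         # stand-in for huggingface_hub's lock
    def __init__(self, path):
        self.path = path
    def __enter__(self):
        print(f'locking {self.path}')   # the real class would lock on disk
        return self
    def __exit__(self, *exc):
        print(f'unlocking {self.path}')

FileLock = contextlib.nullcontext       # the monkey-patch: locking becomes a no-op

with FileLock('model.lock'):            # enters and exits without side effects
    pass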
@@ -1,5 +1,6 @@
 import math
 import torch
+from modules.logger import log
 from modules import shared, devices

 # based on <https://github.com/ljleb/sd-webui-freeu/blob/main/lib_free_u/unet.py>
@@ -91,7 +92,7 @@ def get_fft_device():
         torch_fft_device = devices.device
     except Exception:
         torch_fft_device = devices.cpu
-        shared.log.warning(f'FreeU: device={devices.device} dtype={devices.dtype} does not support FFT')
+        log.warning(f'FreeU: device={devices.device} dtype={devices.dtype} does not support FFT')
     return torch_fft_device

@@ -166,4 +167,4 @@ def apply_freeu(p):
         p.sd_model.disable_freeu()
         state_enabled = False
     if enabled and state_enabled:
-        shared.log.info(f'Applying Free-U: b1={b1} b2={b2} s1={s1} s2={s2}')
+        log.info(f'Applying Free-U: b1={b1} b2={b2} s1={s1} s2={s2}')
@@ -10,10 +10,10 @@ import math
 import torch
 import torch.nn as nn
 from einops import rearrange
-from installer import log
+from modules.logger import log


-def _p_or_opt(p, key):
+def get_opt(p, key):
     val = getattr(p, key, None)
     if val is not None:
         return val
@@ -184,10 +184,10 @@ def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=128

 def context_hypertile_vae(p):
     from modules import shared
-    if p.sd_model is None or not _p_or_opt(p, 'hypertile_vae_enabled'):
+    if p.sd_model is None or not get_opt(p, 'hypertile_vae_enabled'):
         return nullcontext()
     if shared.opts.cross_attention_optimization == 'Sub-quadratic':
-        shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
+        log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
         return nullcontext()
     global max_h, max_w, error_reported # pylint: disable=global-statement
     error_reported = False
@@ -204,21 +204,21 @@ def context_hypertile_vae(p):
     if vae is None:
         return nullcontext()
     else:
-        _vae_tile = _p_or_opt(p, 'hypertile_vae_tile')
+        _vae_tile = get_opt(p, 'hypertile_vae_tile')
         tile_size = _vae_tile if _vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
-        _min_tile = _p_or_opt(p, 'hypertile_unet_min_tile')
+        _min_tile = get_opt(p, 'hypertile_unet_min_tile')
         min_tile_size = _min_tile if _min_tile > 0 else 128
-        shared.log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
+        log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
         p.extra_generation_params['Hypertile VAE'] = tile_size
-        return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_vae_swap_size'))
+        return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=get_opt(p, 'hypertile_vae_swap_size'))


 def context_hypertile_unet(p):
     from modules import shared
-    if p.sd_model is None or not _p_or_opt(p, 'hypertile_unet_enabled'):
+    if p.sd_model is None or not get_opt(p, 'hypertile_unet_enabled'):
         return nullcontext()
     if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental:
-        shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
+        log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
         return nullcontext()
     global max_h, max_w, error_reported # pylint: disable=global-statement
     error_reported = False
@@ -232,25 +232,25 @@ def context_hypertile_unet(p):
         log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8')
         return nullcontext()
     if unet is None:
-        # shared.log.warning('Hypertile UNet is enabled but no Unet model was found')
+        # log.warning('Hypertile UNet is enabled but no Unet model was found')
         return nullcontext()
     else:
-        _unet_tile = _p_or_opt(p, 'hypertile_unet_tile')
+        _unet_tile = get_opt(p, 'hypertile_unet_tile')
         tile_size = _unet_tile if _unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
-        _min_tile = _p_or_opt(p, 'hypertile_unet_min_tile')
+        _min_tile = get_opt(p, 'hypertile_unet_min_tile')
         min_tile_size = _min_tile if _min_tile > 0 else 128
-        shared.log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
+        log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
         p.extra_generation_params['Hypertile UNet'] = tile_size
-        return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_unet_swap_size'), depth=_p_or_opt(p, 'hypertile_unet_depth'))
+        return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=get_opt(p, 'hypertile_unet_swap_size'), depth=get_opt(p, 'hypertile_unet_depth'))


 def hypertile_set(p, hr=False):
     global error_reported, reset_needed, skip_hypertile # pylint: disable=global-statement
-    if not _p_or_opt(p, 'hypertile_unet_enabled'):
+    if not get_opt(p, 'hypertile_unet_enabled'):
         return
     error_reported = False
     set_resolution(p, hr=hr)
-    skip_hypertile = _p_or_opt(p, 'hypertile_hires_only') and not getattr(p, 'is_hr_pass', False)
+    skip_hypertile = get_opt(p, 'hypertile_hires_only') and not getattr(p, 'is_hr_pass', False)
     reset_needed = True
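Note: when no explicit tile size is configured, both the VAE and UNet paths above derive one from the image resolution via `max(128, 64 * min(p.width // 128, p.height // 128))`. A quick worked check of that heuristic:

def auto_tile_size(width: int, height: int) -> int:
    # scale the tile with the smaller image dimension, never below 128
    return max(128, 64 * min(width // 128, height // 128))

assert auto_tile_size(1024, 768) == 384   # min(8, 6) = 6 -> 64*6
assert auto_tile_size(512, 512) == 256    # min(4, 4) = 4 -> 64*4
assert auto_tile_size(128, 128) == 128    # the 128 floor kicks in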
wiki
@@ -1 +1 @@
-Subproject commit ca56aaecedf08089e94a0bd0c97bcfa9f79cd731
+Subproject commit 33dbd026a2e2fb7311d545a3b2d2db0363bb887f