Signed-off-by: vladmandic <mandic00@live.com>
pull/4690/head
vladmandic 2026-03-13 14:44:58 +01:00
parent 35803746df
commit e0faa149dd
11 changed files with 69 additions and 73 deletions

View File

@ -206,7 +206,6 @@ def post_unload_checkpoint():
from modules import sd_models from modules import sd_models
sd_models.unload_model_weights(op='model') sd_models.unload_model_weights(op='model')
sd_models.unload_model_weights(op='refiner') sd_models.unload_model_weights(op='refiner')
sd_models.unload_auxiliary_models()
return {} return {}
def post_reload_checkpoint(force:bool=False): def post_reload_checkpoint(force:bool=False):

View File

@ -284,7 +284,7 @@ def process_init(p: StableDiffusionProcessing):
p.negative_prompts = p.all_negative_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)] p.negative_prompts = p.all_negative_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
def _p_or_opt(p, key): def get_opt(p, key):
val = getattr(p, key, None) val = getattr(p, key, None)
return val if val is not None else getattr(shared.opts, key) return val if val is not None else getattr(shared.opts, key)
@ -313,9 +313,9 @@ def process_samples(p: StableDiffusionProcessing, samples):
if p.detailer_enabled: if p.detailer_enabled:
p.ops.append('detailer') p.ops.append('detailer')
if not p.do_not_save_samples and _p_or_opt(p, 'save_images_before_detailer'): if not p.do_not_save_samples and get_opt(p, 'save_images_before_detailer'):
info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i) info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
images.save_image(Image.fromarray(sample), path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=_p_or_opt(p, 'samples_format'), info=info, p=p, suffix="-before-detailer") images.save_image(Image.fromarray(sample), path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=get_opt(p, 'samples_format'), info=info, p=p, suffix="-before-detailer")
sample = detailer.detail(sample, p) sample = detailer.detail(sample, p)
if isinstance(sample, list): if isinstance(sample, list):
if len(sample) > 0: if len(sample) > 0:
@ -329,10 +329,10 @@ def process_samples(p: StableDiffusionProcessing, samples):
if p.color_corrections is not None and i < len(p.color_corrections): if p.color_corrections is not None and i < len(p.color_corrections):
p.ops.append('color') p.ops.append('color')
if not p.do_not_save_samples and _p_or_opt(p, 'save_images_before_color_correction'): if not p.do_not_save_samples and get_opt(p, 'save_images_before_color_correction'):
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i) info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
images.save_image(image_without_cc, path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=_p_or_opt(p, 'samples_format'), info=info, p=p, suffix="-before-color-correct") images.save_image(image_without_cc, path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=get_opt(p, 'samples_format'), info=info, p=p, suffix="-before-color-correct")
method = p.color_correction_method if p.color_correction_method is not None else getattr(shared.opts, 'color_correction_method', 'histogram') method = p.color_correction_method if p.color_correction_method is not None else getattr(shared.opts, 'color_correction_method', 'histogram')
image = apply_color_correction(p.color_corrections[i], image, method=method) image = apply_color_correction(p.color_corrections[i], image, method=method)
@ -373,10 +373,10 @@ def process_samples(p: StableDiffusionProcessing, samples):
if _overlay: if _overlay:
image = apply_overlay(image, p.paste_to, i, p.overlay_images) image = apply_overlay(image, p.paste_to, i, p.overlay_images)
_save_mask = _p_or_opt(p, 'save_mask') _save_mask = get_opt(p, 'save_mask')
_save_mask_composite = _p_or_opt(p, 'save_mask_composite') _save_mask_composite = get_opt(p, 'save_mask_composite')
_return_mask = _p_or_opt(p, 'return_mask') _return_mask = get_opt(p, 'return_mask')
_return_mask_composite = _p_or_opt(p, 'return_mask_composite') _return_mask_composite = get_opt(p, 'return_mask_composite')
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([_save_mask, _save_mask_composite, _return_mask, _return_mask_composite]): if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([_save_mask, _save_mask_composite, _return_mask, _return_mask_composite]):
image_mask = p.mask_for_overlay.convert('RGB') image_mask = p.mask_for_overlay.convert('RGB')
image1 = image.convert('RGBA').convert('RGBa') image1 = image.convert('RGBA').convert('RGBa')
@ -384,7 +384,7 @@ def process_samples(p: StableDiffusionProcessing, samples):
mask = images.resize_image(3, p.mask_for_overlay, image.width, image.height).convert('L') mask = images.resize_image(3, p.mask_for_overlay, image.width, image.height).convert('L')
image_mask_composite = Image.composite(image1, image2, mask).convert('RGBA') image_mask_composite = Image.composite(image1, image2, mask).convert('RGBA')
info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i) info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
_fmt = _p_or_opt(p, 'samples_format') _fmt = get_opt(p, 'samples_format')
if _save_mask: if _save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], _fmt, info=info, p=p, suffix="-mask") images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], _fmt, info=info, p=p, suffix="-mask")
if _save_mask_composite: if _save_mask_composite:
@ -420,8 +420,8 @@ def process_samples(p: StableDiffusionProcessing, samples):
image = images.resize_image(p.resize_mode_after, image, p.width_after, p.height_after, p.resize_name_after, context=p.resize_context_after) image = images.resize_image(p.resize_mode_after, image, p.width_after, p.height_after, p.resize_name_after, context=p.resize_context_after)
info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i) info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
if _p_or_opt(p, 'samples_save') and not p.do_not_save_samples and p.outpath_samples is not None: if get_opt(p, 'samples_save') and not p.do_not_save_samples and p.outpath_samples is not None:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], _p_or_opt(p, 'samples_format'), info=info, p=p) # main save image images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], get_opt(p, 'samples_format'), info=info, p=p) # main save image
image.info["parameters"] = info image.info["parameters"] = info
out_infotexts.append(info) out_infotexts.append(info)
@ -499,7 +499,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.state.interrupted: if shared.state.interrupted:
log.debug(f'Process: batch={n+1}/{p.n_iter} interrupted') log.debug(f'Process: batch={n+1}/{p.n_iter} interrupted')
_keep = _p_or_opt(p, 'keep_incomplete') _keep = get_opt(p, 'keep_incomplete')
p.do_not_save_samples = not _keep p.do_not_save_samples = not _keep
if shared.state.current_image is not None and isinstance(shared.state.current_image, Image.Image): if shared.state.current_image is not None and isinstance(shared.state.current_image, Image.Image):
samples = [shared.state.current_image] samples = [shared.state.current_image]
@ -544,8 +544,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.color_corrections = None p.color_corrections = None
index_of_first_image = 0 index_of_first_image = 0
_return_grid = _p_or_opt(p, 'return_grid') _return_grid = get_opt(p, 'return_grid')
_grid_save = _p_or_opt(p, 'grid_save') _grid_save = get_opt(p, 'grid_save')
if (_return_grid or _grid_save) and (not p.do_not_save_grid) and (len(output_images) > 1): if (_return_grid or _grid_save) and (not p.do_not_save_grid) and (len(output_images) > 1):
if images.check_grid_size(output_images): if images.check_grid_size(output_images):
r, c = images.get_grid_size(output_images, p.batch_size) r, c = images.get_grid_size(output_images, p.batch_size)
@ -557,7 +557,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
output_images.insert(0, grid) output_images.insert(0, grid)
index_of_first_image = 1 index_of_first_image = 1
if _grid_save: if _grid_save:
images.save_image(grid, p.outpath_grids, "", p.all_seeds[0], p.all_prompts[0], _p_or_opt(p, 'grid_format'), info=grid_info, p=p, grid=True) # main save grid images.save_image(grid, p.outpath_grids, "", p.all_seeds[0], p.all_prompts[0], get_opt(p, 'grid_format'), info=grid_info, p=p, grid=True) # main save grid
results = get_processed( results = get_processed(
p, p,

View File

@ -75,7 +75,7 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = No
time.sleep(0.1) time.sleep(0.1)
if latents is None: if latents is None:
return kwargs return kwargs
elif (getattr(p, 'nan_skip', None) if (p is not None and getattr(p, 'nan_skip', None) is not None) else shared.opts.nan_skip): elif shared.opts.nan_skip:
assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...' assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...'
if p is None: if p is None:
return kwargs return kwargs

View File

@ -143,6 +143,7 @@ def process_base(p: processing.StableDiffusionProcessing):
update_sampler(p, shared.sd_model) update_sampler(p, shared.sd_model)
timer.process.record('prepare') timer.process.record('prepare')
process_pre(p) process_pre(p)
sched_eta = p.scheduler_eta if p.scheduler_eta is not None else shared.opts.scheduler_eta
desc = 'Base' desc = 'Base'
if 'detailer' in p.ops: if 'detailer' in p.ops:
desc = 'Detail' desc = 'Detail'
@ -154,7 +155,7 @@ def process_base(p: processing.StableDiffusionProcessing):
prompts_2=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts, prompts_2=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
negative_prompts_2=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts, negative_prompts_2=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
num_inference_steps=calculate_base_steps(p, use_refiner_start=use_refiner_start, use_denoise_start=use_denoise_start), num_inference_steps=calculate_base_steps(p, use_refiner_start=use_refiner_start, use_denoise_start=use_denoise_start),
eta=shared.opts.scheduler_eta, eta=sched_eta,
guidance_scale=p.cfg_scale, guidance_scale=p.cfg_scale,
guidance_rescale=p.diffusers_guidance_rescale, guidance_rescale=p.diffusers_guidance_rescale,
true_cfg_scale=p.pag_scale, true_cfg_scale=p.pag_scale,
@ -163,12 +164,13 @@ def process_base(p: processing.StableDiffusionProcessing):
num_frames=getattr(p, 'frames', 1), num_frames=getattr(p, 'frames', 1),
output_type=output_type, output_type=output_type,
clip_skip=p.clip_skip, clip_skip=p.clip_skip,
prompt_attention=getattr(p, 'prompt_attention', None),
desc=desc, desc=desc,
) )
base_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None) base_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None)
shared.state.update(get_job_name(p, shared.sd_model), base_steps, 1) shared.state.update(get_job_name(p, shared.sd_model), base_steps, 1)
if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1: if sched_eta is not None and sched_eta > 0 and sched_eta < 1:
p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta p.extra_generation_params["Sampler Eta"] = sched_eta
output = None output = None
if debug: if debug:
modelstats.analyze() modelstats.analyze()
@ -304,6 +306,7 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
prompts = p.prompts prompts = p.prompts
reset_prompts = False reset_prompts = False
sched_eta = p.scheduler_eta if p.scheduler_eta is not None else shared.opts.scheduler_eta
if len(p.refiner_prompt) > 0: if len(p.refiner_prompt) > 0:
prompts = len(output.images)* [p.refiner_prompt] prompts = len(output.images)* [p.refiner_prompt]
prompts, p.network_data = extra_networks.parse_prompts(prompts) prompts, p.network_data = extra_networks.parse_prompts(prompts)
@ -319,13 +322,14 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
prompts_2=len(output.images) * [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts, prompts_2=len(output.images) * [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
negative_prompts_2=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts, negative_prompts_2=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
num_inference_steps=calculate_hires_steps(p), num_inference_steps=calculate_hires_steps(p),
eta=shared.opts.scheduler_eta, eta=sched_eta,
guidance_scale=p.image_cfg_scale if p.image_cfg_scale is not None else p.cfg_scale, guidance_scale=p.image_cfg_scale if p.image_cfg_scale is not None else p.cfg_scale,
guidance_rescale=p.diffusers_guidance_rescale, guidance_rescale=p.diffusers_guidance_rescale,
output_type=output_type, output_type=output_type,
clip_skip=p.clip_skip, clip_skip=p.clip_skip,
image=output.images, image=output.images,
strength=strength, strength=strength,
prompt_attention=getattr(p, 'prompt_attention', None),
desc='Hires', desc='Hires',
) )
@ -397,15 +401,14 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
p.extra_generation_params['Noise level'] = noise_level p.extra_generation_params['Noise level'] = noise_level
refiner_output_type = 'np' refiner_output_type = 'np'
update_sampler(p, shared.sd_refiner, second_pass=True) update_sampler(p, shared.sd_refiner, second_pass=True)
shared.opts.prompt_attention = 'fixed' sched_eta = p.scheduler_eta if p.scheduler_eta is not None else shared.opts.scheduler_eta
refiner_args = set_pipeline_args( refiner_args = set_pipeline_args(
p=p, p=p,
model=shared.sd_refiner, model=shared.sd_refiner,
prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts[i], prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts[i],
negative_prompts=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts[i], negative_prompts=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts[i],
num_inference_steps=calculate_refiner_steps(p), num_inference_steps=calculate_refiner_steps(p),
eta=shared.opts.scheduler_eta, eta=sched_eta,
# strength=p.denoising_strength,
noise_level=noise_level, # StableDiffusionUpscalePipeline only noise_level=noise_level, # StableDiffusionUpscalePipeline only
guidance_scale=p.image_cfg_scale if p.image_cfg_scale is not None else p.cfg_scale, guidance_scale=p.image_cfg_scale if p.image_cfg_scale is not None else p.cfg_scale,
guidance_rescale=p.diffusers_guidance_rescale, guidance_rescale=p.diffusers_guidance_rescale,

View File

@ -567,8 +567,7 @@ def save_intermediate(p, latents, suffix):
info=create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, [], iteration=p.iteration, position_in_batch=i) info=create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, [], iteration=p.iteration, position_in_batch=i)
decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, output_type='pil', vae_type=p.vae_type, width=p.width, height=p.height) decoded = processing_vae.vae_decode(latents=latents, model=shared.sd_model, output_type='pil', vae_type=p.vae_type, width=p.width, height=p.height)
for j in range(len(decoded)): for j in range(len(decoded)):
_fmt = p.samples_format if p.samples_format is not None else shared.opts.samples_format images.save_image(decoded[j], path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix=suffix)
images.save_image(decoded[j], path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=_fmt, info=info, p=p, suffix=suffix)
def update_sampler(p, sd_model, second_pass=False): def update_sampler(p, sd_model, second_pass=False):

View File

@ -161,15 +161,15 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
if 'color' in p.ops: if 'color' in p.ops:
args["Color correction"] = True args["Color correction"] = True
def _p_or_opt(key): def get_opt(key):
val = getattr(p, key, None) val = getattr(p, key, None)
if val is not None: if val is not None:
return val return val
return getattr(shared.opts, key, None) return getattr(shared.opts, key, None)
_token_method = _p_or_opt('token_merging_method') _token_method = get_opt('token_merging_method')
_tome = _p_or_opt('tome_ratio') _tome = get_opt('tome_ratio')
_todo = _p_or_opt('todo_ratio') _todo = get_opt('todo_ratio')
if _token_method == 'ToMe': # tome/todo if _token_method == 'ToMe': # tome/todo
args['ToMe'] = _tome if _tome != 0 else None args['ToMe'] = _tome if _tome != 0 else None
elif _token_method == 'ToDo': elif _token_method == 'ToDo':
@ -179,23 +179,23 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
# samplers # samplers
if getattr(p, 'sampler_name', None) is not None and p.sampler_name.lower() != 'default': if getattr(p, 'sampler_name', None) is not None and p.sampler_name.lower() != 'default':
_eta_delta = _p_or_opt('eta_noise_seed_delta') _eta_delta = get_opt('eta_noise_seed_delta')
args["Sampler eta delta"] = _eta_delta if _eta_delta != 0 and sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) else None args["Sampler eta delta"] = _eta_delta if _eta_delta != 0 and sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) else None
args["Sampler eta multiplier"] = p.initial_noise_multiplier if getattr(p, 'initial_noise_multiplier', 1.0) != 1.0 else None args["Sampler eta multiplier"] = p.initial_noise_multiplier if getattr(p, 'initial_noise_multiplier', 1.0) != 1.0 else None
args['Sampler timesteps'] = _p_or_opt('schedulers_timesteps') if _p_or_opt('schedulers_timesteps') != shared.opts.data_labels.get('schedulers_timesteps').default else None args['Sampler timesteps'] = get_opt('schedulers_timesteps') if get_opt('schedulers_timesteps') != shared.opts.data_labels.get('schedulers_timesteps').default else None
args['Sampler spacing'] = _p_or_opt('schedulers_timestep_spacing') if _p_or_opt('schedulers_timestep_spacing') != shared.opts.data_labels.get('schedulers_timestep_spacing').default else None args['Sampler spacing'] = get_opt('schedulers_timestep_spacing') if get_opt('schedulers_timestep_spacing') != shared.opts.data_labels.get('schedulers_timestep_spacing').default else None
args['Sampler sigma'] = _p_or_opt('schedulers_sigma') if _p_or_opt('schedulers_sigma') != shared.opts.data_labels.get('schedulers_sigma').default else None args['Sampler sigma'] = get_opt('schedulers_sigma') if get_opt('schedulers_sigma') != shared.opts.data_labels.get('schedulers_sigma').default else None
args['Sampler order'] = _p_or_opt('schedulers_solver_order') if _p_or_opt('schedulers_solver_order') != shared.opts.data_labels.get('schedulers_solver_order').default else None args['Sampler order'] = get_opt('schedulers_solver_order') if get_opt('schedulers_solver_order') != shared.opts.data_labels.get('schedulers_solver_order').default else None
args['Sampler type'] = _p_or_opt('schedulers_prediction_type') if _p_or_opt('schedulers_prediction_type') != shared.opts.data_labels.get('schedulers_prediction_type').default else None args['Sampler type'] = get_opt('schedulers_prediction_type') if get_opt('schedulers_prediction_type') != shared.opts.data_labels.get('schedulers_prediction_type').default else None
args['Sampler beta schedule'] = _p_or_opt('schedulers_beta_schedule') if _p_or_opt('schedulers_beta_schedule') != shared.opts.data_labels.get('schedulers_beta_schedule').default else None args['Sampler beta schedule'] = get_opt('schedulers_beta_schedule') if get_opt('schedulers_beta_schedule') != shared.opts.data_labels.get('schedulers_beta_schedule').default else None
args['Sampler low order'] = _p_or_opt('schedulers_use_loworder') if _p_or_opt('schedulers_use_loworder') != shared.opts.data_labels.get('schedulers_use_loworder').default else None args['Sampler low order'] = get_opt('schedulers_use_loworder') if get_opt('schedulers_use_loworder') != shared.opts.data_labels.get('schedulers_use_loworder').default else None
args['Sampler dynamic'] = _p_or_opt('schedulers_use_thresholding') if _p_or_opt('schedulers_use_thresholding') != shared.opts.data_labels.get('schedulers_use_thresholding').default else None args['Sampler dynamic'] = get_opt('schedulers_use_thresholding') if get_opt('schedulers_use_thresholding') != shared.opts.data_labels.get('schedulers_use_thresholding').default else None
args['Sampler rescale'] = _p_or_opt('schedulers_rescale_betas') if _p_or_opt('schedulers_rescale_betas') != shared.opts.data_labels.get('schedulers_rescale_betas').default else None args['Sampler rescale'] = get_opt('schedulers_rescale_betas') if get_opt('schedulers_rescale_betas') != shared.opts.data_labels.get('schedulers_rescale_betas').default else None
args['Sampler beta start'] = _p_or_opt('schedulers_beta_start') if _p_or_opt('schedulers_beta_start') != shared.opts.data_labels.get('schedulers_beta_start').default else None args['Sampler beta start'] = get_opt('schedulers_beta_start') if get_opt('schedulers_beta_start') != shared.opts.data_labels.get('schedulers_beta_start').default else None
args['Sampler beta end'] = _p_or_opt('schedulers_beta_end') if _p_or_opt('schedulers_beta_end') != shared.opts.data_labels.get('schedulers_beta_end').default else None args['Sampler beta end'] = get_opt('schedulers_beta_end') if get_opt('schedulers_beta_end') != shared.opts.data_labels.get('schedulers_beta_end').default else None
args['Sampler range'] = _p_or_opt('schedulers_timesteps_range') if _p_or_opt('schedulers_timesteps_range') != shared.opts.data_labels.get('schedulers_timesteps_range').default else None args['Sampler range'] = get_opt('schedulers_timesteps_range') if get_opt('schedulers_timesteps_range') != shared.opts.data_labels.get('schedulers_timesteps_range').default else None
args['Sampler shift'] = _p_or_opt('schedulers_shift') if _p_or_opt('schedulers_shift') != shared.opts.data_labels.get('schedulers_shift').default else None args['Sampler shift'] = get_opt('schedulers_shift') if get_opt('schedulers_shift') != shared.opts.data_labels.get('schedulers_shift').default else None
args['Sampler dynamic shift'] = _p_or_opt('schedulers_dynamic_shift') if _p_or_opt('schedulers_dynamic_shift') != shared.opts.data_labels.get('schedulers_dynamic_shift').default else None args['Sampler dynamic shift'] = get_opt('schedulers_dynamic_shift') if get_opt('schedulers_dynamic_shift') != shared.opts.data_labels.get('schedulers_dynamic_shift').default else None
# model specific # model specific
if shared.sd_model_type == 'h1': if shared.sd_model_type == 'h1':

View File

@ -114,7 +114,7 @@ class CheckpointInfo:
def setup_model(): def setup_model():
list_models() list_models()
sd_hijack_accelerate.hijack_hfhub() # sd_hijack_accelerate.hijack_hfhub()
# sd_hijack_accelerate.hijack_torch_conv() # sd_hijack_accelerate.hijack_torch_conv()

View File

@ -80,12 +80,6 @@ def restore_accelerate():
accelerate.utils.set_module_tensor_to_device = orig_set_module accelerate.utils.set_module_tensor_to_device = orig_set_module
def hijack_hfhub():
import contextlib
import huggingface_hub.file_download
huggingface_hub.file_download.FileLock = contextlib.nullcontext
def torch_conv_forward(self, input, weight, bias): # pylint: disable=redefined-builtin def torch_conv_forward(self, input, weight, bias): # pylint: disable=redefined-builtin
if self.padding_mode != 'zeros': if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), weight, bias, self.stride, _pair(0), self.dilation, self.groups) # pylint: disable=protected-access return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), weight, bias, self.stride, _pair(0), self.dilation, self.groups) # pylint: disable=protected-access

View File

@ -1,5 +1,6 @@
import math import math
import torch import torch
from modules.logger import log
from modules import shared, devices from modules import shared, devices
# based on <https://github.com/ljleb/sd-webui-freeu/blob/main/lib_free_u/unet.py> # based on <https://github.com/ljleb/sd-webui-freeu/blob/main/lib_free_u/unet.py>
@ -91,7 +92,7 @@ def get_fft_device():
torch_fft_device = devices.device torch_fft_device = devices.device
except Exception: except Exception:
torch_fft_device = devices.cpu torch_fft_device = devices.cpu
shared.log.warning(f'FreeU: device={devices.device} dtype={devices.dtype} does not support FFT') log.warning(f'FreeU: device={devices.device} dtype={devices.dtype} does not support FFT')
return torch_fft_device return torch_fft_device
@ -166,4 +167,4 @@ def apply_freeu(p):
p.sd_model.disable_freeu() p.sd_model.disable_freeu()
state_enabled = False state_enabled = False
if enabled and state_enabled: if enabled and state_enabled:
shared.log.info(f'Applying Free-U: b1={b1} b2={b2} s1={s1} s2={s2}') log.info(f'Applying Free-U: b1={b1} b2={b2} s1={s1} s2={s2}')

View File

@ -10,10 +10,10 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from einops import rearrange from einops import rearrange
from installer import log from modules.logger import log
def _p_or_opt(p, key): def get_opt(p, key):
val = getattr(p, key, None) val = getattr(p, key, None)
if val is not None: if val is not None:
return val return val
@ -184,10 +184,10 @@ def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=128
def context_hypertile_vae(p): def context_hypertile_vae(p):
from modules import shared from modules import shared
if p.sd_model is None or not _p_or_opt(p, 'hypertile_vae_enabled'): if p.sd_model is None or not get_opt(p, 'hypertile_vae_enabled'):
return nullcontext() return nullcontext()
if shared.opts.cross_attention_optimization == 'Sub-quadratic': if shared.opts.cross_attention_optimization == 'Sub-quadratic':
shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization') log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
return nullcontext() return nullcontext()
global max_h, max_w, error_reported # pylint: disable=global-statement global max_h, max_w, error_reported # pylint: disable=global-statement
error_reported = False error_reported = False
@ -204,21 +204,21 @@ def context_hypertile_vae(p):
if vae is None: if vae is None:
return nullcontext() return nullcontext()
else: else:
_vae_tile = _p_or_opt(p, 'hypertile_vae_tile') _vae_tile = get_opt(p, 'hypertile_vae_tile')
tile_size = _vae_tile if _vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128)) tile_size = _vae_tile if _vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
_min_tile = _p_or_opt(p, 'hypertile_unet_min_tile') _min_tile = get_opt(p, 'hypertile_unet_min_tile')
min_tile_size = _min_tile if _min_tile > 0 else 128 min_tile_size = _min_tile if _min_tile > 0 else 128
shared.log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}') log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
p.extra_generation_params['Hypertile VAE'] = tile_size p.extra_generation_params['Hypertile VAE'] = tile_size
return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_vae_swap_size')) return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=get_opt(p, 'hypertile_vae_swap_size'))
def context_hypertile_unet(p): def context_hypertile_unet(p):
from modules import shared from modules import shared
if p.sd_model is None or not _p_or_opt(p, 'hypertile_unet_enabled'): if p.sd_model is None or not get_opt(p, 'hypertile_unet_enabled'):
return nullcontext() return nullcontext()
if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental: if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental:
shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization') log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
return nullcontext() return nullcontext()
global max_h, max_w, error_reported # pylint: disable=global-statement global max_h, max_w, error_reported # pylint: disable=global-statement
error_reported = False error_reported = False
@ -232,25 +232,25 @@ def context_hypertile_unet(p):
log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8') log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8')
return nullcontext() return nullcontext()
if unet is None: if unet is None:
# shared.log.warning('Hypertile UNet is enabled but no Unet model was found') # log.warning('Hypertile UNet is enabled but no Unet model was found')
return nullcontext() return nullcontext()
else: else:
_unet_tile = _p_or_opt(p, 'hypertile_unet_tile') _unet_tile = get_opt(p, 'hypertile_unet_tile')
tile_size = _unet_tile if _unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128)) tile_size = _unet_tile if _unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
_min_tile = _p_or_opt(p, 'hypertile_unet_min_tile') _min_tile = get_opt(p, 'hypertile_unet_min_tile')
min_tile_size = _min_tile if _min_tile > 0 else 128 min_tile_size = _min_tile if _min_tile > 0 else 128
shared.log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}') log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
p.extra_generation_params['Hypertile UNet'] = tile_size p.extra_generation_params['Hypertile UNet'] = tile_size
return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_unet_swap_size'), depth=_p_or_opt(p, 'hypertile_unet_depth')) return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=get_opt(p, 'hypertile_unet_swap_size'), depth=get_opt(p, 'hypertile_unet_depth'))
def hypertile_set(p, hr=False): def hypertile_set(p, hr=False):
global error_reported, reset_needed, skip_hypertile # pylint: disable=global-statement global error_reported, reset_needed, skip_hypertile # pylint: disable=global-statement
if not _p_or_opt(p, 'hypertile_unet_enabled'): if not get_opt(p, 'hypertile_unet_enabled'):
return return
error_reported = False error_reported = False
set_resolution(p, hr=hr) set_resolution(p, hr=hr)
skip_hypertile = _p_or_opt(p, 'hypertile_hires_only') and not getattr(p, 'is_hr_pass', False) skip_hypertile = get_opt(p, 'hypertile_hires_only') and not getattr(p, 'is_hr_pass', False)
reset_needed = True reset_needed = True

2
wiki

@ -1 +1 @@
Subproject commit ca56aaecedf08089e94a0bd0c97bcfa9f79cd731 Subproject commit 33dbd026a2e2fb7311d545a3b2d2db0363bb887f