mirror of https://github.com/vladmandic/automatic
RUF013 updates and typing update
parent
f0bb0a921a
commit
3f830589d1
|
|
@ -467,7 +467,7 @@ def report_model_stats(module_name, module):
|
||||||
log.error(f'Module stats: name={module_name} {e}')
|
log.error(f'Module stats: name={module_name} {e}')
|
||||||
|
|
||||||
|
|
||||||
def apply_balanced_offload(sd_model=None, exclude:list[str]=None, force:bool=False, silent:bool=False):
|
def apply_balanced_offload(sd_model=None, exclude: list[str] | None = None, force: bool = False, silent: bool = False):
|
||||||
global offload_hook_instance # pylint: disable=global-statement
|
global offload_hook_instance # pylint: disable=global-statement
|
||||||
if shared.opts.diffusers_offload_mode != "balanced":
|
if shared.opts.diffusers_offload_mode != "balanced":
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ def deregister_aux(name: str) -> None:
|
||||||
debug_move(f'Offload: type=aux op=deregister name={name}')
|
debug_move(f'Offload: type=aux op=deregister name={name}')
|
||||||
|
|
||||||
|
|
||||||
def evict_aux(exclude: str = None, reason: str = 'evict') -> None:
|
def evict_aux(exclude: str | None = None, reason: str = 'evict') -> None:
|
||||||
for name, entry in aux_models.items():
|
for name, entry in aux_models.items():
|
||||||
if name == exclude:
|
if name == exclude:
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from modules.logger import log
|
||||||
|
|
||||||
|
|
||||||
def get_t5_prompt_embeds(
|
def get_t5_prompt_embeds(
|
||||||
prompt: str | list[str] = None,
|
prompt: str | list[str] | None = None,
|
||||||
num_images_per_prompt: int = 1, # pylint: disable=unused-argument
|
num_images_per_prompt: int = 1, # pylint: disable=unused-argument
|
||||||
max_sequence_length: int = 512, # pylint: disable=unused-argument
|
max_sequence_length: int = 512, # pylint: disable=unused-argument
|
||||||
device: torch.device | None = None,
|
device: torch.device | None = None,
|
||||||
|
|
|
||||||
|
|
@ -35,11 +35,11 @@ def load_unet_sdxl_nunchaku(repo_id):
|
||||||
return unet
|
return unet
|
||||||
|
|
||||||
|
|
||||||
def load_unet(model, repo_id:str=None):
|
def load_unet(model, repo_id: str | None = None):
|
||||||
global loaded_unet # pylint: disable=global-statement
|
global loaded_unet # pylint: disable=global-statement
|
||||||
|
|
||||||
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and (('stable-diffusion-xl-base' in repo_id) or ('sdxl-turbo' in repo_id)):
|
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and repo_id is not None and (("stable-diffusion-xl-base" in repo_id) or ("sdxl-turbo" in repo_id)):
|
||||||
if model_quant.check_nunchaku('Model'):
|
if model_quant.check_nunchaku("Model"):
|
||||||
unet = load_unet_sdxl_nunchaku(repo_id)
|
unet = load_unet_sdxl_nunchaku(repo_id)
|
||||||
if unet is not None:
|
if unet is not None:
|
||||||
model.unet = unet
|
model.unet = unet
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue