mirror of https://github.com/vladmandic/automatic
RUF013 updates
parent
92960de8d6
commit
641321d7d2
|
|
@ -51,7 +51,7 @@ def dont_quant():
|
|||
return False
|
||||
|
||||
|
||||
def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
|
||||
def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
|
||||
from modules import shared, devices
|
||||
if allow and (module == 'any' or module in shared.opts.bnb_quantization):
|
||||
load_bnb()
|
||||
|
|
@ -74,7 +74,7 @@ def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model',
|
|||
return kwargs
|
||||
|
||||
|
||||
def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
|
||||
def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
|
||||
from modules import shared
|
||||
if allow and (shared.opts.torchao_quantization_mode in {'pre', 'auto'}) and (module == 'any' or module in shared.opts.torchao_quantization):
|
||||
torchao = load_torchao()
|
||||
|
|
@ -93,7 +93,7 @@ def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', m
|
|||
return kwargs
|
||||
|
||||
|
||||
def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
|
||||
def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
|
||||
from modules import shared
|
||||
if allow and (module == 'any' or module in shared.opts.quanto_quantization):
|
||||
load_quanto(silent=True)
|
||||
|
|
@ -115,7 +115,7 @@ def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model
|
|||
return kwargs
|
||||
|
||||
|
||||
def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
|
||||
def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
|
||||
from modules import shared
|
||||
if allow and (module == 'any' or module in shared.opts.trt_quantization):
|
||||
load_trt()
|
||||
|
|
@ -163,7 +163,7 @@ def get_sdnq_devices(mode="pre"):
|
|||
return quantization_device, return_device
|
||||
|
||||
|
||||
def create_sdnq_config(kwargs = None, allow: bool = True, module: str = 'Model', weights_dtype: str = None, quantized_matmul_dtype: str = None, modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
|
||||
def create_sdnq_config(kwargs = None, allow: bool = True, module: str = 'Model', weights_dtype: str | None = None, quantized_matmul_dtype: str | None = None, modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
|
||||
from modules import shared
|
||||
if allow and (shared.opts.sdnq_quantize_mode in {'pre', 'auto'}) and (module == 'any' or module in shared.opts.sdnq_quantize_weights):
|
||||
from modules.sdnq import SDNQConfig
|
||||
|
|
@ -276,7 +276,7 @@ def check_nunchaku(module: str = ''):
|
|||
return False
|
||||
|
||||
|
||||
def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
|
||||
def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if module == 'Model' and dont_quant():
|
||||
|
|
@ -508,7 +508,7 @@ def apply_layerwise(sd_model, quiet:bool=False):
|
|||
log.error(f'Quantization: type=layerwise {e}')
|
||||
|
||||
|
||||
def sdnq_quantize_model(model, op=None, sd_model=None, do_gc: bool = True, weights_dtype: str = None, quantized_matmul_dtype: str = None, modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
|
||||
def sdnq_quantize_model(model, op=None, sd_model=None, do_gc: bool = True, weights_dtype: str | None = None, quantized_matmul_dtype: str | None = None, modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
|
||||
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
|
||||
from modules import devices, shared, timer
|
||||
from modules.sdnq import sdnq_post_load_quant
|
||||
|
|
@ -774,7 +774,7 @@ def torchao_quantization(sd_model):
|
|||
return sd_model
|
||||
|
||||
|
||||
def get_dit_args(load_config:dict=None, module:str=None, device_map:bool=False, allow_quant:bool=True, modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
|
||||
def get_dit_args(load_config: dict | None = None, module: str | None = None, device_map: bool = False, allow_quant: bool = True, modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
|
||||
from modules import shared, devices
|
||||
config = {} if load_config is None else load_config.copy()
|
||||
if 'torch_dtype' not in config:
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def hf_login(token=None):
|
|||
return True
|
||||
|
||||
|
||||
def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
|
||||
def download_diffusers_model(hub_id: str, cache_dir: str | None = None, download_config: dict[str, str | bool] | None = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
|
||||
if hub_id is None or len(hub_id) == 0:
|
||||
return None
|
||||
from diffusers import DiffusionPipeline
|
||||
|
|
@ -219,7 +219,7 @@ def get_reference_opts(name: str, quiet=False):
|
|||
return model_opts
|
||||
|
||||
|
||||
def load_reference(name: str, variant: str = None, revision: str = None, mirror: str = None, custom_pipeline: str = None):
|
||||
def load_reference(name: str, variant: str | None = None, revision: str | None = None, mirror: str | None = None, custom_pipeline: str | None = None):
|
||||
if '+' in name:
|
||||
name = name.split('+')[0]
|
||||
found = [r for r in diffuser_repos if name == r['name'] or name == r['friendly'] or name == r['path']]
|
||||
|
|
@ -337,7 +337,7 @@ def load_file_from_url(url: str, *, model_dir: str, progress: bool = True, file_
|
|||
return None
|
||||
|
||||
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
||||
def load_models(model_path: str, model_url: str | None = None, command_path: str | None = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
||||
"""
|
||||
A one-and done loader to try finding the desired models in specified directories.
|
||||
@param download_name: Specify to download from model_url immediately.
|
||||
|
|
@ -404,7 +404,7 @@ def cleanup_models():
|
|||
move_files(src_path, dest_path)
|
||||
|
||||
|
||||
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
def move_files(src_path: str, dest_path: str, ext_filter: str | None = None):
|
||||
try:
|
||||
if not os.path.exists(dest_path):
|
||||
os.makedirs(dest_path)
|
||||
|
|
|
|||
Loading…
Reference in New Issue