RUF013 updates

pull/4706/head
awsr 2026-03-24 04:34:38 -07:00
parent 92960de8d6
commit 641321d7d2
No known key found for this signature in database
2 changed files with 12 additions and 12 deletions

View File

@ -51,7 +51,7 @@ def dont_quant():
return False
def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
from modules import shared, devices
if allow and (module == 'any' or module in shared.opts.bnb_quantization):
load_bnb()
@ -74,7 +74,7 @@ def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model',
return kwargs
def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
from modules import shared
if allow and (shared.opts.torchao_quantization_mode in {'pre', 'auto'}) and (module == 'any' or module in shared.opts.torchao_quantization):
torchao = load_torchao()
@ -93,7 +93,7 @@ def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', m
return kwargs
def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
from modules import shared
if allow and (module == 'any' or module in shared.opts.quanto_quantization):
load_quanto(silent=True)
@ -115,7 +115,7 @@ def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model
return kwargs
def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None):
def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
from modules import shared
if allow and (module == 'any' or module in shared.opts.trt_quantization):
load_trt()
@ -163,7 +163,7 @@ def get_sdnq_devices(mode="pre"):
return quantization_device, return_device
def create_sdnq_config(kwargs = None, allow: bool = True, module: str = 'Model', weights_dtype: str = None, quantized_matmul_dtype: str = None, modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
def create_sdnq_config(kwargs = None, allow: bool = True, module: str = 'Model', weights_dtype: str | None = None, quantized_matmul_dtype: str | None = None, modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
from modules import shared
if allow and (shared.opts.sdnq_quantize_mode in {'pre', 'auto'}) and (module == 'any' or module in shared.opts.sdnq_quantize_weights):
from modules.sdnq import SDNQConfig
@ -276,7 +276,7 @@ def check_nunchaku(module: str = ''):
return False
def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
if kwargs is None:
kwargs = {}
if module == 'Model' and dont_quant():
@ -508,7 +508,7 @@ def apply_layerwise(sd_model, quiet:bool=False):
log.error(f'Quantization: type=layerwise {e}')
def sdnq_quantize_model(model, op=None, sd_model=None, do_gc: bool = True, weights_dtype: str = None, quantized_matmul_dtype: str = None, modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
def sdnq_quantize_model(model, op=None, sd_model=None, do_gc: bool = True, weights_dtype: str | None = None, quantized_matmul_dtype: str | None = None, modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
from modules import devices, shared, timer
from modules.sdnq import sdnq_post_load_quant
@ -774,7 +774,7 @@ def torchao_quantization(sd_model):
return sd_model
def get_dit_args(load_config:dict=None, module:str=None, device_map:bool=False, allow_quant:bool=True, modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
def get_dit_args(load_config: dict | None = None, module: str | None = None, device_map: bool = False, allow_quant: bool = True, modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
from modules import shared, devices
config = {} if load_config is None else load_config.copy()
if 'torch_dtype' not in config:

View File

@ -56,7 +56,7 @@ def hf_login(token=None):
return True
def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
def download_diffusers_model(hub_id: str, cache_dir: str | None = None, download_config: dict[str, str | bool] | None = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
if hub_id is None or len(hub_id) == 0:
return None
from diffusers import DiffusionPipeline
@ -219,7 +219,7 @@ def get_reference_opts(name: str, quiet=False):
return model_opts
def load_reference(name: str, variant: str = None, revision: str = None, mirror: str = None, custom_pipeline: str = None):
def load_reference(name: str, variant: str | None = None, revision: str | None = None, mirror: str | None = None, custom_pipeline: str | None = None):
if '+' in name:
name = name.split('+')[0]
found = [r for r in diffuser_repos if name == r['name'] or name == r['friendly'] or name == r['path']]
@ -337,7 +337,7 @@ def load_file_from_url(url: str, *, model_dir: str, progress: bool = True, file_
return None
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
def load_models(model_path: str, model_url: str | None = None, command_path: str | None = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@param download_name: Specify to download from model_url immediately.
@ -404,7 +404,7 @@ def cleanup_models():
move_files(src_path, dest_path)
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def move_files(src_path: str, dest_path: str, ext_filter: str | None = None):
try:
if not os.path.exists(dest_path):
os.makedirs(dest_path)