RUF013 updates

pull/4706/head
awsr 2026-03-24 05:48:19 -07:00
parent 3f830589d1
commit c4ebef29a9
No known key found for this signature in database
13 changed files with 73 additions and 73 deletions

View File

@ -600,7 +600,7 @@ def split_prompts(pipe, prompt, SD3 = False):
return prompt, prompt2, prompt3, prompt4
def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None, prompt_mean_norm=None, diffusers_zeros_prompt_pad=None, te_pooled_embeds=None):
def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int | None = None, prompt_mean_norm=None, diffusers_zeros_prompt_pad=None, te_pooled_embeds=None):
device = devices.device
if prompt is None:
prompt = ''

View File

@ -13,7 +13,7 @@ def map_keys(key: str, key_mapping: dict) -> str:
return new_key
def load_safetensors(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict:
def load_safetensors(files: list[str], state_dict: dict | None = None, key_mapping: dict | None = None, device: torch.device = "cpu") -> dict:
from safetensors.torch import safe_open
if state_dict is None:
state_dict = {}
@ -23,7 +23,7 @@ def load_safetensors(files: list[str], state_dict: dict = None, key_mapping: dic
state_dict[map_keys(key, key_mapping)] = f.get_tensor(key)
def load_threaded(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict:
def load_threaded(files: list[str], state_dict: dict | None = None, key_mapping: dict | None = None, device: torch.device = "cpu") -> dict:
future_items = {}
if state_dict is None:
state_dict = {}
@ -34,7 +34,7 @@ def load_threaded(files: list[str], state_dict: dict = None, key_mapping: dict =
future.result()
def load_streamer(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict:
def load_streamer(files: list[str], state_dict: dict | None = None, key_mapping: dict | None = None, device: torch.device = "cpu") -> dict:
# requires pip install runai_model_streamer
from runai_model_streamer import SafetensorsStreamer
if state_dict is None:
@ -45,7 +45,7 @@ def load_streamer(files: list[str], state_dict: dict = None, key_mapping: dict =
state_dict[map_keys(key, key_mapping)] = tensor.to(device)
def load_files(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu", method: str = None) -> dict:
def load_files(files: list[str], state_dict: dict | None = None, key_mapping: dict | None = None, device: torch.device = "cpu", method: str | None = None) -> dict:
# note: files is list-of-files within a module for chunked loading, not across model
if isinstance(files, str):
files = [files]

View File

@ -20,11 +20,11 @@ def conv_fp16_matmul(
padding_mode: str, conv_type: int,
groups: int, stride: list[int],
padding: list[int], dilation: list[int],
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
quantized_weight_shape: torch.Size = None,
weights_dtype: str = None,
bias: torch.FloatTensor | None = None,
svd_up: torch.FloatTensor | None = None,
svd_down: torch.FloatTensor | None = None,
quantized_weight_shape: torch.Size | None = None,
weights_dtype: str | None = None,
) -> torch.FloatTensor:
return_dtype = input.dtype
input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation)

View File

@ -19,11 +19,11 @@ def conv_fp8_matmul(
padding_mode: str, conv_type: int,
groups: int, stride: list[int],
padding: list[int], dilation: list[int],
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
quantized_weight_shape: torch.Size = None,
weights_dtype: str = None,
bias: torch.FloatTensor | None = None,
svd_up: torch.FloatTensor | None = None,
svd_down: torch.FloatTensor | None = None,
quantized_weight_shape: torch.Size | None = None,
weights_dtype: str | None = None,
) -> torch.FloatTensor:
return_dtype = input.dtype
input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation)

View File

@ -20,11 +20,11 @@ def conv_fp8_matmul_tensorwise(
padding_mode: str, conv_type: int,
groups: int, stride: list[int],
padding: list[int], dilation: list[int],
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
quantized_weight_shape: torch.Size = None,
weights_dtype: str = None,
bias: torch.FloatTensor | None = None,
svd_up: torch.FloatTensor | None = None,
svd_down: torch.FloatTensor | None = None,
quantized_weight_shape: torch.Size | None = None,
weights_dtype: str | None = None,
) -> torch.FloatTensor:
return_dtype = input.dtype
input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation)

View File

@ -20,11 +20,11 @@ def conv_int8_matmul(
padding_mode: str, conv_type: int,
groups: int, stride: list[int],
padding: list[int], dilation: list[int],
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
quantized_weight_shape: torch.Size = None,
weights_dtype: str = None,
bias: torch.FloatTensor | None = None,
svd_up: torch.FloatTensor | None = None,
svd_down: torch.FloatTensor | None = None,
quantized_weight_shape: torch.Size | None = None,
weights_dtype: str | None = None,
) -> torch.FloatTensor:
return_dtype = input.dtype
input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation)

View File

@ -76,16 +76,16 @@ def quantized_conv_forward(self, input) -> torch.FloatTensor:
return self._conv_forward(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias)
def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: list[int] = None) -> torch.FloatTensor:
def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor:
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation)
return torch.nn.functional.conv_transpose1d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: list[int] = None) -> torch.FloatTensor:
def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor:
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation)
return torch.nn.functional.conv_transpose2d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: list[int] = None) -> torch.FloatTensor:
def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor:
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 3, self.dilation)
return torch.nn.functional.conv_transpose3d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

View File

@ -14,11 +14,11 @@ def fp16_matmul(
input: torch.FloatTensor,
weight: torch.Tensor,
scale: torch.FloatTensor,
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
quantized_weight_shape: torch.Size = None,
weights_dtype: str = None,
bias: torch.FloatTensor | None = None,
svd_up: torch.FloatTensor | None = None,
svd_down: torch.FloatTensor | None = None,
quantized_weight_shape: torch.Size | None = None,
weights_dtype: str | None = None,
) -> torch.FloatTensor:
if quantized_weight_shape is not None:
weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float16).t_()

View File

@ -19,11 +19,11 @@ def fp8_matmul(
input: torch.FloatTensor,
weight: torch.Tensor,
scale: torch.FloatTensor,
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
quantized_weight_shape: torch.Size = None,
weights_dtype: str = None,
bias: torch.FloatTensor | None = None,
svd_up: torch.FloatTensor | None = None,
svd_down: torch.FloatTensor | None = None,
quantized_weight_shape: torch.Size | None = None,
weights_dtype: str | None = None,
) -> torch.FloatTensor:
if quantized_weight_shape is not None:
weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float8_e4m3fn).t_()

View File

@ -22,11 +22,11 @@ def fp8_matmul_tensorwise(
input: torch.FloatTensor,
weight: torch.Tensor,
scale: torch.FloatTensor,
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
quantized_weight_shape: torch.Size = None,
weights_dtype: str = None,
bias: torch.FloatTensor | None = None,
svd_up: torch.FloatTensor | None = None,
svd_down: torch.FloatTensor | None = None,
quantized_weight_shape: torch.Size | None = None,
weights_dtype: str | None = None,
) -> torch.FloatTensor:
if quantized_weight_shape is not None:
weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float8_e4m3fn).t_()

View File

@ -22,11 +22,11 @@ def int8_matmul(
input: torch.FloatTensor,
weight: torch.Tensor,
scale: torch.FloatTensor,
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
quantized_weight_shape: torch.Size = None,
weights_dtype: str = None,
bias: torch.FloatTensor | None = None,
svd_up: torch.FloatTensor | None = None,
svd_down: torch.FloatTensor | None = None,
quantized_weight_shape: torch.Size | None = None,
weights_dtype: str | None = None,
) -> torch.FloatTensor:
if quantized_weight_shape is not None:
weight = unpack_int(weight, weights_dtype, quantized_weight_shape, dtype=torch.int8).t_()

View File

@ -25,7 +25,7 @@ def unset_config_on_save(quantization_config: SDNQConfig) -> SDNQConfig:
return quantization_config
def save_sdnq_model(model: ModelMixin, model_path: str, max_shard_size: str = "5GB", is_pipeline: bool = False, sdnq_config: SDNQConfig = None) -> None:
def save_sdnq_model(model: ModelMixin, model_path: str, max_shard_size: str = "5GB", is_pipeline: bool = False, sdnq_config: SDNQConfig | None = None) -> None:
if is_pipeline:
for module_name in get_module_names(model):
module = getattr(model, module_name, None)
@ -63,7 +63,7 @@ def save_sdnq_model(model: ModelMixin, model_path: str, max_shard_size: str = "5
model.config.quantization_config.to_json_file(quantization_config_path)
def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: str = None, dtype: torch.dtype = None, device: torch.device = "cpu", dequantize_fp32: bool = None, use_quantized_matmul: bool = None, model_config: dict = None, quantization_config: dict = None, load_method: str = "safetensors") -> ModelMixin:
def load_sdnq_model(model_path: str, model_cls: ModelMixin | None = None, file_name: str | None = None, dtype: torch.dtype | None = None, device: torch.device = "cpu", dequantize_fp32: bool | None = None, use_quantized_matmul: bool | None = None, model_config: dict | None = None, quantization_config: dict | None = None, load_method: str = "safetensors") -> ModelMixin:
from accelerate import init_empty_weights
with init_empty_weights():
@ -162,7 +162,7 @@ def post_process_model(model):
return model
def apply_sdnq_options_to_module(model, dtype: torch.dtype = None, dequantize_fp32: bool = None, use_quantized_matmul: bool = None):
def apply_sdnq_options_to_module(model, dtype: torch.dtype | None = None, dequantize_fp32: bool | None = None, use_quantized_matmul: bool | None = None):
has_children = list(model.children())
if not has_children:
if dtype is not None and getattr(model, "dtype", torch.float32) not in {torch.float32, torch.float64}:
@ -231,7 +231,7 @@ def apply_sdnq_options_to_module(model, dtype: torch.dtype = None, dequantize_fp
return model
def apply_sdnq_options_to_model(model, dtype: torch.dtype = None, dequantize_fp32: bool = None, use_quantized_matmul: bool = None):
def apply_sdnq_options_to_model(model, dtype: torch.dtype | None = None, dequantize_fp32: bool | None = None, use_quantized_matmul: bool | None = None):
if use_quantized_matmul and not check_torch_compile():
raise RuntimeError("SDNQ Quantized MatMul requires a working Triton install.")
model = apply_sdnq_options_to_module(model, dtype=dtype, dequantize_fp32=dequantize_fp32, use_quantized_matmul=use_quantized_matmul)

View File

@ -189,7 +189,7 @@ def get_quant_kwargs(quant_kwargs: dict, modules_quant_config: dict[str, dict])
return quant_kwargs
def add_module_skip_keys(model, modules_to_not_convert: list[str] = None, modules_dtype_dict: dict[str, list[str]] = None):
def add_module_skip_keys(model, modules_to_not_convert: list[str] | None = None, modules_dtype_dict: dict[str, list[str]] | None = None):
if modules_to_not_convert is None:
modules_to_not_convert = []
if modules_dtype_dict is None:
@ -547,7 +547,7 @@ def sdnq_quantize_layer(layer, weights_dtype="int8", quantized_matmul_dtype=None
@devices.inference_context()
def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=True, non_blocking=False, modules_to_not_convert: list[str] = None, modules_dtype_dict: dict[str, list[str]] = None, modules_quant_config: dict[str, dict] = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument
def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=True, non_blocking=False, modules_to_not_convert: list[str] | None = None, modules_dtype_dict: dict[str, list[str]] | None = None, modules_quant_config: dict[str, dict] | None = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument
has_children = list(model.children())
if not has_children:
return model, modules_to_not_convert, modules_dtype_dict
@ -628,8 +628,8 @@ def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=Non
def sdnq_post_load_quant(
model: torch.nn.Module,
weights_dtype: str = "int8",
quantized_matmul_dtype: str = None,
torch_dtype: torch.dtype = None,
quantized_matmul_dtype: str | None = None,
torch_dtype: torch.dtype | None = None,
group_size: int = 0,
svd_rank: int = 32,
svd_steps: int = 8,
@ -643,11 +643,11 @@ def sdnq_post_load_quant(
dequantize_fp32: bool = True,
non_blocking: bool = False,
add_skip_keys:bool = True,
quantization_device: torch.device = None,
return_device: torch.device = None,
modules_to_not_convert: list[str] = None,
modules_dtype_dict: dict[str, list[str]] = None,
modules_quant_config: dict[str, dict] = None,
quantization_device: torch.device | None = None,
return_device: torch.device | None = None,
modules_to_not_convert: list[str] | None = None,
modules_dtype_dict: dict[str, list[str]] | None = None,
modules_quant_config: dict[str, dict] | None = None,
):
if modules_to_not_convert is None:
modules_to_not_convert = []
@ -735,9 +735,9 @@ class SDNQQuantize:
def convert(
self,
input_dict: dict[str, list[torch.Tensor]],
model: torch.nn.Module = None,
full_layer_name: str = None,
missing_keys: list[str] = None, # pylint: disable=unused-argument
model: torch.nn.Module | None = None,
full_layer_name: str | None = None,
missing_keys: list[str] | None = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> dict[str, torch.FloatTensor]:
_module_name, value = tuple(input_dict.items())[0]
@ -898,11 +898,11 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer):
def adjust_target_dtype(self, target_dtype: torch.dtype) -> torch.dtype: # pylint: disable=unused-argument,arguments-renamed
return dtype_dict[self.quantization_config.weights_dtype]["target_dtype"]
def update_torch_dtype(self, torch_dtype: torch.dtype = None) -> torch.dtype:
def update_torch_dtype(self, torch_dtype: torch.dtype | None = None) -> torch.dtype:
self.torch_dtype = torch_dtype
return torch_dtype
def update_dtype(self, dtype: torch.dtype = None) -> torch.dtype:
def update_dtype(self, dtype: torch.dtype | None = None) -> torch.dtype:
"""
needed for transformers compatibility, returns self.update_torch_dtype
"""
@ -912,7 +912,7 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer):
self,
model,
device_map, # pylint: disable=unused-argument
keep_in_fp32_modules: list[str] = None,
keep_in_fp32_modules: list[str] | None = None,
**kwargs, # pylint: disable=unused-argument
):
if self.pre_quantized:
@ -1055,7 +1055,7 @@ class SDNQConfig(QuantizationConfigMixin):
def __init__( # pylint: disable=super-init-not-called
self,
weights_dtype: str = "int8",
quantized_matmul_dtype: str = None,
quantized_matmul_dtype: str | None = None,
group_size: int = 0,
svd_rank: int = 32,
svd_steps: int = 8,
@ -1071,11 +1071,11 @@ class SDNQConfig(QuantizationConfigMixin):
dequantize_fp32: bool = True,
non_blocking: bool = False,
add_skip_keys: bool = True,
quantization_device: torch.device = None,
return_device: torch.device = None,
modules_to_not_convert: list[str] = None,
modules_dtype_dict: dict[str, list[str]] = None,
modules_quant_config: dict[str, dict] = None,
quantization_device: torch.device | None = None,
return_device: torch.device | None = None,
modules_to_not_convert: list[str] | None = None,
modules_dtype_dict: dict[str, list[str]] | None = None,
modules_quant_config: dict[str, dict] | None = None,
is_training: bool = False,
**kwargs, # pylint: disable=unused-argument
):