diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py index 93bfcbdab..8c7dd13ba 100644 --- a/modules/prompt_parser_diffusers.py +++ b/modules/prompt_parser_diffusers.py @@ -600,7 +600,7 @@ def split_prompts(pipe, prompt, SD3 = False): return prompt, prompt2, prompt3, prompt4 -def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None, prompt_mean_norm=None, diffusers_zeros_prompt_pad=None, te_pooled_embeds=None): +def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int | None = None, prompt_mean_norm=None, diffusers_zeros_prompt_pad=None, te_pooled_embeds=None): device = devices.device if prompt is None: prompt = '' diff --git a/modules/sdnq/file_loader.py b/modules/sdnq/file_loader.py index 8028627b8..4390745b7 100644 --- a/modules/sdnq/file_loader.py +++ b/modules/sdnq/file_loader.py @@ -13,7 +13,7 @@ def map_keys(key: str, key_mapping: dict) -> str: return new_key -def load_safetensors(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict: +def load_safetensors(files: list[str], state_dict: dict | None = None, key_mapping: dict | None = None, device: torch.device = "cpu") -> dict: from safetensors.torch import safe_open if state_dict is None: state_dict = {} @@ -23,7 +23,7 @@ def load_safetensors(files: list[str], state_dict: dict = None, key_mapping: dic state_dict[map_keys(key, key_mapping)] = f.get_tensor(key) -def load_threaded(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict: +def load_threaded(files: list[str], state_dict: dict | None = None, key_mapping: dict | None = None, device: torch.device = "cpu") -> dict: future_items = {} if state_dict is None: state_dict = {} @@ -34,7 +34,7 @@ def load_threaded(files: list[str], state_dict: dict = None, key_mapping: dict = future.result() -def load_streamer(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict: +def load_streamer(files: list[str], state_dict: dict | None = None, key_mapping: dict | None = None, device: torch.device = "cpu") -> dict: # requires pip install runai_model_streamer from runai_model_streamer import SafetensorsStreamer if state_dict is None: @@ -45,7 +45,7 @@ def load_streamer(files: list[str], state_dict: dict = None, key_mapping: dict = state_dict[map_keys(key, key_mapping)] = tensor.to(device) -def load_files(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu", method: str = None) -> dict: +def load_files(files: list[str], state_dict: dict | None = None, key_mapping: dict | None = None, device: torch.device = "cpu", method: str | None = None) -> dict: # note: files is list-of-files within a module for chunked loading, not accross model if isinstance(files, str): files = [files] diff --git a/modules/sdnq/layers/conv/conv_fp16.py b/modules/sdnq/layers/conv/conv_fp16.py index 31beb017e..4f6cdecef 100644 --- a/modules/sdnq/layers/conv/conv_fp16.py +++ b/modules/sdnq/layers/conv/conv_fp16.py @@ -20,11 +20,11 @@ def conv_fp16_matmul( padding_mode: str, conv_type: int, groups: int, stride: list[int], padding: list[int], dilation: list[int], - bias: torch.FloatTensor = None, - svd_up: torch.FloatTensor = None, - svd_down: torch.FloatTensor = None, - quantized_weight_shape: torch.Size = None, - weights_dtype: str = None, + bias: torch.FloatTensor | None = None, + svd_up: torch.FloatTensor | None = None, + svd_down: torch.FloatTensor | None = None, + quantized_weight_shape: torch.Size | None = None, + weights_dtype: str | None = None, ) -> torch.FloatTensor: return_dtype = input.dtype input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation) diff --git a/modules/sdnq/layers/conv/conv_fp8.py b/modules/sdnq/layers/conv/conv_fp8.py index 3595a5366..2099dfd18 100644 --- a/modules/sdnq/layers/conv/conv_fp8.py +++ b/modules/sdnq/layers/conv/conv_fp8.py @@ -19,11 +19,11 @@ def conv_fp8_matmul( padding_mode: str, conv_type: int, groups: int, stride: list[int], padding: list[int], dilation: list[int], - bias: torch.FloatTensor = None, - svd_up: torch.FloatTensor = None, - svd_down: torch.FloatTensor = None, - quantized_weight_shape: torch.Size = None, - weights_dtype: str = None, + bias: torch.FloatTensor | None = None, + svd_up: torch.FloatTensor | None = None, + svd_down: torch.FloatTensor | None = None, + quantized_weight_shape: torch.Size | None = None, + weights_dtype: str | None = None, ) -> torch.FloatTensor: return_dtype = input.dtype input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation) diff --git a/modules/sdnq/layers/conv/conv_fp8_tensorwise.py b/modules/sdnq/layers/conv/conv_fp8_tensorwise.py index 38258bff7..2ad268eb8 100644 --- a/modules/sdnq/layers/conv/conv_fp8_tensorwise.py +++ b/modules/sdnq/layers/conv/conv_fp8_tensorwise.py @@ -20,11 +20,11 @@ def conv_fp8_matmul_tensorwise( padding_mode: str, conv_type: int, groups: int, stride: list[int], padding: list[int], dilation: list[int], - bias: torch.FloatTensor = None, - svd_up: torch.FloatTensor = None, - svd_down: torch.FloatTensor = None, - quantized_weight_shape: torch.Size = None, - weights_dtype: str = None, + bias: torch.FloatTensor | None = None, + svd_up: torch.FloatTensor | None = None, + svd_down: torch.FloatTensor | None = None, + quantized_weight_shape: torch.Size | None = None, + weights_dtype: str | None = None, ) -> torch.FloatTensor: return_dtype = input.dtype input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation) diff --git a/modules/sdnq/layers/conv/conv_int8.py b/modules/sdnq/layers/conv/conv_int8.py index df54830ef..b5b2dcebf 100644 --- a/modules/sdnq/layers/conv/conv_int8.py +++ b/modules/sdnq/layers/conv/conv_int8.py @@ -20,11 +20,11 @@ def conv_int8_matmul( padding_mode: str, conv_type: int, groups: int, stride: list[int], padding: list[int], dilation: list[int], - bias: torch.FloatTensor = None, - svd_up: torch.FloatTensor = None, - svd_down: torch.FloatTensor = None, - quantized_weight_shape: torch.Size = None, - weights_dtype: str = None, + bias: torch.FloatTensor | None = None, + svd_up: torch.FloatTensor | None = None, + svd_down: torch.FloatTensor | None = None, + quantized_weight_shape: torch.Size | None = None, + weights_dtype: str | None = None, ) -> torch.FloatTensor: return_dtype = input.dtype input, mm_output_shape = process_conv_input(conv_type, input, reversed_padding_repeated_twice, padding_mode, result_shape, stride, padding, dilation) diff --git a/modules/sdnq/layers/conv/forward.py b/modules/sdnq/layers/conv/forward.py index 9f6336173..2504b9090 100644 --- a/modules/sdnq/layers/conv/forward.py +++ b/modules/sdnq/layers/conv/forward.py @@ -76,16 +76,16 @@ def quantized_conv_forward(self, input) -> torch.FloatTensor: return self._conv_forward(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias) -def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: list[int] = None) -> torch.FloatTensor: +def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor: output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation) return torch.nn.functional.conv_transpose1d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) -def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: list[int] = None) -> torch.FloatTensor: +def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor: output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation) return torch.nn.functional.conv_transpose2d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) -def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: list[int] = None) -> torch.FloatTensor: +def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor: output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 3, self.dilation) return torch.nn.functional.conv_transpose3d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) diff --git a/modules/sdnq/layers/linear/linear_fp16.py b/modules/sdnq/layers/linear/linear_fp16.py index 2999d09cb..705aaeb6f 100644 --- a/modules/sdnq/layers/linear/linear_fp16.py +++ b/modules/sdnq/layers/linear/linear_fp16.py @@ -14,11 +14,11 @@ def fp16_matmul( input: torch.FloatTensor, weight: torch.Tensor, scale: torch.FloatTensor, - bias: torch.FloatTensor = None, - svd_up: torch.FloatTensor = None, - svd_down: torch.FloatTensor = None, - quantized_weight_shape: torch.Size = None, - weights_dtype: str = None, + bias: torch.FloatTensor | None = None, + svd_up: torch.FloatTensor | None = None, + svd_down: torch.FloatTensor | None = None, + quantized_weight_shape: torch.Size | None = None, + weights_dtype: str | None = None, ) -> torch.FloatTensor: if quantized_weight_shape is not None: weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float16).t_() diff --git a/modules/sdnq/layers/linear/linear_fp8.py b/modules/sdnq/layers/linear/linear_fp8.py index c0b005b75..132dcb647 100644 --- a/modules/sdnq/layers/linear/linear_fp8.py +++ b/modules/sdnq/layers/linear/linear_fp8.py @@ -19,11 +19,11 @@ def fp8_matmul( input: torch.FloatTensor, weight: torch.Tensor, scale: torch.FloatTensor, - bias: torch.FloatTensor = None, - svd_up: torch.FloatTensor = None, - svd_down: torch.FloatTensor = None, - quantized_weight_shape: torch.Size = None, - weights_dtype: str = None, + bias: torch.FloatTensor | None = None, + svd_up: torch.FloatTensor | None = None, + svd_down: torch.FloatTensor | None = None, + quantized_weight_shape: torch.Size | None = None, + weights_dtype: str | None = None, ) -> torch.FloatTensor: if quantized_weight_shape is not None: weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float8_e4m3fn).t_() diff --git a/modules/sdnq/layers/linear/linear_fp8_tensorwise.py b/modules/sdnq/layers/linear/linear_fp8_tensorwise.py index 9977fbe7c..235dde48d 100644 --- a/modules/sdnq/layers/linear/linear_fp8_tensorwise.py +++ b/modules/sdnq/layers/linear/linear_fp8_tensorwise.py @@ -22,11 +22,11 @@ def fp8_matmul_tensorwise( input: torch.FloatTensor, weight: torch.Tensor, scale: torch.FloatTensor, - bias: torch.FloatTensor = None, - svd_up: torch.FloatTensor = None, - svd_down: torch.FloatTensor = None, - quantized_weight_shape: torch.Size = None, - weights_dtype: str = None, + bias: torch.FloatTensor | None = None, + svd_up: torch.FloatTensor | None = None, + svd_down: torch.FloatTensor | None = None, + quantized_weight_shape: torch.Size | None = None, + weights_dtype: str | None = None, ) -> torch.FloatTensor: if quantized_weight_shape is not None: weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float8_e4m3fn).t_() diff --git a/modules/sdnq/layers/linear/linear_int8.py b/modules/sdnq/layers/linear/linear_int8.py index 21eed8e10..e34222c5f 100644 --- a/modules/sdnq/layers/linear/linear_int8.py +++ b/modules/sdnq/layers/linear/linear_int8.py @@ -22,11 +22,11 @@ def int8_matmul( input: torch.FloatTensor, weight: torch.Tensor, scale: torch.FloatTensor, - bias: torch.FloatTensor = None, - svd_up: torch.FloatTensor = None, - svd_down: torch.FloatTensor = None, - quantized_weight_shape: torch.Size = None, - weights_dtype: str = None, + bias: torch.FloatTensor | None = None, + svd_up: torch.FloatTensor | None = None, + svd_down: torch.FloatTensor | None = None, + quantized_weight_shape: torch.Size | None = None, + weights_dtype: str | None = None, ) -> torch.FloatTensor: if quantized_weight_shape is not None: weight = unpack_int(weight, weights_dtype, quantized_weight_shape, dtype=torch.int8).t_() diff --git a/modules/sdnq/loader.py b/modules/sdnq/loader.py index 0edda40c0..1b3c5ae3e 100644 --- a/modules/sdnq/loader.py +++ b/modules/sdnq/loader.py @@ -25,7 +25,7 @@ def unset_config_on_save(quantization_config: SDNQConfig) -> SDNQConfig: return quantization_config -def save_sdnq_model(model: ModelMixin, model_path: str, max_shard_size: str = "5GB", is_pipeline: bool = False, sdnq_config: SDNQConfig = None) -> None: +def save_sdnq_model(model: ModelMixin, model_path: str, max_shard_size: str = "5GB", is_pipeline: bool = False, sdnq_config: SDNQConfig | None = None) -> None: if is_pipeline: for module_name in get_module_names(model): module = getattr(model, module_name, None) @@ -63,7 +63,7 @@ def save_sdnq_model(model: ModelMixin, model_path: str, max_shard_size: str = "5 model.config.quantization_config.to_json_file(quantization_config_path) -def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: str = None, dtype: torch.dtype = None, device: torch.device = "cpu", dequantize_fp32: bool = None, use_quantized_matmul: bool = None, model_config: dict = None, quantization_config: dict = None, load_method: str = "safetensors") -> ModelMixin: +def load_sdnq_model(model_path: str, model_cls: ModelMixin | None = None, file_name: str | None = None, dtype: torch.dtype | None = None, device: torch.device = "cpu", dequantize_fp32: bool | None = None, use_quantized_matmul: bool | None = None, model_config: dict | None = None, quantization_config: dict | None = None, load_method: str = "safetensors") -> ModelMixin: from accelerate import init_empty_weights with init_empty_weights(): @@ -162,7 +162,7 @@ def post_process_model(model): return model -def apply_sdnq_options_to_module(model, dtype: torch.dtype = None, dequantize_fp32: bool = None, use_quantized_matmul: bool = None): +def apply_sdnq_options_to_module(model, dtype: torch.dtype | None = None, dequantize_fp32: bool | None = None, use_quantized_matmul: bool | None = None): has_children = list(model.children()) if not has_children: if dtype is not None and getattr(model, "dtype", torch.float32) not in {torch.float32, torch.float64}: @@ -231,7 +231,7 @@ def apply_sdnq_options_to_module(model, dtype: torch.dtype = None, dequantize_fp return model -def apply_sdnq_options_to_model(model, dtype: torch.dtype = None, dequantize_fp32: bool = None, use_quantized_matmul: bool = None): +def apply_sdnq_options_to_model(model, dtype: torch.dtype | None = None, dequantize_fp32: bool | None = None, use_quantized_matmul: bool | None = None): if use_quantized_matmul and not check_torch_compile(): raise RuntimeError("SDNQ Quantized MatMul requires a working Triton install.") model = apply_sdnq_options_to_module(model, dtype=dtype, dequantize_fp32=dequantize_fp32, use_quantized_matmul=use_quantized_matmul) diff --git a/modules/sdnq/quantizer.py b/modules/sdnq/quantizer.py index e4b462ad0..13c89d77a 100644 --- a/modules/sdnq/quantizer.py +++ b/modules/sdnq/quantizer.py @@ -189,7 +189,7 @@ def get_quant_kwargs(quant_kwargs: dict, modules_quant_config: dict[str, dict]) return quant_kwargs -def add_module_skip_keys(model, modules_to_not_convert: list[str] = None, modules_dtype_dict: dict[str, list[str]] = None): +def add_module_skip_keys(model, modules_to_not_convert: list[str] | None = None, modules_dtype_dict: dict[str, list[str]] | None = None): if modules_to_not_convert is None: modules_to_not_convert = [] if modules_dtype_dict is None: @@ -547,7 +547,7 @@ def sdnq_quantize_layer(layer, weights_dtype="int8", quantized_matmul_dtype=None @devices.inference_context() -def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=True, non_blocking=False, modules_to_not_convert: list[str] = None, modules_dtype_dict: dict[str, list[str]] = None, modules_quant_config: dict[str, dict] = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument +def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=True, non_blocking=False, modules_to_not_convert: list[str] | None = None, modules_dtype_dict: dict[str, list[str]] | None = None, modules_quant_config: dict[str, dict] | None = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument has_children = list(model.children()) if not has_children: return model, modules_to_not_convert, modules_dtype_dict @@ -628,8 +628,8 @@ def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=Non def sdnq_post_load_quant( model: torch.nn.Module, weights_dtype: str = "int8", - quantized_matmul_dtype: str = None, - torch_dtype: torch.dtype = None, + quantized_matmul_dtype: str | None = None, + torch_dtype: torch.dtype | None = None, group_size: int = 0, svd_rank: int = 32, svd_steps: int = 8, @@ -643,11 +643,11 @@ def sdnq_post_load_quant( dequantize_fp32: bool = True, non_blocking: bool = False, add_skip_keys:bool = True, - quantization_device: torch.device = None, - return_device: torch.device = None, - modules_to_not_convert: list[str] = None, - modules_dtype_dict: dict[str, list[str]] = None, - modules_quant_config: dict[str, dict] = None, + quantization_device: torch.device | None = None, + return_device: torch.device | None = None, + modules_to_not_convert: list[str] | None = None, + modules_dtype_dict: dict[str, list[str]] | None = None, + modules_quant_config: dict[str, dict] | None = None, ): if modules_to_not_convert is None: modules_to_not_convert = [] @@ -735,9 +735,9 @@ class SDNQQuantize: def convert( self, input_dict: dict[str, list[torch.Tensor]], - model: torch.nn.Module = None, - full_layer_name: str = None, - missing_keys: list[str] = None, # pylint: disable=unused-argument + model: torch.nn.Module | None = None, + full_layer_name: str | None = None, + missing_keys: list[str] | None = None, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument ) -> dict[str, torch.FloatTensor]: _module_name, value = tuple(input_dict.items())[0] @@ -898,11 +898,11 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer): def adjust_target_dtype(self, target_dtype: torch.dtype) -> torch.dtype: # pylint: disable=unused-argument,arguments-renamed return dtype_dict[self.quantization_config.weights_dtype]["target_dtype"] - def update_torch_dtype(self, torch_dtype: torch.dtype = None) -> torch.dtype: + def update_torch_dtype(self, torch_dtype: torch.dtype | None = None) -> torch.dtype: self.torch_dtype = torch_dtype return torch_dtype - def update_dtype(self, dtype: torch.dtype = None) -> torch.dtype: + def update_dtype(self, dtype: torch.dtype | None = None) -> torch.dtype: """ needed for transformers compatibilty, returns self.update_torch_dtype """ @@ -912,7 +912,7 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer): self, model, device_map, # pylint: disable=unused-argument - keep_in_fp32_modules: list[str] = None, + keep_in_fp32_modules: list[str] | None = None, **kwargs, # pylint: disable=unused-argument ): if self.pre_quantized: @@ -1055,7 +1055,7 @@ class SDNQConfig(QuantizationConfigMixin): def __init__( # pylint: disable=super-init-not-called self, weights_dtype: str = "int8", - quantized_matmul_dtype: str = None, + quantized_matmul_dtype: str | None = None, group_size: int = 0, svd_rank: int = 32, svd_steps: int = 8, @@ -1071,11 +1071,11 @@ class SDNQConfig(QuantizationConfigMixin): dequantize_fp32: bool = True, non_blocking: bool = False, add_skip_keys: bool = True, - quantization_device: torch.device = None, - return_device: torch.device = None, - modules_to_not_convert: list[str] = None, - modules_dtype_dict: dict[str, list[str]] = None, - modules_quant_config: dict[str, dict] = None, + quantization_device: torch.device | None = None, + return_device: torch.device | None = None, + modules_to_not_convert: list[str] | None = None, + modules_dtype_dict: dict[str, list[str]] | None = None, + modules_quant_config: dict[str, dict] | None = None, is_training: bool = False, **kwargs, # pylint: disable=unused-argument ):