diff --git a/modules/lora/lora_apply.py b/modules/lora/lora_apply.py index d076efb97..51be0a74f 100644 --- a/modules/lora/lora_apply.py +++ b/modules/lora/lora_apply.py @@ -186,6 +186,8 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G self.svd_down.to(devices.device) if self.svd_down is not None else None, skip_quantized_matmul=self.sdnq_dequantizer.use_quantized_matmul ) + else: + weights_dtype = devices.dtype new_weight = dequant_weight.to(devices.device, dtype=torch.float32) + lora_weights.to(devices.device, dtype=torch.float32) self.weight = torch.nn.Parameter(new_weight, requires_grad=False) diff --git a/modules/sdnq/loader.py b/modules/sdnq/loader.py index aabd2496c..a7a7f3ae6 100644 --- a/modules/sdnq/loader.py +++ b/modules/sdnq/loader.py @@ -16,7 +16,7 @@ def save_sdnq_model(model: ModelMixin, sdnq_config: SDNQConfig, model_path: str, def load_sdnq_model(model_cls: ModelMixin, model_path: str, file_name: str = "diffusion_pytorch_model.safetensors", use_quantized_matmul: bool = False) -> ModelMixin: with torch.device("meta"): - with open(os.path.join(model_path, "quantization_config.json"), "r") as f: + with open(os.path.join(model_path, "quantization_config.json"), "r", encoding="utf-8") as f: quantization_config = json.load(f) quantization_config.pop("is_integer", None) quantization_config.pop("quant_method", None) diff --git a/modules/sdnq/quantizer.py b/modules/sdnq/quantizer.py index a92cb455a..d4769821e 100644 --- a/modules/sdnq/quantizer.py +++ b/modules/sdnq/quantizer.py @@ -252,13 +252,12 @@ def apply_sdnq_to_module(model, weights_dtype="int8", torch_dtype=None, group_si or any("*" in param and re.match(param.replace(".*", "\\.*").replace("*", ".*"), param_name) for param in modules_to_not_convert) ): continue - else: - layer_class_name = module.__class__.__name__ - if layer_class_name in allowed_types: - if (layer_class_name in conv_types or layer_class_name in conv_transpose_types) and not quant_conv: - continue - else: + layer_class_name = module.__class__.__name__ + if layer_class_name in allowed_types: + if (layer_class_name in conv_types or layer_class_name in conv_transpose_types) and not quant_conv: continue + else: + continue if len(modules_dtype_dict.keys()) > 0: for key, value in modules_dtype_dict.items(): if param_name in value: diff --git a/pipelines/xomni/modeling_siglip_flux.py b/pipelines/xomni/modeling_siglip_flux.py index 8db60374a..18b7b3463 100644 --- a/pipelines/xomni/modeling_siglip_flux.py +++ b/pipelines/xomni/modeling_siglip_flux.py @@ -497,7 +497,7 @@ def teacache_forward( class FluxPipelineWithSigLIP(FluxPipeline): - + @torch.no_grad() def __call__( self, diff --git a/pipelines/xomni/modeling_siglip_tokenizer.py b/pipelines/xomni/modeling_siglip_tokenizer.py index 26e67571d..5e99d4ff8 100644 --- a/pipelines/xomni/modeling_siglip_tokenizer.py +++ b/pipelines/xomni/modeling_siglip_tokenizer.py @@ -13,12 +13,12 @@ from .modeling_vit import create_siglip_vit def create_anyres_preprocess( - short_size=384, - long_size=1152, - patch_size=16, - random_ratio=None, - min_short_size=128, - max_aspect_ratio=3., + short_size=384, + long_size=1152, + patch_size=16, + random_ratio=None, + min_short_size=128, + max_aspect_ratio=3., filtering=True ): @@ -35,21 +35,21 @@ def create_anyres_preprocess( sqrt_ratio = torch.exp(0.5 * torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() target_width = int(round(target_width * sqrt_ratio)) target_height = int(round(target_height / sqrt_ratio)) - + ss = min(target_width, target_height) if ss < short_size: target_width = target_width * (short_size / ss) target_height = target_height * (short_size / ss) - + ls = max(target_width, target_height) if ls > long_size: target_width = target_width * (long_size / ls) target_height = target_height * (long_size / ls) - + target_width = int(round(target_width / patch_size)) * patch_size target_height = int(round(target_height / patch_size)) * patch_size pil_image = pil_image.resize((target_width, target_height), resample=Image.BICUBIC) - + to_tensor = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), @@ -75,7 +75,7 @@ class IBQ(nn.Module): self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) if self.l2_norm: self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1) - + def forward(self, z, temp=None, rescale_logits=False, return_logits=False, **kwargs): assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" assert rescale_logits == False, "Only for interface compatible with Gumbel" @@ -96,7 +96,7 @@ class IBQ(nn.Module): d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ torch.sum(embedding**2, dim=1) - 2 * \ torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding)) - + if self.training: logits = -d / self.quantization_temp soft_one_hot = F.softmax(logits, dim=1) @@ -114,13 +114,13 @@ class IBQ(nn.Module): min_encoding_indices = torch.argmin(d, dim=1) z_q = embedding[min_encoding_indices].view(z.shape) commit_loss = None - + if self.training and self.skip_quantization_prob > 0.0: z_q = torch.where( torch.rand_like(z_q[:, 0:1, 0:1, 0:1]).expand_as(z_q) <= self.skip_quantization_prob, z, z_q, ) - + # reshape back to match original input shape z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() @@ -150,7 +150,7 @@ class ResidualBlock(nn.Module): self.activate = nn.GELU() self.conv2 = nn.Conv2d(channels, channels, 3, padding='same') self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels) - + def forward(self, x): res = x x = self.norm1(x) @@ -164,10 +164,10 @@ class ResidualBlock(nn.Module): class VQConvProjector(nn.Module): def __init__( - self, - z_channels=1536, - codebook_size=16384, - codebook_dim=2048, + self, + z_channels=1536, + codebook_size=16384, + codebook_dim=2048, conv_layers=2, with_norm=True, skip_quant_prob=0.1, @@ -178,7 +178,7 @@ class VQConvProjector(nn.Module): self.post_quant_conv = nn.Conv2d(codebook_dim, z_channels, 1) block = ResidualBlock self.post_conv = nn.Sequential(*[block(z_channels) for _ in range(conv_layers)]) - + def forward(self, x, h, w): x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) z = self.quant_conv(x) @@ -187,37 +187,37 @@ class VQConvProjector(nn.Module): z = self.post_conv(z) z = rearrange(z, 'b c h w -> b (h w) c') return z, codebook_loss - + def encode(self, x, h, w): x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) z = self.quant_conv(x) (_, _, tokens), _ = self.quantize(z) return tokens - + def decode(self, tokens, bhwc): z_q = self.quantize.get_codebook_entry(tokens, bhwc) z = self.post_quant_conv(z_q) - z = self.post_conv(z) + z = self.post_conv(z) return z class SiglipTokenizer(nn.Module): def __init__( - self, - siglip_name, - siglip_path, - projector_path, - z_channels=1536, - codebook_size=16384, - codebook_dim=2048, + self, + siglip_name, + siglip_path, + projector_path, + z_channels=1536, + codebook_size=16384, + codebook_dim=2048, with_norm=True ): super().__init__() self.vit = create_siglip_vit(model_name=siglip_name, path=siglip_path) self.vqproj = VQConvProjector( - z_channels=z_channels, - codebook_size=codebook_size, - codebook_dim=codebook_dim, + z_channels=z_channels, + codebook_size=codebook_size, + codebook_dim=codebook_dim, with_norm=with_norm ) self.vqproj.load_state_dict(torch.load(projector_path, map_location='cpu'), strict=True) @@ -226,6 +226,6 @@ class SiglipTokenizer(nn.Module): features, (h, w), _ = self.vit(x) tokens = self.vqproj.encode(features, h, w) return tokens - + def decode(self, tokens, bhwc): return self.vqproj.decode(tokens, bhwc) diff --git a/pipelines/xomni/modeling_vit.py b/pipelines/xomni/modeling_vit.py index 722851a41..150571c1d 100644 --- a/pipelines/xomni/modeling_vit.py +++ b/pipelines/xomni/modeling_vit.py @@ -541,7 +541,7 @@ class VisionTransformer(nn.Module): for x in x_list: bs, _, h, w = x.shape - # fix patch size=14 in datasets + # fix patch size=14 in datasets pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0] pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1] x = F.pad(x, (0, pad_w, 0, pad_h)) @@ -578,7 +578,7 @@ class VisionTransformer(nn.Module): bs, _, h, w = x.shape h = h // self.patch_embed.patch_size[0] w = w // self.patch_embed.patch_size[1] - + x = self.patch_embed(x) # x = self._pos_embed(x) x = x + self.rescale_positional_embedding(out_size=(h, w)) @@ -658,7 +658,7 @@ def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bic pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True) - return model + return model def create_siglip_vit( diff --git a/pipelines/xomni/modeling_xomni.py b/pipelines/xomni/modeling_xomni.py index 2e05dec0e..7e46f781c 100644 --- a/pipelines/xomni/modeling_xomni.py +++ b/pipelines/xomni/modeling_xomni.py @@ -32,7 +32,7 @@ class XOmniDecoderLayer(Qwen2DecoderLayer): output_hidden_states, *others = super().forward(hidden_states, **kwargs) output_hidden_states = torch.cat([output_hidden_states, multimodal_mask], dim=-1) return output_hidden_states, *others - + # mm_hidden_states = torch.where(multimodal_mask.bool(), hidden_states, torch.zeros_like(hidden_states)) output_hidden_states, *others = super().forward(hidden_states, **kwargs) output_hidden_states = torch.where(multimodal_mask.bool(), output_hidden_states, hidden_states) @@ -48,7 +48,7 @@ class XOmniModel(Qwen2Model, Qwen2PreTrainedModel): Qwen2PreTrainedModel.__init__(self, config) self.padding_idx = -1 self.vocab_size = config.vocab_size - + self.lm_embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.mm_embed_tokens = nn.Embedding(config.mm_vocab_size, config.hidden_size, self.padding_idx) @@ -94,7 +94,7 @@ class XOmniModel(Qwen2Model, Qwen2PreTrainedModel): class XOmniForCausalLM(Qwen2ForCausalLM): model_type = "x-omni" config_class = XOmniConfig - + _keys_to_ignore_on_load_missing = r'image_tokenizer\.*' def __init__(self, config): @@ -102,7 +102,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): self.model = XOmniModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.mm_head = nn.Linear(config.hidden_size, config.mm_vocab_size, bias=False) - + self.generation_mode = 'text' # Initialize weights and apply final processing self.post_init() @@ -110,7 +110,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): @property def device(self): return next(iter(self.parameters())).device - + def init_vision(self, flux_pipe_path, **kwargs): self.som_token = self.config.mm_special_tokens[0] self.eom_token = self.config.mm_special_tokens[1] @@ -125,7 +125,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): self.vision_dtype = dtype_map[self.vision_config.dtype] self.image_transform = create_anyres_preprocess(**self.vision_config.transform) - + self.encoder_config.siglip_path = os.path.join(self.name_or_path, self.encoder_config.siglip_path) if os.path.isdir(self.name_or_path) else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.siglip_path) self.encoder_config.projector_path = os.path.join(self.name_or_path, self.encoder_config.projector_path) if os.path.isdir(self.name_or_path) else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.projector_path) @@ -133,15 +133,15 @@ class XOmniForCausalLM(Qwen2ForCausalLM): self.image_tokenizer.to(self.device, self.vision_dtype) transformer = FluxTransformer2DModelWithSigLIP.from_pretrained( - self.name_or_path, - siglip_channels=self.encoder_config.z_channels, + self.name_or_path, + siglip_channels=self.encoder_config.z_channels, torch_dtype=self.vision_dtype, subfolder=self.decoder_config.model_path, **kwargs, ) self.decoder_pipe = FluxPipelineWithSigLIP.from_pretrained( - flux_pipe_path, + flux_pipe_path, transformer=transformer, torch_dtype=self.vision_dtype, ) @@ -161,7 +161,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): if len(images) > 0: doc += self.tokenize_image(images.pop(0)) return tokenizer.encode(doc, **kwargs) - + def mmdecode(self, tokenizer, token_ids, force_text=None, **kwargs): force_text = force_text or [] if isinstance(token_ids, torch.Tensor): @@ -174,7 +174,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): assert len(token_ids) == 1 token_ids = token_ids[0] assert isinstance(token_ids[0], int) - + doc = tokenizer.decode(token_ids, **kwargs) doc = doc.replace(tokenizer.pad_token, '') doc = doc.replace('', '') @@ -197,7 +197,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): image = self.detokenize_image(texts, images, token_ids, (H, W)) images.append(image) return texts, images - + @torch.no_grad() def tokenize_image(self, image): assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" before that.' @@ -213,7 +213,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): token_str = ''.join(map(lambda x: ''.format(token_id=x), tokens)) image_str = f'{self.som_token}{H} {W}{self.img_token}{token_str}{self.eom_token}' return image_str - + @torch.no_grad() def detokenize_image(self, texts, images, token_ids, shape): assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" before that.' @@ -228,7 +228,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): [texts[0]], negative_prompt=[''], height=H * upscale_factor, width=W * upscale_factor, - num_inference_steps=self.decoder_config.num_inference_steps, + num_inference_steps=self.decoder_config.num_inference_steps, guidance_scale=1.0, true_cfg_scale=self.decoder_config.cfg_scale, true_cfg_scale_2=self.decoder_config.cfg_scale_2, @@ -236,7 +236,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM): return image - + def forward( self, input_ids: torch.LongTensor = None, @@ -275,14 +275,14 @@ class XOmniForCausalLM(Qwen2ForCausalLM): hidden_states = outputs[0] hidden_states = hidden_states[:, -num_logits_to_keep:, :] logits = hidden_states.new_full( - (*hidden_states.shape[:-1], self.config.vocab_size + self.config.mm_vocab_size), + (*hidden_states.shape[:-1], self.config.vocab_size + self.config.mm_vocab_size), torch.finfo(hidden_states.dtype).min ) if self.generation_mode == 'text': logits[:, :, :self.config.vocab_size] = self.lm_head(hidden_states) else: logits[:, :, self.config.vocab_size:self.config.vocab_size + self.config.image_vocab_size] = self.mm_head(hidden_states)[:, :, :self.config.image_vocab_size] - + logits = logits.float() loss = None