Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4252/head
Vladimir Mandic 2025-10-05 20:25:33 -04:00
parent 28e3ae0480
commit a315a004e9
7 changed files with 63 additions and 62 deletions

View File

@ -186,6 +186,8 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
self.svd_down.to(devices.device) if self.svd_down is not None else None,
skip_quantized_matmul=self.sdnq_dequantizer.use_quantized_matmul
)
else:
weights_dtype = devices.dtype
new_weight = dequant_weight.to(devices.device, dtype=torch.float32) + lora_weights.to(devices.device, dtype=torch.float32)
self.weight = torch.nn.Parameter(new_weight, requires_grad=False)

View File

@ -16,7 +16,7 @@ def save_sdnq_model(model: ModelMixin, sdnq_config: SDNQConfig, model_path: str,
def load_sdnq_model(model_cls: ModelMixin, model_path: str, file_name: str = "diffusion_pytorch_model.safetensors", use_quantized_matmul: bool = False) -> ModelMixin:
with torch.device("meta"):
with open(os.path.join(model_path, "quantization_config.json"), "r") as f:
with open(os.path.join(model_path, "quantization_config.json"), "r", encoding="utf-8") as f:
quantization_config = json.load(f)
quantization_config.pop("is_integer", None)
quantization_config.pop("quant_method", None)

View File

@ -252,13 +252,12 @@ def apply_sdnq_to_module(model, weights_dtype="int8", torch_dtype=None, group_si
or any("*" in param and re.match(param.replace(".*", "\\.*").replace("*", ".*"), param_name) for param in modules_to_not_convert)
):
continue
else:
layer_class_name = module.__class__.__name__
if layer_class_name in allowed_types:
if (layer_class_name in conv_types or layer_class_name in conv_transpose_types) and not quant_conv:
continue
else:
layer_class_name = module.__class__.__name__
if layer_class_name in allowed_types:
if (layer_class_name in conv_types or layer_class_name in conv_transpose_types) and not quant_conv:
continue
else:
continue
if len(modules_dtype_dict.keys()) > 0:
for key, value in modules_dtype_dict.items():
if param_name in value:

View File

@ -497,7 +497,7 @@ def teacache_forward(
class FluxPipelineWithSigLIP(FluxPipeline):
@torch.no_grad()
def __call__(
self,

View File

@ -13,12 +13,12 @@ from .modeling_vit import create_siglip_vit
def create_anyres_preprocess(
short_size=384,
long_size=1152,
patch_size=16,
random_ratio=None,
min_short_size=128,
max_aspect_ratio=3.,
short_size=384,
long_size=1152,
patch_size=16,
random_ratio=None,
min_short_size=128,
max_aspect_ratio=3.,
filtering=True
):
@ -35,21 +35,21 @@ def create_anyres_preprocess(
sqrt_ratio = torch.exp(0.5 * torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
target_width = int(round(target_width * sqrt_ratio))
target_height = int(round(target_height / sqrt_ratio))
ss = min(target_width, target_height)
if ss < short_size:
target_width = target_width * (short_size / ss)
target_height = target_height * (short_size / ss)
ls = max(target_width, target_height)
if ls > long_size:
target_width = target_width * (long_size / ls)
target_height = target_height * (long_size / ls)
target_width = int(round(target_width / patch_size)) * patch_size
target_height = int(round(target_height / patch_size)) * patch_size
pil_image = pil_image.resize((target_width, target_height), resample=Image.BICUBIC)
to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
@ -75,7 +75,7 @@ class IBQ(nn.Module):
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
if self.l2_norm:
self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False, **kwargs):
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
assert rescale_logits == False, "Only for interface compatible with Gumbel"
@ -96,7 +96,7 @@ class IBQ(nn.Module):
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(embedding**2, dim=1) - 2 * \
torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
if self.training:
logits = -d / self.quantization_temp
soft_one_hot = F.softmax(logits, dim=1)
@ -114,13 +114,13 @@ class IBQ(nn.Module):
min_encoding_indices = torch.argmin(d, dim=1)
z_q = embedding[min_encoding_indices].view(z.shape)
commit_loss = None
if self.training and self.skip_quantization_prob > 0.0:
z_q = torch.where(
torch.rand_like(z_q[:, 0:1, 0:1, 0:1]).expand_as(z_q) <= self.skip_quantization_prob,
z, z_q,
)
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
@ -150,7 +150,7 @@ class ResidualBlock(nn.Module):
self.activate = nn.GELU()
self.conv2 = nn.Conv2d(channels, channels, 3, padding='same')
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
def forward(self, x):
res = x
x = self.norm1(x)
@ -164,10 +164,10 @@ class ResidualBlock(nn.Module):
class VQConvProjector(nn.Module):
def __init__(
self,
z_channels=1536,
codebook_size=16384,
codebook_dim=2048,
self,
z_channels=1536,
codebook_size=16384,
codebook_dim=2048,
conv_layers=2,
with_norm=True,
skip_quant_prob=0.1,
@ -178,7 +178,7 @@ class VQConvProjector(nn.Module):
self.post_quant_conv = nn.Conv2d(codebook_dim, z_channels, 1)
block = ResidualBlock
self.post_conv = nn.Sequential(*[block(z_channels) for _ in range(conv_layers)])
def forward(self, x, h, w):
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
z = self.quant_conv(x)
@ -187,37 +187,37 @@ class VQConvProjector(nn.Module):
z = self.post_conv(z)
z = rearrange(z, 'b c h w -> b (h w) c')
return z, codebook_loss
def encode(self, x, h, w):
    """Tokenize a flattened feature sequence into discrete codebook indices.

    Restores the (b, h*w, c) sequence to spatial (b, c, h, w) layout,
    projects it through the pre-quantization conv, and returns only the
    token ids produced by the quantizer (losses and z_q are discarded).
    """
    spatial = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
    projected = self.quant_conv(spatial)
    # quantize returns ((_, _, tokens), aux); only the indices are needed here
    (_, _, tokens), _ = self.quantize(projected)
    return tokens
def decode(self, tokens, bhwc):
z_q = self.quantize.get_codebook_entry(tokens, bhwc)
z = self.post_quant_conv(z_q)
z = self.post_conv(z)
z = self.post_conv(z)
return z
class SiglipTokenizer(nn.Module):
def __init__(
self,
siglip_name,
siglip_path,
projector_path,
z_channels=1536,
codebook_size=16384,
codebook_dim=2048,
self,
siglip_name,
siglip_path,
projector_path,
z_channels=1536,
codebook_size=16384,
codebook_dim=2048,
with_norm=True
):
super().__init__()
self.vit = create_siglip_vit(model_name=siglip_name, path=siglip_path)
self.vqproj = VQConvProjector(
z_channels=z_channels,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
z_channels=z_channels,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
with_norm=with_norm
)
self.vqproj.load_state_dict(torch.load(projector_path, map_location='cpu'), strict=True)
@ -226,6 +226,6 @@ class SiglipTokenizer(nn.Module):
features, (h, w), _ = self.vit(x)
tokens = self.vqproj.encode(features, h, w)
return tokens
def decode(self, tokens, bhwc):
    """Delegate token decoding to the VQ projector.

    NOTE(review): presumably ``bhwc`` is the (batch, height, width, channel)
    shape expected by ``VQConvProjector.decode`` — confirm against callers.
    """
    projector = self.vqproj
    return projector.decode(tokens, bhwc)

View File

@ -541,7 +541,7 @@ class VisionTransformer(nn.Module):
for x in x_list:
bs, _, h, w = x.shape
# fix patch size=14 in datasets
# fix patch size=14 in datasets
pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0]
pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
@ -578,7 +578,7 @@ class VisionTransformer(nn.Module):
bs, _, h, w = x.shape
h = h // self.patch_embed.patch_size[0]
w = w // self.patch_embed.patch_size[1]
x = self.patch_embed(x)
# x = self._pos_embed(x)
x = x + self.rescale_positional_embedding(out_size=(h, w))
@ -658,7 +658,7 @@ def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bic
pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True)
return model
return model
def create_siglip_vit(

View File

@ -32,7 +32,7 @@ class XOmniDecoderLayer(Qwen2DecoderLayer):
output_hidden_states, *others = super().forward(hidden_states, **kwargs)
output_hidden_states = torch.cat([output_hidden_states, multimodal_mask], dim=-1)
return output_hidden_states, *others
# mm_hidden_states = torch.where(multimodal_mask.bool(), hidden_states, torch.zeros_like(hidden_states))
output_hidden_states, *others = super().forward(hidden_states, **kwargs)
output_hidden_states = torch.where(multimodal_mask.bool(), output_hidden_states, hidden_states)
@ -48,7 +48,7 @@ class XOmniModel(Qwen2Model, Qwen2PreTrainedModel):
Qwen2PreTrainedModel.__init__(self, config)
self.padding_idx = -1
self.vocab_size = config.vocab_size
self.lm_embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.mm_embed_tokens = nn.Embedding(config.mm_vocab_size, config.hidden_size, self.padding_idx)
@ -94,7 +94,7 @@ class XOmniModel(Qwen2Model, Qwen2PreTrainedModel):
class XOmniForCausalLM(Qwen2ForCausalLM):
model_type = "x-omni"
config_class = XOmniConfig
_keys_to_ignore_on_load_missing = r'image_tokenizer\.*'
def __init__(self, config):
@ -102,7 +102,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
self.model = XOmniModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.mm_head = nn.Linear(config.hidden_size, config.mm_vocab_size, bias=False)
self.generation_mode = 'text'
# Initialize weights and apply final processing
self.post_init()
@ -110,7 +110,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
@property
def device(self):
    """Return the device of the model, read from its first parameter."""
    first_param = next(iter(self.parameters()))
    return first_param.device
def init_vision(self, flux_pipe_path, **kwargs):
self.som_token = self.config.mm_special_tokens[0]
self.eom_token = self.config.mm_special_tokens[1]
@ -125,7 +125,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
self.vision_dtype = dtype_map[self.vision_config.dtype]
self.image_transform = create_anyres_preprocess(**self.vision_config.transform)
self.encoder_config.siglip_path = os.path.join(self.name_or_path, self.encoder_config.siglip_path) if os.path.isdir(self.name_or_path) else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.siglip_path)
self.encoder_config.projector_path = os.path.join(self.name_or_path, self.encoder_config.projector_path) if os.path.isdir(self.name_or_path) else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.projector_path)
@ -133,15 +133,15 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
self.image_tokenizer.to(self.device, self.vision_dtype)
transformer = FluxTransformer2DModelWithSigLIP.from_pretrained(
self.name_or_path,
siglip_channels=self.encoder_config.z_channels,
self.name_or_path,
siglip_channels=self.encoder_config.z_channels,
torch_dtype=self.vision_dtype,
subfolder=self.decoder_config.model_path,
**kwargs,
)
self.decoder_pipe = FluxPipelineWithSigLIP.from_pretrained(
flux_pipe_path,
flux_pipe_path,
transformer=transformer,
torch_dtype=self.vision_dtype,
)
@ -161,7 +161,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
if len(images) > 0:
doc += self.tokenize_image(images.pop(0))
return tokenizer.encode(doc, **kwargs)
def mmdecode(self, tokenizer, token_ids, force_text=None, **kwargs):
force_text = force_text or []
if isinstance(token_ids, torch.Tensor):
@ -174,7 +174,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
assert len(token_ids) == 1
token_ids = token_ids[0]
assert isinstance(token_ids[0], int)
doc = tokenizer.decode(token_ids, **kwargs)
doc = doc.replace(tokenizer.pad_token, '')
doc = doc.replace('<SEP>', '')
@ -197,7 +197,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
image = self.detokenize_image(texts, images, token_ids, (H, W))
images.append(image)
return texts, images
@torch.no_grad()
def tokenize_image(self, image):
assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" before that.'
@ -213,7 +213,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
token_str = ''.join(map(lambda x: '<MM-Token-{token_id}>'.format(token_id=x), tokens))
image_str = f'{self.som_token}{H} {W}{self.img_token}{token_str}{self.eom_token}'
return image_str
@torch.no_grad()
def detokenize_image(self, texts, images, token_ids, shape):
assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" before that.'
@ -228,7 +228,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
[texts[0]],
negative_prompt=[''],
height=H * upscale_factor, width=W * upscale_factor,
num_inference_steps=self.decoder_config.num_inference_steps,
num_inference_steps=self.decoder_config.num_inference_steps,
guidance_scale=1.0,
true_cfg_scale=self.decoder_config.cfg_scale,
true_cfg_scale_2=self.decoder_config.cfg_scale_2,
@ -236,7 +236,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
return image
def forward(
self,
input_ids: torch.LongTensor = None,
@ -275,14 +275,14 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
hidden_states = outputs[0]
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
logits = hidden_states.new_full(
(*hidden_states.shape[:-1], self.config.vocab_size + self.config.mm_vocab_size),
(*hidden_states.shape[:-1], self.config.vocab_size + self.config.mm_vocab_size),
torch.finfo(hidden_states.dtype).min
)
if self.generation_mode == 'text':
logits[:, :, :self.config.vocab_size] = self.lm_head(hidden_states)
else:
logits[:, :, self.config.vocab_size:self.config.vocab_size + self.config.image_vocab_size] = self.mm_head(hidden_states)[:, :, :self.config.image_vocab_size]
logits = logits.float()
loss = None