mirror of https://github.com/vladmandic/automatic
parent
28e3ae0480
commit
a315a004e9
|
|
@ -186,6 +186,8 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
|
|||
self.svd_down.to(devices.device) if self.svd_down is not None else None,
|
||||
skip_quantized_matmul=self.sdnq_dequantizer.use_quantized_matmul
|
||||
)
|
||||
else:
|
||||
weights_dtype = devices.dtype
|
||||
|
||||
new_weight = dequant_weight.to(devices.device, dtype=torch.float32) + lora_weights.to(devices.device, dtype=torch.float32)
|
||||
self.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def save_sdnq_model(model: ModelMixin, sdnq_config: SDNQConfig, model_path: str,
|
|||
|
||||
def load_sdnq_model(model_cls: ModelMixin, model_path: str, file_name: str = "diffusion_pytorch_model.safetensors", use_quantized_matmul: bool = False) -> ModelMixin:
|
||||
with torch.device("meta"):
|
||||
with open(os.path.join(model_path, "quantization_config.json"), "r") as f:
|
||||
with open(os.path.join(model_path, "quantization_config.json"), "r", encoding="utf-8") as f:
|
||||
quantization_config = json.load(f)
|
||||
quantization_config.pop("is_integer", None)
|
||||
quantization_config.pop("quant_method", None)
|
||||
|
|
|
|||
|
|
@ -252,13 +252,12 @@ def apply_sdnq_to_module(model, weights_dtype="int8", torch_dtype=None, group_si
|
|||
or any("*" in param and re.match(param.replace(".*", "\\.*").replace("*", ".*"), param_name) for param in modules_to_not_convert)
|
||||
):
|
||||
continue
|
||||
else:
|
||||
layer_class_name = module.__class__.__name__
|
||||
if layer_class_name in allowed_types:
|
||||
if (layer_class_name in conv_types or layer_class_name in conv_transpose_types) and not quant_conv:
|
||||
continue
|
||||
else:
|
||||
layer_class_name = module.__class__.__name__
|
||||
if layer_class_name in allowed_types:
|
||||
if (layer_class_name in conv_types or layer_class_name in conv_transpose_types) and not quant_conv:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
if len(modules_dtype_dict.keys()) > 0:
|
||||
for key, value in modules_dtype_dict.items():
|
||||
if param_name in value:
|
||||
|
|
|
|||
|
|
@ -497,7 +497,7 @@ def teacache_forward(
|
|||
|
||||
|
||||
class FluxPipelineWithSigLIP(FluxPipeline):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@ from .modeling_vit import create_siglip_vit
|
|||
|
||||
|
||||
def create_anyres_preprocess(
|
||||
short_size=384,
|
||||
long_size=1152,
|
||||
patch_size=16,
|
||||
random_ratio=None,
|
||||
min_short_size=128,
|
||||
max_aspect_ratio=3.,
|
||||
short_size=384,
|
||||
long_size=1152,
|
||||
patch_size=16,
|
||||
random_ratio=None,
|
||||
min_short_size=128,
|
||||
max_aspect_ratio=3.,
|
||||
filtering=True
|
||||
):
|
||||
|
||||
|
|
@ -35,21 +35,21 @@ def create_anyres_preprocess(
|
|||
sqrt_ratio = torch.exp(0.5 * torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
|
||||
target_width = int(round(target_width * sqrt_ratio))
|
||||
target_height = int(round(target_height / sqrt_ratio))
|
||||
|
||||
|
||||
ss = min(target_width, target_height)
|
||||
if ss < short_size:
|
||||
target_width = target_width * (short_size / ss)
|
||||
target_height = target_height * (short_size / ss)
|
||||
|
||||
|
||||
ls = max(target_width, target_height)
|
||||
if ls > long_size:
|
||||
target_width = target_width * (long_size / ls)
|
||||
target_height = target_height * (long_size / ls)
|
||||
|
||||
|
||||
target_width = int(round(target_width / patch_size)) * patch_size
|
||||
target_height = int(round(target_height / patch_size)) * patch_size
|
||||
pil_image = pil_image.resize((target_width, target_height), resample=Image.BICUBIC)
|
||||
|
||||
|
||||
to_tensor = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
|
|
@ -75,7 +75,7 @@ class IBQ(nn.Module):
|
|||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
if self.l2_norm:
|
||||
self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
|
||||
|
||||
|
||||
def forward(self, z, temp=None, rescale_logits=False, return_logits=False, **kwargs):
|
||||
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
||||
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
||||
|
|
@ -96,7 +96,7 @@ class IBQ(nn.Module):
|
|||
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
||||
torch.sum(embedding**2, dim=1) - 2 * \
|
||||
torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
|
||||
|
||||
|
||||
if self.training:
|
||||
logits = -d / self.quantization_temp
|
||||
soft_one_hot = F.softmax(logits, dim=1)
|
||||
|
|
@ -114,13 +114,13 @@ class IBQ(nn.Module):
|
|||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
z_q = embedding[min_encoding_indices].view(z.shape)
|
||||
commit_loss = None
|
||||
|
||||
|
||||
if self.training and self.skip_quantization_prob > 0.0:
|
||||
z_q = torch.where(
|
||||
torch.rand_like(z_q[:, 0:1, 0:1, 0:1]).expand_as(z_q) <= self.skip_quantization_prob,
|
||||
z, z_q,
|
||||
)
|
||||
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ class ResidualBlock(nn.Module):
|
|||
self.activate = nn.GELU()
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding='same')
|
||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
x = self.norm1(x)
|
||||
|
|
@ -164,10 +164,10 @@ class ResidualBlock(nn.Module):
|
|||
|
||||
class VQConvProjector(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
z_channels=1536,
|
||||
codebook_size=16384,
|
||||
codebook_dim=2048,
|
||||
self,
|
||||
z_channels=1536,
|
||||
codebook_size=16384,
|
||||
codebook_dim=2048,
|
||||
conv_layers=2,
|
||||
with_norm=True,
|
||||
skip_quant_prob=0.1,
|
||||
|
|
@ -178,7 +178,7 @@ class VQConvProjector(nn.Module):
|
|||
self.post_quant_conv = nn.Conv2d(codebook_dim, z_channels, 1)
|
||||
block = ResidualBlock
|
||||
self.post_conv = nn.Sequential(*[block(z_channels) for _ in range(conv_layers)])
|
||||
|
||||
|
||||
def forward(self, x, h, w):
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||
z = self.quant_conv(x)
|
||||
|
|
@ -187,37 +187,37 @@ class VQConvProjector(nn.Module):
|
|||
z = self.post_conv(z)
|
||||
z = rearrange(z, 'b c h w -> b (h w) c')
|
||||
return z, codebook_loss
|
||||
|
||||
|
||||
def encode(self, x, h, w):
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||
z = self.quant_conv(x)
|
||||
(_, _, tokens), _ = self.quantize(z)
|
||||
return tokens
|
||||
|
||||
|
||||
def decode(self, tokens, bhwc):
|
||||
z_q = self.quantize.get_codebook_entry(tokens, bhwc)
|
||||
z = self.post_quant_conv(z_q)
|
||||
z = self.post_conv(z)
|
||||
z = self.post_conv(z)
|
||||
return z
|
||||
|
||||
|
||||
class SiglipTokenizer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
siglip_name,
|
||||
siglip_path,
|
||||
projector_path,
|
||||
z_channels=1536,
|
||||
codebook_size=16384,
|
||||
codebook_dim=2048,
|
||||
self,
|
||||
siglip_name,
|
||||
siglip_path,
|
||||
projector_path,
|
||||
z_channels=1536,
|
||||
codebook_size=16384,
|
||||
codebook_dim=2048,
|
||||
with_norm=True
|
||||
):
|
||||
super().__init__()
|
||||
self.vit = create_siglip_vit(model_name=siglip_name, path=siglip_path)
|
||||
self.vqproj = VQConvProjector(
|
||||
z_channels=z_channels,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
z_channels=z_channels,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
with_norm=with_norm
|
||||
)
|
||||
self.vqproj.load_state_dict(torch.load(projector_path, map_location='cpu'), strict=True)
|
||||
|
|
@ -226,6 +226,6 @@ class SiglipTokenizer(nn.Module):
|
|||
features, (h, w), _ = self.vit(x)
|
||||
tokens = self.vqproj.encode(features, h, w)
|
||||
return tokens
|
||||
|
||||
|
||||
def decode(self, tokens, bhwc):
|
||||
return self.vqproj.decode(tokens, bhwc)
|
||||
|
|
|
|||
|
|
@ -541,7 +541,7 @@ class VisionTransformer(nn.Module):
|
|||
for x in x_list:
|
||||
bs, _, h, w = x.shape
|
||||
|
||||
# fix patch size=14 in datasets
|
||||
# fix patch size=14 in datasets
|
||||
pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0]
|
||||
pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1]
|
||||
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||
|
|
@ -578,7 +578,7 @@ class VisionTransformer(nn.Module):
|
|||
bs, _, h, w = x.shape
|
||||
h = h // self.patch_embed.patch_size[0]
|
||||
w = w // self.patch_embed.patch_size[1]
|
||||
|
||||
|
||||
x = self.patch_embed(x)
|
||||
# x = self._pos_embed(x)
|
||||
x = x + self.rescale_positional_embedding(out_size=(h, w))
|
||||
|
|
@ -658,7 +658,7 @@ def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bic
|
|||
pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True)
|
||||
return model
|
||||
return model
|
||||
|
||||
|
||||
def create_siglip_vit(
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class XOmniDecoderLayer(Qwen2DecoderLayer):
|
|||
output_hidden_states, *others = super().forward(hidden_states, **kwargs)
|
||||
output_hidden_states = torch.cat([output_hidden_states, multimodal_mask], dim=-1)
|
||||
return output_hidden_states, *others
|
||||
|
||||
|
||||
# mm_hidden_states = torch.where(multimodal_mask.bool(), hidden_states, torch.zeros_like(hidden_states))
|
||||
output_hidden_states, *others = super().forward(hidden_states, **kwargs)
|
||||
output_hidden_states = torch.where(multimodal_mask.bool(), output_hidden_states, hidden_states)
|
||||
|
|
@ -48,7 +48,7 @@ class XOmniModel(Qwen2Model, Qwen2PreTrainedModel):
|
|||
Qwen2PreTrainedModel.__init__(self, config)
|
||||
self.padding_idx = -1
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
|
||||
self.lm_embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.mm_embed_tokens = nn.Embedding(config.mm_vocab_size, config.hidden_size, self.padding_idx)
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ class XOmniModel(Qwen2Model, Qwen2PreTrainedModel):
|
|||
class XOmniForCausalLM(Qwen2ForCausalLM):
|
||||
model_type = "x-omni"
|
||||
config_class = XOmniConfig
|
||||
|
||||
|
||||
_keys_to_ignore_on_load_missing = r'image_tokenizer\.*'
|
||||
|
||||
def __init__(self, config):
|
||||
|
|
@ -102,7 +102,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
self.model = XOmniModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.mm_head = nn.Linear(config.hidden_size, config.mm_vocab_size, bias=False)
|
||||
|
||||
|
||||
self.generation_mode = 'text'
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
|
@ -110,7 +110,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
|
||||
def init_vision(self, flux_pipe_path, **kwargs):
|
||||
self.som_token = self.config.mm_special_tokens[0]
|
||||
self.eom_token = self.config.mm_special_tokens[1]
|
||||
|
|
@ -125,7 +125,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
self.vision_dtype = dtype_map[self.vision_config.dtype]
|
||||
|
||||
self.image_transform = create_anyres_preprocess(**self.vision_config.transform)
|
||||
|
||||
|
||||
self.encoder_config.siglip_path = os.path.join(self.name_or_path, self.encoder_config.siglip_path) if os.path.isdir(self.name_or_path) else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.siglip_path)
|
||||
self.encoder_config.projector_path = os.path.join(self.name_or_path, self.encoder_config.projector_path) if os.path.isdir(self.name_or_path) else hf_hub_download(repo_id=self.name_or_path, filename=self.encoder_config.projector_path)
|
||||
|
||||
|
|
@ -133,15 +133,15 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
self.image_tokenizer.to(self.device, self.vision_dtype)
|
||||
|
||||
transformer = FluxTransformer2DModelWithSigLIP.from_pretrained(
|
||||
self.name_or_path,
|
||||
siglip_channels=self.encoder_config.z_channels,
|
||||
self.name_or_path,
|
||||
siglip_channels=self.encoder_config.z_channels,
|
||||
torch_dtype=self.vision_dtype,
|
||||
subfolder=self.decoder_config.model_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.decoder_pipe = FluxPipelineWithSigLIP.from_pretrained(
|
||||
flux_pipe_path,
|
||||
flux_pipe_path,
|
||||
transformer=transformer,
|
||||
torch_dtype=self.vision_dtype,
|
||||
)
|
||||
|
|
@ -161,7 +161,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
if len(images) > 0:
|
||||
doc += self.tokenize_image(images.pop(0))
|
||||
return tokenizer.encode(doc, **kwargs)
|
||||
|
||||
|
||||
def mmdecode(self, tokenizer, token_ids, force_text=None, **kwargs):
|
||||
force_text = force_text or []
|
||||
if isinstance(token_ids, torch.Tensor):
|
||||
|
|
@ -174,7 +174,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
assert len(token_ids) == 1
|
||||
token_ids = token_ids[0]
|
||||
assert isinstance(token_ids[0], int)
|
||||
|
||||
|
||||
doc = tokenizer.decode(token_ids, **kwargs)
|
||||
doc = doc.replace(tokenizer.pad_token, '')
|
||||
doc = doc.replace('<SEP>', '')
|
||||
|
|
@ -197,7 +197,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
image = self.detokenize_image(texts, images, token_ids, (H, W))
|
||||
images.append(image)
|
||||
return texts, images
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def tokenize_image(self, image):
|
||||
assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" before that.'
|
||||
|
|
@ -213,7 +213,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
token_str = ''.join(map(lambda x: '<MM-Token-{token_id}>'.format(token_id=x), tokens))
|
||||
image_str = f'{self.som_token}{H} {W}{self.img_token}{token_str}{self.eom_token}'
|
||||
return image_str
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detokenize_image(self, texts, images, token_ids, shape):
|
||||
assert hasattr(self, 'image_tokenizer'), 'Please call "init_vision" before that.'
|
||||
|
|
@ -228,7 +228,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
[texts[0]],
|
||||
negative_prompt=[''],
|
||||
height=H * upscale_factor, width=W * upscale_factor,
|
||||
num_inference_steps=self.decoder_config.num_inference_steps,
|
||||
num_inference_steps=self.decoder_config.num_inference_steps,
|
||||
guidance_scale=1.0,
|
||||
true_cfg_scale=self.decoder_config.cfg_scale,
|
||||
true_cfg_scale_2=self.decoder_config.cfg_scale_2,
|
||||
|
|
@ -236,7 +236,7 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
|
@ -275,14 +275,14 @@ class XOmniForCausalLM(Qwen2ForCausalLM):
|
|||
hidden_states = outputs[0]
|
||||
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
|
||||
logits = hidden_states.new_full(
|
||||
(*hidden_states.shape[:-1], self.config.vocab_size + self.config.mm_vocab_size),
|
||||
(*hidden_states.shape[:-1], self.config.vocab_size + self.config.mm_vocab_size),
|
||||
torch.finfo(hidden_states.dtype).min
|
||||
)
|
||||
if self.generation_mode == 'text':
|
||||
logits[:, :, :self.config.vocab_size] = self.lm_head(hidden_states)
|
||||
else:
|
||||
logits[:, :, self.config.vocab_size:self.config.vocab_size + self.config.image_vocab_size] = self.mm_head(hidden_states)[:, :, :self.config.image_vocab_size]
|
||||
|
||||
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue