diff --git a/tutorial_train.py b/tutorial_train.py
index 2c7273f..10a7ab4 100644
--- a/tutorial_train.py
+++ b/tutorial_train.py
@@ -104,18 +104,49 @@ def collate_fn(data):
 
 class IPAdapter(torch.nn.Module):
     """IP-Adapter"""
-    def __init__(self, unet, image_proj_model, adapter_modules):
+    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
         super().__init__()
         self.unet = unet
         self.image_proj_model = image_proj_model
         self.adapter_modules = adapter_modules
 
+        if ckpt_path is not None:
+            self.load_from_checkpoint(ckpt_path)
+
     def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
         ip_tokens = self.image_proj_model(image_embeds)
         encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
-        # Predict the noise residual and compute loss
+        # Predict the noise residual
         noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
         return noise_pred
+
+    def load_from_checkpoint(self, ckpt_path: str):
+        """Load ``image_proj_model`` and ``adapter_modules`` weights from a checkpoint.
+
+        ``ckpt_path`` must point to a file readable by ``torch.load`` that holds the
+        ``"image_proj"`` and ``"ip_adapter"`` state dicts. An ``AssertionError`` is
+        raised if either module's parameter checksum is unchanged after loading.
+        """
+        # Calculate original checksums
+        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+        state_dict = torch.load(ckpt_path, map_location="cpu")
+
+        # Load state dict for image_proj_model and adapter_modules
+        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
+        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+        # Calculate new checksums
+        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+        # Verify if the weights have changed
+        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
+
 
 
 def parse_args():
@@ -127,6 +158,12 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--pretrained_ip_adapter_path",
+        type=str,
+        default=None,
+        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
+    )
     parser.add_argument(
         "--data_json_file",
         type=str,
@@ -289,7 +326,7 @@ def main():
     unet.set_attn_processor(attn_procs)
     adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
 
-    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)
+    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
 
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
diff --git a/tutorial_train_plus.py b/tutorial_train_plus.py
index e83a48f..83c855b 100644
--- a/tutorial_train_plus.py
+++ b/tutorial_train_plus.py
@@ -104,18 +104,62 @@ def collate_fn(data):
 
 class IPAdapter(torch.nn.Module):
     """IP-Adapter"""
-    def __init__(self, unet, image_proj_model, adapter_modules):
+    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
         super().__init__()
         self.unet = unet
         self.image_proj_model = image_proj_model
         self.adapter_modules = adapter_modules
 
+        if ckpt_path is not None:
+            self.load_from_checkpoint(ckpt_path)
+
     def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
         ip_tokens = self.image_proj_model(image_embeds)
         encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
-        # Predict the noise residual and compute loss
+        # Predict the noise residual
         noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
         return noise_pred
+
+    def load_from_checkpoint(self, ckpt_path: str):
+        """Load ``image_proj_model`` and ``adapter_modules`` weights from a checkpoint.
+
+        ``ckpt_path`` must point to a file readable by ``torch.load`` that holds the
+        ``"image_proj"`` and ``"ip_adapter"`` state dicts. If the checkpoint's
+        ``latents`` tensor has a different shape than the model's, it is dropped and
+        the remaining ``image_proj`` weights are loaded non-strictly. An
+        ``AssertionError`` is raised if either module's parameter checksum is
+        unchanged after loading.
+        """
+        # Calculate original checksums
+        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+        state_dict = torch.load(ckpt_path, map_location="cpu")
+
+        # Check if 'latents' exists in both the saved state_dict and the current model's state_dict
+        strict_load_image_proj_model = True
+        if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
+            # Check if the shapes are mismatched
+            if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
+                print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
+                print("Removing 'latents' from checkpoint and loading the rest of the weights.")
+                del state_dict["image_proj"]["latents"]
+                strict_load_image_proj_model = False
+
+        # Load state dict for image_proj_model and adapter_modules
+        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
+        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+        # Calculate new checksums
+        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+        # Verify if the weights have changed
+        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
+
 
 
 def parse_args():
@@ -127,6 +171,18 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--pretrained_ip_adapter_path",
+        type=str,
+        default=None,
+        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
+    )
+    parser.add_argument(
+        "--num_tokens",
+        type=int,
+        default=16,
+        help="Number of tokens to query from the CLIP image encoding.",
+    )
     parser.add_argument(
         "--data_json_file",
         type=str,
@@ -258,13 +314,12 @@ def main():
     image_encoder.requires_grad_(False)
 
     #ip-adapter-plus
-    num_tokens = 16
     image_proj_model = Resampler(
         dim=unet.config.cross_attention_dim,
         depth=4,
         dim_head=64,
         heads=12,
-        num_queries=num_tokens,
+        num_queries=args.num_tokens,
         embedding_dim=image_encoder.config.hidden_size,
         output_dim=unet.config.cross_attention_dim,
         ff_mult=4
@@ -290,12 +345,12 @@ def main():
                 "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                 "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
             }
-            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens)
+            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)
             attn_procs[name].load_state_dict(weights)
     unet.set_attn_processor(attn_procs)
     adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
 
-    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)
+    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
 
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
diff --git a/tutorial_train_sdxl.py b/tutorial_train_sdxl.py
index 432599d..2fb2e36 100644
--- a/tutorial_train_sdxl.py
+++ b/tutorial_train_sdxl.py
@@ -150,20 +150,50 @@ def collate_fn(data):
 
 class IPAdapter(torch.nn.Module):
     """IP-Adapter"""
-    def __init__(self, unet, image_proj_model, adapter_modules):
+    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
         super().__init__()
         self.unet = unet
         self.image_proj_model = image_proj_model
         self.adapter_modules = adapter_modules
 
+        if ckpt_path is not None:
+            self.load_from_checkpoint(ckpt_path)
+
     def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
         ip_tokens = self.image_proj_model(image_embeds)
         encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
         # Predict the noise residual
         noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
         return noise_pred
+
+    def load_from_checkpoint(self, ckpt_path: str):
+        """Load ``image_proj_model`` and ``adapter_modules`` weights from a checkpoint.
+
+        ``ckpt_path`` must point to a file readable by ``torch.load`` that holds the
+        ``"image_proj"`` and ``"ip_adapter"`` state dicts. An ``AssertionError`` is
+        raised if either module's parameter checksum is unchanged after loading.
+        """
+        # Calculate original checksums
+        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+        state_dict = torch.load(ckpt_path, map_location="cpu")
+
+        # Load state dict for image_proj_model and adapter_modules
+        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
+        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+        # Calculate new checksums
+        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+        # Verify if the weights have changed
+        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
-    
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -173,6 +203,12 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--pretrained_ip_adapter_path",
+        type=str,
+        default=None,
+        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
+    )
     parser.add_argument(
         "--data_json_file",
         type=str,
@@ -340,7 +376,7 @@ def main():
     unet.set_attn_processor(attn_procs)
     adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
 
-    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)
+    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
 
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":