implemented finetune from pretrained checkpoint
parent
5ea001b1f7
commit
cbe5fa0643
|
|
@ -104,19 +104,44 @@ def collate_fn(data):
|
||||||
|
|
||||||
class IPAdapter(torch.nn.Module):
|
class IPAdapter(torch.nn.Module):
|
||||||
"""IP-Adapter"""
|
"""IP-Adapter"""
|
||||||
def __init__(self, unet, image_proj_model, adapter_modules):
|
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.unet = unet
|
self.unet = unet
|
||||||
self.image_proj_model = image_proj_model
|
self.image_proj_model = image_proj_model
|
||||||
self.adapter_modules = adapter_modules
|
self.adapter_modules = adapter_modules
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.load_from_checkpoint(ckpt_path)
|
||||||
|
|
||||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
||||||
ip_tokens = self.image_proj_model(image_embeds)
|
ip_tokens = self.image_proj_model(image_embeds)
|
||||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
||||||
# Predict the noise residual and compute loss
|
# Predict the noise residual
|
||||||
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
|
def load_from_checkpoint(self, ckpt_path: str):
|
||||||
|
# Calculate original checksums
|
||||||
|
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||||
|
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||||
|
|
||||||
|
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
|
||||||
|
# Load state dict for image_proj_model and adapter_modules
|
||||||
|
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
|
||||||
|
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||||
|
|
||||||
|
# Calculate new checksums
|
||||||
|
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||||
|
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||||
|
|
||||||
|
# Verify if the weights have changed
|
||||||
|
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
||||||
|
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
||||||
|
|
||||||
|
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
|
@ -127,6 +152,12 @@ def parse_args():
|
||||||
required=True,
|
required=True,
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained_ip_adapter_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data_json_file",
|
"--data_json_file",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -289,7 +320,7 @@ def main():
|
||||||
unet.set_attn_processor(attn_procs)
|
unet.set_attn_processor(attn_procs)
|
||||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||||
|
|
||||||
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)
|
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
||||||
|
|
||||||
weight_dtype = torch.float32
|
weight_dtype = torch.float32
|
||||||
if accelerator.mixed_precision == "fp16":
|
if accelerator.mixed_precision == "fp16":
|
||||||
|
|
|
||||||
|
|
@ -104,19 +104,54 @@ def collate_fn(data):
|
||||||
|
|
||||||
class IPAdapter(torch.nn.Module):
|
class IPAdapter(torch.nn.Module):
|
||||||
"""IP-Adapter"""
|
"""IP-Adapter"""
|
||||||
def __init__(self, unet, image_proj_model, adapter_modules):
|
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.unet = unet
|
self.unet = unet
|
||||||
self.image_proj_model = image_proj_model
|
self.image_proj_model = image_proj_model
|
||||||
self.adapter_modules = adapter_modules
|
self.adapter_modules = adapter_modules
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.load_from_checkpoint(ckpt_path)
|
||||||
|
|
||||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
||||||
ip_tokens = self.image_proj_model(image_embeds)
|
ip_tokens = self.image_proj_model(image_embeds)
|
||||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
||||||
# Predict the noise residual and compute loss
|
# Predict the noise residual
|
||||||
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
|
def load_from_checkpoint(self, ckpt_path: str):
|
||||||
|
# Calculate original checksums
|
||||||
|
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||||
|
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||||
|
|
||||||
|
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
|
||||||
|
# Check if 'latents' exists in both the saved state_dict and the current model's state_dict
|
||||||
|
strict_load_image_proj_model = True
|
||||||
|
if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
|
||||||
|
# Check if the shapes are mismatched
|
||||||
|
if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
|
||||||
|
print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
|
||||||
|
print("Removing 'latents' from checkpoint and loading the rest of the weights.")
|
||||||
|
del state_dict["image_proj"]["latents"]
|
||||||
|
strict_load_image_proj_model = False
|
||||||
|
|
||||||
|
# Load state dict for image_proj_model and adapter_modules
|
||||||
|
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
|
||||||
|
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||||
|
|
||||||
|
# Calculate new checksums
|
||||||
|
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||||
|
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||||
|
|
||||||
|
# Verify if the weights have changed
|
||||||
|
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
||||||
|
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
||||||
|
|
||||||
|
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
|
@ -127,6 +162,18 @@ def parse_args():
|
||||||
required=True,
|
required=True,
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained_ip_adapter_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_tokens",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="Number of tokens to query from the CLIP image encoding.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data_json_file",
|
"--data_json_file",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -258,13 +305,12 @@ def main():
|
||||||
image_encoder.requires_grad_(False)
|
image_encoder.requires_grad_(False)
|
||||||
|
|
||||||
#ip-adapter-plus
|
#ip-adapter-plus
|
||||||
num_tokens = 16
|
|
||||||
image_proj_model = Resampler(
|
image_proj_model = Resampler(
|
||||||
dim=unet.config.cross_attention_dim,
|
dim=unet.config.cross_attention_dim,
|
||||||
depth=4,
|
depth=4,
|
||||||
dim_head=64,
|
dim_head=64,
|
||||||
heads=12,
|
heads=12,
|
||||||
num_queries=num_tokens,
|
num_queries=args.num_tokens,
|
||||||
embedding_dim=image_encoder.config.hidden_size,
|
embedding_dim=image_encoder.config.hidden_size,
|
||||||
output_dim=unet.config.cross_attention_dim,
|
output_dim=unet.config.cross_attention_dim,
|
||||||
ff_mult=4
|
ff_mult=4
|
||||||
|
|
@ -290,12 +336,12 @@ def main():
|
||||||
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
||||||
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
||||||
}
|
}
|
||||||
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens)
|
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)
|
||||||
attn_procs[name].load_state_dict(weights)
|
attn_procs[name].load_state_dict(weights)
|
||||||
unet.set_attn_processor(attn_procs)
|
unet.set_attn_processor(attn_procs)
|
||||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||||
|
|
||||||
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)
|
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
||||||
|
|
||||||
weight_dtype = torch.float32
|
weight_dtype = torch.float32
|
||||||
if accelerator.mixed_precision == "fp16":
|
if accelerator.mixed_precision == "fp16":
|
||||||
|
|
|
||||||
|
|
@ -150,19 +150,43 @@ def collate_fn(data):
|
||||||
|
|
||||||
class IPAdapter(torch.nn.Module):
|
class IPAdapter(torch.nn.Module):
|
||||||
"""IP-Adapter"""
|
"""IP-Adapter"""
|
||||||
def __init__(self, unet, image_proj_model, adapter_modules):
|
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.unet = unet
|
self.unet = unet
|
||||||
self.image_proj_model = image_proj_model
|
self.image_proj_model = image_proj_model
|
||||||
self.adapter_modules = adapter_modules
|
self.adapter_modules = adapter_modules
|
||||||
|
|
||||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
|
if ckpt_path is not None:
|
||||||
|
self.load_from_checkpoint(ckpt_path)
|
||||||
|
|
||||||
|
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
||||||
ip_tokens = self.image_proj_model(image_embeds)
|
ip_tokens = self.image_proj_model(image_embeds)
|
||||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
|
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
|
def load_from_checkpoint(self, ckpt_path: str):
|
||||||
|
# Calculate original checksums
|
||||||
|
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||||
|
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||||
|
|
||||||
|
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
|
||||||
|
# Load state dict for image_proj_model and adapter_modules
|
||||||
|
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
|
||||||
|
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||||
|
|
||||||
|
# Calculate new checksums
|
||||||
|
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||||
|
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||||
|
|
||||||
|
# Verify if the weights have changed
|
||||||
|
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
||||||
|
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
||||||
|
|
||||||
|
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
|
@ -173,6 +197,12 @@ def parse_args():
|
||||||
required=True,
|
required=True,
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained_ip_adapter_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data_json_file",
|
"--data_json_file",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -340,7 +370,7 @@ def main():
|
||||||
unet.set_attn_processor(attn_procs)
|
unet.set_attn_processor(attn_procs)
|
||||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||||
|
|
||||||
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)
|
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
||||||
|
|
||||||
weight_dtype = torch.float32
|
weight_dtype = torch.float32
|
||||||
if accelerator.mixed_precision == "fp16":
|
if accelerator.mixed_precision == "fp16":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue