implemented finetune from pretrained checkpoint

pull/105/head
danbochman 2023-10-18 11:18:32 +02:00
parent 5ea001b1f7
commit cbe5fa0643
No known key found for this signature in database
GPG Key ID: B0DE112E399D1082
3 changed files with 121 additions and 14 deletions

View File

@@ -104,19 +104,44 @@ def collate_fn(data):
class IPAdapter(torch.nn.Module): class IPAdapter(torch.nn.Module):
"""IP-Adapter""" """IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules): def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__() super().__init__()
self.unet = unet self.unet = unet
self.image_proj_model = image_proj_model self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
ip_tokens = self.image_proj_model(image_embeds) ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual and compute loss # Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
return noise_pred return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
    """Load pretrained IP-Adapter weights into image_proj_model and adapter_modules.

    Args:
        ckpt_path: Path to a checkpoint file containing "image_proj" and
            "ip_adapter" state dicts (as saved by the IP-Adapter training scripts).

    Raises:
        AssertionError: If loading left either sub-module's parameters unchanged.
    """
    # Parameter-sum checksums before loading, used to detect a no-op load.
    orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
    orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

    # NOTE(review): torch.load unpickles arbitrary objects on older torch
    # versions — only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt_path, map_location="cpu")

    # Load state dict for image_proj_model and adapter_modules.
    # strict=False on image_proj_model tolerates key mismatches between
    # projector variants; adapter weights must match exactly.
    self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
    self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

    # Recompute the checksums after loading.
    new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
    new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

    # Explicit raises instead of bare `assert` so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if orig_ip_proj_sum == new_ip_proj_sum:
        raise AssertionError("Weights of image_proj_model did not change!")
    if orig_adapter_sum == new_adapter_sum:
        raise AssertionError("Weights of adapter_modules did not change!")

    print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -127,6 +152,12 @@ def parse_args():
required=True, required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.", help="Path to pretrained model or model identifier from huggingface.co/models.",
) )
parser.add_argument(
"--pretrained_ip_adapter_path",
type=str,
default=None,
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
)
parser.add_argument( parser.add_argument(
"--data_json_file", "--data_json_file",
type=str, type=str,
@@ -289,7 +320,7 @@ def main():
unet.set_attn_processor(attn_procs) unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules) ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
weight_dtype = torch.float32 weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16": if accelerator.mixed_precision == "fp16":

View File

@@ -104,19 +104,54 @@ def collate_fn(data):
class IPAdapter(torch.nn.Module): class IPAdapter(torch.nn.Module):
"""IP-Adapter""" """IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules): def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__() super().__init__()
self.unet = unet self.unet = unet
self.image_proj_model = image_proj_model self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
ip_tokens = self.image_proj_model(image_embeds) ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual and compute loss # Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
return noise_pred return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
    """Load pretrained IP-Adapter weights, tolerating a num_tokens mismatch.

    If the checkpoint's Resampler 'latents' buffer has a different shape than
    the current model's (e.g. it was trained with a different --num_tokens),
    the latents are dropped from the checkpoint and the remaining projector
    weights are loaded non-strictly.

    Args:
        ckpt_path: Path to a checkpoint file containing "image_proj" and
            "ip_adapter" state dicts (as saved by the IP-Adapter training scripts).

    Raises:
        AssertionError: If loading left either sub-module's parameters unchanged.
    """
    # Parameter-sum checksums before loading, used to detect a no-op load.
    orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
    orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

    # NOTE(review): torch.load unpickles arbitrary objects on older torch
    # versions — only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt_path, map_location="cpu")

    # Check if 'latents' exists in both the saved state_dict and the current
    # model's state_dict; the model state dict is hoisted since it is read twice.
    strict_load_image_proj_model = True
    model_image_proj_sd = self.image_proj_model.state_dict()
    if "latents" in state_dict["image_proj"] and "latents" in model_image_proj_sd:
        # Check if the shapes are mismatched
        if state_dict["image_proj"]["latents"].shape != model_image_proj_sd["latents"].shape:
            print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
            print("Removing 'latents' from checkpoint and loading the rest of the weights.")
            del state_dict["image_proj"]["latents"]
            strict_load_image_proj_model = False

    # Load state dict for image_proj_model and adapter_modules; adapter
    # weights must always match exactly.
    self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
    self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

    # Recompute the checksums after loading.
    new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
    new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

    # Explicit raises instead of bare `assert` so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if orig_ip_proj_sum == new_ip_proj_sum:
        raise AssertionError("Weights of image_proj_model did not change!")
    if orig_adapter_sum == new_adapter_sum:
        raise AssertionError("Weights of adapter_modules did not change!")

    print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -127,6 +162,18 @@ def parse_args():
required=True, required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.", help="Path to pretrained model or model identifier from huggingface.co/models.",
) )
parser.add_argument(
"--pretrained_ip_adapter_path",
type=str,
default=None,
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
)
parser.add_argument(
"--num_tokens",
type=int,
default=16,
help="Number of tokens to query from the CLIP image encoding.",
)
parser.add_argument( parser.add_argument(
"--data_json_file", "--data_json_file",
type=str, type=str,
@@ -258,13 +305,12 @@ def main():
image_encoder.requires_grad_(False) image_encoder.requires_grad_(False)
#ip-adapter-plus #ip-adapter-plus
num_tokens = 16
image_proj_model = Resampler( image_proj_model = Resampler(
dim=unet.config.cross_attention_dim, dim=unet.config.cross_attention_dim,
depth=4, depth=4,
dim_head=64, dim_head=64,
heads=12, heads=12,
num_queries=num_tokens, num_queries=args.num_tokens,
embedding_dim=image_encoder.config.hidden_size, embedding_dim=image_encoder.config.hidden_size,
output_dim=unet.config.cross_attention_dim, output_dim=unet.config.cross_attention_dim,
ff_mult=4 ff_mult=4
@@ -290,12 +336,12 @@ def main():
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
} }
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens) attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)
attn_procs[name].load_state_dict(weights) attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs) unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules) ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
weight_dtype = torch.float32 weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16": if accelerator.mixed_precision == "fp16":

View File

@@ -150,19 +150,43 @@ def collate_fn(data):
class IPAdapter(torch.nn.Module): class IPAdapter(torch.nn.Module):
"""IP-Adapter""" """IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules): def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__() super().__init__()
self.unet = unet self.unet = unet
self.image_proj_model = image_proj_model self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules self.adapter_modules = adapter_modules
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds): if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
ip_tokens = self.image_proj_model(image_embeds) ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual # Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
return noise_pred return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
    """Load pretrained IP-Adapter weights into image_proj_model and adapter_modules.

    Args:
        ckpt_path: Path to a checkpoint file containing "image_proj" and
            "ip_adapter" state dicts (as saved by the IP-Adapter training scripts).

    Raises:
        AssertionError: If loading left either sub-module's parameters unchanged.
    """
    # Parameter-sum checksums before loading, used to detect a no-op load.
    orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
    orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

    # NOTE(review): torch.load unpickles arbitrary objects on older torch
    # versions — only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt_path, map_location="cpu")

    # Load state dict for image_proj_model and adapter_modules.
    # strict=False on image_proj_model tolerates key mismatches between
    # projector variants; adapter weights must match exactly.
    self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
    self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

    # Recompute the checksums after loading.
    new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
    new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

    # Explicit raises instead of bare `assert` so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if orig_ip_proj_sum == new_ip_proj_sum:
        raise AssertionError("Weights of image_proj_model did not change!")
    if orig_adapter_sum == new_adapter_sum:
        raise AssertionError("Weights of adapter_modules did not change!")

    print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -173,6 +197,12 @@ def parse_args():
required=True, required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.", help="Path to pretrained model or model identifier from huggingface.co/models.",
) )
parser.add_argument(
"--pretrained_ip_adapter_path",
type=str,
default=None,
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
)
parser.add_argument( parser.add_argument(
"--data_json_file", "--data_json_file",
type=str, type=str,
@@ -340,7 +370,7 @@ def main():
unet.set_attn_processor(attn_procs) unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules) ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
weight_dtype = torch.float32 weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16": if accelerator.mixed_precision == "fp16":