minor fixes:

1. image_proj strict=True loading in regular IPAdapter
2. Re-added unet_added_cond_kwargs deleted by mistake in SDXL
pull/105/head
danbochman 2023-10-18 11:41:09 +02:00
parent cbe5fa0643
commit 383f9691ec
No known key found for this signature in database
GPG Key ID: B0DE112E399D1082
2 changed files with 4 additions and 4 deletions

View File

@ -128,7 +128,7 @@ class IPAdapter(torch.nn.Module):
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums

View File

@ -159,11 +159,11 @@ class IPAdapter(torch.nn.Module):
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
@ -174,7 +174,7 @@ class IPAdapter(torch.nn.Module):
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums