minor fixes:
1. image_proj strict=True loading in regular IPAdapter 2. Re-added unet_added_cond_kwargs deleted by mistake in SDXLpull/105/head
parent
cbe5fa0643
commit
383f9691ec
|
|
@ -128,7 +128,7 @@ class IPAdapter(torch.nn.Module):
|
|||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
# Load state dict for image_proj_model and adapter_modules
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||
|
||||
# Calculate new checksums
|
||||
|
|
|
|||
|
|
@ -159,11 +159,11 @@ class IPAdapter(torch.nn.Module):
|
|||
if ckpt_path is not None:
|
||||
self.load_from_checkpoint(ckpt_path)
|
||||
|
||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
|
||||
ip_tokens = self.image_proj_model(image_embeds)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
|
||||
return noise_pred
|
||||
|
||||
def load_from_checkpoint(self, ckpt_path: str):
|
||||
|
|
@ -174,7 +174,7 @@ class IPAdapter(torch.nn.Module):
|
|||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
# Load state dict for image_proj_model and adapter_modules
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||
|
||||
# Calculate new checksums
|
||||
|
|
|
|||
Loading…
Reference in New Issue