Fix issue with ernie image. (#13393)
parent
acd718598e
commit
402ff1cdb7
|
|
@ -279,7 +279,7 @@ class ErnieImageModel(nn.Module):
|
|||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
|
||||
del image_ids, text_ids
|
||||
|
||||
sample = self.time_proj(timesteps.to(dtype)).to(self.time_embedding.linear_1.weight.dtype)
|
||||
sample = self.time_proj(timesteps).to(dtype)
|
||||
c = self.time_embedding(sample)
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue