mirror of https://github.com/InstantID/InstantID
Fixes issue for tensors on same device when running inference in multithreaded environment
parent
2145b67f96
commit
26103901c9
|
|
@ -427,6 +427,22 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
|
||||
else:
|
||||
mask = torch.ones_like(ip_hidden_states)
|
||||
|
||||
mask = mask.to(ip_hidden_states.device)
|
||||
if mask.shape[1] < ip_hidden_states.shape[1]:
|
||||
# Pad mask if it's shorter
|
||||
pad_size = ip_hidden_states.shape[1] - mask.shape[1]
|
||||
mask = F.pad(mask, (0, 0, 0, pad_size), mode="constant", value=1.0)
|
||||
else:
|
||||
# Truncate mask if it's longer
|
||||
mask = mask[:, :ip_hidden_states.shape[1]]
|
||||
|
||||
# Ensure mask has the same number of dimensions as ip_hidden_states
|
||||
if mask.ndim < ip_hidden_states.ndim:
|
||||
mask = mask.unsqueeze(-1)
|
||||
elif mask.ndim > ip_hidden_states.ndim:
|
||||
mask = mask.squeeze(-1)
|
||||
|
||||
ip_hidden_states = ip_hidden_states * mask
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
|
|
|||
Loading…
Reference in New Issue