Fixes issue for tensors on same device when running inference in multithreaded environment

pull/285/head
amarsingh.thakur 2025-01-15 16:12:18 +05:30
parent 2145b67f96
commit 26103901c9
1 changed files with 16 additions and 0 deletions

View File

@ -427,6 +427,22 @@ class IPAttnProcessor2_0(torch.nn.Module):
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
else:
mask = torch.ones_like(ip_hidden_states)
mask = mask.to(ip_hidden_states.device)
if mask.shape[1] < ip_hidden_states.shape[1]:
# Pad mask if it's shorter
pad_size = ip_hidden_states.shape[1] - mask.shape[1]
mask = F.pad(mask, (0, 0, 0, pad_size), mode="constant", value=1.0)
else:
# Truncate mask if it's longer
mask = mask[:, :ip_hidden_states.shape[1]]
# Ensure mask has the same number of dimensions as ip_hidden_states
if mask.ndim < ip_hidden_states.ndim:
mask = mask.unsqueeze(-1)
elif mask.ndim > ip_hidden_states.ndim:
mask = mask.squeeze(-1)
ip_hidden_states = ip_hidden_states * mask
hidden_states = hidden_states + self.scale * ip_hidden_states