Fix typos

Better support IPAttnProcessor2_0
pull/141/head
InstantX 2024-02-02 01:05:48 +08:00 committed by GitHub
parent 98332df4c1
commit e36ca465ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -421,6 +421,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
if len(region_control.prompt_image_conditioning) == 1:
region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
if region_mask is not None:
query = query.reshape([-1, query.shape[-2], query.shape[-1]])
h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5
mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])