fix: resize issue with t2i-adapter/color alignment

pull/541/head
Mikubill 2023-03-08 03:52:21 +00:00
parent fed7a52909
commit 5dfdc772d4
4 changed files with 32 additions and 3 deletions

View File

@ -1,6 +1,20 @@
import cv2
def cv2_resize_shortest_edge(image, size):
h, w = image.shape[:2]
if h < w:
new_h = size
new_w = int(round(w / h * size))
else:
new_w = size
new_h = int(round(h / w * size))
resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
return resized_image
def apply_color(img, res=512):
input_img_color = cv2.resize(img, (res//64, res//64), interpolation=cv2.INTER_CUBIC)
input_img_color = cv2.resize(input_img_color, (res, res), interpolation=cv2.INTER_NEAREST)
img = cv2_resize_shortest_edge(img, res)
h, w = img.shape[:2]
input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
return input_img_color

View File

@ -90,12 +90,13 @@ class PlugableAdapter(nn.Module):
self.control = None
self.hint_cond = None
def forward(self, hint=None, *args, **kwargs):
def forward(self, x=None, hint=None, *args, **kwargs):
if self.control is not None:
return deepcopy(self.control)
self.hint_cond = hint
hint_in = hint
if hasattr(self.control_model, 'conv_in') and self.control_model.conv_in.in_channels == 64:
hint_in = hint_in[0].unsqueeze(0).unsqueeze(0)
else:

View File

@ -382,6 +382,13 @@ class Script(scripts.Script):
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
gr.update(visible=True)
]
elif module == "color":
return [
gr.update(label="Annotator Resolution", value=64, minimum=64, maximum=2048, step=8, interactive=True),
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False),
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
gr.update(visible=True)
]
elif module == "none":
return [
gr.update(label="Normal Resolution", value=64, minimum=64, maximum=2048, interactive=False),

View File

@ -106,6 +106,12 @@ class UnetHook(nn.Module):
if is_adapter:
return torch.cat([cond + x, uncond], dim=0)
# resize to sample resolution
base_h, base_w = base.shape[-2:]
xh, xw = x.shape[-2:]
if base_h != xh or base_w != xw:
x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")
return base + x
def forward(self, x, timesteps=None, context=None, **kwargs):
@ -187,6 +193,7 @@ class UnetHook(nn.Module):
target[idx] += item
control = total_control
# print(torch.mean(torch.stack([control])))
assert timesteps is not None, ValueError(f"insufficient timestep: {timesteps}")
hs = []
with th.no_grad():