fix: resize issue with t2i-adapter/color alignment
parent
fed7a52909
commit
5dfdc772d4
|
|
@ -1,6 +1,20 @@
|
|||
import cv2
|
||||
|
||||
def cv2_resize_shortest_edge(image, size):
    """Resize ``image`` so that its shorter side becomes ``size`` pixels.

    The longer side is scaled by the same ratio and rounded to the nearest
    integer, so the aspect ratio is preserved up to rounding.

    Args:
        image: Array whose first two ``shape`` entries are read as
            (height, width).
        size: Target length of the shorter side.

    Returns:
        The result of ``cv2.resize`` with ``cv2.INTER_AREA`` interpolation.
    """
    height, width = image.shape[:2]

    # cv2.resize takes the target size as (width, height).
    if height < width:
        target = (int(round(width / height * size)), size)
    else:
        target = (size, int(round(height / width * size)))

    return cv2.resize(image, target, interpolation=cv2.INTER_AREA)
|
||||
|
||||
def apply_color(img, res=512):
    """Build a blocky colour map from ``img``.

    The image is first resized so that its shorter side equals ``res``
    (via ``cv2_resize_shortest_edge``). It is then shrunk to 1/64 of that
    size with bicubic interpolation and scaled back up to the resized
    dimensions with nearest-neighbour interpolation, so every coarse cell
    becomes a flat block of colour.

    Args:
        img: Image array whose first two ``shape`` entries are read as
            (height, width).
        res: Target length of the shorter side, in pixels.

    Returns:
        The colour map, with the same height and width as the image
        resized to ``res``.
    """
    img = cv2_resize_shortest_edge(img, res)
    h, w = img.shape[:2]

    # Clamp to 1 so a side shorter than 64 px does not ask OpenCV for a
    # zero-sized intermediate image.
    coarse_size = (max(1, w // 64), max(1, h // 64))

    input_img_color = cv2.resize(img, coarse_size, interpolation=cv2.INTER_CUBIC)
    input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
    return input_img_color
|
||||
|
|
@ -90,12 +90,13 @@ class PlugableAdapter(nn.Module):
|
|||
self.control = None
|
||||
self.hint_cond = None
|
||||
|
||||
def forward(self, hint=None, *args, **kwargs):
|
||||
def forward(self, x=None, hint=None, *args, **kwargs):
|
||||
if self.control is not None:
|
||||
return deepcopy(self.control)
|
||||
|
||||
self.hint_cond = hint
|
||||
hint_in = hint
|
||||
|
||||
if hasattr(self.control_model, 'conv_in') and self.control_model.conv_in.in_channels == 64:
|
||||
hint_in = hint_in[0].unsqueeze(0).unsqueeze(0)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -382,6 +382,13 @@ class Script(scripts.Script):
|
|||
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
||||
gr.update(visible=True)
|
||||
]
|
||||
elif module == "color":
|
||||
return [
|
||||
gr.update(label="Annotator Resolution", value=64, minimum=64, maximum=2048, step=8, interactive=True),
|
||||
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False),
|
||||
gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False),
|
||||
gr.update(visible=True)
|
||||
]
|
||||
elif module == "none":
|
||||
return [
|
||||
gr.update(label="Normal Resolution", value=64, minimum=64, maximum=2048, interactive=False),
|
||||
|
|
|
|||
|
|
@ -106,6 +106,12 @@ class UnetHook(nn.Module):
|
|||
if is_adapter:
|
||||
return torch.cat([cond + x, uncond], dim=0)
|
||||
|
||||
# resize to sample resolution
|
||||
base_h, base_w = base.shape[-2:]
|
||||
xh, xw = x.shape[-2:]
|
||||
if base_h != xh or base_w != xw:
|
||||
x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")
|
||||
|
||||
return base + x
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, **kwargs):
|
||||
|
|
@ -187,6 +193,7 @@ class UnetHook(nn.Module):
|
|||
target[idx] += item
|
||||
|
||||
control = total_control
|
||||
# print(torch.mean(torch.stack([control])))
|
||||
assert timesteps is not None, ValueError(f"insufficient timestep: {timesteps}")
|
||||
hs = []
|
||||
with th.no_grad():
|
||||
|
|
|
|||
Loading…
Reference in New Issue