diff --git a/annotator/color/__init__.py b/annotator/color/__init__.py index f9453ad..65799a2 100644 --- a/annotator/color/__init__.py +++ b/annotator/color/__init__.py @@ -1,6 +1,20 @@ import cv2 +def cv2_resize_shortest_edge(image, size): + h, w = image.shape[:2] + if h < w: + new_h = size + new_w = int(round(w / h * size)) + else: + new_w = size + new_h = int(round(h / w * size)) + resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA) + return resized_image + def apply_color(img, res=512): - input_img_color = cv2.resize(img, (res//64, res//64), interpolation=cv2.INTER_CUBIC) - input_img_color = cv2.resize(input_img_color, (res, res), interpolation=cv2.INTER_NEAREST) + img = cv2_resize_shortest_edge(img, res) + h, w = img.shape[:2] + + input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC) + input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST) return input_img_color \ No newline at end of file diff --git a/scripts/adapter.py b/scripts/adapter.py index c7809e8..d39e8c9 100644 --- a/scripts/adapter.py +++ b/scripts/adapter.py @@ -90,12 +90,13 @@ class PlugableAdapter(nn.Module): self.control = None self.hint_cond = None - def forward(self, hint=None, *args, **kwargs): + def forward(self, x=None, hint=None, *args, **kwargs): if self.control is not None: return deepcopy(self.control) self.hint_cond = hint hint_in = hint + if hasattr(self.control_model, 'conv_in') and self.control_model.conv_in.in_channels == 64: hint_in = hint_in[0].unsqueeze(0).unsqueeze(0) else: diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 6eb54a7..0a0fffe 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -382,6 +382,13 @@ class Script(scripts.Script): gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False), gr.update(visible=True) ] + elif module == "color": + return [ + gr.update(label="Annotator Resolution", value=64, minimum=64, maximum=2048, step=8, interactive=True), + gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(visible=True) + ] elif module == "none": return [ gr.update(label="Normal Resolution", value=64, minimum=64, maximum=2048, interactive=False), diff --git a/scripts/hook.py b/scripts/hook.py index 725adb0..770af47 100644 --- a/scripts/hook.py +++ b/scripts/hook.py @@ -106,6 +106,12 @@ class UnetHook(nn.Module): if is_adapter: return torch.cat([cond + x, uncond], dim=0) + # resize to sample resolution + base_h, base_w = base.shape[-2:] + xh, xw = x.shape[-2:] + if base_h != xh or base_w != xw: + x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") + return base + x def forward(self, x, timesteps=None, context=None, **kwargs): @@ -187,6 +193,7 @@ class UnetHook(nn.Module): target[idx] += item control = total_control + # print(torch.mean(torch.stack([control]))) assert timesteps is not None, ValueError(f"insufficient timestep: {timesteps}") hs = [] with th.no_grad():