diff --git a/README.md b/README.md index 5666d07..0205943 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ Dragging a file on the Web UI will freeze the entire page. It is better to use t ### Install +Some users may need to install the cv2 library before installing it: `pip install opencv-python` + 1. Open "Extensions" tab. 2. Open "Install from URL" tab in the tab. 3. Enter URL of this repo to "URL for extension's git repository". @@ -32,3 +34,4 @@ Currently it supports both full models and trimmed models. Use `extract_controln | Source | Input | Output | |:-------------------------:|:-------------------------:|:-------------------------:| | | | | +| | | | diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py index 42d8dc6..c0b99c4 100644 --- a/annotator/hed/__init__.py +++ b/annotator/hed/__init__.py @@ -3,9 +3,11 @@ import cv2 import torch from einops import rearrange +import os +from modules import extensions class Network(torch.nn.Module): - def __init__(self): + def __init__(self, model_path): super().__init__() self.netVggOne = torch.nn.Sequential( @@ -64,7 +66,7 @@ class Network(torch.nn.Module): torch.nn.Sigmoid() ) - self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load('./annotator/ckpts/network-bsds500.pth').items()}) + self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()}) # end def forward(self, tenInput): @@ -93,11 +95,19 @@ class Network(torch.nn.Module): # end # end - -netNetwork = Network().cuda().eval() - +netNetwork = None +remote_model_path = "https://huggingface.co/datasets/nyanko7/tmp-public/resolve/main/network-bsds500.pt" +modeldir = os.path.join(extensions.extensions_dir, "sd-webui-controlnet", "annotator") def apply_hed(input_image): + global netNetwork + if netNetwork is None: + modelpath = os.path.join(modeldir, "network-bsds500.pt") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=modeldir) + netNetwork = Network(modelpath).cuda().eval() + assert input_image.ndim == 3 input_image = input_image[:, :, ::-1].copy() with torch.no_grad(): diff --git a/samples/evt_gen.png b/samples/evt_gen.png new file mode 100644 index 0000000..5d0dbf1 Binary files /dev/null and b/samples/evt_gen.png differ diff --git a/samples/evt_hed.png b/samples/evt_hed.png new file mode 100644 index 0000000..fa7feb7 Binary files /dev/null and b/samples/evt_hed.png differ diff --git a/samples/evt_source.jpg b/samples/evt_source.jpg new file mode 100644 index 0000000..0a21210 Binary files /dev/null and b/samples/evt_source.jpg differ diff --git a/samples/mahiro_canny.png b/samples/mahiro_canny.png index cb09613..ecaeef3 100644 Binary files a/samples/mahiro_canny.png and b/samples/mahiro_canny.png differ diff --git a/samples/mahiro_gen.png b/samples/mahiro_gen.png index 9c2ea6f..35ecc1a 100644 Binary files a/samples/mahiro_gen.png and b/samples/mahiro_gen.png differ diff --git a/samples/mahiro_input.png b/samples/mahiro_input.png index bc97e9e..0ee95dd 100644 Binary files a/samples/mahiro_input.png and b/samples/mahiro_input.png differ diff --git a/scripts/controlnet.py b/scripts/controlnet.py index df8b10d..69569da 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -162,15 +162,14 @@ class Script(scripts.Script): selected = dd else: selected = "None" - print(cn_models) - update = gr.Dropdown.update( - value=selected, choices=list(cn_models.keys())) + + update = gr.Dropdown.update(value=selected, choices=list(cn_models.keys())) updates.append(update) return updates refresh_models = gr.Button(value='Refresh models') refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns) - ctrls += (refresh_models, ) + # ctrls += (refresh_models, ) def create_canvas(h, w): return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 @@ -182,7 +181,7 @@ class Script(scripts.Script): gr.Markdown(value='Don\'t drag image into the box! Use upload instead. Change your brush width to make it thinner if you want to draw something.') create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image]) - ctrls += (canvas_width, canvas_height, create_button, input_image, scribble_mode) + ctrls += (input_image, scribble_mode) return ctrls @@ -212,15 +211,14 @@ class Script(scripts.Script): self.latest_network.restore(unet) self.latest_network = None - enabled, module, model, weight, _ = args[:5] - _, _, _, image, scribble_mode = args[5:] + enabled, module, model, weight,image, scribble_mode = args if not enabled: restore_networks() return models_changed = self.latest_params[0] != module or self.latest_params[1] != model \ - or self.latest_model_hash != p.sd_model.sd_model_hash + or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_network == None if models_changed: restore_networks() @@ -273,6 +271,8 @@ class Script(scripts.Script): self.set_infotext_fields(p, self.latest_params) def postprocess(self, p, processed, *args): + if self.latest_network is None: + return if hasattr(self, "control") and self.control is not None: processed.images.append(ToPILImage()((self.control).clip(0, 255)))