diff --git a/README.md b/README.md
index 5666d07..0205943 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,8 @@ Dragging a file on the Web UI will freeze the entire page. It is better to use t
### Install
+Some users may need to install the cv2 library before installing it: `pip install opencv-python`
+
1. Open "Extensions" tab.
2. Open "Install from URL" tab in the tab.
3. Enter URL of this repo to "URL for extension's git repository".
@@ -32,3 +34,4 @@ Currently it supports both full models and trimmed models. Use `extract_controln
| Source | Input | Output |
|:-------------------------:|:-------------------------:|:-------------------------:|
|
|
|
|
+|
|
|
|
diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py
index 42d8dc6..c0b99c4 100644
--- a/annotator/hed/__init__.py
+++ b/annotator/hed/__init__.py
@@ -3,9 +3,11 @@ import cv2
import torch
from einops import rearrange
+import os
+from modules import extensions
class Network(torch.nn.Module):
- def __init__(self):
+ def __init__(self, model_path):
super().__init__()
self.netVggOne = torch.nn.Sequential(
@@ -64,7 +66,7 @@ class Network(torch.nn.Module):
torch.nn.Sigmoid()
)
- self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load('./annotator/ckpts/network-bsds500.pth').items()})
+ self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
# end
def forward(self, tenInput):
@@ -93,11 +95,19 @@ class Network(torch.nn.Module):
# end
# end
-
-netNetwork = Network().cuda().eval()
-
+netNetwork = None
+remote_model_path = "https://huggingface.co/datasets/nyanko7/tmp-public/resolve/main/network-bsds500.pt"
+modeldir = os.path.join(extensions.extensions_dir, "sd-webui-controlnet", "annotator")
def apply_hed(input_image):
+ global netNetwork
+ if netNetwork is None:
+ modelpath = os.path.join(modeldir, "network-bsds500.pt")
+ if not os.path.exists(modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=modeldir)
+ netNetwork = Network(modelpath).cuda().eval()
+
assert input_image.ndim == 3
input_image = input_image[:, :, ::-1].copy()
with torch.no_grad():
diff --git a/samples/evt_gen.png b/samples/evt_gen.png
new file mode 100644
index 0000000..5d0dbf1
Binary files /dev/null and b/samples/evt_gen.png differ
diff --git a/samples/evt_hed.png b/samples/evt_hed.png
new file mode 100644
index 0000000..fa7feb7
Binary files /dev/null and b/samples/evt_hed.png differ
diff --git a/samples/evt_source.jpg b/samples/evt_source.jpg
new file mode 100644
index 0000000..0a21210
Binary files /dev/null and b/samples/evt_source.jpg differ
diff --git a/samples/mahiro_canny.png b/samples/mahiro_canny.png
index cb09613..ecaeef3 100644
Binary files a/samples/mahiro_canny.png and b/samples/mahiro_canny.png differ
diff --git a/samples/mahiro_gen.png b/samples/mahiro_gen.png
index 9c2ea6f..35ecc1a 100644
Binary files a/samples/mahiro_gen.png and b/samples/mahiro_gen.png differ
diff --git a/samples/mahiro_input.png b/samples/mahiro_input.png
index bc97e9e..0ee95dd 100644
Binary files a/samples/mahiro_input.png and b/samples/mahiro_input.png differ
diff --git a/scripts/controlnet.py b/scripts/controlnet.py
index df8b10d..69569da 100644
--- a/scripts/controlnet.py
+++ b/scripts/controlnet.py
@@ -162,15 +162,14 @@ class Script(scripts.Script):
selected = dd
else:
selected = "None"
- print(cn_models)
- update = gr.Dropdown.update(
- value=selected, choices=list(cn_models.keys()))
+
+ update = gr.Dropdown.update(value=selected, choices=list(cn_models.keys()))
updates.append(update)
return updates
refresh_models = gr.Button(value='Refresh models')
refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns)
- ctrls += (refresh_models, )
+ # ctrls += (refresh_models, )
def create_canvas(h, w):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
@@ -182,7 +181,7 @@ class Script(scripts.Script):
gr.Markdown(value='Don\'t drag image into the box! Use upload instead. Change your brush width to make it thinner if you want to draw something.')
create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
- ctrls += (canvas_width, canvas_height, create_button, input_image, scribble_mode)
+ ctrls += (input_image, scribble_mode)
return ctrls
@@ -212,15 +211,14 @@ class Script(scripts.Script):
self.latest_network.restore(unet)
self.latest_network = None
- enabled, module, model, weight, _ = args[:5]
- _, _, _, image, scribble_mode = args[5:]
+ enabled, module, model, weight,image, scribble_mode = args
if not enabled:
restore_networks()
return
models_changed = self.latest_params[0] != module or self.latest_params[1] != model \
- or self.latest_model_hash != p.sd_model.sd_model_hash
+ or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_network == None
if models_changed:
restore_networks()
@@ -273,6 +271,8 @@ class Script(scripts.Script):
self.set_infotext_fields(p, self.latest_params)
def postprocess(self, p, processed, *args):
+ if self.latest_network is None:
+ return
if hasattr(self, "control") and self.control is not None:
processed.images.append(ToPILImage()((self.control).clip(0, 255)))