feat: make hed mode works
|
|
@ -13,6 +13,8 @@ Dragging a file on the Web UI will freeze the entire page. It is better to use t
|
||||||
|
|
||||||
### Install
|
### Install
|
||||||
|
|
||||||
|
Some users may need to install the cv2 library before installing it: `pip install opencv-python`
|
||||||
|
|
||||||
1. Open "Extensions" tab.
|
1. Open "Extensions" tab.
|
||||||
2. Open "Install from URL" tab in the tab.
|
2. Open "Install from URL" tab in the tab.
|
||||||
3. Enter URL of this repo to "URL for extension's git repository".
|
3. Enter URL of this repo to "URL for extension's git repository".
|
||||||
|
|
@ -32,3 +34,4 @@ Currently it supports both full models and trimmed models. Use `extract_controln
|
||||||
| Source | Input | Output |
|
| Source | Input | Output |
|
||||||
|:-------------------------:|:-------------------------:|:-------------------------:|
|
|:-------------------------:|:-------------------------:|:-------------------------:|
|
||||||
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_input.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_canny.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_gen.png?raw=true"> |
|
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_input.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_canny.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_gen.png?raw=true"> |
|
||||||
|
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_source.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_hed.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_gen.png?raw=true"> |
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,11 @@ import cv2
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
import os
|
||||||
|
from modules import extensions
|
||||||
|
|
||||||
class Network(torch.nn.Module):
|
class Network(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, model_path):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.netVggOne = torch.nn.Sequential(
|
self.netVggOne = torch.nn.Sequential(
|
||||||
|
|
@ -64,7 +66,7 @@ class Network(torch.nn.Module):
|
||||||
torch.nn.Sigmoid()
|
torch.nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load('./annotator/ckpts/network-bsds500.pth').items()})
|
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
|
||||||
# end
|
# end
|
||||||
|
|
||||||
def forward(self, tenInput):
|
def forward(self, tenInput):
|
||||||
|
|
@ -93,11 +95,19 @@ class Network(torch.nn.Module):
|
||||||
# end
|
# end
|
||||||
# end
|
# end
|
||||||
|
|
||||||
|
netNetwork = None
|
||||||
netNetwork = Network().cuda().eval()
|
remote_model_path = "https://huggingface.co/datasets/nyanko7/tmp-public/resolve/main/network-bsds500.pt"
|
||||||
|
modeldir = os.path.join(extensions.extensions_dir, "sd-webui-controlnet", "annotator")
|
||||||
|
|
||||||
def apply_hed(input_image):
|
def apply_hed(input_image):
|
||||||
|
global netNetwork
|
||||||
|
if netNetwork is None:
|
||||||
|
modelpath = os.path.join(modeldir, "network-bsds500.pt")
|
||||||
|
if not os.path.exists(modelpath):
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
load_file_from_url(remote_model_path, model_dir=modeldir)
|
||||||
|
netNetwork = Network(modelpath).cuda().eval()
|
||||||
|
|
||||||
assert input_image.ndim == 3
|
assert input_image.ndim == 3
|
||||||
input_image = input_image[:, :, ::-1].copy()
|
input_image = input_image[:, :, ::-1].copy()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
|
||||||
|
After Width: | Height: | Size: 176 KiB |
|
After Width: | Height: | Size: 137 KiB |
|
After Width: | Height: | Size: 142 KiB |
|
Before Width: | Height: | Size: 19 KiB After Width: | Height: | Size: 9.0 KiB |
|
Before Width: | Height: | Size: 359 KiB After Width: | Height: | Size: 95 KiB |
|
Before Width: | Height: | Size: 240 KiB After Width: | Height: | Size: 59 KiB |
|
|
@ -162,15 +162,14 @@ class Script(scripts.Script):
|
||||||
selected = dd
|
selected = dd
|
||||||
else:
|
else:
|
||||||
selected = "None"
|
selected = "None"
|
||||||
print(cn_models)
|
|
||||||
update = gr.Dropdown.update(
|
update = gr.Dropdown.update(value=selected, choices=list(cn_models.keys()))
|
||||||
value=selected, choices=list(cn_models.keys()))
|
|
||||||
updates.append(update)
|
updates.append(update)
|
||||||
return updates
|
return updates
|
||||||
|
|
||||||
refresh_models = gr.Button(value='Refresh models')
|
refresh_models = gr.Button(value='Refresh models')
|
||||||
refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns)
|
refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns)
|
||||||
ctrls += (refresh_models, )
|
# ctrls += (refresh_models, )
|
||||||
|
|
||||||
def create_canvas(h, w):
|
def create_canvas(h, w):
|
||||||
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
||||||
|
|
@ -182,7 +181,7 @@ class Script(scripts.Script):
|
||||||
gr.Markdown(value='Don\'t drag image into the box! Use upload instead. Change your brush width to make it thinner if you want to draw something.')
|
gr.Markdown(value='Don\'t drag image into the box! Use upload instead. Change your brush width to make it thinner if you want to draw something.')
|
||||||
|
|
||||||
create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
|
create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
|
||||||
ctrls += (canvas_width, canvas_height, create_button, input_image, scribble_mode)
|
ctrls += (input_image, scribble_mode)
|
||||||
|
|
||||||
return ctrls
|
return ctrls
|
||||||
|
|
||||||
|
|
@ -212,15 +211,14 @@ class Script(scripts.Script):
|
||||||
self.latest_network.restore(unet)
|
self.latest_network.restore(unet)
|
||||||
self.latest_network = None
|
self.latest_network = None
|
||||||
|
|
||||||
enabled, module, model, weight, _ = args[:5]
|
enabled, module, model, weight,image, scribble_mode = args
|
||||||
_, _, _, image, scribble_mode = args[5:]
|
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
restore_networks()
|
restore_networks()
|
||||||
return
|
return
|
||||||
|
|
||||||
models_changed = self.latest_params[0] != module or self.latest_params[1] != model \
|
models_changed = self.latest_params[0] != module or self.latest_params[1] != model \
|
||||||
or self.latest_model_hash != p.sd_model.sd_model_hash
|
or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_network == None
|
||||||
|
|
||||||
if models_changed:
|
if models_changed:
|
||||||
restore_networks()
|
restore_networks()
|
||||||
|
|
@ -273,6 +271,8 @@ class Script(scripts.Script):
|
||||||
self.set_infotext_fields(p, self.latest_params)
|
self.set_infotext_fields(p, self.latest_params)
|
||||||
|
|
||||||
def postprocess(self, p, processed, *args):
|
def postprocess(self, p, processed, *args):
|
||||||
|
if self.latest_network is None:
|
||||||
|
return
|
||||||
if hasattr(self, "control") and self.control is not None:
|
if hasattr(self, "control") and self.control is not None:
|
||||||
processed.images.append(ToPILImage()((self.control).clip(0, 255)))
|
processed.images.append(ToPILImage()((self.control).clip(0, 255)))
|
||||||
|
|
||||||
|
|
|
||||||