feat: make hed mode works

pull/1/head
Mikubill 2023-02-13 08:21:10 +00:00
parent d314cfb77b
commit 565447cfdb
9 changed files with 26 additions and 13 deletions

View File

@ -13,6 +13,8 @@ Dragging a file on the Web UI will freeze the entire page. It is better to use t
### Install
Some users may need to install the cv2 library before installing it: `pip install opencv-python`
1. Open "Extensions" tab.
2. Open "Install from URL" tab in the tab.
3. Enter URL of this repo to "URL for extension's git repository".
@ -32,3 +34,4 @@ Currently it supports both full models and trimmed models. Use `extract_controln
| Source | Input | Output |
|:-------------------------:|:-------------------------:|:-------------------------:|
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_input.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_canny.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_gen.png?raw=true"> |
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_source.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_hed.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_gen.png?raw=true"> |

View File

@ -3,9 +3,11 @@ import cv2
import torch
from einops import rearrange
import os
from modules import extensions
class Network(torch.nn.Module):
def __init__(self):
def __init__(self, model_path):
super().__init__()
self.netVggOne = torch.nn.Sequential(
@ -64,7 +66,7 @@ class Network(torch.nn.Module):
torch.nn.Sigmoid()
)
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load('./annotator/ckpts/network-bsds500.pth').items()})
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
# end
def forward(self, tenInput):
@ -93,11 +95,19 @@ class Network(torch.nn.Module):
# end
# end
netNetwork = Network().cuda().eval()
netNetwork = None
remote_model_path = "https://huggingface.co/datasets/nyanko7/tmp-public/resolve/main/network-bsds500.pt"
modeldir = os.path.join(extensions.extensions_dir, "sd-webui-controlnet", "annotator")
def apply_hed(input_image):
global netNetwork
if netNetwork is None:
modelpath = os.path.join(modeldir, "network-bsds500.pt")
if not os.path.exists(modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=modeldir)
netNetwork = Network(modelpath).cuda().eval()
assert input_image.ndim == 3
input_image = input_image[:, :, ::-1].copy()
with torch.no_grad():

BIN
samples/evt_gen.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

BIN
samples/evt_hed.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 137 KiB

BIN
samples/evt_source.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 142 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

After

Width:  |  Height:  |  Size: 9.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 359 KiB

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 240 KiB

After

Width:  |  Height:  |  Size: 59 KiB

View File

@ -162,15 +162,14 @@ class Script(scripts.Script):
selected = dd
else:
selected = "None"
print(cn_models)
update = gr.Dropdown.update(
value=selected, choices=list(cn_models.keys()))
update = gr.Dropdown.update(value=selected, choices=list(cn_models.keys()))
updates.append(update)
return updates
refresh_models = gr.Button(value='Refresh models')
refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns)
ctrls += (refresh_models, )
# ctrls += (refresh_models, )
def create_canvas(h, w):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
@ -182,7 +181,7 @@ class Script(scripts.Script):
gr.Markdown(value='Don\'t drag image into the box! Use upload instead. Change your brush width to make it thinner if you want to draw something.')
create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
ctrls += (canvas_width, canvas_height, create_button, input_image, scribble_mode)
ctrls += (input_image, scribble_mode)
return ctrls
@ -212,15 +211,14 @@ class Script(scripts.Script):
self.latest_network.restore(unet)
self.latest_network = None
enabled, module, model, weight, _ = args[:5]
_, _, _, image, scribble_mode = args[5:]
enabled, module, model, weight,image, scribble_mode = args
if not enabled:
restore_networks()
return
models_changed = self.latest_params[0] != module or self.latest_params[1] != model \
or self.latest_model_hash != p.sd_model.sd_model_hash
or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_network == None
if models_changed:
restore_networks()
@ -273,6 +271,8 @@ class Script(scripts.Script):
self.set_infotext_fields(p, self.latest_params)
def postprocess(self, p, processed, *args):
if self.latest_network is None:
return
if hasattr(self, "control") and self.control is not None:
processed.images.append(ToPILImage()((self.control).clip(0, 255)))