feat: make hed mode works

pull/1/head
Mikubill 2023-02-13 08:21:10 +00:00
parent d314cfb77b
commit 565447cfdb
9 changed files with 26 additions and 13 deletions

View File

@ -13,6 +13,8 @@ Dragging a file on the Web UI will freeze the entire page. It is better to use t
### Install ### Install
Some users may need to install the cv2 library before installing it: `pip install opencv-python`
1. Open "Extensions" tab. 1. Open "Extensions" tab.
2. Open "Install from URL" tab in the tab. 2. Open "Install from URL" tab in the tab.
3. Enter URL of this repo to "URL for extension's git repository". 3. Enter URL of this repo to "URL for extension's git repository".
@ -32,3 +34,4 @@ Currently it supports both full models and trimmed models. Use `extract_controln
| Source | Input | Output | | Source | Input | Output |
|:-------------------------:|:-------------------------:|:-------------------------:| |:-------------------------:|:-------------------------:|:-------------------------:|
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_input.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_canny.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_gen.png?raw=true"> | |<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_input.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_canny.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_gen.png?raw=true"> |
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_source.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_hed.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_gen.png?raw=true"> |

View File

@ -3,9 +3,11 @@ import cv2
import torch import torch
from einops import rearrange from einops import rearrange
import os
from modules import extensions
class Network(torch.nn.Module): class Network(torch.nn.Module):
def __init__(self): def __init__(self, model_path):
super().__init__() super().__init__()
self.netVggOne = torch.nn.Sequential( self.netVggOne = torch.nn.Sequential(
@ -64,7 +66,7 @@ class Network(torch.nn.Module):
torch.nn.Sigmoid() torch.nn.Sigmoid()
) )
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load('./annotator/ckpts/network-bsds500.pth').items()}) self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
# end # end
def forward(self, tenInput): def forward(self, tenInput):
@ -93,11 +95,19 @@ class Network(torch.nn.Module):
# end # end
# end # end
netNetwork = None
netNetwork = Network().cuda().eval() remote_model_path = "https://huggingface.co/datasets/nyanko7/tmp-public/resolve/main/network-bsds500.pt"
modeldir = os.path.join(extensions.extensions_dir, "sd-webui-controlnet", "annotator")
def apply_hed(input_image): def apply_hed(input_image):
global netNetwork
if netNetwork is None:
modelpath = os.path.join(modeldir, "network-bsds500.pt")
if not os.path.exists(modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=modeldir)
netNetwork = Network(modelpath).cuda().eval()
assert input_image.ndim == 3 assert input_image.ndim == 3
input_image = input_image[:, :, ::-1].copy() input_image = input_image[:, :, ::-1].copy()
with torch.no_grad(): with torch.no_grad():

BIN
samples/evt_gen.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

BIN
samples/evt_hed.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 137 KiB

BIN
samples/evt_source.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 142 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

After

Width:  |  Height:  |  Size: 9.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 359 KiB

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 240 KiB

After

Width:  |  Height:  |  Size: 59 KiB

View File

@ -162,15 +162,14 @@ class Script(scripts.Script):
selected = dd selected = dd
else: else:
selected = "None" selected = "None"
print(cn_models)
update = gr.Dropdown.update( update = gr.Dropdown.update(value=selected, choices=list(cn_models.keys()))
value=selected, choices=list(cn_models.keys()))
updates.append(update) updates.append(update)
return updates return updates
refresh_models = gr.Button(value='Refresh models') refresh_models = gr.Button(value='Refresh models')
refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns) refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns)
ctrls += (refresh_models, ) # ctrls += (refresh_models, )
def create_canvas(h, w): def create_canvas(h, w):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
@ -182,7 +181,7 @@ class Script(scripts.Script):
gr.Markdown(value='Don\'t drag image into the box! Use upload instead. Change your brush width to make it thinner if you want to draw something.') gr.Markdown(value='Don\'t drag image into the box! Use upload instead. Change your brush width to make it thinner if you want to draw something.')
create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image]) create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
ctrls += (canvas_width, canvas_height, create_button, input_image, scribble_mode) ctrls += (input_image, scribble_mode)
return ctrls return ctrls
@ -212,15 +211,14 @@ class Script(scripts.Script):
self.latest_network.restore(unet) self.latest_network.restore(unet)
self.latest_network = None self.latest_network = None
enabled, module, model, weight, _ = args[:5] enabled, module, model, weight,image, scribble_mode = args
_, _, _, image, scribble_mode = args[5:]
if not enabled: if not enabled:
restore_networks() restore_networks()
return return
models_changed = self.latest_params[0] != module or self.latest_params[1] != model \ models_changed = self.latest_params[0] != module or self.latest_params[1] != model \
or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_network == None
if models_changed: if models_changed:
restore_networks() restore_networks()
@ -273,6 +271,8 @@ class Script(scripts.Script):
self.set_infotext_fields(p, self.latest_params) self.set_infotext_fields(p, self.latest_params)
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
if self.latest_network is None:
return
if hasattr(self, "control") and self.control is not None: if hasattr(self, "control") and self.control is not None:
processed.images.append(ToPILImage()((self.control).clip(0, 255))) processed.images.append(ToPILImage()((self.control).clip(0, 255)))