feat: make hed mode works
|
|
@ -13,6 +13,8 @@ Dragging a file on the Web UI will freeze the entire page. It is better to use t
|
|||
|
||||
### Install
|
||||
|
||||
Some users may need to install the cv2 library before installing it: `pip install opencv-python`
|
||||
|
||||
1. Open "Extensions" tab.
|
||||
2. Open "Install from URL" tab in the tab.
|
||||
3. Enter URL of this repo to "URL for extension's git repository".
|
||||
|
|
@ -32,3 +34,4 @@ Currently it supports both full models and trimmed models. Use `extract_controln
|
|||
| Source | Input | Output |
|
||||
|:-------------------------:|:-------------------------:|:-------------------------:|
|
||||
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_input.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_canny.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_gen.png?raw=true"> |
|
||||
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_source.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_hed.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_gen.png?raw=true"> |
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ import cv2
|
|||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
import os
|
||||
from modules import extensions
|
||||
|
||||
class Network(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, model_path):
|
||||
super().__init__()
|
||||
|
||||
self.netVggOne = torch.nn.Sequential(
|
||||
|
|
@ -64,7 +66,7 @@ class Network(torch.nn.Module):
|
|||
torch.nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load('./annotator/ckpts/network-bsds500.pth').items()})
|
||||
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
|
||||
# end
|
||||
|
||||
def forward(self, tenInput):
|
||||
|
|
@ -93,11 +95,19 @@ class Network(torch.nn.Module):
|
|||
# end
|
||||
# end
|
||||
|
||||
|
||||
netNetwork = Network().cuda().eval()
|
||||
|
||||
netNetwork = None
|
||||
remote_model_path = "https://huggingface.co/datasets/nyanko7/tmp-public/resolve/main/network-bsds500.pt"
|
||||
modeldir = os.path.join(extensions.extensions_dir, "sd-webui-controlnet", "annotator")
|
||||
|
||||
def apply_hed(input_image):
|
||||
global netNetwork
|
||||
if netNetwork is None:
|
||||
modelpath = os.path.join(modeldir, "network-bsds500.pt")
|
||||
if not os.path.exists(modelpath):
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
load_file_from_url(remote_model_path, model_dir=modeldir)
|
||||
netNetwork = Network(modelpath).cuda().eval()
|
||||
|
||||
assert input_image.ndim == 3
|
||||
input_image = input_image[:, :, ::-1].copy()
|
||||
with torch.no_grad():
|
||||
|
|
|
|||
|
After Width: | Height: | Size: 176 KiB |
|
After Width: | Height: | Size: 137 KiB |
|
After Width: | Height: | Size: 142 KiB |
|
Before Width: | Height: | Size: 19 KiB After Width: | Height: | Size: 9.0 KiB |
|
Before Width: | Height: | Size: 359 KiB After Width: | Height: | Size: 95 KiB |
|
Before Width: | Height: | Size: 240 KiB After Width: | Height: | Size: 59 KiB |
|
|
@ -162,15 +162,14 @@ class Script(scripts.Script):
|
|||
selected = dd
|
||||
else:
|
||||
selected = "None"
|
||||
print(cn_models)
|
||||
update = gr.Dropdown.update(
|
||||
value=selected, choices=list(cn_models.keys()))
|
||||
|
||||
update = gr.Dropdown.update(value=selected, choices=list(cn_models.keys()))
|
||||
updates.append(update)
|
||||
return updates
|
||||
|
||||
refresh_models = gr.Button(value='Refresh models')
|
||||
refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns)
|
||||
ctrls += (refresh_models, )
|
||||
# ctrls += (refresh_models, )
|
||||
|
||||
def create_canvas(h, w):
|
||||
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
||||
|
|
@ -182,7 +181,7 @@ class Script(scripts.Script):
|
|||
gr.Markdown(value='Don\'t drag image into the box! Use upload instead. Change your brush width to make it thinner if you want to draw something.')
|
||||
|
||||
create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
|
||||
ctrls += (canvas_width, canvas_height, create_button, input_image, scribble_mode)
|
||||
ctrls += (input_image, scribble_mode)
|
||||
|
||||
return ctrls
|
||||
|
||||
|
|
@ -212,15 +211,14 @@ class Script(scripts.Script):
|
|||
self.latest_network.restore(unet)
|
||||
self.latest_network = None
|
||||
|
||||
enabled, module, model, weight, _ = args[:5]
|
||||
_, _, _, image, scribble_mode = args[5:]
|
||||
enabled, module, model, weight,image, scribble_mode = args
|
||||
|
||||
if not enabled:
|
||||
restore_networks()
|
||||
return
|
||||
|
||||
models_changed = self.latest_params[0] != module or self.latest_params[1] != model \
|
||||
or self.latest_model_hash != p.sd_model.sd_model_hash
|
||||
or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_network == None
|
||||
|
||||
if models_changed:
|
||||
restore_networks()
|
||||
|
|
@ -273,6 +271,8 @@ class Script(scripts.Script):
|
|||
self.set_infotext_fields(p, self.latest_params)
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
if self.latest_network is None:
|
||||
return
|
||||
if hasattr(self, "control") and self.control is not None:
|
||||
processed.images.append(ToPILImage()((self.control).clip(0, 255)))
|
||||
|
||||
|
|
|
|||