mirror of https://github.com/vladmandic/automatic
update requirements and add control-lllite
parent
60e0e110dd
commit
068809719a
|
|
@ -1,11 +1,12 @@
|
|||
# Change Log for SD.Next
|
||||
|
||||
## Update for 2023-12-24
|
||||
## Update for 2023-12-25
|
||||
|
||||
*Note*: based on `diffusers==0.25.0.dev0`
|
||||
|
||||
- **Control**
|
||||
- native implementation of **ControlNet**, **ControlNet XS**, **T2I Adapters** and **IP Adapters**
|
||||
- native implementation of all image control methods:
|
||||
**ControlNet**, **ControlNet XS**, **Control LLLite**, **T2I Adapters** and **IP Adapters**
|
||||
- top-level **Control** next to **Text** and **Image** generate
|
||||
- supports all variations of **SD15** and **SD-XL** models
|
||||
- supports *Text*, *Image*, *Batch* and *Video* processing
|
||||
|
|
@ -104,6 +105,7 @@
|
|||
- improve handling of long filenames and filenames during batch processing
|
||||
- do not set preview samples when using via api
|
||||
- avoid unnecessary resizes in img2img and inpaint
|
||||
- safe handling of config updates avoid file corruption on I/O errors
|
||||
- updated `cli/simple-txt2img.py` and `cli/simple-img2img.py` scripts
|
||||
- save `params.txt` regardless of image save status
|
||||
- update built-in log monitor in ui, thanks @midcoastal
|
||||
|
|
|
|||
|
|
@ -66,45 +66,54 @@ function extract_image_from_gallery(gallery) {
|
|||
|
||||
window.args_to_array = Array.from; // Compatibility with e.g. extensions that may expect this to be around
|
||||
|
||||
function switchToTab(tab) {
|
||||
const tabs = Array.from(gradioApp().querySelectorAll('#tabs > .tab-nav > button'));
|
||||
const btn = tabs?.find((t) => t.innerText === tab);
|
||||
log('switchToTab', tab);
|
||||
if (btn) btn.click();
|
||||
}
|
||||
|
||||
function switch_to_txt2img(...args) {
|
||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
|
||||
switchToTab('Text');
|
||||
return Array.from(arguments);
|
||||
}
|
||||
|
||||
function switch_to_img2img_tab(no) {
|
||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
||||
switchToTab('Image');
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
|
||||
}
|
||||
|
||||
function switch_to_img2img(...args) {
|
||||
switchToTab('Image');
|
||||
switch_to_img2img_tab(0);
|
||||
return Array.from(arguments);
|
||||
}
|
||||
|
||||
function switch_to_sketch(...args) {
|
||||
switchToTab('Image');
|
||||
switch_to_img2img_tab(1);
|
||||
return Array.from(arguments);
|
||||
}
|
||||
|
||||
function switch_to_inpaint(...args) {
|
||||
switchToTab('Image');
|
||||
switch_to_img2img_tab(2);
|
||||
return Array.from(arguments);
|
||||
}
|
||||
|
||||
function switch_to_inpaint_sketch(...args) {
|
||||
switchToTab('Image');
|
||||
switch_to_img2img_tab(3);
|
||||
return Array.from(arguments);
|
||||
}
|
||||
|
||||
function switch_to_extras(...args) {
|
||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
|
||||
switchToTab('Process');
|
||||
return Array.from(arguments);
|
||||
}
|
||||
|
||||
function switch_to_control(...args) {
|
||||
const tabs = Array.from(gradioApp().querySelector('#tabs').querySelectorAll('button'));
|
||||
const btn = tabs.find((el) => el.innerText.toLowerCase() === 'control');
|
||||
btn.click();
|
||||
switchToTab('Control');
|
||||
return Array.from(arguments);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import time
|
||||
from typing import Union
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import
|
||||
from modules.shared import log
|
||||
from modules import errors
|
||||
|
|
@ -33,7 +34,7 @@ models = {}
|
|||
all_models = {}
|
||||
all_models.update(predefined_sd15)
|
||||
all_models.update(predefined_sdxl)
|
||||
cache_dir = 'models/control/adapters'
|
||||
cache_dir = 'models/control/adapter'
|
||||
|
||||
|
||||
def list_models(refresh=False):
|
||||
|
|
@ -105,10 +106,10 @@ class Adapter():
|
|||
|
||||
|
||||
class AdapterPipeline():
|
||||
def __init__(self, adapter: T2IAdapter | list[T2IAdapter], pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline, dtype = None):
|
||||
def __init__(self, adapter: Union[T2IAdapter, list[T2IAdapter]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
|
||||
t0 = time.time()
|
||||
self.orig_pipeline = pipeline
|
||||
self.pipeline = None
|
||||
self.pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline] = None
|
||||
if pipeline is None:
|
||||
log.error(f'Control {what} pipeline: model not loaded')
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import os
|
||||
import time
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
|
||||
from typing import Union
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
|
||||
from modules.shared import log, opts
|
||||
from modules import errors
|
||||
|
||||
|
|
@ -38,11 +37,11 @@ models = {}
|
|||
all_models = {}
|
||||
all_models.update(predefined_sd15)
|
||||
all_models.update(predefined_sdxl)
|
||||
cache_dir = 'models/control/controlnets'
|
||||
cache_dir = 'models/control/controlnet'
|
||||
|
||||
|
||||
def find_models():
|
||||
path = os.path.join(opts.control_dir, 'controlnets')
|
||||
path = os.path.join(opts.control_dir, 'controlnet')
|
||||
files = os.listdir(path)
|
||||
files = [f for f in files if f.endswith('.safetensors')]
|
||||
downloaded_models = {}
|
||||
|
|
@ -123,14 +122,14 @@ class ControlNet():
|
|||
|
||||
|
||||
class ControlNetPipeline():
|
||||
def __init__(self, controlnet: ControlNetModel | list[ControlNetModel], pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline, dtype = None):
|
||||
def __init__(self, controlnet: Union[ControlNetModel, list[ControlNetModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
|
||||
t0 = time.time()
|
||||
self.orig_pipeline = pipeline
|
||||
self.pipeline = None
|
||||
if pipeline is None:
|
||||
log.error('Control model pipeline: model not loaded')
|
||||
return
|
||||
if isinstance(pipeline, StableDiffusionXLPipeline):
|
||||
elif isinstance(pipeline, StableDiffusionXLPipeline):
|
||||
self.pipeline = StableDiffusionXLControlNetPipeline(
|
||||
vae=pipeline.vae,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,135 @@
|
|||
import os
|
||||
import time
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
from modules.shared import log, opts
|
||||
from modules import errors
|
||||
from modules.control.controlnetslite_model import ControlNetLLLite
|
||||
|
||||
|
||||
what = 'ControlLLLite'
|
||||
debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
debug('Trace: CONTROL')
|
||||
predefined_sd15 = {
|
||||
}
|
||||
predefined_sdxl = {
|
||||
'Canny XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny',
|
||||
'Canny anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny_anime',
|
||||
'Depth anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01008016e_sdxl_depth_anime',
|
||||
'Blur anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01016032e_sdxl_blur_anime_beta',
|
||||
'Pose anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_pose_anime',
|
||||
'Replicate anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_replicate_anime_v2',
|
||||
}
|
||||
models = {}
|
||||
all_models = {}
|
||||
all_models.update(predefined_sd15)
|
||||
all_models.update(predefined_sdxl)
|
||||
cache_dir = 'models/control/lite'
|
||||
|
||||
|
||||
def find_models():
|
||||
path = os.path.join(opts.control_dir, 'lite')
|
||||
files = os.listdir(path)
|
||||
files = [f for f in files if f.endswith('.safetensors')]
|
||||
downloaded_models = {}
|
||||
for f in files:
|
||||
basename = os.path.splitext(f)[0]
|
||||
downloaded_models[basename] = os.path.join(path, f)
|
||||
all_models.update(downloaded_models)
|
||||
return downloaded_models
|
||||
|
||||
|
||||
def list_models(refresh=False):
|
||||
import modules.shared
|
||||
global models # pylint: disable=global-statement
|
||||
if not refresh and len(models) > 0:
|
||||
return models
|
||||
models = {}
|
||||
if modules.shared.sd_model_type == 'none':
|
||||
models = ['None']
|
||||
elif modules.shared.sd_model_type == 'sdxl':
|
||||
models = ['None'] + sorted(predefined_sdxl) + sorted(find_models())
|
||||
elif modules.shared.sd_model_type == 'sd':
|
||||
models = ['None'] + sorted(predefined_sd15) + sorted(find_models())
|
||||
else:
|
||||
log.warning(f'Control {what} model list failed: unknown model type')
|
||||
models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models())
|
||||
debug(f'Control list {what}: path={cache_dir} models={models}')
|
||||
return models
|
||||
|
||||
|
||||
class ControlLLLite():
|
||||
def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
|
||||
self.model: ControlNetLLLite = None
|
||||
self.model_id: str = model_id
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.load_config = { 'cache_dir': cache_dir }
|
||||
if load_config is not None:
|
||||
self.load_config.update(load_config)
|
||||
if model_id is not None:
|
||||
self.load()
|
||||
|
||||
def reset(self):
|
||||
if self.model is not None:
|
||||
log.debug(f'Control {what} model unloaded')
|
||||
self.model = None
|
||||
self.model_id = None
|
||||
|
||||
def load(self, model_id: str = None) -> str:
|
||||
try:
|
||||
t0 = time.time()
|
||||
model_id = model_id or self.model_id
|
||||
if model_id is None or model_id == 'None':
|
||||
self.reset()
|
||||
return
|
||||
model_path = all_models[model_id]
|
||||
if model_path == '':
|
||||
return
|
||||
if model_path is None:
|
||||
log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
|
||||
return
|
||||
log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}')
|
||||
if model_path.endswith('.safetensors'):
|
||||
self.model = ControlNetLLLite(model_path)
|
||||
else:
|
||||
import huggingface_hub as hf
|
||||
folder, filename = os.path.split(model_path)
|
||||
model_path = hf.hf_hub_download(repo_id=folder, filename=f'{filename}.safetensors', cache_dir=cache_dir)
|
||||
self.model = ControlNetLLLite(model_path)
|
||||
if self.device is not None:
|
||||
self.model.to(self.device)
|
||||
if self.dtype is not None:
|
||||
self.model.to(self.dtype)
|
||||
t1 = time.time()
|
||||
self.model_id = model_id
|
||||
log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
|
||||
return f'{what} loaded model: {model_id}'
|
||||
except Exception as e:
|
||||
log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
|
||||
errors.display(e, f'Control {what} load')
|
||||
return f'{what} failed to load model: {model_id}'
|
||||
|
||||
|
||||
class ControlLLitePipeline():
|
||||
def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline]):
|
||||
self.pipeline = pipeline
|
||||
self.nets = []
|
||||
|
||||
def apply(self, controlnet: Union[ControlNetLLLite, list[ControlNetLLLite]], image, conditioning):
|
||||
if image is None:
|
||||
return
|
||||
self.nets = [controlnet] if isinstance(controlnet, ControlNetLLLite) else controlnet
|
||||
debug(f'Control {what} apply: models={len(self.nets)} image={image} conditioning={conditioning}')
|
||||
weight = [conditioning] if isinstance(conditioning, float) else conditioning
|
||||
images = [image] if isinstance(image, Image.Image) else image
|
||||
images = [i.convert('RGB') for i in images]
|
||||
for i, cn in enumerate(self.nets):
|
||||
cn.apply(pipe=self.pipeline, cond=np.asarray(images[i % len(images)]), weight=weight[i % len(weight)])
|
||||
|
||||
def restore(self):
|
||||
from modules.control.controlnetslite_model import clear_all_lllite
|
||||
clear_all_lllite()
|
||||
self.nets = []
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
# Credits: <https://github.com/mycodeiscat/ControlNet-LLLite-diffusers>
|
||||
# <https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI/blob/main/node_control_net_lllite.py>
|
||||
|
||||
import re
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
all_hack = {}
|
||||
|
||||
|
||||
class LLLiteModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
is_conv2d: bool,
|
||||
in_dim: int,
|
||||
depth: int,
|
||||
cond_emb_dim: int,
|
||||
mlp_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.is_conv2d = is_conv2d
|
||||
self.is_first = False
|
||||
modules = []
|
||||
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
|
||||
if depth == 1:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
elif depth == 2:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
||||
elif depth == 3:
|
||||
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
self.conditioning1 = torch.nn.Sequential(*modules)
|
||||
if self.is_conv2d:
|
||||
self.down = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.mid = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.up = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
else:
|
||||
self.down = torch.nn.Sequential(
|
||||
torch.nn.Linear(in_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.mid = torch.nn.Sequential(
|
||||
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.up = torch.nn.Sequential(
|
||||
torch.nn.Linear(mlp_dim, in_dim),
|
||||
)
|
||||
self.depth = depth
|
||||
self.cond_image = None
|
||||
self.cond_emb = None
|
||||
|
||||
def set_cond_image(self, cond_image):
|
||||
self.cond_image = cond_image
|
||||
self.cond_emb = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.cond_emb is None:
|
||||
# print(f"cond_emb is None, {self.name}")
|
||||
cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
|
||||
# if blk_shape is not None:
|
||||
# b, c, h, w = blk_shape
|
||||
# cx = torch.nn.functional.interpolate(cx, (h, w), mode="nearest-exact")
|
||||
if not self.is_conv2d:
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = cx.shape
|
||||
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
||||
self.cond_emb = cx
|
||||
cx = self.cond_emb
|
||||
|
||||
# uncond/condでxはバッチサイズが2倍
|
||||
if x.shape[0] != cx.shape[0]:
|
||||
if self.is_conv2d:
|
||||
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
|
||||
else:
|
||||
# print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
|
||||
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
|
||||
|
||||
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.mid(cx)
|
||||
cx = self.up(cx)
|
||||
return cx
|
||||
|
||||
|
||||
def clear_all_lllite():
|
||||
global all_hack # pylint: disable=global-statement
|
||||
for k, v in all_hack.items():
|
||||
k.forward = v
|
||||
k.lllite_list = []
|
||||
all_hack = {}
|
||||
return
|
||||
|
||||
|
||||
class ControlNetLLLite(torch.nn.Module): # pylint: disable=abstract-method
|
||||
def __init__(self, path: str):
|
||||
super().__init__()
|
||||
module_weights = {}
|
||||
try:
|
||||
state_dict = load_file(path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load {path}") from e
|
||||
for key, value in state_dict.items():
|
||||
fragments = key.split(".")
|
||||
module_name = fragments[0]
|
||||
weight_name = ".".join(fragments[1:])
|
||||
if module_name not in module_weights:
|
||||
module_weights[module_name] = {}
|
||||
module_weights[module_name][weight_name] = value
|
||||
modules = {}
|
||||
for module_name, weights in module_weights.items():
|
||||
if "conditioning1.4.weight" in weights:
|
||||
depth = 3
|
||||
elif weights["conditioning1.2.weight"].shape[-1] == 4:
|
||||
depth = 2
|
||||
else:
|
||||
depth = 1
|
||||
|
||||
module = LLLiteModule(
|
||||
name=module_name,
|
||||
is_conv2d=weights["down.0.weight"].ndim == 4,
|
||||
in_dim=weights["down.0.weight"].shape[1],
|
||||
depth=depth,
|
||||
cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
|
||||
mlp_dim=weights["down.0.weight"].shape[0],
|
||||
)
|
||||
# info = module.load_state_dict(weights)
|
||||
modules[module_name] = module
|
||||
setattr(self, module_name, module)
|
||||
if len(modules) == 1:
|
||||
module.is_first = True
|
||||
|
||||
self.modules = modules
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def apply(self, pipe, cond, weight): # pylint: disable=arguments-differ
|
||||
map_down_lllite_to_unet = {4: (1, 0), 5: (1, 1), 7: (2, 0), 8: (2, 1)}
|
||||
model = pipe.unet
|
||||
if type(cond) != torch.Tensor:
|
||||
cond = torch.tensor(cond)
|
||||
cond = cond/255 # 0-255 -> 0-1
|
||||
cond_image = cond.unsqueeze(dim=0).permute(0, 3, 1, 2) # h,w,c -> b,c,h,w
|
||||
cond_image = cond_image * 2.0 - 1.0 # 0-1 -> -1-1
|
||||
|
||||
for module in self.modules.values():
|
||||
module.set_cond_image(cond_image)
|
||||
for k, v in self.modules.items():
|
||||
k = k.replace('middle_block', 'middle_blocks_0')
|
||||
match = re.match("lllite_unet_(.*)_blocks_(.*)_1_transformer_blocks_(.*)_(.*)_to_(.*)", k, re.M | re.I)
|
||||
assert match, 'Failed to load ControlLLLite!'
|
||||
root = match.group(1)
|
||||
block = match.group(2)
|
||||
block_number = match.group(3)
|
||||
attn_name = match.group(4)
|
||||
proj_name = match.group(5)
|
||||
if root == 'input':
|
||||
mapped_block, mapped_number = map_down_lllite_to_unet[int(block)]
|
||||
b = model.down_blocks[mapped_block].attentions[int(mapped_number)].transformer_blocks[int(block_number)]
|
||||
elif root == 'output':
|
||||
# TODO: Map up unet blocks to lite blocks
|
||||
print(f'Not implemented: {root}')
|
||||
else:
|
||||
b = model.mid_block.attentions[0].transformer_blocks[int(block_number)]
|
||||
b = getattr(b, attn_name, None)
|
||||
assert b is not None, 'Failed to load ControlLLLite!'
|
||||
b = getattr(b, 'to_' + proj_name, None)
|
||||
assert b is not None, 'Failed to load ControlLLLite!'
|
||||
if not hasattr(b, 'lllite_list'):
|
||||
b.lllite_list = []
|
||||
if len(b.lllite_list) == 0:
|
||||
all_hack[b] = b.forward
|
||||
b.forward = self.get_hacked_forward(original_forward=b.forward, model=model, blk=b)
|
||||
b.lllite_list.append((weight, v))
|
||||
return
|
||||
|
||||
def get_hacked_forward(self, original_forward, model, blk):
|
||||
@torch.no_grad()
|
||||
def forward(x, **kwargs):
|
||||
hack = 0
|
||||
for weight, module in blk.lllite_list:
|
||||
module.to(x.device)
|
||||
module.to(x.dtype)
|
||||
hack = hack + module(x) * weight
|
||||
x = x + hack
|
||||
return original_forward(x, **kwargs)
|
||||
return forward
|
||||
|
|
@ -1,13 +1,12 @@
|
|||
import os
|
||||
import time
|
||||
from typing import Union
|
||||
from modules.shared import log, opts
|
||||
from modules import errors
|
||||
|
||||
ok = True
|
||||
try:
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import ControlNetXSModel, StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, ControlNetXSModel, StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline
|
||||
except Exception:
|
||||
from diffusers import ControlNetModel
|
||||
ControlNetXSModel = ControlNetModel # dummy
|
||||
|
|
@ -27,11 +26,11 @@ models = {}
|
|||
all_models = {}
|
||||
all_models.update(predefined_sd15)
|
||||
all_models.update(predefined_sdxl)
|
||||
cache_dir = 'models/control/controlnetsxs'
|
||||
cache_dir = 'models/control/xs'
|
||||
|
||||
|
||||
def find_models():
|
||||
path = os.path.join(opts.control_dir, 'controlnetsxs')
|
||||
path = os.path.join(opts.control_dir, 'xs')
|
||||
files = os.listdir(path)
|
||||
files = [f for f in files if f.endswith('.safetensors')]
|
||||
downloaded_models = {}
|
||||
|
|
@ -115,7 +114,7 @@ class ControlNetXS():
|
|||
|
||||
|
||||
class ControlNetXSPipeline():
|
||||
def __init__(self, controlnet: ControlNetXSModel | list[ControlNetXSModel], pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline, dtype = None):
|
||||
def __init__(self, controlnet: Union[ControlNetXSModel, list[ControlNetXSModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
|
||||
t0 = time.time()
|
||||
self.orig_pipeline = pipeline
|
||||
self.pipeline = None
|
||||
|
|
|
|||
|
|
@ -31,5 +31,6 @@ class CannyDetector:
|
|||
|
||||
if output_type == "pil":
|
||||
detected_map = Image.fromarray(detected_map)
|
||||
detected_map = detected_map.convert('L')
|
||||
|
||||
return detected_map
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class EdgeDetector:
|
|||
edge_map = cv2.resize(edge_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if output_type == "pil":
|
||||
edge_map = edge_map.convert('L')
|
||||
edge_map = Image.fromarray(edge_map)
|
||||
|
||||
return edge_map
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import time
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
from modules.shared import log
|
||||
from modules.control.proc.reference_sd15 import StableDiffusionReferencePipeline
|
||||
from modules.control.proc.reference_sdxl import StableDiffusionXLReferencePipeline
|
||||
from modules.shared import log
|
||||
|
||||
|
||||
what = 'Reference'
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from modules.control import unit
|
|||
from modules.control import processors
|
||||
from modules.control import controlnets # lllyasviel ControlNet
|
||||
from modules.control import controlnetsxs # VisLearn ControlNet-XS
|
||||
from modules.control import controlnetslite # Kohya ControlLLLite
|
||||
from modules.control import adapters # TencentARC T2I-Adapter
|
||||
from modules.control import reference # ControlNet-Reference
|
||||
from modules.control import ipadapter # IP-Adapter
|
||||
|
|
@ -156,7 +157,11 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
|
|||
active_strength.append(float(u.strength))
|
||||
active_start.append(float(u.start))
|
||||
active_end.append(float(u.end))
|
||||
p.guess_mode = u.guess
|
||||
shared.log.debug(f'Control ControlNet-XS unit: process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
|
||||
elif unit_type == 'lite' and u.controlnet.model is not None:
|
||||
active_process.append(u.process)
|
||||
active_model.append(u.controlnet)
|
||||
active_strength.append(float(u.strength))
|
||||
shared.log.debug(f'Control ControlNet-XS unit: process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
|
||||
elif unit_type == 'reference':
|
||||
p.override = u.override
|
||||
|
|
@ -173,7 +178,7 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
|
|||
|
||||
has_models = False
|
||||
selected_models: List[Union[controlnets.ControlNetModel, controlnetsxs.ControlNetXSModel, adapters.AdapterModel]] = None
|
||||
if unit_type == 'adapter' or unit_type == 'controlnet' or unit_type == 'xs':
|
||||
if unit_type == 'adapter' or unit_type == 'controlnet' or unit_type == 'xs' or unit_type == 'lite':
|
||||
if len(active_model) == 0:
|
||||
selected_models = None
|
||||
elif len(active_model) == 1:
|
||||
|
|
@ -216,6 +221,14 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
|
|||
pipe = instance.pipeline
|
||||
if inits is not None:
|
||||
shared.log.warning('Control: ControlNet-XS does not support separate init image')
|
||||
elif unit_type == 'lite' and has_models:
|
||||
p.extra_generation_params["Control mode"] = 'ControlLLLite'
|
||||
p.extra_generation_params["Control conditioning"] = use_conditioning
|
||||
p.controlnet_conditioning_scale = use_conditioning
|
||||
instance = controlnetslite.ControlLLitePipeline(shared.sd_model)
|
||||
pipe = instance.pipeline
|
||||
if inits is not None:
|
||||
shared.log.warning('Control: ControlLLLite does not support separate init image')
|
||||
elif unit_type == 'reference':
|
||||
p.extra_generation_params["Control mode"] = 'Reference'
|
||||
p.extra_generation_params["Control attention"] = p.attention
|
||||
|
|
@ -232,6 +245,8 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
|
|||
if len(active_strength) > 0:
|
||||
p.strength = active_strength[0]
|
||||
pipe = diffusers.AutoPipelineForImage2Image.from_pipe(shared.sd_model) # use set_diffuser_pipe
|
||||
instance = None
|
||||
|
||||
debug(f'Control pipeline: class={pipe.__class__} args={vars(p)}')
|
||||
t1, t2, t3 = time.time(), 0, 0
|
||||
status = True
|
||||
|
|
@ -396,7 +411,7 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
|
|||
if hasattr(p, 'init_images'):
|
||||
del p.init_images # control never uses init_image as-is
|
||||
if pipe is not None:
|
||||
if not has_models and (unit_type == 'controlnet' or unit_type == 'adapter' or unit_type == 'xs'): # run in txt2img or img2img mode
|
||||
if not has_models and (unit_type == 'controlnet' or unit_type == 'adapter' or unit_type == 'xs' or unit_type == 'lite'): # run in txt2img or img2img mode
|
||||
if processed_image is not None:
|
||||
p.init_images = [processed_image]
|
||||
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
|
||||
|
|
@ -411,6 +426,8 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
|
|||
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) # only controlnet supports img2img
|
||||
else:
|
||||
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE)
|
||||
if unit_type == 'lite':
|
||||
instance.apply(selected_models, p.image, use_conditioning)
|
||||
|
||||
# pipeline
|
||||
output = None
|
||||
|
|
@ -480,6 +497,8 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
|
|||
image_txt = f'| Frames {len(output_images)} | Size {output_images[0].width}x{output_images[0].height}'
|
||||
|
||||
image_txt += f' | {util.dict2str(p.extra_generation_params)}'
|
||||
if hasattr(instance, 'restore'):
|
||||
instance.restore()
|
||||
restore_pipeline()
|
||||
debug(f'Control ready: {image_txt}')
|
||||
if is_generator:
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ def test_adapters(prompt, negative, image):
|
|||
model_id = 'None'
|
||||
if shared.state.interrupted:
|
||||
continue
|
||||
output = image
|
||||
output = image.copy()
|
||||
if model_id != 'None':
|
||||
adapter = adapters.Adapter(model_id=model_id, device=devices.device, dtype=devices.dtype)
|
||||
if adapter is None:
|
||||
|
|
@ -122,6 +122,7 @@ def test_adapters(prompt, negative, image):
|
|||
pipe = adapters.AdapterPipeline(adapter=adapter.model, pipeline=shared.sd_model)
|
||||
pipe.pipeline.to(device=devices.device, dtype=devices.dtype)
|
||||
sd_models.set_diffuser_options(pipe)
|
||||
image = image.convert('L') if 'Canny' in model_id or 'Sketch' in model_id else image.convert('RGB')
|
||||
try:
|
||||
res = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil')
|
||||
output = res.images[0]
|
||||
|
|
@ -201,3 +202,56 @@ def test_xs(prompt, negative, image):
|
|||
grid.paste(thumb, box=(x, y))
|
||||
yield None, grid, None, images
|
||||
return None, grid, None, images # preview_process, output_image, output_video, output_gallery
|
||||
|
||||
|
||||
def test_lite(prompt, negative, image):
|
||||
from modules import devices, sd_models
|
||||
from modules.control import controlnetslite
|
||||
if image is None:
|
||||
shared.log.error('Image not loaded')
|
||||
return None, None, None
|
||||
from PIL import ImageDraw, ImageFont
|
||||
images = []
|
||||
for model_id in controlnetslite.list_models():
|
||||
if model_id is None:
|
||||
model_id = 'None'
|
||||
if shared.state.interrupted:
|
||||
continue
|
||||
output = image
|
||||
if model_id != 'None':
|
||||
lite = controlnetslite.ControlLLLite(model_id=model_id, device=devices.device, dtype=devices.dtype)
|
||||
if lite is None:
|
||||
shared.log.error(f'Control-LLite load failed: id="{model_id}"')
|
||||
continue
|
||||
shared.log.info(f'Testing ControlNet-XS: {model_id}')
|
||||
pipe = controlnetslite.ControlLLitePipeline(pipeline=shared.sd_model)
|
||||
pipe.apply(controlnet=lite.model, image=image, conditioning=1.0)
|
||||
pipe.pipeline.to(device=devices.device, dtype=devices.dtype)
|
||||
sd_models.set_diffuser_options(pipe)
|
||||
try:
|
||||
res = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil')
|
||||
output = res.images[0]
|
||||
except Exception as e:
|
||||
errors.display(e, f'ControlNet-XS {model_id} inference')
|
||||
model_id = f'{model_id} error'
|
||||
pipe.restore()
|
||||
draw = ImageDraw.Draw(output)
|
||||
font = ImageFont.truetype('DejaVuSansMono', 48)
|
||||
draw.text((10, 10), model_id, (0,0,0), font=font)
|
||||
draw.text((8, 8), model_id, (255,255,255), font=font)
|
||||
images.append(output)
|
||||
yield output, None, None, images
|
||||
rows = round(math.sqrt(len(images)))
|
||||
cols = math.ceil(len(images) / rows)
|
||||
w, h = 256, 256
|
||||
size = (cols * w + cols, rows * h + rows)
|
||||
grid = Image.new('RGB', size=size, color='black')
|
||||
shared.log.info(f'Test ControlNet-XS: images={len(images)} grid={grid}')
|
||||
for i, image in enumerate(images):
|
||||
x = (i % cols * w) + (i % cols)
|
||||
y = (i // cols * h) + (i // cols)
|
||||
thumb = image.copy().convert('RGB')
|
||||
thumb.thumbnail((w, h), Image.Resampling.HAMMING)
|
||||
grid.paste(thumb, box=(x, y))
|
||||
yield None, grid, None, images
|
||||
return None, grid, None, images # preview_process, output_image, output_video, output_gallery
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from modules.shared import log
|
|||
from modules.control import processors
|
||||
from modules.control import controlnets
|
||||
from modules.control import controlnetsxs
|
||||
from modules.control import controlnetslite
|
||||
from modules.control import adapters
|
||||
from modules.control import reference # pylint: disable=unused-import
|
||||
|
||||
|
|
@ -110,6 +111,8 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
|
|||
self.controlnet = controlnets.ControlNet(device=default_device, dtype=default_dtype)
|
||||
elif self.type == 'xs':
|
||||
self.controlnet = controlnetsxs.ControlNetXS(device=default_device, dtype=default_dtype)
|
||||
elif self.type == 'lite':
|
||||
self.controlnet = controlnetslite.ControlLLLite(device=default_device, dtype=default_dtype)
|
||||
elif self.type == 'reference':
|
||||
pass
|
||||
else:
|
||||
|
|
@ -132,6 +135,11 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
|
|||
model_id.change(fn=self.controlnet.load, inputs=[model_id, extra_controls[0]], outputs=[result_txt], show_progress=True)
|
||||
if extra_controls is not None and len(extra_controls) > 0:
|
||||
extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
|
||||
elif self.type == 'lite':
|
||||
if model_id is not None:
|
||||
model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
|
||||
if extra_controls is not None and len(extra_controls) > 0:
|
||||
extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
|
||||
elif self.type == 'reference':
|
||||
if extra_controls is not None and len(extra_controls) > 0:
|
||||
extra_controls[0].change(fn=reference_extra, inputs=extra_controls)
|
||||
|
|
|
|||
|
|
@ -125,7 +125,8 @@ def bind_buttons(buttons, send_image, send_generate_info):
|
|||
for tabname, button in buttons.items():
|
||||
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
|
||||
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
|
||||
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
|
||||
bindings = ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)
|
||||
register_paste_params_button(bindings)
|
||||
|
||||
|
||||
def register_paste_params_button(binding: ParamBinding):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Dict
|
|||
from urllib.parse import urlparse
|
||||
import PIL.Image as Image
|
||||
import rich.progress as p
|
||||
from modules import shared
|
||||
from modules import shared, errors
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
|
||||
|
|
@ -68,6 +68,7 @@ def download_civit_meta(model_path: str, model_id):
|
|||
return msg
|
||||
except Exception as e:
|
||||
msg = f'CivitAI download error: id={model_id} url={url} file={fn} {e}'
|
||||
errors.display(e, 'CivitAI download error')
|
||||
shared.log.error(msg)
|
||||
return msg
|
||||
return f'CivitAI download error: id={model_id} url={url} code={r.status_code}'
|
||||
|
|
@ -100,7 +101,8 @@ def download_civit_preview(model_path: str, preview_url: str):
|
|||
except Exception as e:
|
||||
os.remove(preview_file)
|
||||
res += f' error={e}'
|
||||
shared.log.error(f'CivitAI download error: url={preview_url} file={preview_file} {e}')
|
||||
shared.log.error(f'CivitAI download error: url={preview_url} file={preview_file} written={written} {e}')
|
||||
errors.display(e, 'CivitAI download error')
|
||||
shared.state.end()
|
||||
if img is None:
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -217,6 +217,8 @@ def readfile(filename, silent=False, lock=False):
|
|||
def writefile(data, filename, mode='w', silent=False):
|
||||
lock = None
|
||||
locked = False
|
||||
import tempfile
|
||||
|
||||
|
||||
def default(obj):
|
||||
log.error(f"Saving: {filename} not a valid object: {obj}")
|
||||
|
|
@ -239,13 +241,19 @@ def writefile(data, filename, mode='w', silent=False):
|
|||
raise ValueError('not a valid object')
|
||||
lock = fasteners.InterProcessReaderWriterLock(f"{filename}.lock", logger=log)
|
||||
locked = lock.acquire_write_lock(blocking=True, timeout=3)
|
||||
with open(filename, mode, encoding="utf8") as file:
|
||||
file.write(output)
|
||||
with tempfile.NamedTemporaryFile(mode=mode, encoding="utf8", delete=False, dir=os.path.dirname(filename)) as f:
|
||||
f.write(output)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(f.name, filename)
|
||||
# with open(filename, mode=mode, encoding="utf8") as file:
|
||||
# file.write(output)
|
||||
t1 = time.time()
|
||||
if not silent:
|
||||
log.debug(f'Save: file="{filename}" json={len(data)} bytes={len(output)} time={t1-t0:.3f}')
|
||||
except Exception as e:
|
||||
log.error(f'Saving failed: {filename} {e}')
|
||||
errors.display(e, 'Saving failed')
|
||||
finally:
|
||||
if lock is not None:
|
||||
lock.release_read_lock()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ import modules.script_callbacks
|
|||
|
||||
|
||||
folder_symbol = symbols.folder
|
||||
debug = shared.log.trace if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
debug('Trace: PASTE')
|
||||
|
||||
|
||||
def update_generation_info(generation_info, html_info, img_index):
|
||||
|
|
@ -214,7 +216,10 @@ def create_output_panel(tabname, preview=True):
|
|||
clip_files.click(fn=None, _js='clip_gallery_urls', inputs=[result_gallery], outputs=[])
|
||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
||||
delete = gr.Button('Delete', elem_id=f'delete_{tabname}')
|
||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras", "control"])
|
||||
if shared.backend == shared.Backend.ORIGINAL:
|
||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
||||
else:
|
||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "control", "extras"])
|
||||
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
||||
with gr.Group():
|
||||
|
|
@ -245,9 +250,9 @@ def create_output_panel(tabname, preview=True):
|
|||
else:
|
||||
paste_field_names = []
|
||||
for paste_tabname, paste_button in buttons.items():
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=paste_button, tabname=paste_tabname, source_tabname=("txt2img" if tabname == "txt2img" else None), source_image_component=result_gallery, paste_field_names=paste_field_names
|
||||
))
|
||||
debug(f'Create output panel: button={paste_button} tabname={paste_tabname}')
|
||||
bindings = parameters_copypaste.ParamBinding(paste_button=paste_button, tabname=paste_tabname, source_tabname=("txt2img" if tabname == "txt2img" else None), source_image_component=result_gallery, paste_field_names=paste_field_names)
|
||||
parameters_copypaste.register_paste_params_button(bindings)
|
||||
return result_gallery, generation_info, html_info, html_info_formatted, html_log
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import gradio as gr
|
|||
from modules.control import unit
|
||||
from modules.control import controlnets # lllyasviel ControlNet
|
||||
from modules.control import controlnetsxs # vislearn ControlNet-XS
|
||||
from modules.control import controlnetslite # vislearn ControlNet-XS
|
||||
from modules.control import adapters # TencentARC T2I-Adapter
|
||||
from modules.control import processors # patrickvonplaten controlnet_aux
|
||||
from modules.control import reference # reference pipeline
|
||||
|
|
@ -12,7 +13,7 @@ from modules.ui_components import FormRow, FormGroup
|
|||
|
||||
|
||||
gr_height = 512
|
||||
max_units = 10
|
||||
max_units = 5
|
||||
units: list[unit.Unit] = [] # main state variable
|
||||
input_source = None
|
||||
input_init = None
|
||||
|
|
@ -23,15 +24,17 @@ debug('Trace: CONTROL')
|
|||
def initialize():
|
||||
from modules import devices
|
||||
shared.log.debug(f'Control initialize: models={shared.opts.control_dir}')
|
||||
controlnets.cache_dir = os.path.join(shared.opts.control_dir, 'controlnets')
|
||||
controlnetsxs.cache_dir = os.path.join(shared.opts.control_dir, 'controlnetsxs')
|
||||
adapters.cache_dir = os.path.join(shared.opts.control_dir, 'adapters')
|
||||
processors.cache_dir = os.path.join(shared.opts.control_dir, 'processors')
|
||||
controlnets.cache_dir = os.path.join(shared.opts.control_dir, 'controlnet')
|
||||
controlnetsxs.cache_dir = os.path.join(shared.opts.control_dir, 'xs')
|
||||
controlnetslite.cache_dir = os.path.join(shared.opts.control_dir, 'lite')
|
||||
adapters.cache_dir = os.path.join(shared.opts.control_dir, 'adapter')
|
||||
processors.cache_dir = os.path.join(shared.opts.control_dir, 'processor')
|
||||
unit.default_device = devices.device
|
||||
unit.default_dtype = devices.dtype
|
||||
os.makedirs(shared.opts.control_dir, exist_ok=True)
|
||||
os.makedirs(controlnets.cache_dir, exist_ok=True)
|
||||
os.makedirs(controlnetsxs.cache_dir, exist_ok=True)
|
||||
os.makedirs(controlnetslite.cache_dir, exist_ok=True)
|
||||
os.makedirs(adapters.cache_dir, exist_ok=True)
|
||||
os.makedirs(processors.cache_dir, exist_ok=True)
|
||||
|
||||
|
|
@ -58,7 +61,7 @@ def return_controls(res):
|
|||
def generate_click(job_id: str, active_tab: str, *args):
|
||||
from modules.control.run import control_run
|
||||
shared.log.debug(f'Control: tab={active_tab} job={job_id} args={args}')
|
||||
if active_tab not in ['controlnet', 'xs', 'adapter', 'reference']:
|
||||
if active_tab not in ['controlnet', 'xs', 'adapter', 'reference', 'lite']:
|
||||
return None, None, None, None, f'Control: Unknown mode: {active_tab} args={args}'
|
||||
shared.state.begin('control')
|
||||
progress.add_task_to_queue(job_id)
|
||||
|
|
@ -203,14 +206,14 @@ def create_ui(_blocks: gr.Blocks=None):
|
|||
with gr.Row(elem_id='control_settings'):
|
||||
|
||||
with gr.Accordion(open=False, label="Input", elem_id="control_input", elem_classes=["small-accordion"]):
|
||||
with gr.Row():
|
||||
input_type = gr.Radio(label="Input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type')
|
||||
with gr.Row():
|
||||
denoising_strength = gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label='Denoising strength', value=0.50, elem_id="control_denoising_strength")
|
||||
with gr.Row():
|
||||
show_ip = gr.Checkbox(label="Enable IP adapter", value=False, elem_id="control_show_ip")
|
||||
with gr.Row():
|
||||
show_preview = gr.Checkbox(label="Show preview", value=False, elem_id="control_show_preview")
|
||||
with gr.Row():
|
||||
input_type = gr.Radio(label="Input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type')
|
||||
with gr.Row():
|
||||
denoising_strength = gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label='Denoising strength', value=0.50, elem_id="control_denoising_strength")
|
||||
|
||||
resize_mode, resize_name, width, height, scale_by, selected_scale_tab, resize_time = ui.create_resize_inputs('control', [], time_selector=True, scale_visible=False, mode='Fixed')
|
||||
|
||||
|
|
@ -402,10 +405,10 @@ def create_ui(_blocks: gr.Blocks=None):
|
|||
extra_controls = [
|
||||
gr.Slider(label="Control factor", minimum=0.0, maximum=1.0, step=0.05, value=1.0, scale=3),
|
||||
]
|
||||
num_adaptor_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1)
|
||||
adaptor_ui_units = [] # list of hidable accordions
|
||||
num_adapter_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1)
|
||||
adapter_ui_units = [] # list of hidable accordions
|
||||
for i in range(max_units):
|
||||
with gr.Accordion(f'Adapter unit {i+1}', visible= i < num_adaptor_units.value) as unit_ui:
|
||||
with gr.Accordion(f'Adapter unit {i+1}', visible= i < num_adapter_units.value) as unit_ui:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
|
|
@ -417,7 +420,7 @@ def create_ui(_blocks: gr.Blocks=None):
|
|||
reset_btn = ui_components.ToolButton(value=ui_symbols.reset)
|
||||
image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool'])
|
||||
process_btn= ui_components.ToolButton(value=ui_symbols.preview)
|
||||
adaptor_ui_units.append(unit_ui)
|
||||
adapter_ui_units.append(unit_ui)
|
||||
units.append(unit.Unit(
|
||||
unit_type = 'adapter',
|
||||
result_txt = result_txt,
|
||||
|
|
@ -435,7 +438,47 @@ def create_ui(_blocks: gr.Blocks=None):
|
|||
)
|
||||
if i == 0:
|
||||
units[-1].enabled = True # enable first unit in group
|
||||
num_adaptor_units.change(fn=display_units, inputs=[num_adaptor_units], outputs=adaptor_ui_units)
|
||||
num_adapter_units.change(fn=display_units, inputs=[num_adapter_units], outputs=adapter_ui_units)
|
||||
|
||||
with gr.Tab('Lite') as _tab_lite:
|
||||
gr.HTML('<a href="https://huggingface.co/kohya-ss/controlnet-lllite">Control LLLite</a>')
|
||||
with gr.Row():
|
||||
extra_controls = [
|
||||
]
|
||||
num_lite_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1)
|
||||
lite_ui_units = [] # list of hidable accordions
|
||||
for i in range(max_units):
|
||||
with gr.Accordion(f'Control unit {i+1}', visible= i < num_lite_units.value) as unit_ui:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
enabled_cb = gr.Checkbox(value= i == 0, label="Enabled")
|
||||
process_id = gr.Dropdown(label="Processor", choices=processors.list_models(), value='None')
|
||||
model_id = gr.Dropdown(label="Model", choices=controlnetslite.list_models(), value='None')
|
||||
ui_common.create_refresh_button(model_id, controlnetslite.list_models, lambda: {"choices": controlnetslite.list_models(refresh=True)}, 'refresh_lite_models')
|
||||
model_strength = gr.Slider(label="Strength", minimum=0.01, maximum=1.0, step=0.01, value=1.0-i/10)
|
||||
reset_btn = ui_components.ToolButton(value=ui_symbols.reset)
|
||||
image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool'])
|
||||
process_btn= ui_components.ToolButton(value=ui_symbols.preview)
|
||||
lite_ui_units.append(unit_ui)
|
||||
units.append(unit.Unit(
|
||||
unit_type = 'lite',
|
||||
result_txt = result_txt,
|
||||
image_input = input_image,
|
||||
enabled_cb = enabled_cb,
|
||||
reset_btn = reset_btn,
|
||||
process_id = process_id,
|
||||
model_id = model_id,
|
||||
model_strength = model_strength,
|
||||
preview_process = preview_process,
|
||||
preview_btn = process_btn,
|
||||
image_upload = image_upload,
|
||||
extra_controls = extra_controls,
|
||||
)
|
||||
)
|
||||
if i == 0:
|
||||
units[-1].enabled = True # enable first unit in group
|
||||
num_lite_units.change(fn=display_units, inputs=[num_lite_units], outputs=lite_ui_units)
|
||||
|
||||
with gr.Tab('Reference') as _tab_reference:
|
||||
gr.HTML('<a href="https://github.com/Mikubill/sd-webui-controlnet/discussions/1236">ControlNet reference-only control</a>')
|
||||
|
|
@ -555,16 +598,19 @@ def create_ui(_blocks: gr.Blocks=None):
|
|||
generation_parameters_copypaste.register_paste_params_button(bindings)
|
||||
|
||||
if os.environ.get('SD_CONTROL_DEBUG', None) is not None: # debug only
|
||||
from modules.control.test import test_processors, test_controlnets, test_adapters, test_xs
|
||||
from modules.control.test import test_processors, test_controlnets, test_adapters, test_xs, test_lite
|
||||
gr.HTML('<br><h1>Debug</h1><br>')
|
||||
with gr.Row():
|
||||
run_test_processors_btn = gr.Button(value="Test:Processors", variant='primary', elem_classes=['control-button'])
|
||||
run_test_controlnets_btn = gr.Button(value="Test:ControlNets", variant='primary', elem_classes=['control-button'])
|
||||
run_test_xs_btn = gr.Button(value="Test:ControlNets-XS", variant='primary', elem_classes=['control-button'])
|
||||
run_test_adapters_btn = gr.Button(value="Test:Adapters", variant='primary', elem_classes=['control-button'])
|
||||
run_test_lite_btn = gr.Button(value="Test:Control-LLLite", variant='primary', elem_classes=['control-button'])
|
||||
|
||||
run_test_processors_btn.click(fn=test_processors, inputs=[input_image], outputs=[preview_process, output_image, output_video, output_gallery])
|
||||
run_test_controlnets_btn.click(fn=test_controlnets, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery])
|
||||
run_test_xs_btn.click(fn=test_xs, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery])
|
||||
run_test_adapters_btn.click(fn=test_adapters, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery])
|
||||
run_test_lite_btn.click(fn=test_lite, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery])
|
||||
|
||||
return [(control_ui, 'Control', 'control')]
|
||||
|
|
|
|||
|
|
@ -515,7 +515,8 @@ def create_ui():
|
|||
continue
|
||||
for item in page.list_items():
|
||||
meta = os.path.splitext(item['filename'])[0] + '.json'
|
||||
if ('card-no-preview.png' in item['preview'] or not os.path.isfile(meta)) and os.path.isfile(item['filename']):
|
||||
has_meta = os.path.isfile(meta) and os.stat(meta).st_size > 0
|
||||
if ('card-no-preview.png' in item['preview'] or not has_meta) and os.path.isfile(item['filename']):
|
||||
sha = item.get('hash', None)
|
||||
found = False
|
||||
if sha is not None and len(sha) > 0:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
target-version = "py39"
|
||||
select = [
|
||||
"F",
|
||||
"E",
|
||||
|
|
|
|||
|
|
@ -50,8 +50,8 @@ clip-interrogator==0.6.0
|
|||
antlr4-python3-runtime==4.9.3
|
||||
requests==2.31.0
|
||||
tqdm==4.66.1
|
||||
accelerate==0.24.1
|
||||
opencv-python-headless==4.7.0.72
|
||||
accelerate==0.25.0
|
||||
opencv-python-headless==4.8.1.78
|
||||
diffusers==0.24.0
|
||||
einops==0.4.1
|
||||
gradio==3.43.2
|
||||
|
|
@ -65,9 +65,9 @@ pytorch_lightning==1.9.4
|
|||
tokenizers==0.15.0
|
||||
transformers==4.36.2
|
||||
tomesd==0.1.3
|
||||
urllib3==1.26.15
|
||||
urllib3==1.26.18
|
||||
Pillow==10.1.0
|
||||
timm==0.9.7
|
||||
timm==0.9.12
|
||||
pydantic==1.10.13
|
||||
typing-extensions==4.8.0
|
||||
typing-extensions==4.9.0
|
||||
peft
|
||||
|
|
|
|||
2
wiki
2
wiki
|
|
@ -1 +1 @@
|
|||
Subproject commit 4d9737acc348cfc0ae0b139ec5e39ae61967e7ac
|
||||
Subproject commit 7dadbd912a1854bad2d7c6463eaecc458f2808c3
|
||||
Loading…
Reference in New Issue