update requirements and add control-lllite

pull/2650/head
Vladimir Mandic 2023-12-25 14:45:01 -05:00
parent 60e0e110dd
commit 068809719a
22 changed files with 555 additions and 62 deletions

View File

@ -1,11 +1,12 @@
# Change Log for SD.Next
## Update for 2023-12-24
## Update for 2023-12-25
*Note*: based on `diffusers==0.25.0.dev0`
- **Control**
- native implementation of **ControlNet**, **ControlNet XS**, **T2I Adapters** and **IP Adapters**
- native implementation of all image control methods:
**ControlNet**, **ControlNet XS**, **Control LLLite**, **T2I Adapters** and **IP Adapters**
- top-level **Control** next to **Text** and **Image** generate
- supports all variations of **SD15** and **SD-XL** models
- supports *Text*, *Image*, *Batch* and *Video* processing
@ -104,6 +105,7 @@
- improve handling of long filenames and filenames during batch processing
- do not set preview samples when using via api
- avoid unnecessary resizes in img2img and inpaint
- safe handling of config updates avoid file corruption on I/O errors
- updated `cli/simple-txt2img.py` and `cli/simple-img2img.py` scripts
- save `params.txt` regardless of image save status
- update built-in log monitor in ui, thanks @midcoastal

View File

@ -66,45 +66,54 @@ function extract_image_from_gallery(gallery) {
window.args_to_array = Array.from; // Compatibility with e.g. extensions that may expect this to be around
function switchToTab(tab) {
const tabs = Array.from(gradioApp().querySelectorAll('#tabs > .tab-nav > button'));
const btn = tabs?.find((t) => t.innerText === tab);
log('switchToTab', tab);
if (btn) btn.click();
}
function switch_to_txt2img(...args) {
gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
switchToTab('Text');
return Array.from(arguments);
}
function switch_to_img2img_tab(no) {
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
switchToTab('Image');
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
}
function switch_to_img2img(...args) {
switchToTab('Image');
switch_to_img2img_tab(0);
return Array.from(arguments);
}
function switch_to_sketch(...args) {
switchToTab('Image');
switch_to_img2img_tab(1);
return Array.from(arguments);
}
function switch_to_inpaint(...args) {
switchToTab('Image');
switch_to_img2img_tab(2);
return Array.from(arguments);
}
function switch_to_inpaint_sketch(...args) {
switchToTab('Image');
switch_to_img2img_tab(3);
return Array.from(arguments);
}
function switch_to_extras(...args) {
gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
switchToTab('Process');
return Array.from(arguments);
}
function switch_to_control(...args) {
const tabs = Array.from(gradioApp().querySelector('#tabs').querySelectorAll('button'));
const btn = tabs.find((el) => el.innerText.toLowerCase() === 'control');
btn.click();
switchToTab('Control');
return Array.from(arguments);
}

View File

@ -1,5 +1,6 @@
import os
import time
from typing import Union
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import
from modules.shared import log
from modules import errors
@ -33,7 +34,7 @@ models = {}
all_models = {}
all_models.update(predefined_sd15)
all_models.update(predefined_sdxl)
cache_dir = 'models/control/adapters'
cache_dir = 'models/control/adapter'
def list_models(refresh=False):
@ -105,10 +106,10 @@ class Adapter():
class AdapterPipeline():
def __init__(self, adapter: T2IAdapter | list[T2IAdapter], pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline, dtype = None):
def __init__(self, adapter: Union[T2IAdapter, list[T2IAdapter]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
t0 = time.time()
self.orig_pipeline = pipeline
self.pipeline = None
self.pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline] = None
if pipeline is None:
log.error(f'Control {what} pipeline: model not loaded')
return

View File

@ -1,8 +1,7 @@
import os
import time
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionXLPipeline
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
from typing import Union
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
from modules.shared import log, opts
from modules import errors
@ -38,11 +37,11 @@ models = {}
all_models = {}
all_models.update(predefined_sd15)
all_models.update(predefined_sdxl)
cache_dir = 'models/control/controlnets'
cache_dir = 'models/control/controlnet'
def find_models():
path = os.path.join(opts.control_dir, 'controlnets')
path = os.path.join(opts.control_dir, 'controlnet')
files = os.listdir(path)
files = [f for f in files if f.endswith('.safetensors')]
downloaded_models = {}
@ -123,14 +122,14 @@ class ControlNet():
class ControlNetPipeline():
def __init__(self, controlnet: ControlNetModel | list[ControlNetModel], pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline, dtype = None):
def __init__(self, controlnet: Union[ControlNetModel, list[ControlNetModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
t0 = time.time()
self.orig_pipeline = pipeline
self.pipeline = None
if pipeline is None:
log.error('Control model pipeline: model not loaded')
return
if isinstance(pipeline, StableDiffusionXLPipeline):
elif isinstance(pipeline, StableDiffusionXLPipeline):
self.pipeline = StableDiffusionXLControlNetPipeline(
vae=pipeline.vae,
text_encoder=pipeline.text_encoder,

View File

@ -0,0 +1,135 @@
import os
import time
from typing import Union
import numpy as np
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from modules.shared import log, opts
from modules import errors
from modules.control.controlnetslite_model import ControlNetLLLite
what = 'ControlLLLite'
debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: CONTROL')
predefined_sd15 = {
}
predefined_sdxl = {
'Canny XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny',
'Canny anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny_anime',
'Depth anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01008016e_sdxl_depth_anime',
'Blur anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01016032e_sdxl_blur_anime_beta',
'Pose anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_pose_anime',
'Replicate anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_replicate_anime_v2',
}
models = {}
all_models = {}
all_models.update(predefined_sd15)
all_models.update(predefined_sdxl)
cache_dir = 'models/control/lite'
def find_models():
path = os.path.join(opts.control_dir, 'lite')
files = os.listdir(path)
files = [f for f in files if f.endswith('.safetensors')]
downloaded_models = {}
for f in files:
basename = os.path.splitext(f)[0]
downloaded_models[basename] = os.path.join(path, f)
all_models.update(downloaded_models)
return downloaded_models
def list_models(refresh=False):
import modules.shared
global models # pylint: disable=global-statement
if not refresh and len(models) > 0:
return models
models = {}
if modules.shared.sd_model_type == 'none':
models = ['None']
elif modules.shared.sd_model_type == 'sdxl':
models = ['None'] + sorted(predefined_sdxl) + sorted(find_models())
elif modules.shared.sd_model_type == 'sd':
models = ['None'] + sorted(predefined_sd15) + sorted(find_models())
else:
log.warning(f'Control {what} model list failed: unknown model type')
models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models())
debug(f'Control list {what}: path={cache_dir} models={models}')
return models
class ControlLLLite():
def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
self.model: ControlNetLLLite = None
self.model_id: str = model_id
self.device = device
self.dtype = dtype
self.load_config = { 'cache_dir': cache_dir }
if load_config is not None:
self.load_config.update(load_config)
if model_id is not None:
self.load()
def reset(self):
if self.model is not None:
log.debug(f'Control {what} model unloaded')
self.model = None
self.model_id = None
def load(self, model_id: str = None) -> str:
try:
t0 = time.time()
model_id = model_id or self.model_id
if model_id is None or model_id == 'None':
self.reset()
return
model_path = all_models[model_id]
if model_path == '':
return
if model_path is None:
log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
return
log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}')
if model_path.endswith('.safetensors'):
self.model = ControlNetLLLite(model_path)
else:
import huggingface_hub as hf
folder, filename = os.path.split(model_path)
model_path = hf.hf_hub_download(repo_id=folder, filename=f'{filename}.safetensors', cache_dir=cache_dir)
self.model = ControlNetLLLite(model_path)
if self.device is not None:
self.model.to(self.device)
if self.dtype is not None:
self.model.to(self.dtype)
t1 = time.time()
self.model_id = model_id
log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
return f'{what} loaded model: {model_id}'
except Exception as e:
log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
errors.display(e, f'Control {what} load')
return f'{what} failed to load model: {model_id}'
class ControlLLitePipeline():
def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline]):
self.pipeline = pipeline
self.nets = []
def apply(self, controlnet: Union[ControlNetLLLite, list[ControlNetLLLite]], image, conditioning):
if image is None:
return
self.nets = [controlnet] if isinstance(controlnet, ControlNetLLLite) else controlnet
debug(f'Control {what} apply: models={len(self.nets)} image={image} conditioning={conditioning}')
weight = [conditioning] if isinstance(conditioning, float) else conditioning
images = [image] if isinstance(image, Image.Image) else image
images = [i.convert('RGB') for i in images]
for i, cn in enumerate(self.nets):
cn.apply(pipe=self.pipeline, cond=np.asarray(images[i % len(images)]), weight=weight[i % len(weight)])
def restore(self):
from modules.control.controlnetslite_model import clear_all_lllite
clear_all_lllite()
self.nets = []

View File

@ -0,0 +1,202 @@
# Credits: <https://github.com/mycodeiscat/ControlNet-LLLite-diffusers>
# <https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI/blob/main/node_control_net_lllite.py>
import re
import torch
from safetensors.torch import load_file
all_hack = {}
class LLLiteModule(torch.nn.Module):
def __init__(
self,
name: str,
is_conv2d: bool,
in_dim: int,
depth: int,
cond_emb_dim: int,
mlp_dim: int,
):
super().__init__()
self.name = name
self.is_conv2d = is_conv2d
self.is_first = False
modules = []
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
if depth == 1:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
elif depth == 2:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
elif depth == 3:
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
self.conditioning1 = torch.nn.Sequential(*modules)
if self.is_conv2d:
self.down = torch.nn.Sequential(
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.mid = torch.nn.Sequential(
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.up = torch.nn.Sequential(
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
)
else:
self.down = torch.nn.Sequential(
torch.nn.Linear(in_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
self.mid = torch.nn.Sequential(
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
self.up = torch.nn.Sequential(
torch.nn.Linear(mlp_dim, in_dim),
)
self.depth = depth
self.cond_image = None
self.cond_emb = None
def set_cond_image(self, cond_image):
self.cond_image = cond_image
self.cond_emb = None
def forward(self, x):
if self.cond_emb is None:
# print(f"cond_emb is None, {self.name}")
cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
# if blk_shape is not None:
# b, c, h, w = blk_shape
# cx = torch.nn.functional.interpolate(cx, (h, w), mode="nearest-exact")
if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
cx = cx.view(n, c, h * w).permute(0, 2, 1)
self.cond_emb = cx
cx = self.cond_emb
# uncond/condでxはバッチサイズが2倍
if x.shape[0] != cx.shape[0]:
if self.is_conv2d:
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
else:
# print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
cx = self.mid(cx)
cx = self.up(cx)
return cx
def clear_all_lllite():
global all_hack # pylint: disable=global-statement
for k, v in all_hack.items():
k.forward = v
k.lllite_list = []
all_hack = {}
return
class ControlNetLLLite(torch.nn.Module): # pylint: disable=abstract-method
def __init__(self, path: str):
super().__init__()
module_weights = {}
try:
state_dict = load_file(path)
except Exception as e:
raise RuntimeError(f"Failed to load {path}") from e
for key, value in state_dict.items():
fragments = key.split(".")
module_name = fragments[0]
weight_name = ".".join(fragments[1:])
if module_name not in module_weights:
module_weights[module_name] = {}
module_weights[module_name][weight_name] = value
modules = {}
for module_name, weights in module_weights.items():
if "conditioning1.4.weight" in weights:
depth = 3
elif weights["conditioning1.2.weight"].shape[-1] == 4:
depth = 2
else:
depth = 1
module = LLLiteModule(
name=module_name,
is_conv2d=weights["down.0.weight"].ndim == 4,
in_dim=weights["down.0.weight"].shape[1],
depth=depth,
cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
mlp_dim=weights["down.0.weight"].shape[0],
)
# info = module.load_state_dict(weights)
modules[module_name] = module
setattr(self, module_name, module)
if len(modules) == 1:
module.is_first = True
self.modules = modules
return
@torch.no_grad()
def apply(self, pipe, cond, weight): # pylint: disable=arguments-differ
map_down_lllite_to_unet = {4: (1, 0), 5: (1, 1), 7: (2, 0), 8: (2, 1)}
model = pipe.unet
if type(cond) != torch.Tensor:
cond = torch.tensor(cond)
cond = cond/255 # 0-255 -> 0-1
cond_image = cond.unsqueeze(dim=0).permute(0, 3, 1, 2) # h,w,c -> b,c,h,w
cond_image = cond_image * 2.0 - 1.0 # 0-1 -> -1-1
for module in self.modules.values():
module.set_cond_image(cond_image)
for k, v in self.modules.items():
k = k.replace('middle_block', 'middle_blocks_0')
match = re.match("lllite_unet_(.*)_blocks_(.*)_1_transformer_blocks_(.*)_(.*)_to_(.*)", k, re.M | re.I)
assert match, 'Failed to load ControlLLLite!'
root = match.group(1)
block = match.group(2)
block_number = match.group(3)
attn_name = match.group(4)
proj_name = match.group(5)
if root == 'input':
mapped_block, mapped_number = map_down_lllite_to_unet[int(block)]
b = model.down_blocks[mapped_block].attentions[int(mapped_number)].transformer_blocks[int(block_number)]
elif root == 'output':
# TODO: Map up unet blocks to lite blocks
print(f'Not implemented: {root}')
else:
b = model.mid_block.attentions[0].transformer_blocks[int(block_number)]
b = getattr(b, attn_name, None)
assert b is not None, 'Failed to load ControlLLLite!'
b = getattr(b, 'to_' + proj_name, None)
assert b is not None, 'Failed to load ControlLLLite!'
if not hasattr(b, 'lllite_list'):
b.lllite_list = []
if len(b.lllite_list) == 0:
all_hack[b] = b.forward
b.forward = self.get_hacked_forward(original_forward=b.forward, model=model, blk=b)
b.lllite_list.append((weight, v))
return
def get_hacked_forward(self, original_forward, model, blk):
@torch.no_grad()
def forward(x, **kwargs):
hack = 0
for weight, module in blk.lllite_list:
module.to(x.device)
module.to(x.dtype)
hack = hack + module(x) * weight
x = x + hack
return original_forward(x, **kwargs)
return forward

View File

@ -1,13 +1,12 @@
import os
import time
from typing import Union
from modules.shared import log, opts
from modules import errors
ok = True
try:
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionXLPipeline
from diffusers import ControlNetXSModel, StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, ControlNetXSModel, StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline
except Exception:
from diffusers import ControlNetModel
ControlNetXSModel = ControlNetModel # dummy
@ -27,11 +26,11 @@ models = {}
all_models = {}
all_models.update(predefined_sd15)
all_models.update(predefined_sdxl)
cache_dir = 'models/control/controlnetsxs'
cache_dir = 'models/control/xs'
def find_models():
path = os.path.join(opts.control_dir, 'controlnetsxs')
path = os.path.join(opts.control_dir, 'xs')
files = os.listdir(path)
files = [f for f in files if f.endswith('.safetensors')]
downloaded_models = {}
@ -115,7 +114,7 @@ class ControlNetXS():
class ControlNetXSPipeline():
def __init__(self, controlnet: ControlNetXSModel | list[ControlNetXSModel], pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline, dtype = None):
def __init__(self, controlnet: Union[ControlNetXSModel, list[ControlNetXSModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
t0 = time.time()
self.orig_pipeline = pipeline
self.pipeline = None

View File

@ -31,5 +31,6 @@ class CannyDetector:
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
detected_map = detected_map.convert('L')
return detected_map

View File

@ -58,6 +58,7 @@ class EdgeDetector:
edge_map = cv2.resize(edge_map, (W, H), interpolation=cv2.INTER_LINEAR)
if output_type == "pil":
edge_map = edge_map.convert('L')
edge_map = Image.fromarray(edge_map)
return edge_map

View File

@ -1,8 +1,8 @@
import time
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from modules.shared import log
from modules.control.proc.reference_sd15 import StableDiffusionReferencePipeline
from modules.control.proc.reference_sdxl import StableDiffusionXLReferencePipeline
from modules.shared import log
what = 'Reference'

View File

@ -11,6 +11,7 @@ from modules.control import unit
from modules.control import processors
from modules.control import controlnets # lllyasviel ControlNet
from modules.control import controlnetsxs # VisLearn ControlNet-XS
from modules.control import controlnetslite # Kohya ControlLLLite
from modules.control import adapters # TencentARC T2I-Adapter
from modules.control import reference # ControlNet-Reference
from modules.control import ipadapter # IP-Adapter
@ -156,7 +157,11 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
active_strength.append(float(u.strength))
active_start.append(float(u.start))
active_end.append(float(u.end))
p.guess_mode = u.guess
shared.log.debug(f'Control ControlNet-XS unit: process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
elif unit_type == 'lite' and u.controlnet.model is not None:
active_process.append(u.process)
active_model.append(u.controlnet)
active_strength.append(float(u.strength))
shared.log.debug(f'Control ControlNet-XS unit: process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
elif unit_type == 'reference':
p.override = u.override
@ -173,7 +178,7 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
has_models = False
selected_models: List[Union[controlnets.ControlNetModel, controlnetsxs.ControlNetXSModel, adapters.AdapterModel]] = None
if unit_type == 'adapter' or unit_type == 'controlnet' or unit_type == 'xs':
if unit_type == 'adapter' or unit_type == 'controlnet' or unit_type == 'xs' or unit_type == 'lite':
if len(active_model) == 0:
selected_models = None
elif len(active_model) == 1:
@ -216,6 +221,14 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
pipe = instance.pipeline
if inits is not None:
shared.log.warning('Control: ControlNet-XS does not support separate init image')
elif unit_type == 'lite' and has_models:
p.extra_generation_params["Control mode"] = 'ControlLLLite'
p.extra_generation_params["Control conditioning"] = use_conditioning
p.controlnet_conditioning_scale = use_conditioning
instance = controlnetslite.ControlLLitePipeline(shared.sd_model)
pipe = instance.pipeline
if inits is not None:
shared.log.warning('Control: ControlLLLite does not support separate init image')
elif unit_type == 'reference':
p.extra_generation_params["Control mode"] = 'Reference'
p.extra_generation_params["Control attention"] = p.attention
@ -232,6 +245,8 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
if len(active_strength) > 0:
p.strength = active_strength[0]
pipe = diffusers.AutoPipelineForImage2Image.from_pipe(shared.sd_model) # use set_diffuser_pipe
instance = None
debug(f'Control pipeline: class={pipe.__class__} args={vars(p)}')
t1, t2, t3 = time.time(), 0, 0
status = True
@ -396,7 +411,7 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
if hasattr(p, 'init_images'):
del p.init_images # control never uses init_image as-is
if pipe is not None:
if not has_models and (unit_type == 'controlnet' or unit_type == 'adapter' or unit_type == 'xs'): # run in txt2img or img2img mode
if not has_models and (unit_type == 'controlnet' or unit_type == 'adapter' or unit_type == 'xs' or unit_type == 'lite'): # run in txt2img or img2img mode
if processed_image is not None:
p.init_images = [processed_image]
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
@ -411,6 +426,8 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) # only controlnet supports img2img
else:
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE)
if unit_type == 'lite':
instance.apply(selected_models, p.image, use_conditioning)
# pipeline
output = None
@ -480,6 +497,8 @@ def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_genera
image_txt = f'| Frames {len(output_images)} | Size {output_images[0].width}x{output_images[0].height}'
image_txt += f' | {util.dict2str(p.extra_generation_params)}'
if hasattr(instance, 'restore'):
instance.restore()
restore_pipeline()
debug(f'Control ready: {image_txt}')
if is_generator:

View File

@ -112,7 +112,7 @@ def test_adapters(prompt, negative, image):
model_id = 'None'
if shared.state.interrupted:
continue
output = image
output = image.copy()
if model_id != 'None':
adapter = adapters.Adapter(model_id=model_id, device=devices.device, dtype=devices.dtype)
if adapter is None:
@ -122,6 +122,7 @@ def test_adapters(prompt, negative, image):
pipe = adapters.AdapterPipeline(adapter=adapter.model, pipeline=shared.sd_model)
pipe.pipeline.to(device=devices.device, dtype=devices.dtype)
sd_models.set_diffuser_options(pipe)
image = image.convert('L') if 'Canny' in model_id or 'Sketch' in model_id else image.convert('RGB')
try:
res = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil')
output = res.images[0]
@ -201,3 +202,56 @@ def test_xs(prompt, negative, image):
grid.paste(thumb, box=(x, y))
yield None, grid, None, images
return None, grid, None, images # preview_process, output_image, output_video, output_gallery
def test_lite(prompt, negative, image):
from modules import devices, sd_models
from modules.control import controlnetslite
if image is None:
shared.log.error('Image not loaded')
return None, None, None
from PIL import ImageDraw, ImageFont
images = []
for model_id in controlnetslite.list_models():
if model_id is None:
model_id = 'None'
if shared.state.interrupted:
continue
output = image
if model_id != 'None':
lite = controlnetslite.ControlLLLite(model_id=model_id, device=devices.device, dtype=devices.dtype)
if lite is None:
shared.log.error(f'Control-LLite load failed: id="{model_id}"')
continue
shared.log.info(f'Testing ControlNet-XS: {model_id}')
pipe = controlnetslite.ControlLLitePipeline(pipeline=shared.sd_model)
pipe.apply(controlnet=lite.model, image=image, conditioning=1.0)
pipe.pipeline.to(device=devices.device, dtype=devices.dtype)
sd_models.set_diffuser_options(pipe)
try:
res = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil')
output = res.images[0]
except Exception as e:
errors.display(e, f'ControlNet-XS {model_id} inference')
model_id = f'{model_id} error'
pipe.restore()
draw = ImageDraw.Draw(output)
font = ImageFont.truetype('DejaVuSansMono', 48)
draw.text((10, 10), model_id, (0,0,0), font=font)
draw.text((8, 8), model_id, (255,255,255), font=font)
images.append(output)
yield output, None, None, images
rows = round(math.sqrt(len(images)))
cols = math.ceil(len(images) / rows)
w, h = 256, 256
size = (cols * w + cols, rows * h + rows)
grid = Image.new('RGB', size=size, color='black')
shared.log.info(f'Test ControlNet-XS: images={len(images)} grid={grid}')
for i, image in enumerate(images):
x = (i % cols * w) + (i % cols)
y = (i // cols * h) + (i // cols)
thumb = image.copy().convert('RGB')
thumb.thumbnail((w, h), Image.Resampling.HAMMING)
grid.paste(thumb, box=(x, y))
yield None, grid, None, images
return None, grid, None, images # preview_process, output_image, output_video, output_gallery

View File

@ -4,6 +4,7 @@ from modules.shared import log
from modules.control import processors
from modules.control import controlnets
from modules.control import controlnetsxs
from modules.control import controlnetslite
from modules.control import adapters
from modules.control import reference # pylint: disable=unused-import
@ -110,6 +111,8 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
self.controlnet = controlnets.ControlNet(device=default_device, dtype=default_dtype)
elif self.type == 'xs':
self.controlnet = controlnetsxs.ControlNetXS(device=default_device, dtype=default_dtype)
elif self.type == 'lite':
self.controlnet = controlnetslite.ControlLLLite(device=default_device, dtype=default_dtype)
elif self.type == 'reference':
pass
else:
@ -132,6 +135,11 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
model_id.change(fn=self.controlnet.load, inputs=[model_id, extra_controls[0]], outputs=[result_txt], show_progress=True)
if extra_controls is not None and len(extra_controls) > 0:
extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
elif self.type == 'lite':
if model_id is not None:
model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
if extra_controls is not None and len(extra_controls) > 0:
extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
elif self.type == 'reference':
if extra_controls is not None and len(extra_controls) > 0:
extra_controls[0].change(fn=reference_extra, inputs=extra_controls)

View File

@ -125,7 +125,8 @@ def bind_buttons(buttons, send_image, send_generate_info):
for tabname, button in buttons.items():
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
bindings = ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)
register_paste_params_button(bindings)
def register_paste_params_button(binding: ParamBinding):

View File

@ -6,7 +6,7 @@ from typing import Dict
from urllib.parse import urlparse
import PIL.Image as Image
import rich.progress as p
from modules import shared
from modules import shared, errors
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@ -68,6 +68,7 @@ def download_civit_meta(model_path: str, model_id):
return msg
except Exception as e:
msg = f'CivitAI download error: id={model_id} url={url} file={fn} {e}'
errors.display(e, 'CivitAI download error')
shared.log.error(msg)
return msg
return f'CivitAI download error: id={model_id} url={url} code={r.status_code}'
@ -100,7 +101,8 @@ def download_civit_preview(model_path: str, preview_url: str):
except Exception as e:
os.remove(preview_file)
res += f' error={e}'
shared.log.error(f'CivitAI download error: url={preview_url} file={preview_file} {e}')
shared.log.error(f'CivitAI download error: url={preview_url} file={preview_file} written={written} {e}')
errors.display(e, 'CivitAI download error')
shared.state.end()
if img is None:
return res

View File

@ -217,6 +217,8 @@ def readfile(filename, silent=False, lock=False):
def writefile(data, filename, mode='w', silent=False):
lock = None
locked = False
import tempfile
def default(obj):
log.error(f"Saving: {filename} not a valid object: {obj}")
@ -239,13 +241,19 @@ def writefile(data, filename, mode='w', silent=False):
raise ValueError('not a valid object')
lock = fasteners.InterProcessReaderWriterLock(f"{filename}.lock", logger=log)
locked = lock.acquire_write_lock(blocking=True, timeout=3)
with open(filename, mode, encoding="utf8") as file:
file.write(output)
with tempfile.NamedTemporaryFile(mode=mode, encoding="utf8", delete=False, dir=os.path.dirname(filename)) as f:
f.write(output)
f.flush()
os.fsync(f.fileno())
os.replace(f.name, filename)
# with open(filename, mode=mode, encoding="utf8") as file:
# file.write(output)
t1 = time.time()
if not silent:
log.debug(f'Save: file="{filename}" json={len(data)} bytes={len(output)} time={t1-t0:.3f}')
except Exception as e:
log.error(f'Saving failed: {filename} {e}')
errors.display(e, 'Saving failed')
finally:
if lock is not None:
lock.release_read_lock()

View File

@ -13,6 +13,8 @@ import modules.script_callbacks
folder_symbol = symbols.folder
debug = shared.log.trace if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: PASTE')
def update_generation_info(generation_info, html_info, img_index):
@ -214,7 +216,10 @@ def create_output_panel(tabname, preview=True):
clip_files.click(fn=None, _js='clip_gallery_urls', inputs=[result_gallery], outputs=[])
save = gr.Button('Save', elem_id=f'save_{tabname}')
delete = gr.Button('Delete', elem_id=f'delete_{tabname}')
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras", "control"])
if shared.backend == shared.Backend.ORIGINAL:
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
else:
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "control", "extras"])
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group():
@ -245,9 +250,9 @@ def create_output_panel(tabname, preview=True):
else:
paste_field_names = []
for paste_tabname, paste_button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=paste_button, tabname=paste_tabname, source_tabname=("txt2img" if tabname == "txt2img" else None), source_image_component=result_gallery, paste_field_names=paste_field_names
))
debug(f'Create output panel: button={paste_button} tabname={paste_tabname}')
bindings = parameters_copypaste.ParamBinding(paste_button=paste_button, tabname=paste_tabname, source_tabname=("txt2img" if tabname == "txt2img" else None), source_image_component=result_gallery, paste_field_names=paste_field_names)
parameters_copypaste.register_paste_params_button(bindings)
return result_gallery, generation_info, html_info, html_info_formatted, html_log

View File

@ -3,6 +3,7 @@ import gradio as gr
from modules.control import unit
from modules.control import controlnets # lllyasviel ControlNet
from modules.control import controlnetsxs # vislearn ControlNet-XS
from modules.control import controlnetslite # vislearn ControlNet-XS
from modules.control import adapters # TencentARC T2I-Adapter
from modules.control import processors # patrickvonplaten controlnet_aux
from modules.control import reference # reference pipeline
@ -12,7 +13,7 @@ from modules.ui_components import FormRow, FormGroup
gr_height = 512
max_units = 10
max_units = 5
units: list[unit.Unit] = [] # main state variable
input_source = None
input_init = None
@ -23,15 +24,17 @@ debug('Trace: CONTROL')
def initialize():
from modules import devices
shared.log.debug(f'Control initialize: models={shared.opts.control_dir}')
controlnets.cache_dir = os.path.join(shared.opts.control_dir, 'controlnets')
controlnetsxs.cache_dir = os.path.join(shared.opts.control_dir, 'controlnetsxs')
adapters.cache_dir = os.path.join(shared.opts.control_dir, 'adapters')
processors.cache_dir = os.path.join(shared.opts.control_dir, 'processors')
controlnets.cache_dir = os.path.join(shared.opts.control_dir, 'controlnet')
controlnetsxs.cache_dir = os.path.join(shared.opts.control_dir, 'xs')
controlnetslite.cache_dir = os.path.join(shared.opts.control_dir, 'lite')
adapters.cache_dir = os.path.join(shared.opts.control_dir, 'adapter')
processors.cache_dir = os.path.join(shared.opts.control_dir, 'processor')
unit.default_device = devices.device
unit.default_dtype = devices.dtype
os.makedirs(shared.opts.control_dir, exist_ok=True)
os.makedirs(controlnets.cache_dir, exist_ok=True)
os.makedirs(controlnetsxs.cache_dir, exist_ok=True)
os.makedirs(controlnetslite.cache_dir, exist_ok=True)
os.makedirs(adapters.cache_dir, exist_ok=True)
os.makedirs(processors.cache_dir, exist_ok=True)
@ -58,7 +61,7 @@ def return_controls(res):
def generate_click(job_id: str, active_tab: str, *args):
from modules.control.run import control_run
shared.log.debug(f'Control: tab={active_tab} job={job_id} args={args}')
if active_tab not in ['controlnet', 'xs', 'adapter', 'reference']:
if active_tab not in ['controlnet', 'xs', 'adapter', 'reference', 'lite']:
return None, None, None, None, f'Control: Unknown mode: {active_tab} args={args}'
shared.state.begin('control')
progress.add_task_to_queue(job_id)
@ -203,14 +206,14 @@ def create_ui(_blocks: gr.Blocks=None):
with gr.Row(elem_id='control_settings'):
with gr.Accordion(open=False, label="Input", elem_id="control_input", elem_classes=["small-accordion"]):
with gr.Row():
input_type = gr.Radio(label="Input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type')
with gr.Row():
denoising_strength = gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label='Denoising strength', value=0.50, elem_id="control_denoising_strength")
with gr.Row():
show_ip = gr.Checkbox(label="Enable IP adapter", value=False, elem_id="control_show_ip")
with gr.Row():
show_preview = gr.Checkbox(label="Show preview", value=False, elem_id="control_show_preview")
with gr.Row():
input_type = gr.Radio(label="Input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type')
with gr.Row():
denoising_strength = gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label='Denoising strength', value=0.50, elem_id="control_denoising_strength")
resize_mode, resize_name, width, height, scale_by, selected_scale_tab, resize_time = ui.create_resize_inputs('control', [], time_selector=True, scale_visible=False, mode='Fixed')
@ -402,10 +405,10 @@ def create_ui(_blocks: gr.Blocks=None):
extra_controls = [
gr.Slider(label="Control factor", minimum=0.0, maximum=1.0, step=0.05, value=1.0, scale=3),
]
num_adaptor_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1)
adaptor_ui_units = [] # list of hidable accordions
num_adapter_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1)
adapter_ui_units = [] # list of hidable accordions
for i in range(max_units):
with gr.Accordion(f'Adapter unit {i+1}', visible= i < num_adaptor_units.value) as unit_ui:
with gr.Accordion(f'Adapter unit {i+1}', visible= i < num_adapter_units.value) as unit_ui:
with gr.Row():
with gr.Column():
with gr.Row():
@ -417,7 +420,7 @@ def create_ui(_blocks: gr.Blocks=None):
reset_btn = ui_components.ToolButton(value=ui_symbols.reset)
image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool'])
process_btn= ui_components.ToolButton(value=ui_symbols.preview)
adaptor_ui_units.append(unit_ui)
adapter_ui_units.append(unit_ui)
units.append(unit.Unit(
unit_type = 'adapter',
result_txt = result_txt,
@ -435,7 +438,47 @@ def create_ui(_blocks: gr.Blocks=None):
)
if i == 0:
units[-1].enabled = True # enable first unit in group
num_adaptor_units.change(fn=display_units, inputs=[num_adaptor_units], outputs=adaptor_ui_units)
num_adapter_units.change(fn=display_units, inputs=[num_adapter_units], outputs=adapter_ui_units)
with gr.Tab('Lite') as _tab_lite:
gr.HTML('<a href="https://huggingface.co/kohya-ss/controlnet-lllite">Control LLLite</a>')
with gr.Row():
extra_controls = [
]
num_lite_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1)
lite_ui_units = [] # list of hidable accordions
for i in range(max_units):
with gr.Accordion(f'Control unit {i+1}', visible= i < num_lite_units.value) as unit_ui:
with gr.Row():
with gr.Column():
with gr.Row():
enabled_cb = gr.Checkbox(value= i == 0, label="Enabled")
process_id = gr.Dropdown(label="Processor", choices=processors.list_models(), value='None')
model_id = gr.Dropdown(label="Model", choices=controlnetslite.list_models(), value='None')
ui_common.create_refresh_button(model_id, controlnetslite.list_models, lambda: {"choices": controlnetslite.list_models(refresh=True)}, 'refresh_lite_models')
model_strength = gr.Slider(label="Strength", minimum=0.01, maximum=1.0, step=0.01, value=1.0-i/10)
reset_btn = ui_components.ToolButton(value=ui_symbols.reset)
image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool'])
process_btn= ui_components.ToolButton(value=ui_symbols.preview)
lite_ui_units.append(unit_ui)
units.append(unit.Unit(
unit_type = 'lite',
result_txt = result_txt,
image_input = input_image,
enabled_cb = enabled_cb,
reset_btn = reset_btn,
process_id = process_id,
model_id = model_id,
model_strength = model_strength,
preview_process = preview_process,
preview_btn = process_btn,
image_upload = image_upload,
extra_controls = extra_controls,
)
)
if i == 0:
units[-1].enabled = True # enable first unit in group
num_lite_units.change(fn=display_units, inputs=[num_lite_units], outputs=lite_ui_units)
with gr.Tab('Reference') as _tab_reference:
gr.HTML('<a href="https://github.com/Mikubill/sd-webui-controlnet/discussions/1236">ControlNet reference-only control</a>')
@ -555,16 +598,19 @@ def create_ui(_blocks: gr.Blocks=None):
generation_parameters_copypaste.register_paste_params_button(bindings)
if os.environ.get('SD_CONTROL_DEBUG', None) is not None: # debug only
from modules.control.test import test_processors, test_controlnets, test_adapters, test_xs
from modules.control.test import test_processors, test_controlnets, test_adapters, test_xs, test_lite
gr.HTML('<br><h1>Debug</h1><br>')
with gr.Row():
run_test_processors_btn = gr.Button(value="Test:Processors", variant='primary', elem_classes=['control-button'])
run_test_controlnets_btn = gr.Button(value="Test:ControlNets", variant='primary', elem_classes=['control-button'])
run_test_xs_btn = gr.Button(value="Test:ControlNets-XS", variant='primary', elem_classes=['control-button'])
run_test_adapters_btn = gr.Button(value="Test:Adapters", variant='primary', elem_classes=['control-button'])
run_test_lite_btn = gr.Button(value="Test:Control-LLLite", variant='primary', elem_classes=['control-button'])
run_test_processors_btn.click(fn=test_processors, inputs=[input_image], outputs=[preview_process, output_image, output_video, output_gallery])
run_test_controlnets_btn.click(fn=test_controlnets, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery])
run_test_xs_btn.click(fn=test_xs, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery])
run_test_adapters_btn.click(fn=test_adapters, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery])
run_test_lite_btn.click(fn=test_lite, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery])
return [(control_ui, 'Control', 'control')]

View File

@ -515,7 +515,8 @@ def create_ui():
continue
for item in page.list_items():
meta = os.path.splitext(item['filename'])[0] + '.json'
if ('card-no-preview.png' in item['preview'] or not os.path.isfile(meta)) and os.path.isfile(item['filename']):
has_meta = os.path.isfile(meta) and os.stat(meta).st_size > 0
if ('card-no-preview.png' in item['preview'] or not has_meta) and os.path.isfile(item['filename']):
sha = item.get('hash', None)
found = False
if sha is not None and len(sha) > 0:

View File

@ -1,5 +1,5 @@
[tool.ruff]
target-version = "py310"
target-version = "py39"
select = [
"F",
"E",

View File

@ -50,8 +50,8 @@ clip-interrogator==0.6.0
antlr4-python3-runtime==4.9.3
requests==2.31.0
tqdm==4.66.1
accelerate==0.24.1
opencv-python-headless==4.7.0.72
accelerate==0.25.0
opencv-python-headless==4.8.1.78
diffusers==0.24.0
einops==0.4.1
gradio==3.43.2
@ -65,9 +65,9 @@ pytorch_lightning==1.9.4
tokenizers==0.15.0
transformers==4.36.2
tomesd==0.1.3
urllib3==1.26.15
urllib3==1.26.18
Pillow==10.1.0
timm==0.9.7
timm==0.9.12
pydantic==1.10.13
typing-extensions==4.8.0
typing-extensions==4.9.0
peft

2
wiki

@ -1 +1 @@
Subproject commit 4d9737acc348cfc0ae0b139ec5e39ae61967e7ac
Subproject commit 7dadbd912a1854bad2d7c6463eaecc458f2808c3