Merge pull request #4497 from vladmandic/dev

merge dev
pull/4517/head 2025-12-26
Vladimir Mandic 2025-12-26 09:21:35 +01:00 committed by GitHub
commit 3ae10181dc
GPG Key ID: B5690EEEBB952194
75 changed files with 807 additions and 259 deletions

@ -43,7 +43,7 @@ exclude = [
]
line-length = 250
indent-width = 4
target-version = "py39"
target-version = "py310"
[lint]
select = [
@ -72,6 +72,7 @@ select = [
ignore = [
"B006", # Do not use mutable data structures for argument defaults
"B008", # Do not perform function call in argument defaults
"B905", # Strict zip() usage
"C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead
"C408", # Unnecessary `dict` call
"I001", # Import block is un-sorted or un-formatted
@ -81,6 +82,7 @@ ignore = [
"E731", # Do not assign a `lambda` expression, use a `def`
"E741", # Ambiguous variable name
"F401", # Imported but unused
"EXE001", # file with shebang is not marked executable
"NPY002", # replace legacy random
"RUF005", # Consider iterable unpacking
"RUF008", # Do not use mutable default values for dataclass

@ -1,5 +1,59 @@
# Change Log for SD.Next
## Update for 2025-12-26
### Highlights for 2025-12-26
End-of-year release update, just two weeks after the previous one, with several new models and features:
- Several new models including highly anticipated **Qwen-Image-Edit 2511** as well as **Qwen-Image-Layered**, **LongCat Image** and **Ovis Image**
- New features including support for **Z-Image** *ControlNets* and *fine-tunes* and **Detailer** segmentation support
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
### Details for 2025-12-26
- **Models**
- [LongCat Image](https://github.com/meituan-longcat/LongCat-Image) in *Image* and *Image Edit* variants
LongCat is a new 8B diffusion base model using Qwen-2.5 as text encoder
- [Qwen-Image-Edit 2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) in *base* and *pre-quantized* variants
Key enhancements: mitigate image drift, improved character consistency, enhanced industrial design generation, and strengthened geometric reasoning ability
- [Qwen-Image-Layered](https://huggingface.co/Qwen/Qwen-Image-Layered) in *base* and *pre-quantized* variants
Qwen-Image-Layered is a model capable of decomposing an image into multiple RGBA layers
*note*: set number of desired output layers in *settings -> model options*
- [Ovis Image 7B](https://huggingface.co/AIDC-AI/Ovis-Image-7B)
Ovis Image is a new text-to-image base model based on Qwen3 text-encoder and optimized for text-rendering
- **Features**
- Google **Gemini** and **Veo** models support for both *Dev* and *Vertex* access methods
see [docs](https://vladmandic.github.io/sdnext-docs/Google-GenAI/) for details
- **Z-Image Turbo** support for loading transformer fine-tunes in safetensors format
as with any transformer/unet fine-tunes, place them in `models/unet`
and use **UNET Model** to load the safetensors file, as they are not complete models
- **Z-Image Turbo** support for **ControlNet Union**
includes 1.0, 2.0 and 2.1 variants
- **Detailer** support for segmentation models
some detection models can produce an exact segmentation mask and not just a bounding box
to enable, set `use segmentation` option
added segmentation models: *anzhc-eyes-seg*, *anzhc-face-1024-seg-8n*, *anzhc-head-seg-8n*
- **Internal**
- update nightlies to `rocm==7.1`
- mark `python==3.9` as deprecated
- extensions improved status indicators, thanks @awsr
- additional type-safety checks, thanks @awsr
- add model info to ui overlay
- **Wiki/Docs/Illustrations**
- update models page, thanks @alerikaisattera
- update reference models samples, thanks @liutyi
- **Fixes**
- generate forever fix loop checks, thanks @awsr
- tokenizer explicit use for flux2, thanks @CalamitousFelicitousness
- torch.compile skip offloading steps
- kanvas css with standardui
- control input media with non-english locales
- handle embeds when on meta device
- improve offloading when model has manual modules
- ui section collapsible state, thanks @awsr
- ui filter by model type
## Update for 2025-12-11
### Highlights for 2025-12-11

21
TODO.md

@ -1,12 +1,19 @@
# TODO
## Known issues
- z-image-turbo controlnet device mismatch: <https://github.com/huggingface/diffusers/pull/12886>
- z-image-turbo safetensors loader: <https://github.com/huggingface/diffusers/issues/12887>
- kandinsky-image-5 hardcoded cuda: <https://github.com/huggingface/diffusers/pull/12814>
- peft lora with torch-rocm-windows: <https://github.com/huggingface/peft/pull/2963>
## Project Board
- <https://github.com/users/vladmandic/projects>
## Internal
- Reimplement llama remover for kanvas
- Reimplement `llama` remover for kanvas
- Deploy: Create executable for SD.Next
- Feature: Integrate natural language image search
[ImageDB](https://github.com/vladmandic/imagedb)
@ -21,7 +28,9 @@
- UI: Lite vs Expert mode
- Video tab: add full API support
- Control tab: add overrides handling
- Engine: TensorRT acceleration
- Engine: `TensorRT` acceleration
- Engine: [mmgp](https://github.com/deepbeepmeep/mmgp)
- Engine: [sharpfin](https://github.com/drhead/sharpfin) instead of `torchvision`
## Features
@ -38,6 +47,7 @@
TODO: *Prioritize*!
- [Cosmos-Predict-2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B)
- [NewBie Image Exp0.1](https://github.com/huggingface/diffusers/pull/12803)
- [Sana-I2V](https://github.com/huggingface/diffusers/pull/12634#issuecomment-3540534268)
- [Bria FIBO](https://huggingface.co/briaai/FIBO)
@ -74,11 +84,6 @@ TODO: *Prioritize*!
- [Wan2.2-Animate-14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B)
- [WAN2GP](https://github.com/deepbeepmeep/Wan2GP)
### General
- Review/improve type-hinting and type checking
- A little easier to work with due to syntax changes in Python 3.10
### Asyncio
- Policy system is deprecated and will be removed in **Python 3.16**
@ -91,8 +96,6 @@ TODO: *Prioritize*!
- [asyncio.run](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.run)
- [asyncio.Runner](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.Runner)
### Shutil
#### rmtree
- `onerror` deprecated and replaced with `onexc` in **Python 3.12**

0
cli/api-xyz.py Normal file → Executable file

@ -1 +1 @@
Subproject commit c6dc85eb28a02bc7af268497b7a5a596770c5d7b
Subproject commit 2a7005fbcf8985644b66121365fa7228a65f34b0

@ -1 +1 @@
Subproject commit f3cfab10af26f0c7243878a3c320d50012764694
Subproject commit 989a54a5b2ae4ba12fefbf48c9ed61c3663c4c0c

@ -1 +1 @@
Subproject commit af99fbab29e9a424c4e79fa8e4ae194481cb5f75
Subproject commit ded112e94a94bf64daefa027376e0335fb43e0b7

@ -200,6 +200,24 @@
"size": 56.1,
"date": "2025 September"
},
"Qwen-Image-Edit-2511": {
"path": "Qwen/Qwen-Image-Edit-2511",
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
"desc": "Key enhancements: mitigate image drift, improved character consistency, enhanced industrial design generation, and strengthened geometric reasoning ability.",
"skip": true,
"extras": "",
"size": 56.1,
"date": "2025 December"
},
"Qwen-Image-Layered": {
"path": "Qwen/Qwen-Image-Layered",
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
"desc": "Qwen-Image-Layered, a model capable of decomposing an image into multiple RGBA layers",
"skip": true,
"extras": "",
"size": 53.7,
"date": "2025 December"
},
"Qwen-Image-Lightning": {
"path": "vladmandic/Qwen-Lightning",
"preview": "vladmandic--Qwen-Lightning.jpg",
@ -314,6 +332,25 @@
"date": "2025 July"
},
"Meituan LongCat Image": {
"path": "meituan-longcat/LongCat-Image",
"preview": "meituan-longcat--LongCat-Image.jpg",
"desc": "Pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models.",
"skip": true,
"extras": "",
"size": 27.30,
"date": "2025 December"
},
"Meituan LongCat Image-Edit": {
"path": "meituan-longcat/LongCat-Image-Edit",
"preview": "meituan-longcat--LongCat-Image-Edit.jpg",
"desc": "Pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models.",
"skip": true,
"extras": "",
"size": 27.30,
"date": "2025 December"
},
"Ostris Flex.2 Preview": {
"path": "ostris/Flex.2-preview",
"preview": "ostris--Flex.2-preview.jpg",
@ -782,7 +819,7 @@
"Kandinsky 5.0 T2I Lite": {
"path": "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers",
"desc": "Kandinsky 5.0 Image Lite is a 6B image generation model with 1K resolution, high visual quality and strong text-writing",
"preview": "kandinsky-community--kandinsky-3.jpg",
"preview": "kandinskylab--Kandinsky-5.0-T2I-Lite-sft-Diffusers.jpg",
"skip": true,
"size": 33.20,
"date": "2025 November"
@ -790,7 +827,7 @@
"Kandinsky 5.0 I2I Lite": {
"path": "kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers",
"desc": "Kandinsky 5.0 Image Lite is a 6B image editing model with 1K resolution, high visual quality and strong text-writing",
"preview": "kandinsky-community--kandinsky-3.jpg",
"preview": "kandinskylab--Kandinsky-5.0-T2I-Lite-sft-Diffusers.jpg",
"skip": true,
"size": 33.20,
"date": "2025 November"
@ -901,6 +938,16 @@
"date": "2024 January"
},
"AIDC Ovis-Image 7B": {
"path": "AIDC-AI/Ovis-Image-7B",
"skip": true,
"desc": "Built upon Ovis-U1, Ovis-Image is a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints.",
"preview": "AIDC-AI--Ovis-Image-7B.jpg",
"size": 23.38,
"date": "2025 December",
"extras": ""
},
"HDM-XUT 340M Anime": {
"path": "KBlueLeaf/HDM-xut-340M-anime",
"skip": true,
@ -1075,6 +1122,26 @@
"size": 16.10,
"extras": ""
},
"Qwen-Image-Edit-2511 sdnq-svd-uint4": {
"path": "Disty0/Qwen-Image-Edit-2511-SDNQ-uint4-svd-r32",
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
"desc": "Quantization of Qwen/Qwen-Image-Edit-2511 using SDNQ: sdnq-svd 4-bit uint with svd rank 32",
"skip": true,
"tags": "quantized",
"date": "2025 December",
"size": 16.10,
"extras": ""
},
"Qwen-Image-Layered sdnq-svd-uint4": {
"path": "Disty0/Qwen-Image-Layered-SDNQ-uint4-svd-r32",
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
"desc": "Quantization of Qwen/Qwen-Image-Layered using SDNQ: sdnq-svd 4-bit uint with svd rank 32",
"skip": true,
"tags": "quantized",
"date": "2025 December",
"size": 16.10,
"extras": ""
},
"nVidia ChronoEdit sdnq-svd-uint4": {
"path": "Disty0/ChronoEdit-14B-SDNQ-uint4-svd-r32",
"preview": "Disty0--ChronoEdit-14B-SDNQ-uint4-svd-r32.jpg",

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import overload, List, Optional
import os
import sys
import json
@ -19,7 +19,6 @@ class Dot(dict): # dot notation access to dictionary attributes
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
version = {
'app': 'sd.next',
'updated': 'unknown',
@ -94,19 +93,35 @@ def get_log():
return log
@overload
def str_to_bool(val: str | bool) -> bool: ...
@overload
def str_to_bool(val: None) -> None: ...
def str_to_bool(val: str | bool | None) -> bool | None:
if isinstance(val, str):
if val.strip() and val.strip().lower() in ("1", "true"):
return True
return False
return val
def install_traceback(suppress: list = []):
from rich.traceback import install as traceback_install
from rich.pretty import install as pretty_install
width = os.environ.get("SD_TRACEWIDTH", console.width if console else None)
if width is not None:
width = int(width)
traceback_install(
console=console,
extra_lines=os.environ.get('SD_TRACELINES', 1),
max_frames=os.environ.get('SD_TRACEFRAMES', 16),
width=os.environ.get('SD_TRACEWIDTH', console.width),
word_wrap=os.environ.get('SD_TRACEWRAP', False),
indent_guides=os.environ.get('SD_TRACEINDENT', False),
show_locals=os.environ.get('SD_TRACELOCALS', False),
locals_hide_dunder=os.environ.get('SD_TRACEDUNDER', True),
locals_hide_sunder=os.environ.get('SD_TRACESUNDER', None),
extra_lines=int(os.environ.get("SD_TRACELINES", 1)),
max_frames=int(os.environ.get("SD_TRACEFRAMES", 16)),
width=width,
word_wrap=str_to_bool(os.environ.get("SD_TRACEWRAP", False)),
indent_guides=str_to_bool(os.environ.get("SD_TRACEINDENT", False)),
show_locals=str_to_bool(os.environ.get("SD_TRACELOCALS", False)),
locals_hide_dunder=str_to_bool(os.environ.get("SD_TRACEDUNDER", True)),
locals_hide_sunder=str_to_bool(os.environ.get("SD_TRACESUNDER", None)),
suppress=suppress,
)
pretty_install(console=console)
@ -633,7 +648,7 @@ def check_diffusers():
t_start = time.time()
if args.skip_all:
return
sha = '3d02cd543ef3101d821cb09c8fcab23c6e7ead33' # diffusers commit hash
sha = 'f6b6a7181eb44f0120b29cd897c129275f366c2a' # diffusers commit hash
# if args.use_rocm or args.use_zluda or args.use_directml:
# sha = '043ab2520f6a19fce78e6e060a68dbc947edb9f9' # lock diffusers versions for now
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
@ -781,10 +796,10 @@ def install_rocm_zluda():
else:
#check_python(supported_minors=[10, 11, 12, 13, 14], reason='ROCm backend requires a Python version between 3.10 and 3.13')
if args.use_nightly:
if rocm.version is None or float(rocm.version) >= 7.0: # assume the latest if version check fails
if rocm.version is None or float(rocm.version) >= 7.1: # assume the latest if version check fails
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.1')
else: # oldest rocm version on nightly is 7.0
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0')
else: # oldest rocm version on nightly is 6.4
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.4')
else:
if rocm.version is None or float(rocm.version) >= 6.4: # assume the latest if version check fails
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+rocm6.4 torchvision==0.24.1+rocm6.4 --index-url https://download.pytorch.org/whl/rocm6.4')
@ -1714,7 +1729,7 @@ def add_args(parser):
group_install.add_argument('--skip-env', default=os.environ.get("SD_SKIPENV",False), action='store_true', help="Skips setting of env variables during startup, default: %(default)s")
group_compute = parser.add_argument_group('Compute Engine')
group_compute.add_argument("--device-id", type=str, default=os.environ.get("SD_DEVICEID", None), help="Select the default CUDA device to use, default: %(default)s")
group_compute.add_argument("--device-id", type=str, default=os.environ.get("SD_DEVICEID", None), help="Select the default GPU device to use, default: %(default)s")
group_compute.add_argument("--use-cuda", default=os.environ.get("SD_USECUDA",False), action='store_true', help="Force use nVidia CUDA backend, default: %(default)s")
group_compute.add_argument("--use-ipex", default=os.environ.get("SD_USEIPEX",False), action='store_true', help="Force use Intel OneAPI XPU backend, default: %(default)s")
group_compute.add_argument("--use-rocm", default=os.environ.get("SD_USEROCM",False), action='store_true', help="Force use AMD ROCm backend, default: %(default)s")
@ -1730,9 +1745,11 @@ def add_args(parser):
group_paths.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", 'models'), help="Base path where all models are stored, default: %(default)s",)
group_paths.add_argument("--extensions-dir", type=str, default=os.environ.get("SD_EXTENSIONSDIR", None), help="Base path where all extensions are stored, default: %(default)s",)
group_ui = parser.add_argument_group('UI')
group_ui.add_argument('--theme', type=str, default=os.environ.get("SD_THEME", None), help='Override UI theme')
group_ui.add_argument('--locale', type=str, default=os.environ.get("SD_LOCALE", None), help='Override UI locale')
group_http = parser.add_argument_group('HTTP')
group_http.add_argument('--theme', type=str, default=os.environ.get("SD_THEME", None), help='Override UI theme')
group_http.add_argument('--locale', type=str, default=os.environ.get("SD_LOCALE", None), help='Override UI locale')
group_http.add_argument("--server-name", type=str, default=os.environ.get("SD_SERVERNAME", None), help="Sets hostname of server, default: %(default)s")
group_http.add_argument("--tls-keyfile", type=str, default=os.environ.get("SD_TLSKEYFILE", None), help="Enable TLS and specify key file, default: %(default)s")
group_http.add_argument("--tls-certfile", type=str, default=os.environ.get("SD_TLSCERTFILE", None), help="Enable TLS and specify cert file, default: %(default)s")
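
Note on the `str_to_bool` change in this file: `os.environ.get` returns a string whenever the variable is set, so a value such as `SD_TRACEWRAP=0` or `SD_TRACEWRAP=False` is a non-empty (truthy) string unless it is parsed. The sketch below restates the helper's behaviour as it reads in the hunk; `env_flag` is a hypothetical wrapper added only for illustration and is not part of the diff.

```python
import os


def str_to_bool(val):
    # Mirrors the helper in the diff: only "1" and "true" (case-insensitive,
    # surrounding whitespace ignored) are treated as True.
    if isinstance(val, str):
        return val.strip().lower() in ("1", "true")
    # bool and None pass through unchanged
    return val


def env_flag(name, default=None):
    # Hypothetical convenience wrapper, not present in the diff.
    return str_to_bool(os.environ.get(name, default))


os.environ["SD_TRACEWRAP"] = "False"
wrap = env_flag("SD_TRACEWRAP", False)  # parsed as False; the raw string "False" would be truthy
```

Under this reading, values like `yes` or `on` are also treated as `False`, which may or may not be the intended contract.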

@ -100,13 +100,20 @@ const generateForever = (genbuttonid) => {
clearInterval(window.generateOnRepeatInterval);
window.generateOnRepeatInterval = null;
} else {
log('generateForever: start');
const genbutton = gradioApp().querySelector(genbuttonid);
const busy = document.getElementById('progressbar')?.style.display === 'block';
if (!busy) genbutton.click();
const isBusy = () => {
let busy = document.getElementById('progressbar')?.style.display === 'block';
if (!busy) {
// Also check in Modern UI
const outerButton = genbutton.parentElement.closest('button');
busy = outerButton?.classList.contains('generate') && outerButton?.classList.contains('active');
}
return busy;
};
log('generateForever: start');
if (!isBusy()) genbutton.click();
window.generateOnRepeatInterval = setInterval(() => {
const pbBusy = document.getElementById('progressbar')?.style.display === 'block';
if (!pbBusy) genbutton.click();
if (!isBusy()) genbutton.click();
}, 500);
}
};

@ -3,7 +3,11 @@ function controlInputMode(inputMode, ...args) {
if (updateEl) updateEl.click();
const tab = gradioApp().querySelector('#control-tab-input button.selected');
if (!tab) return ['Image', ...args];
let inputTab = tab.innerText;
// let inputTab = tab.innerText;
const tabs = Array.from(gradioApp().querySelectorAll('#control-tab-input button'));
const tabIdx = tabs.findIndex((btn) => btn.classList.contains('selected'));
const tabNames = ['Image', 'Video', 'Batch', 'Folder'];
let inputTab = tabNames[tabIdx] || 'Image';
log('controlInputMode', { mode: inputMode, tab: inputTab, kanvas: typeof Kanvas });
if ((inputTab === 'Image') && (typeof Kanvas !== 'undefined')) {
inputTab = 'Kanvas';

@ -312,8 +312,9 @@ function extraNetworksFilterVersion(event) {
const version = event.target.textContent.trim();
const activeTab = getENActiveTab();
const activePage = getENActivePage().toLowerCase();
const cardContainer = gradioApp().querySelector(`#${activeTab}_${activePage}_cards`);
log('extraNetworksFilterVersion', version);
let cardContainer = gradioApp().querySelector(`#${activeTab}_${activePage}_cards`);
if (!cardContainer) cardContainer = gradioApp().querySelector(`#txt2img_extra_networks_${activePage}_cards`);
log('extraNetworksFilterVersion', { version, activeTab, activePage, cardContainer });
if (!cardContainer) return;
if (cardContainer.dataset.activeVersion === version) {
cardContainer.dataset.activeVersion = '';

@ -46,7 +46,9 @@ function forceLogin() {
})
.then(async (res) => {
const json = await res.json();
const txt = `${res.status}: ${res.statusText} - ${json.detail}`;
let txt = '';
if (res.status === 200) txt = 'login verified';
else txt = `${res.status}: ${res.statusText} - ${json.detail}`;
status.textContent = txt;
console.log('login', txt);
if (res.status === 200) location.reload();

@ -1,3 +1,11 @@
const getModel = () => {
const cp = opts?.sd_model_checkpoint || '';
if (!cp) return 'unknown model';
const noBracket = cp.replace(/\s*\[.*\]\s*$/, ''); // remove trailing [hash]
const parts = noBracket.split(/[\\/]/); // split on / or \
return parts[parts.length - 1].trim() || 'unknown model';
};
async function updateIndicator(online, data, msg) {
const el = document.getElementById('logo_nav');
if (!el || !data) return;
@ -5,9 +13,10 @@ async function updateIndicator(online, data, msg) {
const date = new Date();
const template = `
Version: <b>${data.updated}</b><br>
Commit: <b>${data.hash}</b><br>
Commit: <b>${data.commit}</b><br>
Branch: <b>${data.branch}</b><br>
Status: ${status}<br>
Model: <b>${getModel()}</b><br>
Since: ${date.toLocaleString()}<br>
`;
if (online) {

@ -120,7 +120,7 @@ textarea {
}
span {
font-size: var(--text-md) !important;
font-size: var(--text-md);
}
button {

4 binary image files not shown (sizes: 35 KiB, 64 KiB, 62 KiB, 62 KiB)

@ -8,8 +8,15 @@ except Exception:
nvml_initialized = False
warned = False
def warn_once(msg):
global warned # pylint: disable=global-statement
if not warned:
log.error(msg)
warned = True
def get_reason(val):
throttle = {
1: 'gpu idle',
@ -28,6 +35,8 @@ def get_reason(val):
def get_nvml():
global nvml_initialized # pylint: disable=global-statement
if warned:
return []
try:
from modules.memstats import ram_stats
if not nvml_initialized:
@ -71,7 +80,7 @@ def get_nvml():
# log.debug(f'nmvl: {devices}')
return devices
except Exception as e:
log.error(f'NVML: {e}')
warn_once(f'NVML: {e}')
return []

@ -77,7 +77,7 @@ def civit_update_metadata(raw:bool=False):
model.id = d['modelId']
download_civit_meta(model.fn, model.id)
fn = os.path.splitext(item['filename'])[0] + '.json'
model.meta = readfile(fn, silent=True)
model.meta = readfile(fn, silent=True, as_type="dict")
model.name = model.meta.get('name', model.name)
model.versions = len(model.meta.get('modelVersions', []))
versions = model.meta.get('modelVersions', [])

@ -182,7 +182,7 @@ def check_active(p, unit_type, units):
p.is_tile = p.is_tile or 'tile' in u.mode.lower()
p.control_tile = u.tile
p.extra_generation_params["Control mode"] = u.mode
shared.log.debug(f'Control ControlNet unit: i={num_units} process="{u.process.processor_id}" model="{u.controlnet.model_id}" strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}')
shared.log.debug(f'Control unit: i={num_units} type=ControlNet process="{u.process.processor_id}" model="{u.controlnet.model_id}" strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}')
elif unit_type == 'xs' and u.controlnet.model is not None:
active_process.append(u.process)
active_model.append(u.controlnet)
@ -190,13 +190,13 @@ def check_active(p, unit_type, units):
active_start.append(float(u.start))
active_end.append(float(u.end))
active_units.append(u)
shared.log.debug(f'Control ControlNet-XS unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
shared.log.debug(f'Control unit: i={num_units} type=ControlNetXS process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
elif unit_type == 'lite' and u.controlnet.model is not None:
active_process.append(u.process)
active_model.append(u.controlnet)
active_strength.append(float(u.strength))
active_units.append(u)
shared.log.debug(f'Control ControlLLite unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
shared.log.debug(f'Control unit: i={num_units} type=ControlLLite process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
elif unit_type == 'reference':
p.override = u.override
p.attention = u.attention
@ -209,7 +209,7 @@ def check_active(p, unit_type, units):
if u.process.processor_id is not None:
active_process.append(u.process)
active_units.append(u)
shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}')
shared.log.debug(f'Control unit: i={num_units} type=Process process={u.process.processor_id}')
active_strength.append(float(u.strength))
debug_log(f'Control active: process={len(active_process)} model={len(active_model)}')
return active_process, active_model, active_strength, active_start, active_end, active_units
@ -654,7 +654,7 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
debug_log(f'Control: pipeline units={len(active_model)} process={len(active_process)} outputs={len(output_images)}')
except Exception as e:
shared.log.error(f'Control pipeline failed: type={unit_type} units={len(active_model)} error={e}')
shared.log.error(f'Control: type={unit_type} units={len(active_model)} {e}')
errors.display(e, 'Control')
if len(output_images) == 0:

@ -134,15 +134,15 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
if image_file is None:
self.process.override = None
self.override = None
log.debug('Control process clear image')
log.debug('Control image: clear')
return gr.update(value=None)
try:
self.process.override = Image.open(image_file.name)
self.override = self.process.override
log.debug(f'Control process upload image: path="{image_file.name}" image={self.process.override}')
log.debug(f'Control image: upload={self.process.override} path="{image_file.name}"')
return gr.update(visible=self.process.override is not None, value=self.process.override)
except Exception as e:
log.error(f'Control process upload image failed: path="{image_file.name}" error={e}')
log.error(f'Control image: upload path="{image_file.name}" error={e}')
return gr.update(visible=False, value=None)
def reuse_image(image):

@ -111,6 +111,11 @@ predefined_hunyuandit = {
"HunyuanDiT Pose": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Pose',
"HunyuanDiT Depth": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Depth',
}
predefined_zimage = {
"Z-Image-Turbo Union 1.0": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union',
"Z-Image-Turbo Union 2.0": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0',
"Z-Image-Turbo Union 2.1": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1',
}
variants = {
'NoobAI Canny XL': 'fp16',
@ -137,6 +142,7 @@ all_models.update(predefined_f1)
all_models.update(predefined_sd3)
all_models.update(predefined_qwen)
all_models.update(predefined_hunyuandit)
all_models.update(predefined_zimage)
cache_dir = 'models/control/controlnet'
load_lock = threading.Lock()
@ -175,6 +181,8 @@ def api_list_models(model_type: str = None):
model_list += list(predefined_qwen)
if model_type == 'hunyuandit' or model_type == 'all':
model_list += list(predefined_hunyuandit)
if model_type == 'z_image':
model_list += list(predefined_zimage)
model_list += sorted(find_models())
return model_list
@ -199,9 +207,11 @@ def list_models(refresh=False):
models = ['None'] + list(predefined_qwen) + sorted(find_models())
elif modules.shared.sd_model_type == 'hunyuandit':
models = ['None'] + list(predefined_hunyuandit) + sorted(find_models())
elif modules.shared.sd_model_type == 'z_image':
models = ['None'] + list(predefined_zimage) + sorted(find_models())
else:
log.warning(f'Control {what} model list failed: unknown model type')
models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(predefined_f1) + sorted(predefined_sd3) + sorted(find_models())
models = ['None'] + list(all_models) + sorted(find_models())
debug_log(f'Control list {what}: path={cache_dir} models={models}')
return models
@ -263,6 +273,14 @@ class ControlNet():
elif shared.sd_model_type == 'hunyuandit':
from diffusers import HunyuanDiT2DControlNetModel as cls
config = 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Canny'
elif shared.sd_model_type == 'z_image':
from diffusers import ZImageControlNetModel as cls
if '2.0' in model_id:
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0'
elif '2.1' in model_id:
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1'
else:
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union'
else:
log.error(f'Control {what}: type={shared.sd_model_type} unsupported model')
return None, None
@ -508,6 +526,17 @@ class ControlNetPipeline():
feature_extractor=None,
controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list
)
elif detect.is_zimage(pipeline) and len(controlnets) > 0:
from diffusers import ZImageControlNetPipeline
self.pipeline = ZImageControlNetPipeline(
vae=pipeline.vae,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
transformer=pipeline.transformer,
scheduler=pipeline.scheduler,
controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list
)
self.pipeline.task_args = { 'guidance_scale': 1 }
elif len(loras) > 0:
self.pipeline = pipeline
for lora in loras:
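
The `z_image` branch in `ControlNet` picks a config repo by looking for a version substring in `model_id`. The sketch below isolates that selection as a pure function so the fallthrough is easy to see; `pick_zimage_config` is a hypothetical name, and the repo ids are copied from the `predefined_zimage` table in the diff.

```python
ZIMAGE_BASE = "hlky/Z-Image-Turbo-Fun-Controlnet-Union"


def pick_zimage_config(model_id: str) -> str:
    # Hypothetical helper; same order of checks as the diff:
    # "2.0" first, then "2.1", otherwise the 1.0 repo.
    if "2.0" in model_id:
        return f"{ZIMAGE_BASE}-2.0"
    if "2.1" in model_id:
        return f"{ZIMAGE_BASE}-2.1"
    return ZIMAGE_BASE


config = pick_zimage_config("hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1")
```

Because this is plain substring matching, a locally stored file whose name happens to contain `2.0` or `2.1` for unrelated reasons would select the corresponding config as well.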

@ -28,3 +28,6 @@ def is_qwen(model):
def is_hunyuandit(model):
return is_compatible(model, pattern='HunyuanDiT')
def is_zimage(model):
return is_compatible(model, pattern='ZImage')

@ -238,8 +238,8 @@ def torch_gc(force:bool=False, fast:bool=False, reason:str=None):
torch.cuda.synchronize()
torch.cuda.empty_cache() # cuda gc
torch.cuda.ipc_collect()
except Exception:
pass
except Exception as e:
log.error(f'GC: {e}')
else:
return gpu, ram
t1 = time.time()
@ -390,7 +390,7 @@ def test_triton(early: bool = False):
def test_triton_func(a,b,c):
return a * b + c
test_triton_func = torch.compile(test_triton_func, fullgraph=True)
test_triton_func(torch.randn(32, device=device), torch.randn(32, device=device), torch.randn(32, device=device))
test_triton_func(torch.randn(16, device=device), torch.randn(16, device=device), torch.randn(16, device=device))
triton_ok = True
else:
triton_ok = False

@ -1,3 +1,4 @@
from __future__ import annotations
import os
from datetime import datetime
import git
@ -5,7 +6,7 @@ from modules import shared, errors
from modules.paths import extensions_dir, extensions_builtin_dir
extensions = []
extensions: list[Extension] = []
if not os.path.exists(extensions_dir):
os.makedirs(extensions_dir)

@ -12,7 +12,7 @@ progress_ok = True
def init_cache():
global cache_data # pylint: disable=global-statement
if cache_data is None:
cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True)
cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True, as_type="dict")
def dump_cache():
@ -22,7 +22,7 @@ def dump_cache():
def cache(subsection):
global cache_data # pylint: disable=global-statement
if cache_data is None:
cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True)
cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True, as_type="dict")
s = cache_data.get(subsection, {})
cache_data[subsection] = s
return s

@ -120,9 +120,9 @@ def atomically_save_image():
if not fn.endswith('.json'):
fn += '.json'
entries = shared.readfile(fn, silent=True)
idx = len(list(entries))
if idx == 0:
if not isinstance(entries, list):
entries = []
idx = len(entries)
entry = { 'id': idx, 'filename': filename, 'time': datetime.datetime.now().isoformat(), 'info': exifinfo }
entries.append(entry)
shared.writefile(entries, fn, mode='w', silent=True)
@ -132,7 +132,7 @@ def atomically_save_image():
save_queue.task_done()
save_queue = queue.Queue()
save_queue: queue.Queue[tuple[Image.Image, str, str, script_callbacks.ImageSaveParams, str, str | None, bool]] = queue.Queue()
save_thread = threading.Thread(target=atomically_save_image, daemon=True)
save_thread.start()
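
On the `entries` change in `atomically_save_image`: the earlier lines only reset `entries` to a list when it was empty, so a non-empty dict read from the index file would reach `entries.append(...)` and fail. The new `isinstance` check covers that case. A minimal sketch of the revised logic, with file I/O left out and `append_entry` as a hypothetical name:

```python
import datetime


def append_entry(entries, filename, info):
    # Anything that is not already a list (e.g. the dict returned for a
    # missing or malformed index file) is replaced by a fresh list.
    if not isinstance(entries, list):
        entries = []
    entry = {
        "id": len(entries),
        "filename": filename,
        "time": datetime.datetime.now().isoformat(),
        "info": info,
    }
    entries.append(entry)
    return entries


index = append_entry({"unexpected": "dict"}, "a.png", "info-a")
index = append_entry(index, "b.png", "info-b")
```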

@ -2,6 +2,7 @@ import os
import sys
import time
import json
from typing import overload, Literal
import fasteners
import orjson
from installer import log
@ -10,7 +11,13 @@ from installer import log
locking_available = True # used by file read/write locking
def readfile(filename, silent=False, lock=False):
@overload
def readfile(filename: str, silent: bool = False, lock: bool = False, *, as_type: Literal["dict"]) -> dict: ...
@overload
def readfile(filename: str, silent: bool = False, lock: bool = False, *, as_type: Literal["list"]) -> list: ...
@overload
def readfile(filename: str, silent: bool = False, lock: bool = False) -> dict | list: ...
def readfile(filename: str, silent: bool = False, lock: bool = False, *, as_type="") -> dict | list:
global locking_available # pylint: disable=global-statement
data = {}
lock_file = None
@ -51,6 +58,13 @@ def readfile(filename, silent=False, lock=False):
os.remove(f"{filename}.lock")
except Exception:
locking_available = False
if isinstance(data, list) and as_type == "dict":
data0 = data[0] if len(data) > 0 else None
if isinstance(data0, dict):
return data0
return {}
if isinstance(data, dict) and as_type == "list":
return [data]
return data
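
The `as_type` handling added to `readfile` above can be sketched in isolation. This is a minimal illustration only: `coerce_json` is a hypothetical helper name, not part of the repository, and the sketch also guards against an empty list.

```python
from typing import Any


def coerce_json(data: Any, as_type: str = "") -> Any:
    """Hypothetical helper mirroring the as_type coercion at the end of readfile."""
    if isinstance(data, list) and as_type == "dict":
        # use the first element when it is a dict, otherwise fall back to an empty dict
        first = data[0] if data else None
        return first if isinstance(first, dict) else {}
    if isinstance(data, dict) and as_type == "list":
        # wrap a single dict so that callers can always iterate
        return [data]
    # no coercion requested, or the data already has the requested shape
    return data
```

Callers that previously checked `type(data) is list` by hand, such as `NetworkOnDisk`, can then request `as_type="dict"` and rely on the returned shape.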

View File

@ -94,9 +94,9 @@ timer.startup.record("torch")
try:
import bitsandbytes # pylint: disable=W0611,C0411
_bnb = True
except Exception:
from diffusers.utils import import_utils
import_utils._bitsandbytes_available = False # pylint: disable=protected-access
_bnb = False
timer.startup.record("bnb")
import transformers # pylint: disable=W0611,C0411
@ -134,6 +134,7 @@ try:
import diffusers.utils.import_utils # pylint: disable=W0611,C0411
diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git
diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access
diffusers.utils.import_utils._bitsandbytes_available = _bnb # pylint: disable=protected-access
import diffusers # pylint: disable=W0611,C0411
import diffusers.loaders.single_file # pylint: disable=W0611,C0411

View File

@ -43,6 +43,7 @@ force_models_diffusers = [ # forced always
'chrono',
'z_image',
'f2',
'longcat',
# video models
'hunyuanvideo',
'hunyuanvideo15'

View File

@ -107,9 +107,7 @@ class NetworkOnDisk:
if self.filename is not None:
fn = os.path.splitext(self.filename)[0] + '.json'
if os.path.exists(fn):
data = shared.readfile(fn, silent=True)
if type(data) is list:
data = data[0]
data = shared.readfile(fn, silent=True, as_type="dict")
return data
def get_desc(self):
@ -118,7 +116,8 @@ class NetworkOnDisk:
if self.filename is not None:
fn = os.path.splitext(self.filename)[0] + '.txt'
if os.path.exists(fn):
return shared.readfile(fn, silent=True)
with open(fn, "r", encoding="utf-8") as file:
return file.read()
return None
def get_alias(self):

View File

@ -76,6 +76,10 @@ def get_model_type(pipe):
model_type = 'x-omni'
elif 'Photoroom' in name:
model_type = 'prx'
elif 'LongCat' in name:
model_type = 'longcat'
elif 'Ovis-Image' in name:
model_type = 'ovis'
# video models
elif "CogVideo" in name:
model_type = 'cogvideo'

View File

@ -1,8 +1,18 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from installer import log
def options_section(section_identifier, options_dict):
if TYPE_CHECKING:
from collections.abc import Callable
from gradio.components import Component
from modules.shared_legacy import LegacyOption
from modules.ui_components import DropdownEditable
def options_section(section_identifier: tuple[str, str], options_dict: dict[str, OptionInfo | LegacyOption]):
"""Set the `section` value for all OptionInfo/LegacyOption items"""
for v in options_dict.values():
v.section = section_identifier
return options_dict
@ -11,14 +21,14 @@ def options_section(section_identifier, options_dict):
class OptionInfo:
def __init__(
self,
default=None,
default: Any | None = None,
label="",
component=None,
component_args=None,
onchange=None,
section=None,
refresh=None,
folder=None,
component: type[Component] | type[DropdownEditable] | None = None,
component_args: dict | Callable[..., dict] | None = None,
onchange: Callable | None = None,
section: tuple[str, ...] | None = None,
refresh: Callable | None = None,
folder=False,
submit=None,
comment_before='',
comment_after='',
@ -40,7 +50,7 @@ class OptionInfo:
self.exclude = ['sd_model_checkpoint', 'sd_model_refiner', 'sd_vae', 'sd_unet', 'sd_text_encoder']
self.dynamic = callable(component_args)
args = {} if self.dynamic else (component_args or {}) # executing callable here is too expensive
self.visible = args.get('visible', True) and len(self.label) > 2
self.visible = args.get('visible', True) and len(self.label) > 2 # type: ignore - Type checking only sees the value of self.dynamic, not the `callable` check
def needs_reload_ui(self):
return self

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import os
import json
import threading
@ -6,10 +7,12 @@ from modules import cmd_args, errors
from modules.json_helpers import readfile, writefile
from modules.shared_legacy import LegacyOption
from installer import log
if TYPE_CHECKING:
from modules.options import OptionInfo
if TYPE_CHECKING:
from collections.abc import Callable
from modules.options import OptionInfo
cmd_opts = cmd_args.parse_args()
compatibility_opts = ['clip_skip', 'uni_pc_lower_order_final', 'uni_pc_order']
@ -21,7 +24,9 @@ class Options():
typemap = {int: float}
debug = os.environ.get('SD_CONFIG_DEBUG', None) is not None
def __init__(self, options_templates:dict={}, restricted_opts:dict={}):
def __init__(self, options_templates: dict[str, OptionInfo | LegacyOption] = {}, restricted_opts: set[str] | None = None):
if restricted_opts is None:
restricted_opts = set()
self.data_labels = options_templates
self.restricted_opts = restricted_opts
self.data = {k: v.default for k, v in self.data_labels.items()}
@ -163,12 +168,12 @@ class Options():
log.debug(f'Settings: fn="{filename}" created')
self.save(filename)
return
self.data = readfile(filename, lock=True)
self.data = readfile(filename, lock=True, as_type="dict")
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
unknown_settings = []
for k, v in self.data.items():
info: OptionInfo = self.data_labels.get(k, None)
info: OptionInfo | None = self.data_labels.get(k, None)
if info is not None:
if not info.validate(k, v):
self.data[k] = info.default
@ -180,7 +185,7 @@ class Options():
if len(unknown_settings) > 0:
log.warning(f"Setting validation: unknown={unknown_settings}")
def onchange(self, key, func, call=True):
def onchange(self, key, func: Callable, call=True):
item = self.data_labels.get(key)
item.onchange = func
if call:

View File

@ -18,6 +18,9 @@ predefined = [ # <https://huggingface.co/vladmandic/yolo-detailers/tree/main>
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/person_yolov8n-seg.pt',
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/eyes-v1.pt',
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/eyes-full-v1.pt',
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/anzhc-eyes-seg.pt',
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/anzhc-face-1024-seg-8n.pt',
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/anzhc-head-seg-8n.pt',
'https://huggingface.co/netrunner-exe/Face-Upscalers-onnx/resolve/main/codeformer.fp16.onnx',
'https://huggingface.co/netrunner-exe/Face-Upscalers-onnx/resolve/main/restoreformer.fp16.onnx',
'https://huggingface.co/netrunner-exe/Face-Upscalers-onnx/resolve/main/GFPGANv1.4.fp16.onnx',
@ -136,25 +139,45 @@ class YoloRestorer(Detailer):
boxes = prediction.boxes.xyxy.detach().int().cpu().numpy() if prediction.boxes is not None else []
scores = prediction.boxes.conf.detach().float().cpu().numpy() if prediction.boxes is not None else []
classes = prediction.boxes.cls.detach().float().cpu().numpy() if prediction.boxes is not None else []
for score, box, cls in zip(scores, boxes, classes):
masks = prediction.masks.data.cpu().float().numpy() if prediction.masks is not None else []
if len(masks) < len(classes):
masks = len(classes) * [None]
for score, box, cls, seg in zip(scores, boxes, classes, masks):
if seg is not None:
try:
seg = (255 * seg).astype(np.uint8)
seg = Image.fromarray(seg).resize(image.size).convert('L')
except Exception:
seg = None
cls = int(cls)
label = prediction.names[cls] if cls < len(prediction.names) else f'cls{cls}'
if len(desired) > 0 and label.lower() not in desired:
continue
box = box.tolist()
mask_image = None
w, h = box[2] - box[0], box[3] - box[1]
x_size, y_size = w/image.width, h/image.height
min_size = shared.opts.detailer_min_size if shared.opts.detailer_min_size >= 0 and shared.opts.detailer_min_size <= 1 else 0
max_size = shared.opts.detailer_max_size if shared.opts.detailer_max_size >= 0 and shared.opts.detailer_max_size <= 1 else 1
if x_size >= min_size and y_size >=min_size and x_size <= max_size and y_size <= max_size:
if mask:
mask_image = image.copy()
mask_image = Image.new('L', image.size, 0)
draw = ImageDraw.Draw(mask_image)
draw.rectangle(box, fill="white", outline=None, width=0)
if shared.opts.detailer_seg and seg is not None:
masked = seg
else:
masked = Image.new('L', image.size, 0)
draw = ImageDraw.Draw(masked)
draw.rectangle(box, fill="white", outline=None, width=0)
cropped = image.crop(box)
res = YoloResult(cls=cls, label=label, score=round(score, 2), box=box, mask=mask_image, item=cropped, width=w, height=h, args=args)
res = YoloResult(
cls=cls,
label=label,
score=round(score, 2),
box=box,
mask=masked,
item=cropped,
width=w,
height=h,
args=args,
)
result.append(res)
if len(result) >= shared.opts.detailer_max:
break
@ -217,18 +240,30 @@ class YoloRestorer(Detailer):
)
return [merged]
def draw_boxes(self, image: Image.Image, items: list[YoloResult]) -> Image.Image:
if isinstance(image, Image.Image):
draw = ImageDraw.Draw(image)
else:
def draw_masks(self, image: Image.Image, items: list[YoloResult]) -> np.ndarray:
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
draw = ImageDraw.Draw(image)
font = images.get_font(16)
image = image.convert('RGBA')
size = min(image.width, image.height) // 32
font = images.get_font(size)
color = (0, 190, 190)
shared.log.debug(f'Detailer: draw={items}')
for i, item in enumerate(items):
draw.rectangle(item.box, outline="#00C8C8", width=3)
draw.text((item.box[0]+4, item.box[1]+4), f'{i+1} {item.label} {item.score:.2f}', fill="black", font=font)
draw.text((item.box[0]+2, item.box[1]+2), f'{i+1} {item.label} {item.score:.2f}', fill="white", font=font)
if shared.opts.detailer_seg and item.mask is not None:
mask = item.mask.convert('L')
else:
mask = Image.new('L', image.size, 0)
draw_mask = ImageDraw.Draw(mask)
draw_mask.rectangle(item.box, fill="white", outline=None, width=0)
alpha = mask.point(lambda p: int(p * 0.5))
overlay = Image.new("RGBA", image.size, color + (0,))
overlay.putalpha(alpha)
image = Image.alpha_composite(image, overlay)
draw_text = ImageDraw.Draw(image)
draw_text.text((item.box[0] + 2, item.box[1] - size - 2), f'{i+1} {item.label} {item.score:.2f}', fill="black", font=font)
draw_text.text((item.box[0] + 0, item.box[1] - size - 4), f'{i+1} {item.label} {item.score:.2f}', fill="white", font=font)
image = image.convert("RGB")
return np.array(image)
def restore(self, np_image, p: processing.StableDiffusionProcessing = None):
@ -369,7 +404,7 @@ class YoloRestorer(Detailer):
if shared.opts.detailer_sort:
items = sorted(items, key=lambda x: x.box[0]) # sort items left-to-right to improve consistency
if shared.opts.detailer_save:
annotated = self.draw_boxes(annotated, items)
annotated = self.draw_masks(annotated, items)
for j, item in enumerate(items):
if item.mask is None:
@ -382,7 +417,7 @@ class YoloRestorer(Detailer):
pc.negative_prompts = [pc.negative_prompt]
pc.prompts, pc.network_data = extra_networks.parse_prompts(pc.prompts)
extra_networks.activate(pc, pc.network_data)
shared.log.debug(f'Detail: model="{i+1}:{name}" item={j+1}/{len(items)} box={item.box} label="{item.label} score={item.score:.2f} prompt="{pc.prompt}"')
shared.log.debug(f'Detail: model="{i+1}:{name}" item={j+1}/{len(items)} box={item.box} label="{item.label}" score={item.score:.2f} seg={shared.opts.detailer_seg} prompt="{pc.prompt}"')
pc.init_images = [image]
pc.image_mask = [item.mask]
pc.overlay_images = []
@ -437,7 +472,7 @@ class YoloRestorer(Detailer):
return gr.update(visible=False), gr.update(visible=True, value=value), gr.update(visible=False)
def ui(self, tab: str):
def ui_settings_change(merge, detailers, text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort):
def ui_settings_change(merge, detailers, text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg):
shared.opts.detailer_merge = merge
shared.opts.detailer_models = detailers
shared.opts.detailer_args = text if not self.ui_mode else ''
@ -453,15 +488,18 @@ class YoloRestorer(Detailer):
shared.opts.detailer_sigma_adjust_max = renoise_end
shared.opts.detailer_save = save
shared.opts.detailer_sort = sort
shared.opts.detailer_seg = seg
# shared.opts.detailer_resolution = resolution
shared.opts.save(shared.config_filename, silent=True)
shared.log.debug(f'Detailer settings: models={detailers} classes={classes} strength={strength} conf={min_confidence} max={max_detected} iou={iou} size={min_size}-{max_size} padding={padding} steps={steps} resolution={resolution} save={save} sort={sort}')
shared.log.debug(f'Detailer settings: models={detailers} classes={classes} strength={strength} conf={min_confidence} max={max_detected} iou={iou} size={min_size}-{max_size} padding={padding} steps={steps} resolution={resolution} save={save} sort={sort} seg={seg}')
if not self.ui_mode:
shared.log.debug(f'Detailer expert: {text}')
with gr.Accordion(open=False, label="Detailer", elem_id=f"{tab}_detailer_accordion", elem_classes=["small-accordion"]):
with gr.Row():
enabled = gr.Checkbox(label="Enable detailer pass", elem_id=f"{tab}_detailer_enabled", value=False)
with gr.Row():
seg = gr.Checkbox(label="Use segmentation", elem_id=f"{tab}_detailer_seg", value=shared.opts.detailer_seg, visible=True)
save = gr.Checkbox(label="Include detection results", elem_id=f"{tab}_detailer_save", value=shared.opts.detailer_save, visible=True)
with gr.Row():
merge = gr.Checkbox(label="Merge detailers", elem_id=f"{tab}_detailer_merge", value=shared.opts.detailer_merge, visible=True)
@ -499,20 +537,21 @@ class YoloRestorer(Detailer):
renoise_value = gr.Slider(minimum=0.5, maximum=1.5, step=0.01, label='Renoise', value=shared.opts.detailer_sigma_adjust, elem_id=f"{tab}_detailer_renoise")
renoise_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Renoise end', value=shared.opts.detailer_sigma_adjust_max, elem_id=f"{tab}_detailer_renoise_end")
merge.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
detailers.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
detailers_text.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
classes.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
padding.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
blur.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
min_confidence.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
max_detected.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
min_size.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
max_size.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
iou.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
resolution.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
save.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
sort.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
merge.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
detailers.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
detailers_text.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
classes.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
padding.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
blur.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
min_confidence.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
max_detected.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
min_size.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
max_size.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
iou.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
resolution.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
save.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
sort.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
seg.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
return enabled, prompt, negative, steps, strength, resolution
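
The relative-size filter applied to each detection in the `YoloRestorer` hunk above can be read as a small pure function. The sketch below is an illustration under assumptions: `clamp_fraction` and `box_passes_size_filter` are hypothetical names, and the `min_size`/`max_size` arguments stand in for `shared.opts.detailer_min_size` and `shared.opts.detailer_max_size`.

```python
def clamp_fraction(value: float, default: float) -> float:
    """Return value if it lies in [0, 1], otherwise the given default."""
    return value if 0 <= value <= 1 else default


def box_passes_size_filter(box, image_width, image_height, min_size, max_size) -> bool:
    """Check that a detection box covers an allowed fraction of the image on both axes.

    box is (x0, y0, x1, y1) in pixels; min_size and max_size are fractions of the image.
    """
    w, h = box[2] - box[0], box[3] - box[1]
    x_size, y_size = w / image_width, h / image_height
    lo = clamp_fraction(min_size, 0)  # out-of-range minimum disables the lower bound
    hi = clamp_fraction(max_size, 1)  # out-of-range maximum disables the upper bound
    return lo <= x_size <= hi and lo <= y_size <= hi
```

Only boxes that pass this check go on to receive either the segmentation mask or the rectangular fallback mask.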

View File

@ -127,6 +127,8 @@ def task_specific_kwargs(p, model):
task_args['image'] = [Image.new('RGB', (p.width, p.height), (0, 0, 0))] # monkey-patch so qwen-image-edit pipeline does not error-out on t2i
if ('QwenImageEditPlusPipeline' in model_cls) and (p.init_control is not None) and (len(p.init_control) > 0):
task_args['image'] += p.init_control
if ('QwenImageLayeredPipeline' in model_cls) and (p.init_images is not None) and (len(p.init_images) > 0):
task_args['image'] = p.init_images[0].convert('RGBA')
if ('Flux2' in model_cls) and (p.init_control is not None) and (len(p.init_control) > 0):
task_args['image'] += p.init_control
if ('LatentConsistencyModelPipeline' in model_cls) and (len(p.init_images) > 0):
@ -235,6 +237,11 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
embeds = prompt_parser_diffusers.embedder('prompt_embeds')
if embeds is None:
shared.log.warning('Prompt parser encode: empty prompt embeds')
prompt_parser_diffusers.embedder = None
args['prompt'] = prompts
elif embeds.device == torch.device('meta'):
shared.log.warning('Prompt parser encode: embeds on meta device')
prompt_parser_diffusers.embedder = None
args['prompt'] = prompts
else:
args['prompt_embeds'] = embeds
@ -271,6 +278,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
args['negative_prompt'] = negative_prompts[0]
else:
args['negative_prompt'] = negative_prompts
if 'complex_human_instruction' in possible:
chi = shared.opts.te_complex_human_instruction
p.extra_generation_params["CHI"] = chi
@ -454,7 +462,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
args['max_area'] = args['width'] * args['height']
# handle implicit controlnet
if 'control_image' in possible and 'control_image' not in args and 'image' in args:
if ('control_image' in possible) and ('control_image' not in args) and ('image' in args):
if sd_models.get_diffusers_task(model) != sd_models.DiffusersTaskType.MODULAR:
debug_log('Process: set control image')
args['control_image'] = args['image']

View File

@ -185,6 +185,8 @@ def process_base(p: processing.StableDiffusionProcessing):
output = SimpleNamespace(images=output)
if isinstance(output, Image.Image):
output = SimpleNamespace(images=[output])
if hasattr(output, 'image'):
output.images = output.image
if hasattr(output, 'images'):
shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops)
timer.process.record('pipeline')

View File

@ -408,7 +408,7 @@ def calculate_base_steps(p, use_denoise_start, use_refiner_start):
if len(getattr(p, 'timesteps', [])) > 0:
return None
cls = shared.sd_model.__class__.__name__
if 'Flex' in cls or 'Kontext' in cls or 'Edit' in cls or 'Wan' in cls or 'Flux2' in cls:
if 'Flex' in cls or 'Kontext' in cls or 'Edit' in cls or 'Wan' in cls or 'Flux2' in cls or 'Layered' in cls:
steps = p.steps
elif is_modular():
steps = p.steps

View File

@ -410,7 +410,7 @@ def get_tokens(pipe, msg, prompt):
fn = os.path.join(fn, 'vocab.json')
else:
fn = os.path.join(fn, 'tokenizer', 'vocab.json')
token_dict = shared.readfile(fn, silent=True)
token_dict = shared.readfile(fn, silent=True, as_type="dict")
added_tokens = getattr(tokenizer, 'added_tokens_decoder', {})
for k, v in added_tokens.items():
token_dict[str(v)] = k

View File

@ -329,7 +329,7 @@ def select_checkpoint(op='model', sd_model_checkpoint=None):
def init_metadata():
global sd_metadata # pylint: disable=global-statement
if sd_metadata is None:
sd_metadata = shared.readfile(sd_metadata_file, lock=True) if os.path.isfile(sd_metadata_file) else {}
sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict") if os.path.isfile(sd_metadata_file) else {}
def extract_thumbnail(filename, data):
@ -349,7 +349,7 @@ def extract_thumbnail(filename, data):
def read_metadata_from_safetensors(filename):
global sd_metadata # pylint: disable=global-statement
if sd_metadata is None:
sd_metadata = shared.readfile(sd_metadata_file, lock=True) if os.path.isfile(sd_metadata_file) else {}
sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict") if os.path.isfile(sd_metadata_file) else {}
res = sd_metadata.get(filename, None)
if res is not None:
return res

View File

@ -139,6 +139,10 @@ def guess_by_name(fn, current_guess):
new_guess = 'NanoBanana'
elif 'z-image' in fn.lower() or 'z_image' in fn.lower():
new_guess = 'Z-Image'
elif 'longcat-image' in fn.lower():
new_guess = 'LongCat'
elif 'ovis-image' in fn.lower():
new_guess = 'Ovis-Image'
if debug_load:
shared.log.trace(f'Autodetect: method=name file="{fn}" previous="{current_guess}" current="{new_guess}"')
return new_guess or current_guess
@ -150,7 +154,7 @@ def guess_by_diffusers(fn, current_guess):
return current_guess, None
index = os.path.join(fn, 'model_index.json')
if os.path.exists(index) and os.path.isfile(index):
index = shared.readfile(index, silent=True)
index = shared.readfile(index, silent=True, as_type="dict")
name = index.get('_name_or_path', None)
if name is not None and name in exclude_by_name:
return current_guess, None
@ -169,7 +173,7 @@ def guess_by_diffusers(fn, current_guess):
is_quant = True
break
if folder.endswith('config.json'):
quantization_config = shared.readfile(folder, silent=True).get("quantization_config", None)
quantization_config = shared.readfile(folder, silent=True, as_type="dict").get("quantization_config", None)
if quantization_config is not None:
is_quant = True
break
@ -180,7 +184,7 @@ def guess_by_diffusers(fn, current_guess):
is_quant = True
break
if f.endswith('config.json'):
quantization_config = shared.readfile(f, silent=True).get("quantization_config", None)
quantization_config = shared.readfile(f, silent=True, as_type="dict").get("quantization_config", None)
if quantization_config is not None:
is_quant = True
break
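
The filename matching that `guess_by_name` gains for the new models can be sketched as follows. This is a reduced illustration covering only the three branches visible in the hunk above; `guess_by_name_sketch` is a hypothetical name and the real function contains many more branches plus debug logging.

```python
from typing import Optional


def guess_by_name_sketch(fn: str, current_guess: Optional[str]) -> Optional[str]:
    """Return a model type guessed from the file name, or keep the current guess."""
    name = fn.lower()
    new_guess = None
    if "z-image" in name or "z_image" in name:
        new_guess = "Z-Image"
    elif "longcat-image" in name:
        new_guess = "LongCat"
    elif "ovis-image" in name:
        new_guess = "Ovis-Image"
    # fall back to the earlier guess when no pattern matched
    return new_guess or current_guess
```

The returned string is what the loader later compares against when it selects a model-specific load function, so the two places need to use the same identifier.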

View File

@ -184,6 +184,23 @@ def set_diffuser_options(sd_model, vae=None, op:str='model', offload:bool=True,
def move_model(model, device=None, force=False):
def set_execution_device(module, device):
if device == torch.device('cpu'):
return
if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device"): # pylint: disable=protected-access
try:
"""
for k, v in module.named_parameters(recurse=True):
if v.device == torch.device('meta'):
from accelerate.utils import set_module_tensor_to_device
set_module_tensor_to_device(module, k, device, tied_params_map=module._hf_hook.tied_params_map)
"""
module._hf_hook.execution_device = device # pylint: disable=protected-access
# module._hf_hook.offload = True
except Exception as e:
if os.environ.get('SD_MOVE_DEBUG', None):
shared.log.error(f'Model move execution device: device={device} {e}')
if model is None or device is None:
return
@ -204,20 +221,20 @@ def move_model(model, device=None, force=False):
if not isinstance(m, torch.nn.Module) or name in model._exclude_from_cpu_offload: # pylint: disable=protected-access
continue
for module in m.modules():
if (hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None): # pylint: disable=protected-access
try:
module._hf_hook.execution_device = device # pylint: disable=protected-access
except Exception as e:
if os.environ.get('SD_MOVE_DEBUG', None):
shared.log.error(f'Model move execution device: device={device} {e}')
set_execution_device(module, device)
# set_execution_device(model, device)
if getattr(model, 'has_accelerate', False) and not force:
return
if hasattr(model, "device") and devices.normalize_device(model.device) == devices.normalize_device(device) and not force:
return
try:
t0 = time.time()
try:
if hasattr(model, 'to'):
if model.device == torch.device('meta'):
set_execution_device(model, device)
elif hasattr(model, 'to'):
model.to(device)
if hasattr(model, "prior_pipe"):
model.prior_pipe.to(device)
@ -458,6 +475,14 @@ def load_diffuser_force(detected_model_type, checkpoint_info, diffusers_load_con
from pipelines.model_z_image import load_z_image
sd_model = load_z_image(checkpoint_info, diffusers_load_config)
allow_post_quant = False
elif model_type in ['LongCat']:
from pipelines.model_longcat import load_longcat
sd_model = load_longcat(checkpoint_info, diffusers_load_config)
allow_post_quant = False
elif model_type in ['Ovis-Image']:
from pipelines.model_ovis import load_ovis
sd_model = load_ovis(checkpoint_info, diffusers_load_config)
allow_post_quant = False
except Exception as e:
shared.log.error(f'Load {op}: path="{checkpoint_info.path}" {e}')
if debug_load:
@ -604,9 +629,9 @@ def load_sdnq_module(fn: str, module_name: str, load_method: str):
quantization_config_path = os.path.join(fn, module_name, 'quantization_config.json')
model_config_path = os.path.join(fn, module_name, 'config.json')
if os.path.exists(quantization_config_path):
quantization_config = shared.readfile(quantization_config_path, silent=True)
quantization_config = shared.readfile(quantization_config_path, silent=True, as_type="dict")
elif os.path.exists(model_config_path):
quantization_config = shared.readfile(model_config_path, silent=True).get("quantization_config", None)
quantization_config = shared.readfile(model_config_path, silent=True, as_type="dict").get("quantization_config", None)
if quantization_config is None:
return None, module_name, 0
model_name = os.path.join(fn, module_name)
@ -1050,11 +1075,13 @@ def copy_diffuser_options(new_pipe, orig_pipe):
new_pipe.feature_extractor = getattr(orig_pipe, 'feature_extractor', None)
new_pipe.mask_processor = getattr(orig_pipe, 'mask_processor', None)
new_pipe.restore_pipeline = getattr(orig_pipe, 'restore_pipeline', None)
new_pipe.task_args = getattr(orig_pipe, 'task_args', None)
new_pipe.is_sdxl = getattr(orig_pipe, 'is_sdxl', False) # a1111 compatibility item
new_pipe.is_sd2 = getattr(orig_pipe, 'is_sd2', False)
new_pipe.is_sd1 = getattr(orig_pipe, 'is_sd1', True)
add_noise_pred_to_diffusers_callback(new_pipe)
if getattr(new_pipe, 'task_args', None) is None:
new_pipe.task_args = {}
new_pipe.task_args.update(getattr(orig_pipe, 'task_args', {}))
if new_pipe.has_accelerate:
set_accelerate(new_pipe)
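
The `task_args` change at the end of `copy_diffuser_options` merges the original pipeline's arguments into whatever the new pipeline already carries, instead of overwriting them. A minimal standalone sketch of that merge pattern follows. `merge_task_args` and the `SimpleNamespace` stand-ins are hypothetical, and the `or {}` guard for an attribute that exists but is `None` is an added assumption, not something the diff does.

```python
from types import SimpleNamespace

def merge_task_args(new_pipe, orig_pipe):
    """Hypothetical helper: merge orig_pipe.task_args into new_pipe.task_args."""
    # make sure the target has a dict to update
    if getattr(new_pipe, "task_args", None) is None:
        new_pipe.task_args = {}
    # `or {}` covers an attribute that is present but set to None (assumption)
    new_pipe.task_args.update(getattr(orig_pipe, "task_args", None) or {})
    return new_pipe

orig = SimpleNamespace(task_args={"strength": 0.5})
new = SimpleNamespace(task_args={"width": 1024})
merge_task_args(new, orig)
print(new.task_args)
```

Keys already present on the new pipeline survive unless the original pipeline defines the same key, in which case the original value wins.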

@ -14,7 +14,7 @@ from modules.timer import process as process_timer
debug = os.environ.get('SD_MOVE_DEBUG', None) is not None
verbose = os.environ.get('SD_MOVE_VERBOSE', None) is not None
debug_move = log.trace if debug else lambda *args, **kwargs: None
offload_warn = ['sc', 'sd3', 'f1', 'f2', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'omnigen2', 'cogview4', 'cosmos', 'chroma', 'x-omni', 'hunyuanimage', 'hunyuanimage3']
offload_warn = ['sc', 'sd3', 'f1', 'f2', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'omnigen2', 'cogview4', 'cosmos', 'chroma', 'x-omni', 'hunyuanimage', 'hunyuanimage3', 'longcat']
offload_post = ['h1']
offload_hook_instance = None
balanced_offload_exclude = ['CogView4Pipeline', 'MeissonicPipeline']
@ -213,6 +213,7 @@ class OffloadHook(accelerate.hooks.ModelHook):
return False
return True
@torch.compiler.disable
def pre_forward(self, module, *args, **kwargs):
_id = id(module)
@ -239,14 +240,26 @@ class OffloadHook(accelerate.hooks.ModelHook):
max_memory = { device_index: self.gpu, "cpu": self.cpu }
device_map = getattr(module, "balanced_offload_device_map", None)
if (device_map is None) or (max_memory != getattr(module, "balanced_offload_max_memory", None)):
device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory, no_split_module_classes=no_split_module_classes, verbose=verbose)
device_map = accelerate.infer_auto_device_map(module,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
verbose=verbose,
clean_result=False,
)
offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
if devices.backend == "directml":
for k, v in device_map.items():
if isinstance(v, int):
device_map[k] = f"{devices.device.type}:{v}" # int implies CUDA or XPU device, but it will break DirectML backend so we add type
if debug:
shared.log.trace(f'Offload: type=balanced op=dispatch map={device_map}')
if device_map is not None:
module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir)
module = accelerate.dispatch_model(module,
main_device=torch.device(devices.device),
device_map=device_map,
offload_dir=offload_dir,
force_hooks=True,
)
module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
module.balanced_offload_device_map = device_map
module.balanced_offload_max_memory = max_memory
@ -261,6 +274,7 @@ class OffloadHook(accelerate.hooks.ModelHook):
self.last_pre = _id
return args, kwargs
@torch.compiler.disable
def post_forward(self, module, output):
if self.last_post != id(module):
self.last_post = id(module)
@ -289,6 +303,15 @@ def get_pipe_variants(pipe=None):
def get_module_names(pipe=None, exclude=None):
def is_valid(module):
if isinstance(getattr(pipe, module, None), torch.nn.ModuleDict):
return True
if isinstance(getattr(pipe, module, None), torch.nn.ModuleList):
return True
if isinstance(getattr(pipe, module, None), torch.nn.Module):
return True
return False
if exclude is None:
exclude = []
if pipe is None:
@ -296,12 +319,19 @@ def get_module_names(pipe=None, exclude=None):
pipe = shared.sd_model
else:
return []
if hasattr(pipe, "_internal_dict"):
modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access
else:
modules_names = get_signature(pipe).keys()
modules_names = []
try:
dict_keys = pipe._internal_dict.keys() # pylint: disable=protected-access
modules_names.extend(dict_keys)
except Exception:
pass
try:
dict_keys = get_signature(pipe).keys()
modules_names.extend(dict_keys)
except Exception:
pass
modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')]
modules_names = [m for m in modules_names if isinstance(getattr(pipe, m, None), torch.nn.Module)]
modules_names = [m for m in modules_names if is_valid(m)]
modules_names = sorted(set(modules_names))
return modules_names
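
The reworked `get_module_names` gathers candidate names from two sources, tolerates either source failing, then filters and de-duplicates. The sketch below reproduces that flow without any dependencies, under stated assumptions: `FakeModule` stands in for `torch.nn.Module`, and `signature_keys` is a hypothetical replacement for `get_signature(pipe).keys()`.

```python
class FakeModule:
    """Stand-in for torch.nn.Module (assumption, keeps the sketch dependency-free)."""

class FakePipe:
    def __init__(self):
        self._internal_dict = {"unet": None, "vae": None, "scheduler": None, "_name": None}
        self.unet = FakeModule()
        self.vae = FakeModule()
        self.scheduler = object()  # not a module, so it should be filtered out
        self.text_encoder = FakeModule()

def signature_keys(pipe):
    """Hypothetical second source of component names."""
    return ["text_encoder", "unet"]

def collect_module_names(pipe, exclude=None):
    exclude = exclude or []
    names = []
    sources = (lambda: pipe._internal_dict.keys(), lambda: signature_keys(pipe))
    for source in sources:
        try:
            names.extend(source())
        except Exception:
            pass  # a source that is missing or fails is simply skipped
    # drop excluded and private names, then keep only real module attributes
    names = [n for n in names if n not in exclude and not n.startswith("_")]
    names = [n for n in names if isinstance(getattr(pipe, n, None), FakeModule)]
    return sorted(set(names))

print(collect_module_names(FakePipe(), exclude=["vae"]))
```

Because both sources are merged before filtering, a component that appears only in the signature (here `text_encoder`) is still picked up.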

@ -9,7 +9,7 @@ from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_t
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 }
flow_models = ['f1', 'f2', 'sd3', 'lumina', 'auraflow', 'sana', 'z_image', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma', 'omnigen', 'omnigen2']
flow_models = ['f1', 'f2', 'sd3', 'lumina', 'auraflow', 'sana', 'z_image', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma', 'omnigen', 'omnigen2', 'longcat']
warned = False
queue_lock = threading.Lock()

@ -52,7 +52,7 @@ def load_unet(model, repo_id:str=None):
config_file = os.path.splitext(unet_dict[shared.opts.sd_unet])[0] + '.json'
if os.path.exists(config_file):
config = shared.readfile(config_file)
config = shared.readfile(config_file, as_type="dict")
else:
config = None
config_file = 'default'

@ -128,13 +128,13 @@ def apply_vae_config(model_file, vae_file, sd_model):
def get_vae_config():
config_file = os.path.join(paths.sd_configs_path, os.path.splitext(os.path.basename(model_file))[0] + '_vae.json')
if config_file is not None and os.path.exists(config_file):
return shared.readfile(config_file)
return shared.readfile(config_file, as_type="dict")
config_file = os.path.join(paths.sd_configs_path, os.path.splitext(os.path.basename(vae_file))[0] + '.json') if vae_file else None
if config_file is not None and os.path.exists(config_file):
return shared.readfile(config_file)
return shared.readfile(config_file, as_type="dict")
config_file = os.path.join(paths.sd_configs_path, shared.sd_model_type, 'vae', 'config.json')
if config_file is not None and os.path.exists(config_file):
return shared.readfile(config_file)
return shared.readfile(config_file, as_type="dict")
return {}
if hasattr(sd_model, 'vae') and hasattr(sd_model.vae, 'config'):
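
`get_vae_config` walks a chain of candidate config files and returns the first one that exists, falling back to an empty dict. The sketch below shows that fallback chain in isolation. What `as_type="dict"` does inside `shared.readfile` is not visible in this diff, so `read_json_dict` is a hypothetical reader that assumes it means "return a dict or nothing useful"; `first_config` is likewise an illustrative name.

```python
import json
import os
import tempfile

def read_json_dict(path: str) -> dict:
    """Hypothetical reader: return parsed JSON only if it is a dict, else {}."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        return {}
    return data if isinstance(data, dict) else {}

def first_config(candidates: list) -> dict:
    # candidates may contain None for entries that do not apply
    for path in candidates:
        if path is not None and os.path.exists(path):
            return read_json_dict(path)
    return {}

with tempfile.TemporaryDirectory() as tmp:
    vae_cfg = os.path.join(tmp, "vae.json")
    with open(vae_cfg, "w", encoding="utf-8") as f:
        json.dump({"scaling_factor": 0.13025}, f)
    # the first candidate does not exist, the second is not applicable
    result = first_config([os.path.join(tmp, "model_vae.json"), None, vae_cfg])
    print(result)
```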

@ -38,7 +38,7 @@ prev_cls = ''
prev_type = ''
prev_model = ''
lock = threading.Lock()
supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'z_image', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen']
supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'z_image', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen', 'longcat']
def warn_once(msg, variant=None):
@ -59,7 +59,7 @@ def get_model(model_type = 'decoder', variant = None):
model_cls = 'sd'
elif model_cls in {'pixartsigma', 'hunyuandit', 'omnigen', 'auraflow'}:
model_cls = 'sdxl'
elif model_cls in {'h1', 'z_image', 'lumina2', 'chroma'}:
elif model_cls in {'h1', 'z_image', 'lumina2', 'chroma', 'longcat'}:
model_cls = 'f1'
elif model_cls in {'wanai', 'qwen', 'chrono'}:
variant = variant or 'TAE WanVideo'
@ -156,8 +156,8 @@ def decode(latents):
image = vae.decode(tensor, return_dict=False)[0]
image = (image / 2.0 + 0.5).clamp(0, 1).detach()
t1 = time.time()
if (t1 - t0) > 1.0 and not first_run:
shared.log.warning(f'Decode: type="taesd" variant="{variant}" time{t1 - t0:.2f}')
if (t1 - t0) > 3.0 and not first_run:
shared.log.warning(f'Decode: type="taesd" variant="{variant}" long decode time={t1 - t0:.2f}')
first_run = False
return image
except Exception as e:

@ -5,7 +5,7 @@ import torch
from modules import shared, devices
sdnq_version = "0.1.2"
sdnq_version = "0.1.3"
dtype_dict = {
"int32": {"min": -2147483648, "max": 2147483647, "num_bits": 32, "sign": 1, "exponent": 0, "mantissa": 31, "target_dtype": torch.int32, "torch_dtype": torch.int32, "storage_dtype": torch.int32, "is_unsigned": False, "is_integer": True, "is_packed": False},
@ -34,17 +34,33 @@ dtype_dict = {
"float8_e5m2": {"min": -57344, "max": 57344, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": torch.float8_e5m2, "torch_dtype": torch.float8_e5m2, "storage_dtype": torch.float8_e5m2, "is_unsigned": False, "is_integer": False, "is_packed": False},
}
if hasattr(torch, "float8_e4m3fnuz"):
dtype_dict["float8_e4m3fnuz"] = {"min": -240, "max": 240, "num_bits": 8, "sign": 1, "exponent": 4, "mantissa": 3, "target_dtype": "fp8", "torch_dtype": torch.float8_e4m3fnuz, "storage_dtype": torch.float8_e4m3fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
if hasattr(torch, "float8_e5m2fnuz"):
dtype_dict["float8_e5m2fnuz"] = {"min": -57344, "max": 57344, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": "fp8", "torch_dtype": torch.float8_e5m2fnuz, "storage_dtype": torch.float8_e5m2fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
dtype_dict["fp32"] = dtype_dict["float32"]
dtype_dict["bf16"] = dtype_dict["bfloat16"]
dtype_dict["fp16"] = dtype_dict["float16"]
dtype_dict["fp8"] = dtype_dict["float8_e4m3fn"]
dtype_dict["bool"] = dtype_dict["uint1"]
torch_dtype_dict = {
torch.int32: "int32",
torch.int16: "int16",
torch.int8: "int8",
torch.uint32: "uint32",
torch.uint16: "uint16",
torch.uint8: "uint8",
torch.float32: "float32",
torch.bfloat16: "bfloat16",
torch.float16: "float16",
torch.float8_e4m3fn: "float8_e4m3fn",
torch.float8_e5m2: "float8_e5m2",
}
if hasattr(torch, "float8_e4m3fnuz"):
dtype_dict["float8_e4m3fnuz"] = {"min": -240, "max": 240, "num_bits": 8, "sign": 1, "exponent": 4, "mantissa": 3, "target_dtype": "fp8", "torch_dtype": torch.float8_e4m3fnuz, "storage_dtype": torch.float8_e4m3fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
torch_dtype_dict[torch.float8_e4m3fnuz] = "float8_e4m3fnuz"
if hasattr(torch, "float8_e5m2fnuz"):
dtype_dict["float8_e5m2fnuz"] = {"min": -57344, "max": 57344, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": "fp8", "torch_dtype": torch.float8_e5m2fnuz, "storage_dtype": torch.float8_e5m2fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
torch_dtype_dict[torch.float8_e5m2fnuz] = "float8_e5m2fnuz"
linear_types = {"Linear"}
conv_types = {"Conv1d", "Conv2d", "Conv3d"}
conv_transpose_types = {"ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d"}
@ -172,7 +188,7 @@ module_skip_keys_dict = {
{}
],
"ZImageTransformer2DModel": [
["layers.0.adaLN_modulation.0.weight", "t_embedder", "cap_embedder", "all_x_embedder", "all_final_layer"],
["layers.0.adaLN_modulation.0.weight", "t_embedder", "cap_embedder", "siglip_embedder", "all_x_embedder", "all_final_layer"],
{}
],
"HunyuanImage3ForCausalMM": [

@ -97,10 +97,10 @@ def quantize_fp_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str =
@devices.inference_context()
def quantize_fp_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]:
mantissa_difference = mantissa_difference = 1 << (23 - dtype_dict[matmul_dtype]["mantissa"])
mantissa_difference = 1 << (23 - dtype_dict[matmul_dtype]["mantissa"])
scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"])
input = torch.div(input, scale).to(dtype=torch.float32).view(dtype=torch.int32)
input = input.add_(torch.randint_like(input, low=0, high=mantissa_difference)).bitwise_and_(-mantissa_difference).view(dtype=torch.float32)
input = input.add_(torch.randint_like(input, low=0, high=mantissa_difference, dtype=torch.int32)).view(dtype=torch.float32)
input = input.nan_to_num_().clamp_(dtype_dict[matmul_dtype]["min"], dtype_dict[matmul_dtype]["max"]).to(dtype=dtype_dict[matmul_dtype]["torch_dtype"])
return input, scale
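
`quantize_fp_mm_sr` implements stochastic rounding by reinterpreting float32 values as integers and adding random noise below the target mantissa width. The pure-Python sketch below illustrates the idea on a single value. It follows the variant with an explicit `bitwise_and_` mask, so the low bits are truncated in the sketch itself rather than by a later dtype cast; `stochastic_round_bits` is a hypothetical name and the sketch ignores overflow near the largest finite values.

```python
import random
import struct

def stochastic_round_bits(x: float, mantissa_bits: int, rng: random.Random) -> float:
    """Round a float32 value to `mantissa_bits` of mantissa, stochastically."""
    step = 1 << (23 - mantissa_bits)  # float32 stores 23 mantissa bits
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # reinterpret float32 as uint32
    # add uniform noise in [0, step), then clear the bits below the target precision
    bits = (bits + rng.randrange(step)) & ~(step - 1) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]

rng = random.Random(0)
# 1.0625 lies halfway between 1.0 and 1.125, the neighbours with 3 mantissa bits
samples = [stochastic_round_bits(1.0625, 3, rng) for _ in range(20000)]
print(sorted(set(samples)), sum(samples) / len(samples))
```

Each sample lands on one of the two neighbouring representable values, and the average over many samples approaches the original value, which is the property that makes stochastic rounding unbiased.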

@ -50,14 +50,14 @@ def load_files(files: list[str], state_dict: dict = None, key_mapping: dict = No
if isinstance(files, str):
files = [files]
if method is None:
method = 'safetensors'
method = "safetensors"
if state_dict is None:
state_dict = {}
if method == 'safetensors':
if method == "safetensors":
load_safetensors(files, state_dict=state_dict, key_mapping=key_mapping, device=device)
elif method == 'threaded':
elif method == "threaded":
load_threaded(files, state_dict=state_dict, key_mapping=key_mapping, device=device)
elif method == 'streamer':
elif method == "streamer":
load_streamer(files, state_dict=state_dict, key_mapping=key_mapping, device=device)
else:
raise ValueError(f"Unsupported loading method: {method}")

@ -63,7 +63,7 @@ def save_sdnq_model(model: ModelMixin, model_path: str, max_shard_size: str = "1
model.config.quantization_config.to_json_file(quantization_config_path)
def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: str = None, dtype: torch.dtype = None, device: torch.device = "cpu", dequantize_fp32: bool = None, use_quantized_matmul: bool = None, model_config: dict = None, quantization_config: dict = None, load_method: str = 'safetensors') -> ModelMixin:
def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: str = None, dtype: torch.dtype = None, device: torch.device = "cpu", dequantize_fp32: bool = None, use_quantized_matmul: bool = None, model_config: dict = None, quantization_config: dict = None, load_method: str = "safetensors") -> ModelMixin:
from accelerate import init_empty_weights
with init_empty_weights():

@ -59,8 +59,7 @@ def quantize_weight(weight: torch.FloatTensor, reduction_axes: Union[int, List[i
else:
if use_stochastic_rounding:
mantissa_difference = 1 << (23 - dtype_dict[weights_dtype]["mantissa"])
quantized_weight = quantized_weight.view(dtype=torch.int32)
quantized_weight = torch.randint_like(quantized_weight, low=0, high=mantissa_difference).add_(quantized_weight).bitwise_and_(-mantissa_difference).view(dtype=torch.float32)
quantized_weight = quantized_weight.view(dtype=torch.int32).add_(torch.randint_like(quantized_weight, low=0, high=mantissa_difference, dtype=torch.int32)).view(dtype=torch.float32)
quantized_weight.nan_to_num_()
quantized_weight = quantized_weight.clamp_(dtype_dict[weights_dtype]["min"], dtype_dict[weights_dtype]["max"]).to(dtype_dict[weights_dtype]["torch_dtype"])
return quantized_weight, scale, zero_point
@ -211,6 +210,7 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
result_shape = None
original_shape = weight.shape
original_stride = weight.stride()
weight = weight.detach()
if torch_dtype is None:
torch_dtype = weight.dtype
@ -229,8 +229,6 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
)
if layer_class_name in conv_types:
if dtype_dict[weights_dtype]["num_bits"] < 4:
weights_dtype = "uint4"
is_conv_type = True
reduction_axes = 1
output_channel_size, channel_size = weight.shape[:2]
@ -242,8 +240,6 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
weight = weight.flatten(1,-1)
reduction_axes = -1
elif layer_class_name in conv_transpose_types:
if dtype_dict[weights_dtype]["num_bits"] < 4:
weights_dtype = "uint4"
is_conv_transpose_type = True
reduction_axes = 0
channel_size, output_channel_size = weight.shape[:2]

@ -2,6 +2,7 @@
Modified from Triton MatMul example.
PyTorch torch._int_mm is broken on backward pass with Nvidia.
AMD RDNA2 doesn't support torch._int_mm, so we use int_mm via Triton.
PyTorch doesn't support FP32 output type with FP16 MM so we use Triton for it too.
"""
import torch
@ -13,47 +14,47 @@ import triton.language as tl
def get_autotune_config():
if triton.runtime.driver.active.get_current_target().backend == "cuda":
return [
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2),
#
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
]
else:
return [
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=2),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=2),
triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=2),
#
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8),
triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4),
]
@triton.autotune(configs=get_autotune_config(), key=['M', 'N', 'K', 'stride_bk'])
@triton.autotune(configs=get_autotune_config(), key=["M", "N", "K", "stride_bk"])
@triton.jit
def int_mm_kernel(
a_ptr, b_ptr, c_ptr,
@ -111,7 +112,7 @@ def int_mm(a, b):
K, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.int32)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
int_mm_kernel[grid](
a, b, c,
M, N, K,
@ -122,7 +123,7 @@ def int_mm(a, b):
return c
@triton.autotune(configs=get_autotune_config(), key=['M', 'N', 'K', 'stride_bk'])
@triton.autotune(configs=get_autotune_config(), key=["M", "N", "K", "stride_bk"])
@triton.jit
def fp_mm_kernel(
a_ptr, b_ptr, c_ptr,
@ -180,7 +181,7 @@ def fp_mm(a, b):
K, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
fp_mm_kernel[grid](
a, b, c,
M, N, K,
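
The `grid` callbacks in `int_mm` and `fp_mm` launch one program per output tile, using ceiling division so that partial tiles at the edges are still covered. The arithmetic can be sketched without Triton; `cdiv` here is a local reimplementation of the ceiling division that `triton.cdiv` performs, and `grid_size` is an illustrative name.

```python
def cdiv(x: int, y: int) -> int:
    """Ceiling division for positive integers."""
    return (x + y - 1) // y

def grid_size(M: int, N: int, block_m: int, block_n: int) -> int:
    # one program per output tile of shape (block_m, block_n)
    return cdiv(M, block_m) * cdiv(N, block_n)

# 1000 rows need 8 tiles of 128 (the last one partial), 512 columns need 2 tiles of 256
print(grid_size(1000, 512, 128, 256))
```

Since the block sizes come from the autotuned config (`META`), the grid has to be a function of the selected config rather than a fixed tuple.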

@ -1,11 +1,12 @@
from __future__ import annotations
import io
import os
import sys
import time
import contextlib
from enum import Enum
from typing import TYPE_CHECKING
import gradio as gr
import diffusers
from modules.json_helpers import readfile, writefile # pylint: disable=W0611
from modules.shared_helpers import listdir, walk_files, html_path, html, req, total_tqdm # pylint: disable=W0611
from modules.shared_defaults import get_default_modes
@ -23,6 +24,12 @@ import modules.styles
import modules.paths as paths
from installer import log, print_dict, console, get_version # pylint: disable=unused-import
if TYPE_CHECKING:
# Behavior modified by __future__.annotations
from diffusers import DiffusionPipeline
from modules.shared_legacy import LegacyOption
from modules.ui_extra_networks import ExtraNetworksPage
class Backend(Enum):
ORIGINAL = 1
@ -30,7 +37,7 @@ class Backend(Enum):
errors.install([gr])
demo: gr.Blocks = None
demo: gr.Blocks | None = None
api = None
url = 'https://github.com/vladmandic/sdnext'
cmd_opts = cmd_args.parse_args()
@ -44,8 +51,8 @@ detailers = []
face_restorers = []
yolo = None
tab_names = []
extra_networks = []
options_templates = {}
extra_networks: list[ExtraNetworksPage] = []
options_templates: dict[str, OptionInfo | LegacyOption] = {}
hypernetworks = {}
settings_components = {}
restricted_opts = {
@ -164,6 +171,11 @@ options_templates.update(options_section(('sd', "Model Loading"), {
options_templates.update(options_section(('model_options', "Model Options"), {
"model_modular_sep": OptionInfo("<h2>Modular Pipelines</h2>", "", gr.HTML),
"model_modular_enable": OptionInfo(False, "Enable modular pipelines (experimental)"),
"model_google_sep": OptionInfo("<h2>Google GenAI</h2>", "", gr.HTML),
"google_use_vertexai": OptionInfo(False, "Google cloud use VertexAI endpoints"),
"google_api_key": OptionInfo("", "Google cloud API key", gr.Textbox),
"google_project_id": OptionInfo("", "Google Cloud project ID", gr.Textbox),
"google_location_id": OptionInfo("", "Google Cloud location ID", gr.Textbox),
"model_sd3_sep": OptionInfo("<h2>Stable Diffusion 3.x</h2>", "", gr.HTML),
"model_sd3_disable_te5": OptionInfo(False, "Disable T5 text encoder"),
"model_h1_sep": OptionInfo("<h2>HiDream</h2>", "", gr.HTML),
@ -173,6 +185,8 @@ options_templates.update(options_section(('model_options', "Model Options"), {
"model_wan_boundary": OptionInfo(0.85, "Stage boundary ratio", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05 }),
"model_chrono_sep": OptionInfo("<h2>ChronoEdit</h2>", "", gr.HTML),
"model_chrono_temporal_steps": OptionInfo(0, "Temporal steps", gr.Slider, {"minimum": 0, "maximum": 50, "step": 1 }),
"model_qwen_layer_sep": OptionInfo("<h2>Qwen-Image-Layered</h2>", "", gr.HTML),
"model_qwen_layers": OptionInfo(2, "Qwen layered number of layers", gr.Slider, {"minimum": 2, "maximum": 9, "step": 1 }),
}))
options_templates.update(options_section(('offload', "Model Offloading"), {
@ -626,7 +640,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
"postprocessing_sep_detailer": OptionInfo("<h2>Detailer</h2>", "", gr.HTML),
"detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"),
"detailer_augment": OptionInfo(True, "Detailer use model augment"),
"detailer_augment": OptionInfo(False, "Detailer use model augment"),
"postprocessing_sep_seedvt": OptionInfo("<h2>SeedVT</h2>", "", gr.HTML),
"seedvt_cfg_scale": OptionInfo(3.5, "SeedVR CFG Scale", gr.Slider, {"minimum": 1, "maximum": 15, "step": 1}),
@ -802,6 +816,7 @@ options_templates.update(options_section(('hidden_options', "Hidden options"), {
"detailer_merge": OptionInfo(False, "Merge multiple results from each detailer model", gr.Checkbox, {"visible": False}),
"detailer_sort": OptionInfo(False, "Sort detailer output by location", gr.Checkbox, {"visible": False}),
"detailer_save": OptionInfo(False, "Include detection results", gr.Checkbox, {"visible": False}),
"detailer_seg": OptionInfo(False, "Use segmentation", gr.Checkbox, {"visible": False}),
}))
@ -822,7 +837,7 @@ log.info(f'Engine: backend={backend} compute={devices.backend} device={devices.g
profiler = None
prompt_styles = modules.styles.StyleDatabase(opts)
reference_models = readfile(os.path.join('html', 'reference.json')) if opts.extra_network_reference_enable else {}
reference_models = readfile(os.path.join('html', 'reference.json'), as_type="dict") if opts.extra_network_reference_enable else {}
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or (cmd_opts.server_name or False)) and not cmd_opts.insecure
devices.args = cmd_opts
devices.opts = opts
@ -884,8 +899,8 @@ def restore_defaults(restart=True):
# startup def of shared.sd_model before its redefined in modeldata
sd_model: diffusers.DiffusionPipeline = None # dummy and overwritten by class
sd_refiner: diffusers.DiffusionPipeline = None # dummy and overwritten by class
sd_model: DiffusionPipeline | None = None # dummy and overwritten by class
sd_refiner: DiffusionPipeline | None = None # dummy and overwritten by class
sd_model_type: str = '' # dummy and overwritten by class
sd_refiner_type: str = '' # dummy and overwritten by class
sd_loaded: bool = False # dummy and overwritten by class

@ -13,9 +13,9 @@ def get_default_modes(cmd_opts, mem_stat):
if "gpu" in mem_stat and gpu_memory != 0:
if gpu_memory <= 4:
cmd_opts.lowvram = True
default_offload_mode = "sequential"
default_offload_mode = "balanced"
default_diffusers_offload_min_gpu_memory = 0
log.info(f"Device detect: memory={gpu_memory:.1f} default=sequential optimization=lowvram")
log.info(f"Device detect: memory={gpu_memory:.1f} default=balanced optimization=lowvram")
elif gpu_memory <= 12:
cmd_opts.medvram = True # VAE Tiling and other stuff
default_offload_mode = "balanced"

@ -48,6 +48,7 @@ pipelines = {
'Qwen': getattr(diffusers, 'QwenImagePipeline', None),
'HunyuanImage': getattr(diffusers, 'HunyuanImagePipeline', None),
'Z-Image': getattr(diffusers, 'ZImagePipeline', None),
'LongCat': getattr(diffusers, 'LongCatImagePipeline', None),
# dynamically imported and redefined later
'Meissonic': getattr(diffusers, 'DiffusionPipeline', None),
'Monetico': getattr(diffusers, 'DiffusionPipeline', None),

@ -414,7 +414,7 @@ def update_token_counter(text):
token_count = 0
max_length = 75
if shared.state.job_count > 0:
shared.log.info('Tokenizer busy')
shared.log.debug('Tokenizer busy')
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
from modules import extra_networks
prompt, _ = extra_networks.parse_prompt(text)

@ -29,8 +29,8 @@ sort_ordering = {
}
def get_installed(ext) -> extensions.Extension:
installed: extensions.Extension = [e for e in extensions.extensions if (e.remote or '').startswith(ext['url'].replace('.git', ''))]
def get_installed(ext):
installed = [e for e in extensions.extensions if (e.remote or '').startswith(ext['url'].replace('.git', ''))]
return installed[0] if len(installed) > 0 else None
@ -382,27 +382,27 @@ def create_html(search_text, sort_column):
if ext.get('status', None) is None or type(ext['status']) == str: # old format
ext['status'] = 0
if ext['url'] is None or ext['url'] == '':
status = f"<span style='cursor:pointer;color:#00C0FD' title='Local'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='Local'>{ui_symbols.svg_bullet.color('#00C0FD')}</div>"
elif ext['status'] > 0:
if ext['status'] == 1:
status = f"<span style='cursor:pointer;color:#00FD9C ' title='Verified'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='Verified'>{ui_symbols.svg_bullet.color('#00FD9C')}</div>"
elif ext['status'] == 2:
status = f"<span style='cursor:pointer;color:#FFC300' title='Supported only with backend:Original'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='Supported only with backend: Original'>{ui_symbols.svg_bullet.color('#FFC300')}</div>"
elif ext['status'] == 3:
status = f"<span style='cursor:pointer;color:#FFC300' title='Supported only with backend:Diffusers'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='Supported only with backend: Diffusers'>{ui_symbols.svg_bullet.color('#FFC300')}</div>"
elif ext['status'] == 4:
status = f"<span style='cursor:pointer;color:#4E22FF' title=\"{ext.get('note', 'custom value')}\">{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title=\"{html.escape(ext.get('note', 'custom value'))}\">{ui_symbols.svg_bullet.color('#4E22FF')}</div>"
elif ext['status'] == 5:
status = f"<span style='cursor:pointer;color:#CE0000' title='Not supported'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='Not supported'>{ui_symbols.svg_bullet.color('#CE0000')}</div>"
elif ext['status'] == 6:
status = f"<span style='cursor:pointer;color:#AEAEAE' title='Just discovered'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='Just discovered'>{ui_symbols.svg_bullet.color('#AEAEAE')}</div>"
else:
status = f"<span style='cursor:pointer;color:#008EBC' title='Unknown status'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='Unknown status'>{ui_symbols.svg_bullet.color('#008EBC')}</div>"
else:
if updated < datetime.timestamp(datetime.now() - timedelta(6*30)):
status = f"<span style='cursor:pointer;color:#C000CF' title='Unmaintained'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='Unmaintained'>{ui_symbols.svg_bullet.color('#C000CF')}</div>"
else:
status = f"<span style='cursor:pointer;color:#7C7C7C' title='No info'>{ui_symbols.bullet}</span>"
status = f"<div style='cursor:help;width:1em;' title='No info'>{ui_symbols.svg_bullet.color('#7C7C7C')}</div>"
code += f"""
<tr style="display: {visible}">
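The status 4 branch in the hunk above now passes the extension note through `html.escape` before placing it inside a double-quoted `title` attribute. A minimal sketch of why that matters, using a made-up note string and a hypothetical `build_title` helper that is not part of the repository:

```python
import html

def build_title(note: str) -> str:
    # Hypothetical helper: mirrors the pattern of embedding free text in a double-quoted attribute
    return f'<div title="{html.escape(note)}"></div>'

# A note containing characters that would otherwise terminate the attribute or open a tag
unsafe_note = 'works with "v2" & <later>'
print(build_title(unsafe_note))
# <div title="works with &quot;v2&quot; &amp; &lt;later&gt;"></div>
```

Without escaping, the first `"` in the note would close the attribute early and the remainder would be parsed as markup.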

@ -8,6 +8,7 @@ import html
import base64
import urllib.parse
import threading
from typing import TYPE_CHECKING
from types import SimpleNamespace
from pathlib import Path
from html.parser import HTMLParser
@ -440,7 +441,7 @@ class ExtraNetworksPage:
def update_all_previews(self, items):
global preview_map # pylint: disable=global-statement
if preview_map is None:
preview_map = shared.readfile('html/previews.json', silent=True)
preview_map = shared.readfile('html/previews.json', silent=True, as_type="dict")
t0 = time.time()
reference_path = os.path.abspath(os.path.join('models', 'Reference'))
possible_paths = list(set([os.path.dirname(item['filename']) for item in items] + [reference_path]))
@ -520,12 +521,10 @@ class ExtraNetworksPage:
t0 = time.time()
fn = os.path.splitext(path)[0] + '.json'
if not data and os.path.exists(fn):
data = shared.readfile(fn, silent=True)
data = shared.readfile(fn, silent=True, as_type="dict")
fn = os.path.join(path, 'model_index.json')
if not data and os.path.exists(fn):
data = shared.readfile(fn, silent=True)
if type(data) is list:
data = data[0]
data = shared.readfile(fn, silent=True, as_type="dict")
t1 = time.time()
self.info_time += t1-t0
return data
@ -862,13 +861,15 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
is_valid = (item is not None) and hasattr(item, 'name') and hasattr(item, 'filename')
if is_valid:
if TYPE_CHECKING:
assert item is not None # Part of the definition of "is_valid"
stat_size, stat_mtime = modelstats.stat(item.filename)
if hasattr(item, 'size') and item.size > 0:
stat_size = item.size
if hasattr(item, 'mtime') and item.mtime is not None:
stat_mtime = item.mtime
desc = item.description
fullinfo = shared.readfile(os.path.splitext(item.filename)[0] + '.json', silent=True)
fullinfo = shared.readfile(os.path.splitext(item.filename)[0] + '.json', silent=True, as_type="dict")
if 'modelVersions' in fullinfo: # sanitize massive objects
fullinfo['modelVersions'] = []
info = fullinfo
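The hunks above replace caller-side checks such as `if type(data) is list: data = data[0]` with an `as_type="dict"` argument to `shared.readfile`. The implementation of that argument is not shown in this diff, so the following is only a sketch of one way a typed JSON reader could behave; `read_json_as` and its fallback rules are assumptions, not the project's API:

```python
import json
import os
import tempfile

def read_json_as(path: str, as_type: str = "dict"):
    # Hypothetical stand-in for a typed JSON reader; NOT the signature of shared.readfile
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError):
        data = None
    if as_type == "dict":
        # Unwrap a list of dicts, the same step the removed caller-side code performed
        if isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict):
            return data[0]
        # Fall back to an empty dict so callers can use `in` and `.items()` safely
        return data if isinstance(data, dict) else {}
    return data

# Usage: a file holding a one-element list is returned as its inner dict
with tempfile.TemporaryDirectory() as tmp:
    fn = os.path.join(tmp, "info.json")
    with open(fn, "w", encoding="utf-8") as f:
        json.dump([{"modelVersions": [1, 2]}], f)
    print(read_json_as(fn))
    print(read_json_as(os.path.join(tmp, "missing.json")))
```

Centralizing the coercion means a check like `'modelVersions' in fullinfo` cannot fail on `None` or on a list, which is presumably the motivation for the change.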

@ -40,7 +40,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
shared.log.debug(f'Networks: type="reference" autodownload={shared.opts.sd_checkpoint_autodownload} enable={shared.opts.extra_network_reference_enable}')
return []
count = { 'total': 0, 'ready': 0, 'hidden': 0, 'experimental': 0, 'base': 0 }
shared.reference_models = readfile(os.path.join('html', 'reference.json'))
shared.reference_models = readfile(os.path.join('html', 'reference.json'), as_type="dict")
for k, v in shared.reference_models.items():
count['total'] += 1
url = v['path']

@ -118,7 +118,7 @@ class UiLoadsave:
def read_from_file(self):
from modules.shared import readfile
return readfile(self.filename)
return readfile(self.filename, as_type="dict")
def write_to_file(self, current_ui_settings):
from modules.shared import writefile

@ -84,7 +84,7 @@ def create_setting_component(key, is_quicksettings=False):
with gr.Row():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
ui_common.create_refresh_button(res, info.refresh, info.component_args, f"settings_{key}_refresh")
elif info.folder is not None:
elif info.folder:
with gr.Row():
res = comp(label=info.label, value=fun(), elem_id=elem_id, elem_classes="folder-selector", **args)
else:

@ -40,3 +40,24 @@ sort_time_asc = '\uf0de'
sort_time_dsc = '\uf0dd'
style_apply = ''
style_save = ''
class SVGSymbol:
def __init__(self, svg: str):
self.svg = svg
self.before = ""
self.after = ""
self.supports_color = False
if "currentColor" in self.svg:
self.supports_color = True
self.before, self.after = self.svg.split("currentColor", maxsplit=1)
def color(self, color: str):
if self.supports_color:
return self.before + color + self.after
else:
return self.svg
def __str__(self):
return self.svg
svg_bullet = SVGSymbol("<svg style='stroke:currentColor;fill:none;stroke-width:2;' viewBox='0 0 16 16'><circle cx='8' cy='8' r='7'/></svg>")
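For reference, the `SVGSymbol` class added above can be exercised on its own. The block below repeats the class with indentation restored (the page rendering dropped it) and shows what `color()` returns; the second symbol with two `currentColor` occurrences is a made-up input used only to illustrate the `maxsplit=1` behavior:

```python
class SVGSymbol:
    def __init__(self, svg: str):
        self.svg = svg
        self.before = ""
        self.after = ""
        self.supports_color = False
        if "currentColor" in self.svg:
            self.supports_color = True
            # Split once: only the first currentColor becomes the substitution point
            self.before, self.after = self.svg.split("currentColor", maxsplit=1)

    def color(self, color: str):
        if self.supports_color:
            return self.before + color + self.after
        else:
            return self.svg

    def __str__(self):
        return self.svg

svg_bullet = SVGSymbol("<svg style='stroke:currentColor;fill:none;stroke-width:2;' viewBox='0 0 16 16'><circle cx='8' cy='8' r='7'/></svg>")
print(svg_bullet.color('#00C0FD'))
# <svg style='stroke:#00C0FD;fill:none;stroke-width:2;' viewBox='0 0 16 16'><circle cx='8' cy='8' r='7'/></svg>

# Made-up input: the second currentColor is left untouched
two = SVGSymbol("<svg stroke='currentColor' fill='currentColor'/>")
print(two.color('red'))
# <svg stroke='red' fill='currentColor'/>
```

Because the original markup is kept in `self.svg`, `str(svg_bullet)` still yields the uncolored symbol, which inherits its color from the surrounding CSS.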

@ -23,7 +23,7 @@ class Upscaler:
def __init__(self, create_dirs=True):
global models # pylint: disable=global-statement
if models is None:
models = shared.readfile('html/upscalers.json')
models = shared.readfile('html/upscalers.json', as_type="dict")
self.mod_pad_h = None
self.tile_size = shared.opts.upscaler_tile_size
self.tile_pad = shared.opts.upscaler_tile_overlap

@ -69,17 +69,38 @@ class GoogleVeoVideoPipeline():
image=genai.types.Image(image_bytes=image_bytes.getvalue(), mime_type='image/jpeg'),
)
def get_args(self):
from modules.shared import opts
api_key = os.getenv("GOOGLE_API_KEY") or opts.google_api_key
vertex_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
if (api_key is None or len(api_key) == 0) and (vertex_credentials is None or len(vertex_credentials) == 0):
log.error(f'Cloud: model="{self.model}" API key not provided')
return None
use_vertexai = (os.getenv("GOOGLE_GENAI_USE_VERTEXAI") is not None) or opts.google_use_vertexai
project_id = os.getenv("GOOGLE_CLOUD_PROJECT") or opts.google_project_id
location_id = os.getenv("GOOGLE_CLOUD_LOCATION") or opts.google_location_id
args = {
'api_key': api_key,
'vertexai': use_vertexai,
'project': project_id if len(project_id) > 0 else None,
'location': location_id if len(location_id) > 0 else None,
}
args_copy = args.copy()
args_copy['api_key'] = '...' + args_copy['api_key'][-4:] # last 4 chars
args_copy['credentials'] = vertex_credentials
log.debug(f'Cloud: model="{self.model}" args={args_copy}')
return args
def __call__(self, prompt: list[str], width: int, height: int, image: Image.Image = None, num_frames: int = 4*24):
from google import genai
if isinstance(prompt, list) and len(prompt) > 0:
prompt = prompt[0]
if self.client is None:
api_key = os.getenv("GOOGLE_API_KEY", None)
if api_key is None:
log.error(f'Cloud: model="{self.model}" GOOGLE_API_KEY environment variable not set')
args = self.get_args()
if args is None:
return None
self.client = genai.Client(api_key=api_key, vertexai=False)
self.client = genai.Client(**args)
resolution, aspect_ratio = get_size_buckets(width, height)
duration = num_frames // 24
@ -115,11 +136,12 @@ class GoogleVeoVideoPipeline():
log.error(f'Cloud video: model="{self.model}" {operation} {e}')
return None
if operation is None or operation.response is None or operation.response.generated_videos is None or len(operation.response.generated_videos) == 0:
try:
response: genai.types.GeneratedVideo = operation.response.generated_videos[0]
except Exception:
log.error(f'Cloud video: model="{self.model}" no response {operation}')
return None
try:
response: genai.types.GeneratedVideo = operation.response.generated_videos[0]
self.client.files.download(file=response.video)
video_bytes = response.video.video_bytes
return { 'bytes': video_bytes, 'images': [] }
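In `get_args` above, the key is masked for logging with `'...' + args_copy['api_key'][-4:]`. The guard before it lets execution continue when only `GOOGLE_APPLICATION_CREDENTIALS` is set; whether the key can then be `None` depends on the default of `opts.google_api_key`, which this diff does not show. If it can, slicing `None` would raise a `TypeError`. A defensive variant could look like the sketch below, where `mask_secret` is a hypothetical helper and not part of the commit:

```python
def mask_secret(value, visible: int = 4):
    # Hypothetical helper: returns a log-safe form of a secret, tolerating None and short values
    if not value:
        return None  # nothing to mask when the key is missing or empty
    if len(value) <= visible:
        return "..."  # do not reveal a short secret in full
    return "..." + value[-visible:]

print(mask_secret("abcdefgh"))  # ...efgh
print(mask_secret(None))        # None
print(mask_secret("abc"))       # ...
```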

@ -56,6 +56,7 @@ def load_transformer(repo_id, cls_name, load_config=None, subfolder="transformer
shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" loader={_loader("diffusers")} args={load_args}')
if dtype is not None:
load_args['torch_dtype'] = dtype
load_args.pop('device_map', None) # single-file uses different syntax
loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
transformer = loader(
local_file,

@ -1,5 +1,6 @@
import io
import os
import time
from PIL import Image
from installer import install, reload, log
@ -67,14 +68,35 @@ class GoogleNanoBananaPipeline():
],
)
def get_args(self):
from modules.shared import opts
api_key = os.getenv("GOOGLE_API_KEY") or opts.google_api_key
vertex_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
if (api_key is None or len(api_key) == 0) and (vertex_credentials is None or len(vertex_credentials) == 0):
log.error(f'Cloud: model="{self.model}" API key not provided')
return None
use_vertexai = (os.getenv("GOOGLE_GENAI_USE_VERTEXAI") is not None) or opts.google_use_vertexai
project_id = os.getenv("GOOGLE_CLOUD_PROJECT") or opts.google_project_id
location_id = os.getenv("GOOGLE_CLOUD_LOCATION") or opts.google_location_id
args = {
'api_key': api_key,
'vertexai': use_vertexai,
'project': project_id if len(project_id) > 0 else None,
'location': location_id if len(location_id) > 0 else None,
}
args_copy = args.copy()
args_copy['api_key'] = '...' + args_copy['api_key'][-4:] # last 4 chars
args_copy['credentials'] = vertex_credentials
log.debug(f'Cloud: model="{self.model}" args={args_copy}')
return args
def __call__(self, prompt: list[str], width: int, height: int, image: Image.Image = None):
from google import genai
if self.client is None:
api_key = os.getenv("GOOGLE_API_KEY", None)
if api_key is None:
log.error(f'Cloud: model="{self.model}" GOOGLE_API_KEY environment variable not set')
args = self.get_args()
if args is None:
return None
self.client = genai.Client(api_key=api_key, vertexai=False)
self.client = genai.Client(**args)
image_size, aspect_ratio = get_size_buckets(width, height)
if 'gemini-3' in self.model:
@ -85,14 +107,21 @@ class GoogleNanoBananaPipeline():
response_modalities=["IMAGE"],
image_config=image_config
)
log.debug(f'Cloud: prompt="{prompt}" size={image_size} ar={aspect_ratio} image={image} model="{self.model}"')
log.debug(f'Cloud: model="{self.model}" prompt="{prompt}" size={image_size} ar={aspect_ratio} image={image}')
# log.debug(f'Cloud: config={self.config}')
try:
t0 = time.time()
if image is not None:
response = self.img2img(prompt, image)
else:
response = self.txt2img(prompt)
t1 = time.time()
try:
tokens = response.usage_metadata.total_token_count
except Exception:
tokens = 0
log.debug(f'Cloud: model="{self.model}" tokens={tokens} time={(t1 - t0):.2f}')
except Exception as e:
log.error(f'Cloud: model="{self.model}" {e}')
return None
@ -100,10 +129,16 @@ class GoogleNanoBananaPipeline():
image = None
if getattr(response, 'prompt_feedback', None) is not None:
log.error(f'Cloud: model="{self.model}" {response.prompt_feedback}')
if not hasattr(response, 'candidates') or (response.candidates is None) or (len(response.candidates) == 0):
parts = []
try:
for candidate in response.candidates:
parts.extend(candidate.content.parts)
except Exception:
log.error(f'Cloud: model="{self.model}" no images received')
return None
for part in response.candidates[0].content.parts:
for part in parts:
if part.inline_data is not None:
image = Image.open(io.BytesIO(part.inline_data.data))
return image
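The last hunk above stops reading only `response.candidates[0]` and instead gathers parts from every candidate, treating any failure while walking the response as "no images received". A standalone sketch of that pattern, using `SimpleNamespace` objects as stand-ins for the real response type (the shapes below are assumptions for illustration, not the `google.genai` API):

```python
from types import SimpleNamespace

def collect_parts(response):
    # Gather parts from all candidates; any missing attribute or None along the way yields None
    parts = []
    try:
        for candidate in response.candidates:
            parts.extend(candidate.content.parts)
    except Exception:
        return None
    return parts

# Stand-in response with two candidates carrying three parts in total
ok = SimpleNamespace(candidates=[
    SimpleNamespace(content=SimpleNamespace(parts=["a", "b"])),
    SimpleNamespace(content=SimpleNamespace(parts=["c"])),
])
print(collect_parts(ok))                                # ['a', 'b', 'c']
print(collect_parts(SimpleNamespace(candidates=None)))  # None
```

Note that an empty `candidates` list does not raise, so it produces an empty list of parts rather than the error path.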

@ -0,0 +1,40 @@
import transformers
import diffusers
from modules import shared, devices, sd_models, model_quant, sd_hijack_te
from pipelines import generic
def load_longcat(checkpoint_info, diffusers_load_config=None):
if diffusers_load_config is None:
diffusers_load_config = {}
repo_id = sd_models.path_to_repo(checkpoint_info)
sd_models.hf_auth_check(checkpoint_info)
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
shared.log.debug(f'Load model: type=LongCat repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={diffusers_load_config}')
transformer = generic.load_transformer(repo_id, cls_name=diffusers.LongCatImageTransformer2DModel, load_config=diffusers_load_config)
text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.Qwen2_5_VLForConditionalGeneration, load_config=diffusers_load_config)
text_processor = transformers.Qwen2VLProcessor.from_pretrained(repo_id, subfolder='tokenizer', cache_dir=shared.opts.hfcache_dir)
if 'edit' in repo_id.lower():
cls = diffusers.LongCatImageEditPipeline
else:
cls = diffusers.LongCatImagePipeline
pipe = cls.from_pretrained(
repo_id,
cache_dir=shared.opts.diffusers_dir,
transformer=transformer,
text_encoder=text_encoder,
text_processor=text_processor,
**load_args,
)
del transformer
del text_encoder
del text_processor
sd_hijack_te.init_hijack(pipe)
devices.torch_gc(force=True, reason='load')
return pipe

pipelines/model_ovis.py

@ -0,0 +1,36 @@
import transformers
import diffusers
from modules import shared, devices, sd_models, model_quant, sd_hijack_te
from pipelines import generic
def load_ovis(checkpoint_info, diffusers_load_config=None):
if diffusers_load_config is None:
diffusers_load_config = {}
repo_id = sd_models.path_to_repo(checkpoint_info)
sd_models.hf_auth_check(checkpoint_info)
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
shared.log.debug(f'Load model: type=OvisImage repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={diffusers_load_config}')
transformer = generic.load_transformer(repo_id, cls_name=diffusers.OvisImageTransformer2DModel, load_config=diffusers_load_config)
text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.Qwen3Model, load_config=diffusers_load_config)
pipe = diffusers.OvisImagePipeline.from_pretrained(
repo_id,
cache_dir=shared.opts.diffusers_dir,
transformer=transformer,
text_encoder=text_encoder,
**load_args,
)
pipe.task_args = {
'output_type': 'np',
}
del transformer
del text_encoder
sd_hijack_te.init_hijack(pipe)
devices.torch_gc(force=True, reason='load')
return pipe

@ -15,7 +15,7 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model')
shared.log.debug(f'Load model: type=Qwen model="{checkpoint_info.name}" repo="{repo_id}" offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
if '2509' in repo_id :
if '2509' in repo_id or '2511' in repo_id:
cls_name = diffusers.QwenImageEditPlusPipeline
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPlusPipeline
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPlusPipeline
@ -25,6 +25,11 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPipeline
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPipeline
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPipeline
elif 'Layered' in repo_id:
cls_name = diffusers.QwenImageLayeredPipeline
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["qwen-layered"] = diffusers.QwenImageLayeredPipeline
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["qwen-layered"] = diffusers.QwenImageLayeredPipeline
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["qwen-layered"] = diffusers.QwenImageLayeredPipeline
else:
cls_name = diffusers.QwenImagePipeline
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImagePipeline
@ -69,6 +74,11 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
pipe.task_args = {
'output_type': 'np',
}
if 'Layered' in repo_id:
pipe.task_args['use_en_prompt'] = True
pipe.task_args['cfg_normalize'] = False
pipe.task_args['layers'] = shared.opts.model_qwen_layers
pipe.task_args['resolution'] = 640
del text_encoder
del transformer

@ -11,7 +11,7 @@ def load_z_image(checkpoint_info, diffusers_load_config=None):
sd_models.hf_auth_check(checkpoint_info)
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
shared.log.debug(f'Load model: type=Z-Image repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={diffusers_load_config}')
shared.log.debug(f'Load model: type=ZImage repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={diffusers_load_config}')
transformer = generic.load_transformer(repo_id, cls_name=diffusers.ZImageTransformer2DModel, load_config=diffusers_load_config)
text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.Qwen3ForCausalLM, load_config=diffusers_load_config)

@ -32,6 +32,7 @@ PyWavelets
pi-heif
# versioned
fastapi==0.124.4
rich==14.1.0
safetensors==0.7.0
tensordict==0.8.3

wiki

@ -1 +1 @@
Subproject commit 12af554d26d12ce84d43e35f81da89bdbeac4057
Subproject commit 01a5b7af78897212a8d1b32def6ff4bd3d03a352