mirror of https://github.com/vladmandic/automatic
commit
3ae10181dc
|
|
@ -43,7 +43,7 @@ exclude = [
|
|||
]
|
||||
line-length = 250
|
||||
indent-width = 4
|
||||
target-version = "py39"
|
||||
target-version = "py310"
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
|
|
@ -72,6 +72,7 @@ select = [
|
|||
ignore = [
|
||||
"B006", # Do not use mutable data structures for argument defaults
|
||||
"B008", # Do not perform function call in argument defaults
|
||||
"B905", # Strict zip() usage
|
||||
"C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead
|
||||
"C408", # Unnecessary `dict` call
|
||||
"I001", # Import block is un-sorted or un-formatted
|
||||
|
|
@ -81,6 +82,7 @@ ignore = [
|
|||
"E731", # Do not assign a `lambda` expression, use a `def`
|
||||
"E741", # Ambiguous variable name
|
||||
"F401", # Imported by unused
|
||||
"EXE001", # file with shebang is not marked executable
|
||||
"NPY002", # replace legacy random
|
||||
"RUF005", # Consider iterable unpacking
|
||||
"RUF008", # Do not use mutable default values for dataclass
|
||||
|
|
|
|||
54
CHANGELOG.md
54
CHANGELOG.md
|
|
@ -1,5 +1,59 @@
|
|||
# Change Log for SD.Next
|
||||
|
||||
## Update for 2025-12-26
|
||||
|
||||
### Highlights for 2025-12-26
|
||||
|
||||
End of year release update, just two weeks after previous one, with several new models and features:
|
||||
- Several new models including highly anticipated **Qwen-Image-Edit 2511** as well as **Qwen-Image-Layered**, **LongCat Image** and **Ovis Image**
|
||||
- New features including support for **Z-Image** *ControlNets* and *fine-tunes* and **Detailer** segmentation support
|
||||
|
||||
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
|
||||
|
||||
### Details for 2025-12-26
|
||||
|
||||
- **Models**
|
||||
- [LongCat Image](https://github.com/meituan-longcat/LongCat-Image) in *Image* and *Image Edit* variants
|
||||
LongCat is a new 8B diffusion base model using Qwen-2.5 as text encoder
|
||||
- [Qwen-Image-Edit 2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) in *base* and *pre-quantized* variants
|
||||
Key enhancements: mitigate image drift, improved character consistency, enhanced industrial design generation, and strengthened geometric reasoning ability
|
||||
- [Qwen-Image-Layered](https://huggingface.co/Qwen/Qwen-Image-Layered) in *base* and *pre-quantized* variants
|
||||
Qwen-Image-Layered, a model capable of decomposing an image into multiple RGBA layers
|
||||
*note*: set number of desired output layers in *settings -> model options*
|
||||
- [Ovis Image 7B](https://huggingface.co/AIDC-AI/Ovis-Image-7B)
|
||||
Ovis Image is a new text-to-image base model based on Qwen3 text-encoder and optimized for text-rendering
|
||||
- **Features**
|
||||
- Google **Gemini** and **Veo** models support for both *Dev* and *Vertex* access methods
|
||||
see [docs](https://vladmandic.github.io/sdnext-docs/Google-GenAI/) for details
|
||||
- **Z-Image Turbo** support loading transformer fine-tunes in safetensors format
|
||||
as with any transformers/unet finetunes, place them in `models/unet`
|
||||
and use **UNET Model** to load safetensors file as they are not complete models
|
||||
- **Z-Image Turbo** support for **ControlNet Union**
|
||||
includes 1.0, 2.0 and 2.1 variants
|
||||
- **Detailer** support for segmentation models
|
||||
some detection models can produce exact segmentation mask and not just box
|
||||
to enable, set `use segmentation` option
|
||||
added segmentation models: *anzhc-eyes-seg*, *anzhc-face-1024-seg-8n*, *anzhc-head-seg-8n*
|
||||
- **Internal**
|
||||
- update nightlies to `rocm==7.1`
|
||||
- mark `python==3.9` as deprecated
|
||||
- extensions improved status indicators, thanks @awsr
|
||||
- additional type-safety checks, thanks @awsr
|
||||
- add model info to ui overlay
|
||||
- **Wiki/Docs/Illustrations**
|
||||
- update models page, thanks @alerikaisattera
|
||||
- update reference models samples, thanks @liutyi
|
||||
- **Fixes**
|
||||
- generate forever fix loop checks, thanks @awsr
|
||||
- tokenizer explicit use for flux2, thanks @CalamitousFelicitousness
|
||||
- torch.compile skip offloading steps
|
||||
- kanvas css with standardui
|
||||
- control input media with non-english locales
|
||||
- handle embeds when on meta device
|
||||
- improve offloading when model has manual modules
|
||||
- ui section collapsible state, thanks @awsr
|
||||
- ui filter by model type
|
||||
|
||||
## Update for 2025-12-11
|
||||
|
||||
### Highlights for 2025-12-11
|
||||
|
|
|
|||
21
TODO.md
21
TODO.md
|
|
@ -1,12 +1,19 @@
|
|||
# TODO
|
||||
|
||||
## Known issues
|
||||
|
||||
- z-image-turbo controlnet device mismatch: <https://github.com/huggingface/diffusers/pull/12886>
|
||||
- z-image-turbo safetensors loader: <https://github.com/huggingface/diffusers/issues/12887>
|
||||
- kandinsky-image-5 hardcoded cuda: <https://github.com/huggingface/diffusers/pull/12814>
|
||||
- peft lora with torch-rocm-windows: <https://github.com/huggingface/peft/pull/2963>
|
||||
|
||||
## Project Board
|
||||
|
||||
- <https://github.com/users/vladmandic/projects>
|
||||
|
||||
## Internal
|
||||
|
||||
- Reimplement llama remover for kanvas
|
||||
- Reimplement `llama` remover for kanvas
|
||||
- Deploy: Create executable for SD.Next
|
||||
- Feature: Integrate natural language image search
|
||||
[ImageDB](https://github.com/vladmandic/imagedb)
|
||||
|
|
@ -21,7 +28,9 @@
|
|||
- UI: Lite vs Expert mode
|
||||
- Video tab: add full API support
|
||||
- Control tab: add overrides handling
|
||||
- Engine: TensorRT acceleration
|
||||
- Engine: `TensorRT` acceleration
|
||||
- Engine: [mmgp](https://github.com/deepbeepmeep/mmgp)
|
||||
- Engine: [sharpfin](https://github.com/drhead/sharpfin) instead of `torchvision`
|
||||
|
||||
## Features
|
||||
|
||||
|
|
@ -38,6 +47,7 @@
|
|||
|
||||
TODO: *Prioritize*!
|
||||
|
||||
- [Cosmos-Predict-2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B)
|
||||
- [NewBie Image Exp0.1](https://github.com/huggingface/diffusers/pull/12803)
|
||||
- [Sana-I2V](https://github.com/huggingface/diffusers/pull/12634#issuecomment-3540534268)
|
||||
- [Bria FIBO](https://huggingface.co/briaai/FIBO)
|
||||
|
|
@ -74,11 +84,6 @@ TODO: *Prioritize*!
|
|||
- [Wan2.2-Animate-14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B)
|
||||
- [WAN2GP](https://github.com/deepbeepmeep/Wan2GP)
|
||||
|
||||
### General
|
||||
|
||||
- Review/improve type-hinting and type checking
|
||||
- A little easier to work with due to syntax changes in Python 3.10
|
||||
|
||||
### Asyncio
|
||||
|
||||
- Policy system is deprecated and will be removed in **Python 3.16**
|
||||
|
|
@ -91,8 +96,6 @@ TODO: *Prioritize*!
|
|||
- [asyncio.run](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.run)
|
||||
- [asyncio.Runner](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.Runner)
|
||||
|
||||
### Shutil
|
||||
|
||||
#### rmtree
|
||||
|
||||
- `onerror` deprecated and replaced with `onexc` in **Python 3.12**
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit c6dc85eb28a02bc7af268497b7a5a596770c5d7b
|
||||
Subproject commit 2a7005fbcf8985644b66121365fa7228a65f34b0
|
||||
|
|
@ -1 +1 @@
|
|||
Subproject commit f3cfab10af26f0c7243878a3c320d50012764694
|
||||
Subproject commit 989a54a5b2ae4ba12fefbf48c9ed61c3663c4c0c
|
||||
|
|
@ -1 +1 @@
|
|||
Subproject commit af99fbab29e9a424c4e79fa8e4ae194481cb5f75
|
||||
Subproject commit ded112e94a94bf64daefa027376e0335fb43e0b7
|
||||
|
|
@ -200,6 +200,24 @@
|
|||
"size": 56.1,
|
||||
"date": "2025 September"
|
||||
},
|
||||
"Qwen-Image-Edit-2511": {
|
||||
"path": "Qwen/Qwen-Image-Edit-2511",
|
||||
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
|
||||
"desc": "Key enhancements: mitigate image drift, improved character consistency, enhanced industrial design generation, and strengthened geometric reasoning ability.",
|
||||
"skip": true,
|
||||
"extras": "",
|
||||
"size": 56.1,
|
||||
"date": "2025 December"
|
||||
},
|
||||
"Qwen-Image-Layered": {
|
||||
"path": "Qwen/Qwen-Image-Layered",
|
||||
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
|
||||
"desc": "Qwen-Image-Layered, a model capable of decomposing an image into multiple RGBA layers",
|
||||
"skip": true,
|
||||
"extras": "",
|
||||
"size": 53.7,
|
||||
"date": "2025 December"
|
||||
},
|
||||
"Qwen-Image-Lightning": {
|
||||
"path": "vladmandic/Qwen-Lightning",
|
||||
"preview": "vladmandic--Qwen-Lightning.jpg",
|
||||
|
|
@ -314,6 +332,25 @@
|
|||
"date": "2025 July"
|
||||
},
|
||||
|
||||
"Meituan LongCat Image": {
|
||||
"path": "meituan-longcat/LongCat-Image",
|
||||
"preview": "meituan-longcat--LongCat-Image.jpg",
|
||||
"desc": "Pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models.",
|
||||
"skip": true,
|
||||
"extras": "",
|
||||
"size": 27.30,
|
||||
"date": "2025 December"
|
||||
},
|
||||
"Meituan LongCat Image-Edit": {
|
||||
"path": "meituan-longcat/LongCat-Image-Edit",
|
||||
"preview": "meituan-longcat--LongCat-Image-Edit.jpg",
|
||||
"desc": "Pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models.",
|
||||
"skip": true,
|
||||
"extras": "",
|
||||
"size": 27.30,
|
||||
"date": "2025 December"
|
||||
},
|
||||
|
||||
"Ostris Flex.2 Preview": {
|
||||
"path": "ostris/Flex.2-preview",
|
||||
"preview": "ostris--Flex.2-preview.jpg",
|
||||
|
|
@ -782,7 +819,7 @@
|
|||
"Kandinsky 5.0 T2I Lite": {
|
||||
"path": "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers",
|
||||
"desc": "Kandinsky 5.0 Image Lite is a 6B image generation models 1K resulution, high visual quality and strong text-writing",
|
||||
"preview": "kandinsky-community--kandinsky-3.jpg",
|
||||
"preview": "kandinskylab--Kandinsky-5.0-T2I-Lite-sft-Diffusers.jpg",
|
||||
"skip": true,
|
||||
"size": 33.20,
|
||||
"date": "2025 November"
|
||||
|
|
@ -790,7 +827,7 @@
|
|||
"Kandinsky 5.0 I2I Lite": {
|
||||
"path": "kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers",
|
||||
"desc": "Kandinsky 5.0 Image Lite is a 6B image editing models 1K resulution, high visual quality and strong text-writing",
|
||||
"preview": "kandinsky-community--kandinsky-3.jpg",
|
||||
"preview": "kandinskylab--Kandinsky-5.0-T2I-Lite-sft-Diffusers.jpg",
|
||||
"skip": true,
|
||||
"size": 33.20,
|
||||
"date": "2025 November"
|
||||
|
|
@ -901,6 +938,16 @@
|
|||
"date": "2024 January"
|
||||
},
|
||||
|
||||
"AIDC Ovis-Image 7B": {
|
||||
"path": "AIDC-AI/Ovis-Image-7B",
|
||||
"skip": true,
|
||||
"desc": "Built upon Ovis-U1, Ovis-Image is a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints.",
|
||||
"preview": "AIDC-AI--Ovis-Image-7B.jpg",
|
||||
"size": 23.38,
|
||||
"date": "2025 December",
|
||||
"extras": ""
|
||||
},
|
||||
|
||||
"HDM-XUT 340M Anime": {
|
||||
"path": "KBlueLeaf/HDM-xut-340M-anime",
|
||||
"skip": true,
|
||||
|
|
@ -1075,6 +1122,26 @@
|
|||
"size": 16.10,
|
||||
"extras": ""
|
||||
},
|
||||
"Qwen-Image-Edit-2511 sdnq-svd-uint4": {
|
||||
"path": "Disty0/Qwen-Image-Edit-2511-SDNQ-uint4-svd-r32",
|
||||
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
|
||||
"desc": "Quantization of Qwen/Qwen-Image-Edit-2511 using SDNQ: sdnq-svd 4-bit uint with svd rank 32",
|
||||
"skip": true,
|
||||
"tags": "quantized",
|
||||
"date": "2025 December",
|
||||
"size": 16.10,
|
||||
"extras": ""
|
||||
},
|
||||
"Qwen-Image-Layered sdnq-svd-uint4": {
|
||||
"path": "Disty0/Qwen-Image-Layered-SDNQ-uint4-svd-r32",
|
||||
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
|
||||
"desc": "Quantization of Qwen/Qwen-Image-Layered using SDNQ: sdnq-svd 4-bit uint with svd rank 32",
|
||||
"skip": true,
|
||||
"tags": "quantized",
|
||||
"date": "2025 December",
|
||||
"size": 16.10,
|
||||
"extras": ""
|
||||
},
|
||||
"nVidia ChronoEdit sdnq-svd-uint4": {
|
||||
"path": "Disty0/ChronoEdit-14B-SDNQ-uint4-svd-r32",
|
||||
"preview": "Disty0--ChronoEdit-14B-SDNQ-uint4-svd-r32.jpg",
|
||||
|
|
|
|||
51
installer.py
51
installer.py
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional
|
||||
from typing import overload, List, Optional
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
|
@ -19,7 +19,6 @@ class Dot(dict): # dot notation access to dictionary attributes
|
|||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
|
||||
version = {
|
||||
'app': 'sd.next',
|
||||
'updated': 'unknown',
|
||||
|
|
@ -94,19 +93,35 @@ def get_log():
|
|||
return log
|
||||
|
||||
|
||||
@overload
|
||||
def str_to_bool(val: str | bool) -> bool: ...
|
||||
@overload
|
||||
def str_to_bool(val: None) -> None: ...
|
||||
def str_to_bool(val: str | bool | None) -> bool | None:
|
||||
if isinstance(val, str):
|
||||
if val.strip() and val.strip().lower() in ("1", "true"):
|
||||
return True
|
||||
return False
|
||||
return val
|
||||
|
||||
|
||||
def install_traceback(suppress: list = []):
|
||||
from rich.traceback import install as traceback_install
|
||||
from rich.pretty import install as pretty_install
|
||||
|
||||
width = os.environ.get("SD_TRACEWIDTH", console.width if console else None)
|
||||
if width is not None:
|
||||
width = int(width)
|
||||
traceback_install(
|
||||
console=console,
|
||||
extra_lines=os.environ.get('SD_TRACELINES', 1),
|
||||
max_frames=os.environ.get('SD_TRACEFRAMES', 16),
|
||||
width=os.environ.get('SD_TRACEWIDTH', console.width),
|
||||
word_wrap=os.environ.get('SD_TRACEWRAP', False),
|
||||
indent_guides=os.environ.get('SD_TRACEINDENT', False),
|
||||
show_locals=os.environ.get('SD_TRACELOCALS', False),
|
||||
locals_hide_dunder=os.environ.get('SD_TRACEDUNDER', True),
|
||||
locals_hide_sunder=os.environ.get('SD_TRACESUNDER', None),
|
||||
extra_lines=int(os.environ.get("SD_TRACELINES", 1)),
|
||||
max_frames=int(os.environ.get("SD_TRACEFRAMES", 16)),
|
||||
width=width,
|
||||
word_wrap=str_to_bool(os.environ.get("SD_TRACEWRAP", False)),
|
||||
indent_guides=str_to_bool(os.environ.get("SD_TRACEINDENT", False)),
|
||||
show_locals=str_to_bool(os.environ.get("SD_TRACELOCALS", False)),
|
||||
locals_hide_dunder=str_to_bool(os.environ.get("SD_TRACEDUNDER", True)),
|
||||
locals_hide_sunder=str_to_bool(os.environ.get("SD_TRACESUNDER", None)),
|
||||
suppress=suppress,
|
||||
)
|
||||
pretty_install(console=console)
|
||||
|
|
@ -633,7 +648,7 @@ def check_diffusers():
|
|||
t_start = time.time()
|
||||
if args.skip_all:
|
||||
return
|
||||
sha = '3d02cd543ef3101d821cb09c8fcab23c6e7ead33' # diffusers commit hash
|
||||
sha = 'f6b6a7181eb44f0120b29cd897c129275f366c2a' # diffusers commit hash
|
||||
# if args.use_rocm or args.use_zluda or args.use_directml:
|
||||
# sha = '043ab2520f6a19fce78e6e060a68dbc947edb9f9' # lock diffusers versions for now
|
||||
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
|
||||
|
|
@ -781,10 +796,10 @@ def install_rocm_zluda():
|
|||
else:
|
||||
#check_python(supported_minors=[10, 11, 12, 13, 14], reason='ROCm backend requires a Python version between 3.10 and 3.13')
|
||||
if args.use_nightly:
|
||||
if rocm.version is None or float(rocm.version) >= 7.0: # assume the latest if version check fails
|
||||
if rocm.version is None or float(rocm.version) >= 7.1: # assume the latest if version check fails
|
||||
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.1')
|
||||
else: # oldest rocm version on nightly is 7.0
|
||||
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0')
|
||||
else: # oldest rocm version on nightly is 6.4
|
||||
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.4')
|
||||
else:
|
||||
if rocm.version is None or float(rocm.version) >= 6.4: # assume the latest if version check fails
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+rocm6.4 torchvision==0.24.1+rocm6.4 --index-url https://download.pytorch.org/whl/rocm6.4')
|
||||
|
|
@ -1714,7 +1729,7 @@ def add_args(parser):
|
|||
group_install.add_argument('--skip-env', default=os.environ.get("SD_SKIPENV",False), action='store_true', help="Skips setting of env variables during startup, default: %(default)s")
|
||||
|
||||
group_compute = parser.add_argument_group('Compute Engine')
|
||||
group_compute.add_argument("--device-id", type=str, default=os.environ.get("SD_DEVICEID", None), help="Select the default CUDA device to use, default: %(default)s")
|
||||
group_compute.add_argument("--device-id", type=str, default=os.environ.get("SD_DEVICEID", None), help="Select the default GPU device to use, default: %(default)s")
|
||||
group_compute.add_argument("--use-cuda", default=os.environ.get("SD_USECUDA",False), action='store_true', help="Force use nVidia CUDA backend, default: %(default)s")
|
||||
group_compute.add_argument("--use-ipex", default=os.environ.get("SD_USEIPEX",False), action='store_true', help="Force use Intel OneAPI XPU backend, default: %(default)s")
|
||||
group_compute.add_argument("--use-rocm", default=os.environ.get("SD_USEROCM",False), action='store_true', help="Force use AMD ROCm backend, default: %(default)s")
|
||||
|
|
@ -1730,9 +1745,11 @@ def add_args(parser):
|
|||
group_paths.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", 'models'), help="Base path where all models are stored, default: %(default)s",)
|
||||
group_paths.add_argument("--extensions-dir", type=str, default=os.environ.get("SD_EXTENSIONSDIR", None), help="Base path where all extensions are stored, default: %(default)s",)
|
||||
|
||||
group_ui = parser.add_argument_group('UI')
|
||||
group_ui.add_argument('--theme', type=str, default=os.environ.get("SD_THEME", None), help='Override UI theme')
|
||||
group_ui.add_argument('--locale', type=str, default=os.environ.get("SD_LOCALE", None), help='Override UI locale')
|
||||
|
||||
group_http = parser.add_argument_group('HTTP')
|
||||
group_http.add_argument('--theme', type=str, default=os.environ.get("SD_THEME", None), help='Override UI theme')
|
||||
group_http.add_argument('--locale', type=str, default=os.environ.get("SD_LOCALE", None), help='Override UI locale')
|
||||
group_http.add_argument("--server-name", type=str, default=os.environ.get("SD_SERVERNAME", None), help="Sets hostname of server, default: %(default)s")
|
||||
group_http.add_argument("--tls-keyfile", type=str, default=os.environ.get("SD_TLSKEYFILE", None), help="Enable TLS and specify key file, default: %(default)s")
|
||||
group_http.add_argument("--tls-certfile", type=str, default=os.environ.get("SD_TLSCERTFILE", None), help="Enable TLS and specify cert file, default: %(default)s")
|
||||
|
|
|
|||
|
|
@ -100,13 +100,20 @@ const generateForever = (genbuttonid) => {
|
|||
clearInterval(window.generateOnRepeatInterval);
|
||||
window.generateOnRepeatInterval = null;
|
||||
} else {
|
||||
log('generateForever: start');
|
||||
const genbutton = gradioApp().querySelector(genbuttonid);
|
||||
const busy = document.getElementById('progressbar')?.style.display === 'block';
|
||||
if (!busy) genbutton.click();
|
||||
const isBusy = () => {
|
||||
let busy = document.getElementById('progressbar')?.style.display === 'block';
|
||||
if (!busy) {
|
||||
// Also check in Modern UI
|
||||
const outerButton = genbutton.parentElement.closest('button');
|
||||
busy = outerButton?.classList.contains('generate') && outerButton?.classList.contains('active');
|
||||
}
|
||||
return busy;
|
||||
};
|
||||
log('generateForever: start');
|
||||
if (!isBusy()) genbutton.click();
|
||||
window.generateOnRepeatInterval = setInterval(() => {
|
||||
const pbBusy = document.getElementById('progressbar')?.style.display === 'block';
|
||||
if (!pbBusy) genbutton.click();
|
||||
if (!isBusy()) genbutton.click();
|
||||
}, 500);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3,7 +3,11 @@ function controlInputMode(inputMode, ...args) {
|
|||
if (updateEl) updateEl.click();
|
||||
const tab = gradioApp().querySelector('#control-tab-input button.selected');
|
||||
if (!tab) return ['Image', ...args];
|
||||
let inputTab = tab.innerText;
|
||||
// let inputTab = tab.innerText;
|
||||
const tabs = Array.from(gradioApp().querySelectorAll('#control-tab-input button'));
|
||||
const tabIdx = tabs.findIndex((btn) => btn.classList.contains('selected'));
|
||||
const tabNames = ['Image', 'Video', 'Batch', 'Folder'];
|
||||
let inputTab = tabNames[tabIdx] || 'Image';
|
||||
log('controlInputMode', { mode: inputMode, tab: inputTab, kanvas: typeof Kanvas });
|
||||
if ((inputTab === 'Image') && (typeof 'Kanvas' !== 'undefined')) {
|
||||
inputTab = 'Kanvas';
|
||||
|
|
|
|||
|
|
@ -312,8 +312,9 @@ function extraNetworksFilterVersion(event) {
|
|||
const version = event.target.textContent.trim();
|
||||
const activeTab = getENActiveTab();
|
||||
const activePage = getENActivePage().toLowerCase();
|
||||
const cardContainer = gradioApp().querySelector(`#${activeTab}_${activePage}_cards`);
|
||||
log('extraNetworksFilterVersion', version);
|
||||
let cardContainer = gradioApp().querySelector(`#${activeTab}_${activePage}_cards`);
|
||||
if (!cardContainer) cardContainer = gradioApp().querySelector(`#txt2img_extra_networks_${activePage}_cards`);
|
||||
log('extraNetworksFilterVersion', { version, activeTab, activePage, cardContainer });
|
||||
if (!cardContainer) return;
|
||||
if (cardContainer.dataset.activeVersion === version) {
|
||||
cardContainer.dataset.activeVersion = '';
|
||||
|
|
|
|||
|
|
@ -46,7 +46,9 @@ function forceLogin() {
|
|||
})
|
||||
.then(async (res) => {
|
||||
const json = await res.json();
|
||||
const txt = `${res.status}: ${res.statusText} - ${json.detail}`;
|
||||
let txt = '';
|
||||
if (res.status === 200) txt = 'login verified';
|
||||
else txt = `${res.status}: ${res.statusText} - ${json.detail}`;
|
||||
status.textContent = txt;
|
||||
console.log('login', txt);
|
||||
if (res.status === 200) location.reload();
|
||||
|
|
|
|||
|
|
@ -1,3 +1,11 @@
|
|||
const getModel = () => {
|
||||
const cp = opts?.sd_model_checkpoint || '';
|
||||
if (!cp) return 'unknown model';
|
||||
const noBracket = cp.replace(/\s*\[.*\]\s*$/, ''); // remove trailing [hash]
|
||||
const parts = noBracket.split(/[\\/]/); // split on / or \
|
||||
return parts[parts.length - 1].trim() || 'unknown model';
|
||||
};
|
||||
|
||||
async function updateIndicator(online, data, msg) {
|
||||
const el = document.getElementById('logo_nav');
|
||||
if (!el || !data) return;
|
||||
|
|
@ -5,9 +13,10 @@ async function updateIndicator(online, data, msg) {
|
|||
const date = new Date();
|
||||
const template = `
|
||||
Version: <b>${data.updated}</b><br>
|
||||
Commit: <b>${data.hash}</b><br>
|
||||
Commit: <b>${data.commit}</b><br>
|
||||
Branch: <b>${data.branch}</b><br>
|
||||
Status: ${status}<br>
|
||||
Model: <b>${getModel()}</b><br>
|
||||
Since: ${date.toLocaleString()}<br>
|
||||
`;
|
||||
if (online) {
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ textarea {
|
|||
}
|
||||
|
||||
span {
|
||||
font-size: var(--text-md) !important;
|
||||
font-size: var(--text-md);
|
||||
}
|
||||
|
||||
button {
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 35 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 64 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
|
|
@ -8,8 +8,15 @@ except Exception:
|
|||
|
||||
|
||||
nvml_initialized = False
|
||||
warned = False
|
||||
|
||||
|
||||
def warn_once(msg):
|
||||
global warned # pylint: disable=global-statement
|
||||
if not warned:
|
||||
log.error(msg)
|
||||
warned = True
|
||||
|
||||
def get_reason(val):
|
||||
throttle = {
|
||||
1: 'gpu idle',
|
||||
|
|
@ -28,6 +35,8 @@ def get_reason(val):
|
|||
|
||||
def get_nvml():
|
||||
global nvml_initialized # pylint: disable=global-statement
|
||||
if warned:
|
||||
return []
|
||||
try:
|
||||
from modules.memstats import ram_stats
|
||||
if not nvml_initialized:
|
||||
|
|
@ -71,7 +80,7 @@ def get_nvml():
|
|||
# log.debug(f'nmvl: {devices}')
|
||||
return devices
|
||||
except Exception as e:
|
||||
log.error(f'NVML: {e}')
|
||||
warn_once(f'NVML: {e}')
|
||||
return []
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ def civit_update_metadata(raw:bool=False):
|
|||
model.id = d['modelId']
|
||||
download_civit_meta(model.fn, model.id)
|
||||
fn = os.path.splitext(item['filename'])[0] + '.json'
|
||||
model.meta = readfile(fn, silent=True)
|
||||
model.meta = readfile(fn, silent=True, as_type="dict")
|
||||
model.name = model.meta.get('name', model.name)
|
||||
model.versions = len(model.meta.get('modelVersions', []))
|
||||
versions = model.meta.get('modelVersions', [])
|
||||
|
|
|
|||
|
|
@ -182,7 +182,7 @@ def check_active(p, unit_type, units):
|
|||
p.is_tile = p.is_tile or 'tile' in u.mode.lower()
|
||||
p.control_tile = u.tile
|
||||
p.extra_generation_params["Control mode"] = u.mode
|
||||
shared.log.debug(f'Control ControlNet unit: i={num_units} process="{u.process.processor_id}" model="{u.controlnet.model_id}" strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}')
|
||||
shared.log.debug(f'Control unit: i={num_units} type=ControlNet process="{u.process.processor_id}" model="{u.controlnet.model_id}" strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}')
|
||||
elif unit_type == 'xs' and u.controlnet.model is not None:
|
||||
active_process.append(u.process)
|
||||
active_model.append(u.controlnet)
|
||||
|
|
@ -190,13 +190,13 @@ def check_active(p, unit_type, units):
|
|||
active_start.append(float(u.start))
|
||||
active_end.append(float(u.end))
|
||||
active_units.append(u)
|
||||
shared.log.debug(f'Control ControlNet-XS unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
|
||||
shared.log.debug(f'Control unit: i={num_units} type=ControlNetXS process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
|
||||
elif unit_type == 'lite' and u.controlnet.model is not None:
|
||||
active_process.append(u.process)
|
||||
active_model.append(u.controlnet)
|
||||
active_strength.append(float(u.strength))
|
||||
active_units.append(u)
|
||||
shared.log.debug(f'Control ControlLLite unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
|
||||
shared.log.debug(f'Control unit: i={num_units} type=ControlLLite process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
|
||||
elif unit_type == 'reference':
|
||||
p.override = u.override
|
||||
p.attention = u.attention
|
||||
|
|
@ -209,7 +209,7 @@ def check_active(p, unit_type, units):
|
|||
if u.process.processor_id is not None:
|
||||
active_process.append(u.process)
|
||||
active_units.append(u)
|
||||
shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}')
|
||||
shared.log.debug(f'Control unit: i={num_units} type=Process process={u.process.processor_id}')
|
||||
active_strength.append(float(u.strength))
|
||||
debug_log(f'Control active: process={len(active_process)} model={len(active_model)}')
|
||||
return active_process, active_model, active_strength, active_start, active_end, active_units
|
||||
|
|
@ -654,7 +654,7 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
|
|||
|
||||
debug_log(f'Control: pipeline units={len(active_model)} process={len(active_process)} outputs={len(output_images)}')
|
||||
except Exception as e:
|
||||
shared.log.error(f'Control pipeline failed: type={unit_type} units={len(active_model)} error={e}')
|
||||
shared.log.error(f'Control: type={unit_type} units={len(active_model)} {e}')
|
||||
errors.display(e, 'Control')
|
||||
|
||||
if len(output_images) == 0:
|
||||
|
|
|
|||
|
|
@ -134,15 +134,15 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
|
|||
if image_file is None:
|
||||
self.process.override = None
|
||||
self.override = None
|
||||
log.debug('Control process clear image')
|
||||
log.debug('Control image: clear')
|
||||
return gr.update(value=None)
|
||||
try:
|
||||
self.process.override = Image.open(image_file.name)
|
||||
self.override = self.process.override
|
||||
log.debug(f'Control process upload image: path="{image_file.name}" image={self.process.override}')
|
||||
log.debug(f'Control image: upload={self.process.override} path="{image_file.name}"')
|
||||
return gr.update(visible=self.process.override is not None, value=self.process.override)
|
||||
except Exception as e:
|
||||
log.error(f'Control process upload image failed: path="{image_file.name}" error={e}')
|
||||
log.error(f'Control image: upload path="{image_file.name}" error={e}')
|
||||
return gr.update(visible=False, value=None)
|
||||
|
||||
def reuse_image(image):
|
||||
|
|
|
|||
|
|
@ -111,6 +111,11 @@ predefined_hunyuandit = {
|
|||
"HunyuanDiT Pose": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Pose',
|
||||
"HunyuanDiT Depth": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Depth',
|
||||
}
|
||||
predefined_zimage = {
|
||||
"Z-Image-Turbo Union 1.0": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union',
|
||||
"Z-Image-Turbo Union 2.0": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0',
|
||||
"Z-Image-Turbo Union 2.1": 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1',
|
||||
}
|
||||
|
||||
variants = {
|
||||
'NoobAI Canny XL': 'fp16',
|
||||
|
|
@ -137,6 +142,7 @@ all_models.update(predefined_f1)
|
|||
all_models.update(predefined_sd3)
|
||||
all_models.update(predefined_qwen)
|
||||
all_models.update(predefined_hunyuandit)
|
||||
all_models.update(predefined_zimage)
|
||||
cache_dir = 'models/control/controlnet'
|
||||
load_lock = threading.Lock()
|
||||
|
||||
|
|
@ -175,6 +181,8 @@ def api_list_models(model_type: str = None):
|
|||
model_list += list(predefined_qwen)
|
||||
if model_type == 'hunyuandit' or model_type == 'all':
|
||||
model_list += list(predefined_hunyuandit)
|
||||
if model_type == 'z_image':
|
||||
model_list += list(predefined_zimage)
|
||||
model_list += sorted(find_models())
|
||||
return model_list
|
||||
|
||||
|
|
@ -199,9 +207,11 @@ def list_models(refresh=False):
|
|||
models = ['None'] + list(predefined_qwen) + sorted(find_models())
|
||||
elif modules.shared.sd_model_type == 'hunyuandit':
|
||||
models = ['None'] + list(predefined_hunyuandit) + sorted(find_models())
|
||||
elif modules.shared.sd_model_type == 'z_image':
|
||||
models = ['None'] + list(predefined_zimage) + sorted(find_models())
|
||||
else:
|
||||
log.warning(f'Control {what} model list failed: unknown model type')
|
||||
models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(predefined_f1) + sorted(predefined_sd3) + sorted(find_models())
|
||||
models = ['None'] + list(all_models) + sorted(find_models())
|
||||
debug_log(f'Control list {what}: path={cache_dir} models={models}')
|
||||
return models
|
||||
|
||||
|
|
@ -263,6 +273,14 @@ class ControlNet():
|
|||
elif shared.sd_model_type == 'hunyuandit':
|
||||
from diffusers import HunyuanDiT2DControlNetModel as cls
|
||||
config = 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Canny'
|
||||
elif shared.sd_model_type == 'z_image':
|
||||
from diffusers import ZImageControlNetModel as cls
|
||||
if '2.0' in model_id:
|
||||
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0'
|
||||
elif '2.1' in model_id:
|
||||
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1'
|
||||
else:
|
||||
config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union'
|
||||
else:
|
||||
log.error(f'Control {what}: type={shared.sd_model_type} unsupported model')
|
||||
return None, None
|
||||
|
|
@ -508,6 +526,17 @@ class ControlNetPipeline():
|
|||
feature_extractor=None,
|
||||
controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list
|
||||
)
|
||||
elif detect.is_zimage(pipeline) and len(controlnets) > 0:
|
||||
from diffusers import ZImageControlNetPipeline
|
||||
self.pipeline = ZImageControlNetPipeline(
|
||||
vae=pipeline.vae,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
transformer=pipeline.transformer,
|
||||
scheduler=pipeline.scheduler,
|
||||
controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list
|
||||
)
|
||||
self.pipeline.task_args = { 'guidance_scale': 1 }
|
||||
elif len(loras) > 0:
|
||||
self.pipeline = pipeline
|
||||
for lora in loras:
|
||||
|
|
|
|||
|
|
@ -28,3 +28,6 @@ def is_qwen(model):
|
|||
|
||||
def is_hunyuandit(model):
    """Return the result of `is_compatible(model, pattern='HunyuanDiT')`.

    The matching rule itself lives in `is_compatible`, which is defined outside this view.
    """
    return is_compatible(model, pattern='HunyuanDiT')
|
||||
|
||||
def is_zimage(model):
    """Return the result of `is_compatible(model, pattern='ZImage')`.

    The matching rule itself lives in `is_compatible`, which is defined outside this view.
    """
    return is_compatible(model, pattern='ZImage')
|
||||
|
|
|
|||
|
|
@ -238,8 +238,8 @@ def torch_gc(force:bool=False, fast:bool=False, reason:str=None):
|
|||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache() # cuda gc
|
||||
torch.cuda.ipc_collect()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.error(f'GC: {e}')
|
||||
else:
|
||||
return gpu, ram
|
||||
t1 = time.time()
|
||||
|
|
@ -390,7 +390,7 @@ def test_triton(early: bool = False):
|
|||
def test_triton_func(a,b,c):
|
||||
return a * b + c
|
||||
test_triton_func = torch.compile(test_triton_func, fullgraph=True)
|
||||
test_triton_func(torch.randn(32, device=device), torch.randn(32, device=device), torch.randn(32, device=device))
|
||||
test_triton_func(torch.randn(16, device=device), torch.randn(16, device=device), torch.randn(16, device=device))
|
||||
triton_ok = True
|
||||
else:
|
||||
triton_ok = False
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import annotations
|
||||
import os
|
||||
from datetime import datetime
|
||||
import git
|
||||
|
|
@ -5,7 +6,7 @@ from modules import shared, errors
|
|||
from modules.paths import extensions_dir, extensions_builtin_dir
|
||||
|
||||
|
||||
extensions = []
|
||||
extensions: list[Extension] = []
|
||||
if not os.path.exists(extensions_dir):
|
||||
os.makedirs(extensions_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ progress_ok = True
|
|||
def init_cache():
|
||||
global cache_data # pylint: disable=global-statement
|
||||
if cache_data is None:
|
||||
cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True)
|
||||
cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True, as_type="dict")
|
||||
|
||||
|
||||
def dump_cache():
|
||||
|
|
@ -22,7 +22,7 @@ def dump_cache():
|
|||
def cache(subsection):
|
||||
global cache_data # pylint: disable=global-statement
|
||||
if cache_data is None:
|
||||
cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True)
|
||||
cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True, as_type="dict")
|
||||
s = cache_data.get(subsection, {})
|
||||
cache_data[subsection] = s
|
||||
return s
|
||||
|
|
|
|||
|
|
@ -120,9 +120,9 @@ def atomically_save_image():
|
|||
if not fn.endswith('.json'):
|
||||
fn += '.json'
|
||||
entries = shared.readfile(fn, silent=True)
|
||||
idx = len(list(entries))
|
||||
if idx == 0:
|
||||
if not isinstance(entries, list):
|
||||
entries = []
|
||||
idx = len(entries)
|
||||
entry = { 'id': idx, 'filename': filename, 'time': datetime.datetime.now().isoformat(), 'info': exifinfo }
|
||||
entries.append(entry)
|
||||
shared.writefile(entries, fn, mode='w', silent=True)
|
||||
|
|
@ -132,7 +132,7 @@ def atomically_save_image():
|
|||
save_queue.task_done()
|
||||
|
||||
|
||||
save_queue = queue.Queue()
|
||||
save_queue: queue.Queue[tuple[Image.Image, str, str, script_callbacks.ImageSaveParams, str, str | None, bool]] = queue.Queue()
|
||||
save_thread = threading.Thread(target=atomically_save_image, daemon=True)
|
||||
save_thread.start()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
import time
|
||||
import json
|
||||
from typing import overload, Literal
|
||||
import fasteners
|
||||
import orjson
|
||||
from installer import log
|
||||
|
|
@ -10,7 +11,13 @@ from installer import log
|
|||
locking_available = True # used by file read/write locking
|
||||
|
||||
|
||||
def readfile(filename, silent=False, lock=False):
|
||||
@overload
|
||||
def readfile(filename: str, silent: bool = False, lock: bool = False, *, as_type: Literal["dict"]) -> dict: ...
|
||||
@overload
|
||||
def readfile(filename: str, silent: bool = False, lock: bool = False, *, as_type: Literal["list"]) -> list: ...
|
||||
@overload
|
||||
def readfile(filename: str, silent: bool = False, lock: bool = False) -> dict | list: ...
|
||||
def readfile(filename: str, silent: bool = False, lock: bool = False, *, as_type="") -> dict | list:
|
||||
global locking_available # pylint: disable=global-statement
|
||||
data = {}
|
||||
lock_file = None
|
||||
|
|
@ -51,6 +58,13 @@ def readfile(filename, silent=False, lock=False):
|
|||
os.remove(f"{filename}.lock")
|
||||
except Exception:
|
||||
locking_available = False
|
||||
if isinstance(data, list) and as_type == "dict":
|
||||
data0 = data[0]
|
||||
if isinstance(data0, dict):
|
||||
return data0
|
||||
return {}
|
||||
if isinstance(data, dict) and as_type == "list":
|
||||
return [data]
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -94,9 +94,9 @@ timer.startup.record("torch")
|
|||
|
||||
try:
|
||||
import bitsandbytes # pylint: disable=W0611,C0411
|
||||
_bnb = True
|
||||
except Exception:
|
||||
from diffusers.utils import import_utils
|
||||
import_utils._bitsandbytes_available = False # pylint: disable=protected-access
|
||||
_bnb = False
|
||||
timer.startup.record("bnb")
|
||||
|
||||
import transformers # pylint: disable=W0611,C0411
|
||||
|
|
@ -134,6 +134,7 @@ try:
|
|||
import diffusers.utils.import_utils # pylint: disable=W0611,C0411
|
||||
diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git
|
||||
diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access
|
||||
diffusers.utils.import_utils._bitsandbytes_available = _bnb # pylint: disable=protected-access
|
||||
|
||||
import diffusers # pylint: disable=W0611,C0411
|
||||
import diffusers.loaders.single_file # pylint: disable=W0611,C0411
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ force_models_diffusers = [ # forced always
|
|||
'chrono',
|
||||
'z_image',
|
||||
'f2',
|
||||
'longcat',
|
||||
# video models
|
||||
'hunyuanvideo',
|
||||
'hunyuanvideo15'
|
||||
|
|
|
|||
|
|
@ -107,9 +107,7 @@ class NetworkOnDisk:
|
|||
if self.filename is not None:
|
||||
fn = os.path.splitext(self.filename)[0] + '.json'
|
||||
if os.path.exists(fn):
|
||||
data = shared.readfile(fn, silent=True)
|
||||
if type(data) is list:
|
||||
data = data[0]
|
||||
data = shared.readfile(fn, silent=True, as_type="dict")
|
||||
return data
|
||||
|
||||
def get_desc(self):
|
||||
|
|
@ -118,7 +116,8 @@ class NetworkOnDisk:
|
|||
if self.filename is not None:
|
||||
fn = os.path.splitext(self.filename)[0] + '.txt'
|
||||
if os.path.exists(fn):
|
||||
return shared.readfile(fn, silent=True)
|
||||
with open(fn, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
return None
|
||||
|
||||
def get_alias(self):
|
||||
|
|
|
|||
|
|
@ -76,6 +76,10 @@ def get_model_type(pipe):
|
|||
model_type = 'x-omni'
|
||||
elif 'Photoroom' in name:
|
||||
model_type = 'prx'
|
||||
elif 'LongCat' in name:
|
||||
model_type = 'longcat'
|
||||
elif 'Ovis-Image' in name:
|
||||
model_type = 'ovis'
|
||||
# video models
|
||||
elif "CogVideo" in name:
|
||||
model_type = 'cogvideo'
|
||||
|
|
|
|||
|
|
@ -1,8 +1,18 @@
|
|||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from installer import log
|
||||
|
||||
|
||||
def options_section(section_identifier, options_dict):
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from gradio.components import Component
|
||||
from modules.shared_legacy import LegacyOption
|
||||
from modules.ui_components import DropdownEditable
|
||||
|
||||
|
||||
def options_section(section_identifier: tuple[str, str], options_dict: dict[str, OptionInfo | LegacyOption]):
    """Tag every option in `options_dict` with `section_identifier`.

    Each value's `section` attribute is overwritten in place; the same
    dict object that was passed in is returned.
    """
    for option_key in options_dict:
        options_dict[option_key].section = section_identifier
    return options_dict
|
||||
|
|
@ -11,14 +21,14 @@ def options_section(section_identifier, options_dict):
|
|||
class OptionInfo:
|
||||
def __init__(
|
||||
self,
|
||||
default=None,
|
||||
default: Any | None = None,
|
||||
label="",
|
||||
component=None,
|
||||
component_args=None,
|
||||
onchange=None,
|
||||
section=None,
|
||||
refresh=None,
|
||||
folder=None,
|
||||
component: type[Component] | type[DropdownEditable] | None = None,
|
||||
component_args: dict | Callable[..., dict] | None = None,
|
||||
onchange: Callable | None = None,
|
||||
section: tuple[str, ...] | None = None,
|
||||
refresh: Callable | None = None,
|
||||
folder=False,
|
||||
submit=None,
|
||||
comment_before='',
|
||||
comment_after='',
|
||||
|
|
@ -40,7 +50,7 @@ class OptionInfo:
|
|||
self.exclude = ['sd_model_checkpoint', 'sd_model_refiner', 'sd_vae', 'sd_unet', 'sd_text_encoder']
|
||||
self.dynamic = callable(component_args)
|
||||
args = {} if self.dynamic else (component_args or {}) # executing callable here is too expensive
|
||||
self.visible = args.get('visible', True) and len(self.label) > 2
|
||||
self.visible = args.get('visible', True) and len(self.label) > 2 # type: ignore - Type checking only sees the value of self.dynamic, not the `callable` check
|
||||
|
||||
def needs_reload_ui(self):
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import annotations
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
|
|
@ -6,10 +7,12 @@ from modules import cmd_args, errors
|
|||
from modules.json_helpers import readfile, writefile
|
||||
from modules.shared_legacy import LegacyOption
|
||||
from installer import log
|
||||
if TYPE_CHECKING:
|
||||
from modules.options import OptionInfo
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from modules.options import OptionInfo
|
||||
|
||||
cmd_opts = cmd_args.parse_args()
|
||||
compatibility_opts = ['clip_skip', 'uni_pc_lower_order_final', 'uni_pc_order']
|
||||
|
||||
|
|
@ -21,7 +24,9 @@ class Options():
|
|||
typemap = {int: float}
|
||||
debug = os.environ.get('SD_CONFIG_DEBUG', None) is not None
|
||||
|
||||
def __init__(self, options_templates:dict={}, restricted_opts:dict={}):
|
||||
def __init__(self, options_templates: dict[str, OptionInfo | LegacyOption] = {}, restricted_opts: set[str] | None = None):
|
||||
if restricted_opts is None:
|
||||
restricted_opts = set()
|
||||
self.data_labels = options_templates
|
||||
self.restricted_opts = restricted_opts
|
||||
self.data = {k: v.default for k, v in self.data_labels.items()}
|
||||
|
|
@ -163,12 +168,12 @@ class Options():
|
|||
log.debug(f'Settings: fn="{filename}" created')
|
||||
self.save(filename)
|
||||
return
|
||||
self.data = readfile(filename, lock=True)
|
||||
self.data = readfile(filename, lock=True, as_type="dict")
|
||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||
unknown_settings = []
|
||||
for k, v in self.data.items():
|
||||
info: OptionInfo = self.data_labels.get(k, None)
|
||||
info: OptionInfo | None = self.data_labels.get(k, None)
|
||||
if info is not None:
|
||||
if not info.validate(k, v):
|
||||
self.data[k] = info.default
|
||||
|
|
@ -180,7 +185,7 @@ class Options():
|
|||
if len(unknown_settings) > 0:
|
||||
log.warning(f"Setting validation: unknown={unknown_settings}")
|
||||
|
||||
def onchange(self, key, func, call=True):
|
||||
def onchange(self, key, func: Callable, call=True):
|
||||
item = self.data_labels.get(key)
|
||||
item.onchange = func
|
||||
if call:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ predefined = [ # <https://huggingface.co/vladmandic/yolo-detailers/tree/main>
|
|||
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/person_yolov8n-seg.pt',
|
||||
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/eyes-v1.pt',
|
||||
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/eyes-full-v1.pt',
|
||||
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/anzhc-eyes-seg.pt',
|
||||
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/anzhc-face-1024-seg-8n.pt',
|
||||
'https://huggingface.co/vladmandic/yolo-detailers/resolve/main/anzhc-head-seg-8n.pt',
|
||||
'https://huggingface.co/netrunner-exe/Face-Upscalers-onnx/resolve/main/codeformer.fp16.onnx',
|
||||
'https://huggingface.co/netrunner-exe/Face-Upscalers-onnx/resolve/main/restoreformer.fp16.onnx',
|
||||
'https://huggingface.co/netrunner-exe/Face-Upscalers-onnx/resolve/main/GFPGANv1.4.fp16.onnx',
|
||||
|
|
@ -136,25 +139,45 @@ class YoloRestorer(Detailer):
|
|||
boxes = prediction.boxes.xyxy.detach().int().cpu().numpy() if prediction.boxes is not None else []
|
||||
scores = prediction.boxes.conf.detach().float().cpu().numpy() if prediction.boxes is not None else []
|
||||
classes = prediction.boxes.cls.detach().float().cpu().numpy() if prediction.boxes is not None else []
|
||||
for score, box, cls in zip(scores, boxes, classes):
|
||||
masks = prediction.masks.data.cpu().float().numpy() if prediction.masks is not None else []
|
||||
if len(masks) < len(classes):
|
||||
masks = len(classes) * [None]
|
||||
for score, box, cls, seg in zip(scores, boxes, classes, masks):
|
||||
if seg is not None:
|
||||
try:
|
||||
seg = (255 * seg).astype(np.uint8)
|
||||
seg = Image.fromarray(seg).resize(image.size).convert('L')
|
||||
except Exception:
|
||||
seg = None
|
||||
cls = int(cls)
|
||||
label = prediction.names[cls] if cls < len(prediction.names) else f'cls{cls}'
|
||||
if len(desired) > 0 and label.lower() not in desired:
|
||||
continue
|
||||
box = box.tolist()
|
||||
mask_image = None
|
||||
w, h = box[2] - box[0], box[3] - box[1]
|
||||
x_size, y_size = w/image.width, h/image.height
|
||||
min_size = shared.opts.detailer_min_size if shared.opts.detailer_min_size >= 0 and shared.opts.detailer_min_size <= 1 else 0
|
||||
max_size = shared.opts.detailer_max_size if shared.opts.detailer_max_size >= 0 and shared.opts.detailer_max_size <= 1 else 1
|
||||
if x_size >= min_size and y_size >=min_size and x_size <= max_size and y_size <= max_size:
|
||||
if mask:
|
||||
mask_image = image.copy()
|
||||
mask_image = Image.new('L', image.size, 0)
|
||||
draw = ImageDraw.Draw(mask_image)
|
||||
draw.rectangle(box, fill="white", outline=None, width=0)
|
||||
if shared.opts.detailer_seg and seg is not None:
|
||||
masked = seg
|
||||
else:
|
||||
masked = Image.new('L', image.size, 0)
|
||||
draw = ImageDraw.Draw(masked)
|
||||
draw.rectangle(box, fill="white", outline=None, width=0)
|
||||
cropped = image.crop(box)
|
||||
res = YoloResult(cls=cls, label=label, score=round(score, 2), box=box, mask=mask_image, item=cropped, width=w, height=h, args=args)
|
||||
res = YoloResult(
|
||||
cls=cls,
|
||||
label=label,
|
||||
score=round(score, 2),
|
||||
box=box,
|
||||
mask=masked,
|
||||
item=cropped,
|
||||
width=w,
|
||||
height=h,
|
||||
args=args,
|
||||
)
|
||||
result.append(res)
|
||||
if len(result) >= shared.opts.detailer_max:
|
||||
break
|
||||
|
|
@ -217,18 +240,30 @@ class YoloRestorer(Detailer):
|
|||
)
|
||||
return [merged]
|
||||
|
||||
def draw_boxes(self, image: Image.Image, items: list[YoloResult]) -> Image.Image:
|
||||
if isinstance(image, Image.Image):
|
||||
draw = ImageDraw.Draw(image)
|
||||
else:
|
||||
def draw_masks(self, image: Image.Image, items: list[YoloResult]) -> Image.Image:
|
||||
if not isinstance(image, Image.Image):
|
||||
image = Image.fromarray(image)
|
||||
draw = ImageDraw.Draw(image)
|
||||
font = images.get_font(16)
|
||||
image = image.convert('RGBA')
|
||||
size = min(image.width, image.height) // 32
|
||||
font = images.get_font(size)
|
||||
color = (0, 190, 190)
|
||||
shared.log.debug(f'Detailer: draw={items}')
|
||||
for i, item in enumerate(items):
|
||||
draw.rectangle(item.box, outline="#00C8C8", width=3)
|
||||
draw.text((item.box[0]+4, item.box[1]+4), f'{i+1} {item.label} {item.score:.2f}', fill="black", font=font)
|
||||
draw.text((item.box[0]+2, item.box[1]+2), f'{i+1} {item.label} {item.score:.2f}', fill="white", font=font)
|
||||
if shared.opts.detailer_seg and item.mask is not None:
|
||||
mask = item.mask.convert('L')
|
||||
else:
|
||||
mask = Image.new('L', image.size, 0)
|
||||
draw_mask = ImageDraw.Draw(mask)
|
||||
draw_mask.rectangle(item.box, fill="white", outline=None, width=0)
|
||||
alpha = mask.point(lambda p: int(p * 0.5))
|
||||
overlay = Image.new("RGBA", image.size, color + (0,))
|
||||
overlay.putalpha(alpha)
|
||||
image = Image.alpha_composite(image, overlay)
|
||||
|
||||
draw_text = ImageDraw.Draw(image)
|
||||
draw_text.text((item.box[0] + 2, item.box[1] - size - 2), f'{i+1} {item.label} {item.score:.2f}', fill="black", font=font)
|
||||
draw_text.text((item.box[0] + 0, item.box[1] - size - 4), f'{i+1} {item.label} {item.score:.2f}', fill="white", font=font)
|
||||
image = image.convert("RGB")
|
||||
return np.array(image)
|
||||
|
||||
def restore(self, np_image, p: processing.StableDiffusionProcessing = None):
|
||||
|
|
@ -369,7 +404,7 @@ class YoloRestorer(Detailer):
|
|||
if shared.opts.detailer_sort:
|
||||
items = sorted(items, key=lambda x: x.box[0]) # sort items left-to-right to improve consistency
|
||||
if shared.opts.detailer_save:
|
||||
annotated = self.draw_boxes(annotated, items)
|
||||
annotated = self.draw_masks(annotated, items)
|
||||
|
||||
for j, item in enumerate(items):
|
||||
if item.mask is None:
|
||||
|
|
@ -382,7 +417,7 @@ class YoloRestorer(Detailer):
|
|||
pc.negative_prompts = [pc.negative_prompt]
|
||||
pc.prompts, pc.network_data = extra_networks.parse_prompts(pc.prompts)
|
||||
extra_networks.activate(pc, pc.network_data)
|
||||
shared.log.debug(f'Detail: model="{i+1}:{name}" item={j+1}/{len(items)} box={item.box} label="{item.label} score={item.score:.2f} prompt="{pc.prompt}"')
|
||||
shared.log.debug(f'Detail: model="{i+1}:{name}" item={j+1}/{len(items)} box={item.box} label="{item.label}" score={item.score:.2f} seg={shared.opts.detailer_seg} prompt="{pc.prompt}"')
|
||||
pc.init_images = [image]
|
||||
pc.image_mask = [item.mask]
|
||||
pc.overlay_images = []
|
||||
|
|
@ -437,7 +472,7 @@ class YoloRestorer(Detailer):
|
|||
return gr.update(visible=False), gr.update(visible=True, value=value), gr.update(visible=False)
|
||||
|
||||
def ui(self, tab: str):
|
||||
def ui_settings_change(merge, detailers, text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort):
|
||||
def ui_settings_change(merge, detailers, text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg):
|
||||
shared.opts.detailer_merge = merge
|
||||
shared.opts.detailer_models = detailers
|
||||
shared.opts.detailer_args = text if not self.ui_mode else ''
|
||||
|
|
@ -453,15 +488,18 @@ class YoloRestorer(Detailer):
|
|||
shared.opts.detailer_sigma_adjust_max = renoise_end
|
||||
shared.opts.detailer_save = save
|
||||
shared.opts.detailer_sort = sort
|
||||
shared.opts.detailer_seg = seg
|
||||
# shared.opts.detailer_resolution = resolution
|
||||
shared.opts.save(shared.config_filename, silent=True)
|
||||
shared.log.debug(f'Detailer settings: models={detailers} classes={classes} strength={strength} conf={min_confidence} max={max_detected} iou={iou} size={min_size}-{max_size} padding={padding} steps={steps} resolution={resolution} save={save} sort={sort}')
|
||||
shared.log.debug(f'Detailer settings: models={detailers} classes={classes} strength={strength} conf={min_confidence} max={max_detected} iou={iou} size={min_size}-{max_size} padding={padding} steps={steps} resolution={resolution} save={save} sort={sort} seg={seg}')
|
||||
if not self.ui_mode:
|
||||
shared.log.debug(f'Detailer expert: {text}')
|
||||
|
||||
with gr.Accordion(open=False, label="Detailer", elem_id=f"{tab}_detailer_accordion", elem_classes=["small-accordion"]):
|
||||
with gr.Row():
|
||||
enabled = gr.Checkbox(label="Enable detailer pass", elem_id=f"{tab}_detailer_enabled", value=False)
|
||||
with gr.Row():
|
||||
seg = gr.Checkbox(label="Use segmentation", elem_id=f"{tab}_detailer_seg", value=shared.opts.detailer_seg, visible=True)
|
||||
save = gr.Checkbox(label="Include detection results", elem_id=f"{tab}_detailer_save", value=shared.opts.detailer_save, visible=True)
|
||||
with gr.Row():
|
||||
merge = gr.Checkbox(label="Merge detailers", elem_id=f"{tab}_detailer_merge", value=shared.opts.detailer_merge, visible=True)
|
||||
|
|
@ -499,20 +537,21 @@ class YoloRestorer(Detailer):
|
|||
renoise_value = gr.Slider(minimum=0.5, maximum=1.5, step=0.01, label='Renoise', value=shared.opts.detailer_sigma_adjust, elem_id=f"{tab}_detailer_renoise")
|
||||
renoise_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Renoise end', value=shared.opts.detailer_sigma_adjust_max, elem_id=f"{tab}_detailer_renoise_end")
|
||||
|
||||
merge.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
detailers.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
detailers_text.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
classes.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
padding.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
blur.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
min_confidence.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
max_detected.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
min_size.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
max_size.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
iou.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
resolution.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
save.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
sort.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort], outputs=[])
|
||||
merge.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
detailers.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
detailers_text.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
classes.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
padding.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
blur.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
min_confidence.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
max_detected.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
min_size.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
max_size.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
iou.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
resolution.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
save.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
sort.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
seg.change(fn=ui_settings_change, inputs=[merge, detailers, detailers_text, classes, strength, padding, blur, min_confidence, max_detected, min_size, max_size, iou, steps, renoise_value, renoise_end, resolution, save, sort, seg], outputs=[])
|
||||
return enabled, prompt, negative, steps, strength, resolution
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -127,6 +127,8 @@ def task_specific_kwargs(p, model):
|
|||
task_args['image'] = [Image.new('RGB', (p.width, p.height), (0, 0, 0))] # monkey-patch so qwen-image-edit pipeline does not error-out on t2i
|
||||
if ('QwenImageEditPlusPipeline' in model_cls) and (p.init_control is not None) and (len(p.init_control) > 0):
|
||||
task_args['image'] += p.init_control
|
||||
if ('QwenImageLayeredPipeline' in model_cls) and (p.init_images is not None) and (len(p.init_images) > 0):
|
||||
task_args['image'] = p.init_images[0].convert('RGBA')
|
||||
if ('Flux2' in model_cls) and (p.init_control is not None) and (len(p.init_control) > 0):
|
||||
task_args['image'] += p.init_control
|
||||
if ('LatentConsistencyModelPipeline' in model_cls) and (len(p.init_images) > 0):
|
||||
|
|
@ -235,6 +237,11 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
embeds = prompt_parser_diffusers.embedder('prompt_embeds')
|
||||
if embeds is None:
|
||||
shared.log.warning('Prompt parser encode: empty prompt embeds')
|
||||
prompt_parser_diffusers.embedder = None
|
||||
args['prompt'] = prompts
|
||||
elif embeds.device == torch.device('meta'):
|
||||
shared.log.warning('Prompt parser encode: embeds on meta device')
|
||||
prompt_parser_diffusers.embedder = None
|
||||
args['prompt'] = prompts
|
||||
else:
|
||||
args['prompt_embeds'] = embeds
|
||||
|
|
@ -271,6 +278,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
args['negative_prompt'] = negative_prompts[0]
|
||||
else:
|
||||
args['negative_prompt'] = negative_prompts
|
||||
|
||||
if 'complex_human_instruction' in possible:
|
||||
chi = shared.opts.te_complex_human_instruction
|
||||
p.extra_generation_params["CHI"] = chi
|
||||
|
|
@ -454,7 +462,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
args['max_area'] = args['width'] * args['height']
|
||||
|
||||
# handle implicit controlnet
|
||||
if 'control_image' in possible and 'control_image' not in args and 'image' in args:
|
||||
if ('control_image' in possible) and ('control_image' not in args) and ('image' in args):
|
||||
if sd_models.get_diffusers_task(model) != sd_models.DiffusersTaskType.MODULAR:
|
||||
debug_log('Process: set control image')
|
||||
args['control_image'] = args['image']
|
||||
|
|
|
|||
|
|
@ -185,6 +185,8 @@ def process_base(p: processing.StableDiffusionProcessing):
|
|||
output = SimpleNamespace(images=output)
|
||||
if isinstance(output, Image.Image):
|
||||
output = SimpleNamespace(images=[output])
|
||||
if hasattr(output, 'image'):
|
||||
output.images = output.image
|
||||
if hasattr(output, 'images'):
|
||||
shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops)
|
||||
timer.process.record('pipeline')
|
||||
|
|
|
|||
|
|
@ -408,7 +408,7 @@ def calculate_base_steps(p, use_denoise_start, use_refiner_start):
|
|||
if len(getattr(p, 'timesteps', [])) > 0:
|
||||
return None
|
||||
cls = shared.sd_model.__class__.__name__
|
||||
if 'Flex' in cls or 'Kontext' in cls or 'Edit' in cls or 'Wan' in cls or 'Flux2' in cls:
|
||||
if 'Flex' in cls or 'Kontext' in cls or 'Edit' in cls or 'Wan' in cls or 'Flux2' in cls or 'Layered' in cls:
|
||||
steps = p.steps
|
||||
elif is_modular():
|
||||
steps = p.steps
|
||||
|
|
|
|||
|
|
@ -410,7 +410,7 @@ def get_tokens(pipe, msg, prompt):
|
|||
fn = os.path.join(fn, 'vocab.json')
|
||||
else:
|
||||
fn = os.path.join(fn, 'tokenizer', 'vocab.json')
|
||||
token_dict = shared.readfile(fn, silent=True)
|
||||
token_dict = shared.readfile(fn, silent=True, as_type="dict")
|
||||
added_tokens = getattr(tokenizer, 'added_tokens_decoder', {})
|
||||
for k, v in added_tokens.items():
|
||||
token_dict[str(v)] = k
|
||||
|
|
|
|||
|
|
@ -329,7 +329,7 @@ def select_checkpoint(op='model', sd_model_checkpoint=None):
|
|||
def init_metadata():
    """Populate the module-level `sd_metadata` cache on first use.

    Does nothing when `sd_metadata` is already set. Otherwise the cache becomes
    the dict read from `sd_metadata_file` (read with `lock=True`) when that path
    is an existing file, or an empty dict when it is not.
    """
    global sd_metadata # pylint: disable=global-statement
    if sd_metadata is not None:
        return
    if os.path.isfile(sd_metadata_file):
        sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict")
    else:
        sd_metadata = {}
|
||||
|
||||
|
||||
def extract_thumbnail(filename, data):
|
||||
|
|
@ -349,7 +349,7 @@ def extract_thumbnail(filename, data):
|
|||
def read_metadata_from_safetensors(filename):
|
||||
global sd_metadata # pylint: disable=global-statement
|
||||
if sd_metadata is None:
|
||||
sd_metadata = shared.readfile(sd_metadata_file, lock=True) if os.path.isfile(sd_metadata_file) else {}
|
||||
sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict") if os.path.isfile(sd_metadata_file) else {}
|
||||
res = sd_metadata.get(filename, None)
|
||||
if res is not None:
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -139,6 +139,10 @@ def guess_by_name(fn, current_guess):
|
|||
new_guess = 'NanoBanana'
|
||||
elif 'z-image' in fn.lower() or 'z_image' in fn.lower():
|
||||
new_guess = 'Z-Image'
|
||||
elif 'longcat-image' in fn.lower():
|
||||
new_guess = 'LongCat'
|
||||
elif 'ovis-image' in fn.lower():
|
||||
new_guess = 'Ovis-Image'
|
||||
if debug_load:
|
||||
shared.log.trace(f'Autodetect: method=name file="{fn}" previous="{current_guess}" current="{new_guess}"')
|
||||
return new_guess or current_guess
|
||||
|
|
@ -150,7 +154,7 @@ def guess_by_diffusers(fn, current_guess):
|
|||
return current_guess, None
|
||||
index = os.path.join(fn, 'model_index.json')
|
||||
if os.path.exists(index) and os.path.isfile(index):
|
||||
index = shared.readfile(index, silent=True)
|
||||
index = shared.readfile(index, silent=True, as_type="dict")
|
||||
name = index.get('_name_or_path', None)
|
||||
if name is not None and name in exclude_by_name:
|
||||
return current_guess, None
|
||||
|
|
@ -169,7 +173,7 @@ def guess_by_diffusers(fn, current_guess):
|
|||
is_quant = True
|
||||
break
|
||||
if folder.endswith('config.json'):
|
||||
quantization_config = shared.readfile(folder, silent=True).get("quantization_config", None)
|
||||
quantization_config = shared.readfile(folder, silent=True, as_type="dict").get("quantization_config", None)
|
||||
if quantization_config is not None:
|
||||
is_quant = True
|
||||
break
|
||||
|
|
@ -180,7 +184,7 @@ def guess_by_diffusers(fn, current_guess):
|
|||
is_quant = True
|
||||
break
|
||||
if f.endswith('config.json'):
|
||||
quantization_config = shared.readfile(f, silent=True).get("quantization_config", None)
|
||||
quantization_config = shared.readfile(f, silent=True, as_type="dict").get("quantization_config", None)
|
||||
if quantization_config is not None:
|
||||
is_quant = True
|
||||
break
|
||||
|
|
|
|||
|
|
@ -184,6 +184,23 @@ def set_diffuser_options(sd_model, vae=None, op:str='model', offload:bool=True,
|
|||
|
||||
|
||||
def move_model(model, device=None, force=False):
|
||||
def set_execution_device(module, device):
    """Retarget the accelerate hook attached to `module` so it executes on `device`.

    Args:
        module: object that may carry an `_hf_hook` attribute.
        device: requested execution device; a CPU target is ignored.

    Returns:
        None. The call is a no-op when `device` equals `torch.device('cpu')`,
        when `module` has no `_hf_hook`, or when the hook has no
        `execution_device` attribute.

    Failures while assigning the device are best-effort: they are swallowed and
    only logged when the `SD_MOVE_DEBUG` environment variable is set.
    """
    if device == torch.device('cpu'):
        return
    if not (hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device")): # pylint: disable=protected-access
        return
    try:
        # only the hook target is changed here; parameters themselves are not moved
        module._hf_hook.execution_device = device # pylint: disable=protected-access
    except Exception as e:
        if os.environ.get('SD_MOVE_DEBUG', None):
            shared.log.error(f'Model move execution device: device={device} {e}')
|
||||
|
||||
if model is None or device is None:
|
||||
return
|
||||
|
||||
|
|
@ -204,20 +221,20 @@ def move_model(model, device=None, force=False):
|
|||
if not isinstance(m, torch.nn.Module) or name in model._exclude_from_cpu_offload: # pylint: disable=protected-access
|
||||
continue
|
||||
for module in m.modules():
|
||||
if (hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None): # pylint: disable=protected-access
|
||||
try:
|
||||
module._hf_hook.execution_device = device # pylint: disable=protected-access
|
||||
except Exception as e:
|
||||
if os.environ.get('SD_MOVE_DEBUG', None):
|
||||
shared.log.error(f'Model move execution device: device={device} {e}')
|
||||
set_execution_device(module, device)
|
||||
# set_execution_device(model, device)
|
||||
|
||||
if getattr(model, 'has_accelerate', False) and not force:
|
||||
return
|
||||
if hasattr(model, "device") and devices.normalize_device(model.device) == devices.normalize_device(device) and not force:
|
||||
return
|
||||
|
||||
try:
|
||||
t0 = time.time()
|
||||
try:
|
||||
if hasattr(model, 'to'):
|
||||
if model.device == torch.device('meta'):
|
||||
set_execution_device(model, device)
|
||||
elif hasattr(model, 'to'):
|
||||
model.to(device)
|
||||
if hasattr(model, "prior_pipe"):
|
||||
model.prior_pipe.to(device)
|
||||
|
|
@ -458,6 +475,14 @@ def load_diffuser_force(detected_model_type, checkpoint_info, diffusers_load_con
|
|||
from pipelines.model_z_image import load_z_image
|
||||
sd_model = load_z_image(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['LongCat']:
|
||||
from pipelines.model_longcat import load_longcat
|
||||
sd_model = load_longcat(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['Overfit']:
|
||||
from pipelines.model_ovis import load_ovis
|
||||
sd_model = load_ovis(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
except Exception as e:
|
||||
shared.log.error(f'Load {op}: path="{checkpoint_info.path}" {e}')
|
||||
if debug_load:
|
||||
|
|
@ -604,9 +629,9 @@ def load_sdnq_module(fn: str, module_name: str, load_method: str):
|
|||
quantization_config_path = os.path.join(fn, module_name, 'quantization_config.json')
|
||||
model_config_path = os.path.join(fn, module_name, 'config.json')
|
||||
if os.path.exists(quantization_config_path):
|
||||
quantization_config = shared.readfile(quantization_config_path, silent=True)
|
||||
quantization_config = shared.readfile(quantization_config_path, silent=True, as_type="dict")
|
||||
elif os.path.exists(model_config_path):
|
||||
quantization_config = shared.readfile(model_config_path, silent=True).get("quantization_config", None)
|
||||
quantization_config = shared.readfile(model_config_path, silent=True, as_type="dict").get("quantization_config", None)
|
||||
if quantization_config is None:
|
||||
return None, module_name, 0
|
||||
model_name = os.path.join(fn, module_name)
|
||||
|
|
@ -1050,11 +1075,13 @@ def copy_diffuser_options(new_pipe, orig_pipe):
|
|||
new_pipe.feature_extractor = getattr(orig_pipe, 'feature_extractor', None)
|
||||
new_pipe.mask_processor = getattr(orig_pipe, 'mask_processor', None)
|
||||
new_pipe.restore_pipeline = getattr(orig_pipe, 'restore_pipeline', None)
|
||||
new_pipe.task_args = getattr(orig_pipe, 'task_args', None)
|
||||
new_pipe.is_sdxl = getattr(orig_pipe, 'is_sdxl', False) # a1111 compatibility item
|
||||
new_pipe.is_sd2 = getattr(orig_pipe, 'is_sd2', False)
|
||||
new_pipe.is_sd1 = getattr(orig_pipe, 'is_sd1', True)
|
||||
add_noise_pred_to_diffusers_callback(new_pipe)
|
||||
if getattr(new_pipe, 'task_args', None) is None:
|
||||
new_pipe.task_args = {}
|
||||
new_pipe.task_args.update(getattr(orig_pipe, 'task_args', {}))
|
||||
if new_pipe.has_accelerate:
|
||||
set_accelerate(new_pipe)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from modules.timer import process as process_timer
|
|||
debug = os.environ.get('SD_MOVE_DEBUG', None) is not None
|
||||
verbose = os.environ.get('SD_MOVE_VERBOSE', None) is not None
|
||||
debug_move = log.trace if debug else lambda *args, **kwargs: None
|
||||
offload_warn = ['sc', 'sd3', 'f1', 'f2', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'omnigen2', 'cogview4', 'cosmos', 'chroma', 'x-omni', 'hunyuanimage', 'hunyuanimage3']
|
||||
offload_warn = ['sc', 'sd3', 'f1', 'f2', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'omnigen2', 'cogview4', 'cosmos', 'chroma', 'x-omni', 'hunyuanimage', 'hunyuanimage3', 'longcat']
|
||||
offload_post = ['h1']
|
||||
offload_hook_instance = None
|
||||
balanced_offload_exclude = ['CogView4Pipeline', 'MeissonicPipeline']
|
||||
|
|
@ -213,6 +213,7 @@ class OffloadHook(accelerate.hooks.ModelHook):
|
|||
return False
|
||||
return True
|
||||
|
||||
@torch.compiler.disable
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
_id = id(module)
|
||||
|
||||
|
|
@ -239,14 +240,26 @@ class OffloadHook(accelerate.hooks.ModelHook):
|
|||
max_memory = { device_index: self.gpu, "cpu": self.cpu }
|
||||
device_map = getattr(module, "balanced_offload_device_map", None)
|
||||
if (device_map is None) or (max_memory != getattr(module, "balanced_offload_max_memory", None)):
|
||||
device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory, no_split_module_classes=no_split_module_classes, verbose=verbose)
|
||||
device_map = accelerate.infer_auto_device_map(module,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=no_split_module_classes,
|
||||
verbose=verbose,
|
||||
clean_result=False,
|
||||
)
|
||||
offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
|
||||
if devices.backend == "directml":
|
||||
for k, v in device_map.items():
|
||||
if isinstance(v, int):
|
||||
device_map[k] = f"{devices.device.type}:{v}" # int implies CUDA or XPU device, but it will break DirectML backend so we add type
|
||||
if debug:
|
||||
shared.log.trace(f'Offload: type=balanced op=dispatch map={device_map}')
|
||||
if device_map is not None:
|
||||
module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir)
|
||||
module = accelerate.dispatch_model(module,
|
||||
main_device=torch.device(devices.device),
|
||||
device_map=device_map,
|
||||
offload_dir=offload_dir,
|
||||
force_hooks=True,
|
||||
)
|
||||
module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
|
||||
module.balanced_offload_device_map = device_map
|
||||
module.balanced_offload_max_memory = max_memory
|
||||
|
|
@ -261,6 +274,7 @@ class OffloadHook(accelerate.hooks.ModelHook):
|
|||
self.last_pre = _id
|
||||
return args, kwargs
|
||||
|
||||
@torch.compiler.disable
|
||||
def post_forward(self, module, output):
|
||||
if self.last_post != id(module):
|
||||
self.last_post = id(module)
|
||||
|
|
@ -289,6 +303,15 @@ def get_pipe_variants(pipe=None):
|
|||
|
||||
|
||||
def get_module_names(pipe=None, exclude=None):
|
||||
def is_valid(module):
    """Return True when the attribute named `module` on `pipe` is a torch module.

    Args:
        module: attribute name to look up on the enclosing `pipe`.

    Returns:
        bool: True for any `torch.nn.Module` instance, False for anything else,
        including a missing attribute (looked up with a `None` default).
    """
    # ModuleDict and ModuleList are nn.Module subclasses, so one check covers all three container kinds
    return isinstance(getattr(pipe, module, None), torch.nn.Module)
|
||||
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
if pipe is None:
|
||||
|
|
@ -296,12 +319,19 @@ def get_module_names(pipe=None, exclude=None):
|
|||
pipe = shared.sd_model
|
||||
else:
|
||||
return []
|
||||
if hasattr(pipe, "_internal_dict"):
|
||||
modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access
|
||||
else:
|
||||
modules_names = get_signature(pipe).keys()
|
||||
modules_names = []
|
||||
try:
|
||||
dict_keys = pipe._internal_dict.keys() # pylint: disable=protected-access
|
||||
modules_names.extend(dict_keys)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
dict_keys = get_signature(pipe).keys()
|
||||
modules_names.extend(dict_keys)
|
||||
except Exception:
|
||||
pass
|
||||
modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')]
|
||||
modules_names = [m for m in modules_names if isinstance(getattr(pipe, m, None), torch.nn.Module)]
|
||||
modules_names = [m for m in modules_names if is_valid(m)]
|
||||
modules_names = sorted(set(modules_names))
|
||||
return modules_names
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_t
|
|||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 }
|
||||
flow_models = ['f1', 'f2', 'sd3', 'lumina', 'auraflow', 'sana', 'z_image', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma', 'omnigen', 'omnigen2']
|
||||
flow_models = ['f1', 'f2', 'sd3', 'lumina', 'auraflow', 'sana', 'z_image', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma', 'omnigen', 'omnigen2', 'longcat']
|
||||
warned = False
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ def load_unet(model, repo_id:str=None):
|
|||
|
||||
config_file = os.path.splitext(unet_dict[shared.opts.sd_unet])[0] + '.json'
|
||||
if os.path.exists(config_file):
|
||||
config = shared.readfile(config_file)
|
||||
config = shared.readfile(config_file, as_type="dict")
|
||||
else:
|
||||
config = None
|
||||
config_file = 'default'
|
||||
|
|
|
|||
|
|
@ -128,13 +128,13 @@ def apply_vae_config(model_file, vae_file, sd_model):
|
|||
def get_vae_config():
    """Return the VAE config dict from the first candidate file that exists.

    Candidates are tried in this order, all under `paths.sd_configs_path`:
      1. `<model_file basename without extension>_vae.json`
      2. `<vae_file basename without extension>.json` (only when `vae_file` is truthy)
      3. `<shared.sd_model_type>/vae/config.json`

    Returns:
        The value of `shared.readfile(..., as_type="dict")` for the first hit,
        or an empty dict when none of the candidates exists.
    """
    def stem(path):
        return os.path.splitext(os.path.basename(path))[0]

    def candidates():
        # generator keeps each path lazily evaluated, exactly as in a sequential if-chain
        yield os.path.join(paths.sd_configs_path, stem(model_file) + '_vae.json')
        yield os.path.join(paths.sd_configs_path, stem(vae_file) + '.json') if vae_file else None
        yield os.path.join(paths.sd_configs_path, shared.sd_model_type, 'vae', 'config.json')

    for config_file in candidates():
        if config_file is not None and os.path.exists(config_file):
            return shared.readfile(config_file, as_type="dict")
    return {}
|
||||
|
||||
if hasattr(sd_model, 'vae') and hasattr(sd_model.vae, 'config'):
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ prev_cls = ''
|
|||
prev_type = ''
|
||||
prev_model = ''
|
||||
lock = threading.Lock()
|
||||
supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'z_image', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen']
|
||||
supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'z_image', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen', 'longcat']
|
||||
|
||||
|
||||
def warn_once(msg, variant=None):
|
||||
|
|
@ -59,7 +59,7 @@ def get_model(model_type = 'decoder', variant = None):
|
|||
model_cls = 'sd'
|
||||
elif model_cls in {'pixartsigma', 'hunyuandit', 'omnigen', 'auraflow'}:
|
||||
model_cls = 'sdxl'
|
||||
elif model_cls in {'h1', 'z_image', 'lumina2', 'chroma'}:
|
||||
elif model_cls in {'h1', 'z_image', 'lumina2', 'chroma', 'longcat'}:
|
||||
model_cls = 'f1'
|
||||
elif model_cls in {'wanai', 'qwen', 'chrono'}:
|
||||
variant = variant or 'TAE WanVideo'
|
||||
|
|
@ -156,8 +156,8 @@ def decode(latents):
|
|||
image = vae.decode(tensor, return_dict=False)[0]
|
||||
image = (image / 2.0 + 0.5).clamp(0, 1).detach()
|
||||
t1 = time.time()
|
||||
if (t1 - t0) > 1.0 and not first_run:
|
||||
shared.log.warning(f'Decode: type="taesd" variant="{variant}" time{t1 - t0:.2f}')
|
||||
if (t1 - t0) > 3.0 and not first_run:
|
||||
shared.log.warning(f'Decode: type="taesd" variant="{variant}" long decode time={t1 - t0:.2f}')
|
||||
first_run = False
|
||||
return image
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch
|
|||
|
||||
from modules import shared, devices
|
||||
|
||||
sdnq_version = "0.1.2"
|
||||
sdnq_version = "0.1.3"
|
||||
|
||||
dtype_dict = {
|
||||
"int32": {"min": -2147483648, "max": 2147483647, "num_bits": 32, "sign": 1, "exponent": 0, "mantissa": 31, "target_dtype": torch.int32, "torch_dtype": torch.int32, "storage_dtype": torch.int32, "is_unsigned": False, "is_integer": True, "is_packed": False},
|
||||
|
|
@ -34,17 +34,33 @@ dtype_dict = {
|
|||
"float8_e5m2": {"min": -57344, "max": 57344, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": torch.float8_e5m2, "torch_dtype": torch.float8_e5m2, "storage_dtype": torch.float8_e5m2, "is_unsigned": False, "is_integer": False, "is_packed": False},
|
||||
}
|
||||
|
||||
if hasattr(torch, "float8_e4m3fnuz"):
|
||||
dtype_dict["float8_e4m3fnuz"] = {"min": -240, "max": 240, "num_bits": 8, "sign": 1, "exponent": 4, "mantissa": 3, "target_dtype": "fp8", "torch_dtype": torch.float8_e4m3fnuz, "storage_dtype": torch.float8_e4m3fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
|
||||
if hasattr(torch, "float8_e5m2fnuz"):
|
||||
dtype_dict["float8_e5m2fnuz"] = {"min": -57344, "max": 57344, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": "fp8", "torch_dtype": torch.float8_e5m2fnuz, "storage_dtype": torch.float8_e5m2fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
|
||||
|
||||
dtype_dict["fp32"] = dtype_dict["float32"]
|
||||
dtype_dict["bf16"] = dtype_dict["bfloat16"]
|
||||
dtype_dict["fp16"] = dtype_dict["float16"]
|
||||
dtype_dict["fp8"] = dtype_dict["float8_e4m3fn"]
|
||||
dtype_dict["bool"] = dtype_dict["uint1"]
|
||||
|
||||
torch_dtype_dict = {
|
||||
torch.int32: "int32",
|
||||
torch.int16: "int16",
|
||||
torch.int8: "int8",
|
||||
torch.uint32: "uint32",
|
||||
torch.uint16: "uint16",
|
||||
torch.uint8: "uint8",
|
||||
torch.float32: "float32",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.float16: "float16",
|
||||
torch.float8_e4m3fn: "float8_e4m3fn",
|
||||
torch.float8_e5m2: "float8_e5m2",
|
||||
}
|
||||
|
||||
if hasattr(torch, "float8_e4m3fnuz"):
|
||||
dtype_dict["float8_e4m3fnuz"] = {"min": -240, "max": 240, "num_bits": 8, "sign": 1, "exponent": 4, "mantissa": 3, "target_dtype": "fp8", "torch_dtype": torch.float8_e4m3fnuz, "storage_dtype": torch.float8_e4m3fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
|
||||
torch_dtype_dict[torch.float8_e4m3fnuz] = "float8_e4m3fnuz"
|
||||
if hasattr(torch, "float8_e5m2fnuz"):
|
||||
dtype_dict["float8_e5m2fnuz"] = {"min": -57344, "max": 57344, "num_bits": 8, "sign": 1, "exponent": 5, "mantissa": 2, "target_dtype": "fp8", "torch_dtype": torch.float8_e5m2fnuz, "storage_dtype": torch.float8_e5m2fnuz, "is_unsigned": False, "is_integer": False, "is_packed": False}
|
||||
torch_dtype_dict[torch.float8_e5m2fnuz] = "float8_e5m2fnuz"
|
||||
|
||||
linear_types = {"Linear"}
|
||||
conv_types = {"Conv1d", "Conv2d", "Conv3d"}
|
||||
conv_transpose_types = {"ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d"}
|
||||
|
|
@ -172,7 +188,7 @@ module_skip_keys_dict = {
|
|||
{}
|
||||
],
|
||||
"ZImageTransformer2DModel": [
|
||||
["layers.0.adaLN_modulation.0.weight", "t_embedder", "cap_embedder", "all_x_embedder", "all_final_layer"],
|
||||
["layers.0.adaLN_modulation.0.weight", "t_embedder", "cap_embedder", "siglip_embedder", "all_x_embedder", "all_final_layer"],
|
||||
{}
|
||||
],
|
||||
"HunyuanImage3ForCausalMM": [
|
||||
|
|
|
|||
|
|
@ -97,10 +97,10 @@ def quantize_fp_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str =
|
|||
|
||||
@devices.inference_context()
def quantize_fp_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]:
    """Scale `input` into the range of `matmul_dtype` and cast it after adding random mantissa noise.

    Args:
        input: tensor to quantize.
        dim: dimension over which the absolute maximum is taken for the scale.
        matmul_dtype: key into `dtype_dict` selecting the target format.

    Returns:
        Tuple of (quantized tensor in the target torch dtype, scale), where
        scale is `amax(|input|, dim, keepdims=True) / dtype max`.
    """
    target = dtype_dict[matmul_dtype]
    # number of float32 mantissa steps spanned by one step of the target mantissa
    mantissa_difference = 1 << (23 - target["mantissa"])
    scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(target["max"])
    # reinterpret the scaled float32 values as int32 so the noise lands on the raw bit pattern
    bits = torch.div(input, scale).to(dtype=torch.float32).view(dtype=torch.int32)
    noise = torch.randint_like(bits, low=0, high=mantissa_difference, dtype=torch.int32)
    bits.add_(noise)
    # NOTE(review): low mantissa bits are not masked after the noise is added, so the final cast decides the rounding - confirm this is intended
    perturbed = bits.view(dtype=torch.float32)
    perturbed = perturbed.nan_to_num_().clamp_(target["min"], target["max"])
    return perturbed.to(dtype=target["torch_dtype"]), scale
|
||||
|
||||
|
|
|
|||
|
|
@ -50,14 +50,14 @@ def load_files(files: list[str], state_dict: dict = None, key_mapping: dict = No
|
|||
if isinstance(files, str):
|
||||
files = [files]
|
||||
if method is None:
|
||||
method = 'safetensors'
|
||||
method = "safetensors"
|
||||
if state_dict is None:
|
||||
state_dict = {}
|
||||
if method == 'safetensors':
|
||||
if method == "safetensors":
|
||||
load_safetensors(files, state_dict=state_dict, key_mapping=key_mapping, device=device)
|
||||
elif method == 'threaded':
|
||||
elif method == "threaded":
|
||||
load_threaded(files, state_dict=state_dict, key_mapping=key_mapping, device=device)
|
||||
elif method == 'streamer':
|
||||
elif method == "streamer":
|
||||
load_streamer(files, state_dict=state_dict, key_mapping=key_mapping, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported loading method: {method}")
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ def save_sdnq_model(model: ModelMixin, model_path: str, max_shard_size: str = "1
|
|||
model.config.quantization_config.to_json_file(quantization_config_path)
|
||||
|
||||
|
||||
def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: str = None, dtype: torch.dtype = None, device: torch.device = "cpu", dequantize_fp32: bool = None, use_quantized_matmul: bool = None, model_config: dict = None, quantization_config: dict = None, load_method: str = 'safetensors') -> ModelMixin:
|
||||
def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: str = None, dtype: torch.dtype = None, device: torch.device = "cpu", dequantize_fp32: bool = None, use_quantized_matmul: bool = None, model_config: dict = None, quantization_config: dict = None, load_method: str = "safetensors") -> ModelMixin:
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
with init_empty_weights():
|
||||
|
|
|
|||
|
|
@ -59,8 +59,7 @@ def quantize_weight(weight: torch.FloatTensor, reduction_axes: Union[int, List[i
|
|||
else:
|
||||
if use_stochastic_rounding:
|
||||
mantissa_difference = 1 << (23 - dtype_dict[weights_dtype]["mantissa"])
|
||||
quantized_weight = quantized_weight.view(dtype=torch.int32)
|
||||
quantized_weight = torch.randint_like(quantized_weight, low=0, high=mantissa_difference).add_(quantized_weight).bitwise_and_(-mantissa_difference).view(dtype=torch.float32)
|
||||
quantized_weight = quantized_weight.view(dtype=torch.int32).add_(torch.randint_like(quantized_weight, low=0, high=mantissa_difference, dtype=torch.int32)).view(dtype=torch.float32)
|
||||
quantized_weight.nan_to_num_()
|
||||
quantized_weight = quantized_weight.clamp_(dtype_dict[weights_dtype]["min"], dtype_dict[weights_dtype]["max"]).to(dtype_dict[weights_dtype]["torch_dtype"])
|
||||
return quantized_weight, scale, zero_point
|
||||
|
|
@ -211,6 +210,7 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
|
|||
result_shape = None
|
||||
original_shape = weight.shape
|
||||
original_stride = weight.stride()
|
||||
weight = weight.detach()
|
||||
|
||||
if torch_dtype is None:
|
||||
torch_dtype = weight.dtype
|
||||
|
|
@ -229,8 +229,6 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
|
|||
)
|
||||
|
||||
if layer_class_name in conv_types:
|
||||
if dtype_dict[weights_dtype]["num_bits"] < 4:
|
||||
weights_dtype = "uint4"
|
||||
is_conv_type = True
|
||||
reduction_axes = 1
|
||||
output_channel_size, channel_size = weight.shape[:2]
|
||||
|
|
@ -242,8 +240,6 @@ def sdnq_quantize_layer_weight(weight, layer_class_name=None, weights_dtype="int
|
|||
weight = weight.flatten(1,-1)
|
||||
reduction_axes = -1
|
||||
elif layer_class_name in conv_transpose_types:
|
||||
if dtype_dict[weights_dtype]["num_bits"] < 4:
|
||||
weights_dtype = "uint4"
|
||||
is_conv_transpose_type = True
|
||||
reduction_axes = 0
|
||||
channel_size, output_channel_size = weight.shape[:2]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
Modified from Triton MatMul example.
|
||||
PyTorch torch._int_mm is broken on backward pass with Nvidia.
|
||||
AMD RDNA2 doesn't support torch._int_mm, so we use int_mm via Triton.
|
||||
PyTorch doesn't support FP32 output type with FP16 MM so we use Triton for it too.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
|
@ -13,47 +14,47 @@ import triton.language as tl
|
|||
def get_autotune_config():
    """Build the list of `triton.Config` candidates passed to `triton.autotune`.

    The same 16 tile shapes and warp counts are used for every backend; only
    `num_stages` differs: the per-row value on the "cuda" backend, and 2 on any
    other backend. `GROUP_SIZE_M` is 8 for every candidate.

    Returns:
        list of `triton.Config`, in table order.
    """
    # (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages on cuda, num_warps)
    tile_table = (
        (128, 256, 64, 3, 8),
        (64, 256, 32, 4, 4),
        (128, 128, 32, 4, 4),
        (128, 64, 32, 4, 4),
        (64, 128, 32, 4, 4),
        (128, 32, 32, 4, 4),
        (64, 32, 32, 5, 2),
        (32, 64, 32, 5, 2),
        #
        (128, 256, 128, 3, 8),
        (64, 256, 128, 4, 4),
        (128, 128, 128, 4, 4),
        (128, 64, 64, 4, 4),
        (64, 128, 64, 4, 4),
        (128, 32, 64, 4, 4),
        (256, 128, 128, 3, 8),
        (256, 64, 128, 4, 4),
    )
    is_cuda = triton.runtime.driver.active.get_current_target().backend == "cuda"
    return [
        triton.Config(
            {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k, "GROUP_SIZE_M": 8},
            num_stages=cuda_stages if is_cuda else 2,
            num_warps=warps,
        )
        for block_m, block_n, block_k, cuda_stages, warps in tile_table
    ]
|
||||
|
||||
|
||||
@triton.autotune(configs=get_autotune_config(), key=['M', 'N', 'K', 'stride_bk'])
|
||||
@triton.autotune(configs=get_autotune_config(), key=["M", "N", "K", "stride_bk"])
|
||||
@triton.jit
|
||||
def int_mm_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
|
|
@ -111,7 +112,7 @@ def int_mm(a, b):
|
|||
K, N = b.shape
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.int32)
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
int_mm_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
|
|
@ -122,7 +123,7 @@ def int_mm(a, b):
|
|||
return c
|
||||
|
||||
|
||||
@triton.autotune(configs=get_autotune_config(), key=['M', 'N', 'K', 'stride_bk'])
|
||||
@triton.autotune(configs=get_autotune_config(), key=["M", "N", "K", "stride_bk"])
|
||||
@triton.jit
|
||||
def fp_mm_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
|
|
@ -180,7 +181,7 @@ def fp_mm(a, b):
|
|||
K, N = b.shape
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
fp_mm_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from __future__ import annotations
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import contextlib
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
import gradio as gr
|
||||
import diffusers
|
||||
from modules.json_helpers import readfile, writefile # pylint: disable=W0611
|
||||
from modules.shared_helpers import listdir, walk_files, html_path, html, req, total_tqdm # pylint: disable=W0611
|
||||
from modules.shared_defaults import get_default_modes
|
||||
|
|
@ -23,6 +24,12 @@ import modules.styles
|
|||
import modules.paths as paths
|
||||
from installer import log, print_dict, console, get_version # pylint: disable=unused-import
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Behavior modified by __future__.annotations
|
||||
from diffusers import DiffusionPipeline
|
||||
from modules.shared_legacy import LegacyOption
|
||||
from modules.ui_extra_networks import ExtraNetworksPage
|
||||
|
||||
|
||||
class Backend(Enum):
|
||||
ORIGINAL = 1
|
||||
|
|
@ -30,7 +37,7 @@ class Backend(Enum):
|
|||
|
||||
|
||||
errors.install([gr])
|
||||
demo: gr.Blocks = None
|
||||
demo: gr.Blocks | None = None
|
||||
api = None
|
||||
url = 'https://github.com/vladmandic/sdnext'
|
||||
cmd_opts = cmd_args.parse_args()
|
||||
|
|
@ -44,8 +51,8 @@ detailers = []
|
|||
face_restorers = []
|
||||
yolo = None
|
||||
tab_names = []
|
||||
extra_networks = []
|
||||
options_templates = {}
|
||||
extra_networks: list[ExtraNetworksPage] = []
|
||||
options_templates: dict[str, OptionInfo | LegacyOption] = {}
|
||||
hypernetworks = {}
|
||||
settings_components = {}
|
||||
restricted_opts = {
|
||||
|
|
@ -164,6 +171,11 @@ options_templates.update(options_section(('sd', "Model Loading"), {
|
|||
options_templates.update(options_section(('model_options', "Model Options"), {
|
||||
"model_modular_sep": OptionInfo("<h2>Modular Pipelines</h2>", "", gr.HTML),
|
||||
"model_modular_enable": OptionInfo(False, "Enable modular pipelines (experimental)"),
|
||||
"model_google_sep": OptionInfo("<h2>Google GenAI</h2>", "", gr.HTML),
|
||||
"google_use_vertexai": OptionInfo(False, "Google cloud use VertexAI endpoints"),
|
||||
"google_api_key": OptionInfo("", "Google cloud API key", gr.Textbox),
|
||||
"google_project_id": OptionInfo("", "Google Cloud project ID", gr.Textbox),
|
||||
"google_location_id": OptionInfo("", "Google Cloud location ID", gr.Textbox),
|
||||
"model_sd3_sep": OptionInfo("<h2>Stable Diffusion 3.x</h2>", "", gr.HTML),
|
||||
"model_sd3_disable_te5": OptionInfo(False, "Disable T5 text encoder"),
|
||||
"model_h1_sep": OptionInfo("<h2>HiDream</h2>", "", gr.HTML),
|
||||
|
|
@ -173,6 +185,8 @@ options_templates.update(options_section(('model_options', "Model Options"), {
|
|||
"model_wan_boundary": OptionInfo(0.85, "Stage boundary ratio", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05 }),
|
||||
"model_chrono_sep": OptionInfo("<h2>ChronoEdit</h2>", "", gr.HTML),
|
||||
"model_chrono_temporal_steps": OptionInfo(0, "Temporal steps", gr.Slider, {"minimum": 0, "maximum": 50, "step": 1 }),
|
||||
"model_qwen_layer_sep": OptionInfo("<h2>WanAI</h2>", "", gr.HTML),
|
||||
"model_qwen_layers": OptionInfo(2, "Qwen layered number of layers", gr.Slider, {"minimum": 2, "maximum": 9, "step": 1 }),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('offload', "Model Offloading"), {
|
||||
|
|
@ -626,7 +640,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
|||
|
||||
"postprocessing_sep_detailer": OptionInfo("<h2>Detailer</h2>", "", gr.HTML),
|
||||
"detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"),
|
||||
"detailer_augment": OptionInfo(True, "Detailer use model augment"),
|
||||
"detailer_augment": OptionInfo(False, "Detailer use model augment"),
|
||||
|
||||
"postprocessing_sep_seedvt": OptionInfo("<h2>SeedVT</h2>", "", gr.HTML),
|
||||
"seedvt_cfg_scale": OptionInfo(3.5, "SeedVR CFG Scale", gr.Slider, {"minimum": 1, "maximum": 15, "step": 1}),
|
||||
|
|
@ -802,6 +816,7 @@ options_templates.update(options_section(('hidden_options', "Hidden options"), {
|
|||
"detailer_merge": OptionInfo(False, "Merge multiple results from each detailer model", gr.Checkbox, {"visible": False}),
|
||||
"detailer_sort": OptionInfo(False, "Sort detailer output by location", gr.Checkbox, {"visible": False}),
|
||||
"detailer_save": OptionInfo(False, "Include detection results", gr.Checkbox, {"visible": False}),
|
||||
"detailer_seg": OptionInfo(False, "Use segmentation", gr.Checkbox, {"visible": False}),
|
||||
}))
|
||||
|
||||
|
||||
|
|
@ -822,7 +837,7 @@ log.info(f'Engine: backend={backend} compute={devices.backend} device={devices.g
|
|||
|
||||
profiler = None
|
||||
prompt_styles = modules.styles.StyleDatabase(opts)
|
||||
reference_models = readfile(os.path.join('html', 'reference.json')) if opts.extra_network_reference_enable else {}
|
||||
reference_models = readfile(os.path.join('html', 'reference.json'), as_type="dict") if opts.extra_network_reference_enable else {}
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or (cmd_opts.server_name or False)) and not cmd_opts.insecure
|
||||
devices.args = cmd_opts
|
||||
devices.opts = opts
|
||||
|
|
@ -884,8 +899,8 @@ def restore_defaults(restart=True):
|
|||
|
||||
|
||||
# startup def of shared.sd_model before its redefined in modeldata
|
||||
sd_model: diffusers.DiffusionPipeline = None # dummy and overwritten by class
|
||||
sd_refiner: diffusers.DiffusionPipeline = None # dummy and overwritten by class
|
||||
sd_model: DiffusionPipeline | None = None # dummy and overwritten by class
|
||||
sd_refiner: DiffusionPipeline | None = None # dummy and overwritten by class
|
||||
sd_model_type: str = '' # dummy and overwritten by class
|
||||
sd_refiner_type: str = '' # dummy and overwritten by class
|
||||
sd_loaded: bool = False # dummy and overwritten by class
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ def get_default_modes(cmd_opts, mem_stat):
|
|||
if "gpu" in mem_stat and gpu_memory != 0:
|
||||
if gpu_memory <= 4:
|
||||
cmd_opts.lowvram = True
|
||||
default_offload_mode = "sequential"
|
||||
default_offload_mode = "balanced"
|
||||
default_diffusers_offload_min_gpu_memory = 0
|
||||
log.info(f"Device detect: memory={gpu_memory:.1f} default=sequential optimization=lowvram")
|
||||
log.info(f"Device detect: memory={gpu_memory:.1f} default=balanced optimization=lowvram")
|
||||
elif gpu_memory <= 12:
|
||||
cmd_opts.medvram = True # VAE Tiling and other stuff
|
||||
default_offload_mode = "balanced"
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ pipelines = {
|
|||
'Qwen': getattr(diffusers, 'QwenImagePipeline', None),
|
||||
'HunyuanImage': getattr(diffusers, 'HunyuanImagePipeline', None),
|
||||
'Z-Image': getattr(diffusers, 'ZImagePipeline', None),
|
||||
'LongCat': getattr(diffusers, 'LongCatImagePipeline', None),
|
||||
# dynamically imported and redefined later
|
||||
'Meissonic': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
'Monetico': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
|
|
|
|||
|
|
@ -414,7 +414,7 @@ def update_token_counter(text):
|
|||
token_count = 0
|
||||
max_length = 75
|
||||
if shared.state.job_count > 0:
|
||||
shared.log.info('Tokenizer busy')
|
||||
shared.log.debug('Tokenizer busy')
|
||||
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
||||
from modules import extra_networks
|
||||
prompt, _ = extra_networks.parse_prompt(text)
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ sort_ordering = {
|
|||
}
|
||||
|
||||
|
||||
def get_installed(ext) -> extensions.Extension:
|
||||
installed: extensions.Extension = [e for e in extensions.extensions if (e.remote or '').startswith(ext['url'].replace('.git', ''))]
|
||||
def get_installed(ext):
    """Return the first installed extension whose remote matches the catalog entry, or None.

    An installed extension matches when its ``remote`` (treated as an empty
    string when falsy) starts with ``ext['url']`` with every ``.git``
    substring removed.
    """
    matches = []
    for candidate in extensions.extensions:
        remote = candidate.remote or ''
        if remote.startswith(ext['url'].replace('.git', '')):
            matches.append(candidate)
    if matches:
        return matches[0]
    return None
|
||||
|
||||
|
||||
|
|
@ -382,27 +382,27 @@ def create_html(search_text, sort_column):
|
|||
if ext.get('status', None) is None or type(ext['status']) == str: # old format
|
||||
ext['status'] = 0
|
||||
if ext['url'] is None or ext['url'] == '':
|
||||
status = f"<span style='cursor:pointer;color:#00C0FD' title='Local'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='Local'>{ui_symbols.svg_bullet.color('#00C0FD')}</div>"
|
||||
elif ext['status'] > 0:
|
||||
if ext['status'] == 1:
|
||||
status = f"<span style='cursor:pointer;color:#00FD9C ' title='Verified'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='Verified'>{ui_symbols.svg_bullet.color('#00FD9C')}</div>"
|
||||
elif ext['status'] == 2:
|
||||
status = f"<span style='cursor:pointer;color:#FFC300' title='Supported only with backend:Original'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='Supported only with backend: Original'>{ui_symbols.svg_bullet.color('#FFC300')}</div>"
|
||||
elif ext['status'] == 3:
|
||||
status = f"<span style='cursor:pointer;color:#FFC300' title='Supported only with backend:Diffusers'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='Supported only with backend: Diffusers'>{ui_symbols.svg_bullet.color('#FFC300')}</div>"
|
||||
elif ext['status'] == 4:
|
||||
status = f"<span style='cursor:pointer;color:#4E22FF' title=\"{ext.get('note', 'custom value')}\">{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title=\"{html.escape(ext.get('note', 'custom value'))}\">{ui_symbols.svg_bullet.color('#4E22FF')}</div>"
|
||||
elif ext['status'] == 5:
|
||||
status = f"<span style='cursor:pointer;color:#CE0000' title='Not supported'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='Not supported'>{ui_symbols.svg_bullet.color('#CE0000')}</div>"
|
||||
elif ext['status'] == 6:
|
||||
status = f"<span style='cursor:pointer;color:#AEAEAE' title='Just discovered'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='Just discovered'>{ui_symbols.svg_bullet.color('#AEAEAE')}</div>"
|
||||
else:
|
||||
status = f"<span style='cursor:pointer;color:#008EBC' title='Unknown status'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='Unknown status'>{ui_symbols.svg_bullet.color('#008EBC')}</div>"
|
||||
else:
|
||||
if updated < datetime.timestamp(datetime.now() - timedelta(6*30)):
|
||||
status = f"<span style='cursor:pointer;color:#C000CF' title='Unmaintained'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='Unmaintained'>{ui_symbols.svg_bullet.color('#C000CF')}</div>"
|
||||
else:
|
||||
status = f"<span style='cursor:pointer;color:#7C7C7C' title='No info'>{ui_symbols.bullet}</span>"
|
||||
status = f"<div style='cursor:help;width:1em;' title='No info'>{ui_symbols.svg_bullet.color('#7C7C7C')}</div>"
|
||||
|
||||
code += f"""
|
||||
<tr style="display: {visible}">
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import html
|
|||
import base64
|
||||
import urllib.parse
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
from types import SimpleNamespace
|
||||
from pathlib import Path
|
||||
from html.parser import HTMLParser
|
||||
|
|
@ -440,7 +441,7 @@ class ExtraNetworksPage:
|
|||
def update_all_previews(self, items):
|
||||
global preview_map # pylint: disable=global-statement
|
||||
if preview_map is None:
|
||||
preview_map = shared.readfile('html/previews.json', silent=True)
|
||||
preview_map = shared.readfile('html/previews.json', silent=True, as_type="dict")
|
||||
t0 = time.time()
|
||||
reference_path = os.path.abspath(os.path.join('models', 'Reference'))
|
||||
possible_paths = list(set([os.path.dirname(item['filename']) for item in items] + [reference_path]))
|
||||
|
|
@ -520,12 +521,10 @@ class ExtraNetworksPage:
|
|||
t0 = time.time()
|
||||
fn = os.path.splitext(path)[0] + '.json'
|
||||
if not data and os.path.exists(fn):
|
||||
data = shared.readfile(fn, silent=True)
|
||||
data = shared.readfile(fn, silent=True, as_type="dict")
|
||||
fn = os.path.join(path, 'model_index.json')
|
||||
if not data and os.path.exists(fn):
|
||||
data = shared.readfile(fn, silent=True)
|
||||
if type(data) is list:
|
||||
data = data[0]
|
||||
data = shared.readfile(fn, silent=True, as_type="dict")
|
||||
t1 = time.time()
|
||||
self.info_time += t1-t0
|
||||
return data
|
||||
|
|
@ -862,13 +861,15 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
|
|||
is_valid = (item is not None) and hasattr(item, 'name') and hasattr(item, 'filename')
|
||||
|
||||
if is_valid:
|
||||
if TYPE_CHECKING:
|
||||
assert item is not None # Part of the definition of "is_valid"
|
||||
stat_size, stat_mtime = modelstats.stat(item.filename)
|
||||
if hasattr(item, 'size') and item.size > 0:
|
||||
stat_size = item.size
|
||||
if hasattr(item, 'mtime') and item.mtime is not None:
|
||||
stat_mtime = item.mtime
|
||||
desc = item.description
|
||||
fullinfo = shared.readfile(os.path.splitext(item.filename)[0] + '.json', silent=True)
|
||||
fullinfo = shared.readfile(os.path.splitext(item.filename)[0] + '.json', silent=True, as_type="dict")
|
||||
if 'modelVersions' in fullinfo: # sanitize massive objects
|
||||
fullinfo['modelVersions'] = []
|
||||
info = fullinfo
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||
shared.log.debug(f'Networks: type="reference" autodownload={shared.opts.sd_checkpoint_autodownload} enable={shared.opts.extra_network_reference_enable}')
|
||||
return []
|
||||
count = { 'total': 0, 'ready': 0, 'hidden': 0, 'experimental': 0, 'base': 0 }
|
||||
shared.reference_models = readfile(os.path.join('html', 'reference.json'))
|
||||
shared.reference_models = readfile(os.path.join('html', 'reference.json'), as_type="dict")
|
||||
for k, v in shared.reference_models.items():
|
||||
count['total'] += 1
|
||||
url = v['path']
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class UiLoadsave:
|
|||
|
||||
def read_from_file(self):
|
||||
from modules.shared import readfile
|
||||
return readfile(self.filename)
|
||||
return readfile(self.filename, as_type="dict")
|
||||
|
||||
def write_to_file(self, current_ui_settings):
|
||||
from modules.shared import writefile
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ def create_setting_component(key, is_quicksettings=False):
|
|||
with gr.Row():
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
|
||||
ui_common.create_refresh_button(res, info.refresh, info.component_args, f"settings_{key}_refresh")
|
||||
elif info.folder is not None:
|
||||
elif info.folder:
|
||||
with gr.Row():
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, elem_classes="folder-selector", **args)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -40,3 +40,24 @@ sort_time_asc = '\uf0de'
|
|||
sort_time_dsc = '\uf0dd'
|
||||
style_apply = '↶'
|
||||
style_save = '↷'
|
||||
|
||||
class SVGSymbol:
    """Inline SVG markup that can optionally be tinted.

    The markup is considered tintable when it contains the ``currentColor``
    keyword; ``color()`` then returns a copy of the markup with a concrete
    color substituted for that keyword.
    """

    def __init__(self, svg: str):
        """Store the markup and record whether it contains ``currentColor``.

        Args:
            svg: Raw SVG markup.
        """
        self.svg = svg
        # Markup on either side of the first ``currentColor``; kept as
        # attributes so any existing reader of ``before``/``after`` still works.
        self.before = ""
        self.after = ""
        self.supports_color = "currentColor" in svg
        if self.supports_color:
            self.before, self.after = svg.split("currentColor", maxsplit=1)

    def color(self, color: str) -> str:
        """Return the markup with every ``currentColor`` replaced by ``color``.

        Args:
            color: Color value inserted verbatim (it is not escaped).

        Returns:
            The tinted markup, or the original markup unchanged when it does
            not contain ``currentColor``.
        """
        if not self.supports_color:
            return self.svg
        # Replace all occurrences, not just the first, so a symbol that uses
        # currentColor for both stroke and fill is tinted consistently.
        return self.svg.replace("currentColor", color)

    def __str__(self) -> str:
        """Return the untinted markup."""
        return self.svg
|
||||
|
||||
svg_bullet = SVGSymbol("<svg style='stroke:currentColor;fill:none;stroke-width:2;' viewBox='0 0 16 16'><circle cx='8' cy='8' r='7'/></svg>")
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class Upscaler:
|
|||
def __init__(self, create_dirs=True):
|
||||
global models # pylint: disable=global-statement
|
||||
if models is None:
|
||||
models = shared.readfile('html/upscalers.json')
|
||||
models = shared.readfile('html/upscalers.json', as_type="dict")
|
||||
self.mod_pad_h = None
|
||||
self.tile_size = shared.opts.upscaler_tile_size
|
||||
self.tile_pad = shared.opts.upscaler_tile_overlap
|
||||
|
|
|
|||
|
|
@ -69,17 +69,38 @@ class GoogleVeoVideoPipeline():
|
|||
image=genai.types.Image(image_bytes=image_bytes.getvalue(), mime_type='image/jpeg'),
|
||||
)
|
||||
|
||||
def get_args(self):
    """Build the keyword arguments used to construct the Google GenAI client.

    Each value is taken from its environment variable first and falls back to
    the matching ``opts`` setting: ``GOOGLE_API_KEY``, ``GOOGLE_CLOUD_PROJECT``
    and ``GOOGLE_CLOUD_LOCATION``. VertexAI mode is enabled when
    ``GOOGLE_GENAI_USE_VERTEXAI`` is set to any value or the
    ``google_use_vertexai`` option is on.

    Returns:
        A dict with ``api_key``, ``vertexai``, ``project`` and ``location``
        keys, or ``None`` (after logging an error) when neither an API key nor
        ``GOOGLE_APPLICATION_CREDENTIALS`` is available.
    """
    from modules.shared import opts
    api_key = os.getenv("GOOGLE_API_KEY") or opts.google_api_key
    vertex_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
    if not api_key and not vertex_credentials:
        log.error(f'Cloud: model="{self.model}" API key not provided')
        return None
    use_vertexai = (os.getenv("GOOGLE_GENAI_USE_VERTEXAI") is not None) or opts.google_use_vertexai
    project_id = os.getenv("GOOGLE_CLOUD_PROJECT") or opts.google_project_id
    location_id = os.getenv("GOOGLE_CLOUD_LOCATION") or opts.google_location_id
    args = {
        'api_key': api_key,
        'vertexai': use_vertexai,
        # Empty or unset values become None; `or` also tolerates an option that is None rather than ''.
        'project': project_id or None,
        'location': location_id or None,
    }
    # Log a redacted copy only; the key may be missing entirely when
    # authenticating through application credentials, so guard the slice.
    args_copy = args.copy()
    args_copy['api_key'] = '...' + (api_key or '')[-4:] # last 4 chars
    args_copy['credentials'] = vertex_credentials
    log.debug(f'Cloud: model="{self.model}" args={args_copy}')
    return args
|
||||
|
||||
def __call__(self, prompt: list[str], width: int, height: int, image: Image.Image = None, num_frames: int = 4*24):
|
||||
from google import genai
|
||||
|
||||
if isinstance(prompt, list) and len(prompt) > 0:
|
||||
prompt = prompt[0]
|
||||
if self.client is None:
|
||||
api_key = os.getenv("GOOGLE_API_KEY", None)
|
||||
if api_key is None:
|
||||
log.error(f'Cloud: model="{self.model}" GOOGLE_API_KEY environment variable not set')
|
||||
args = self.get_args()
|
||||
if args is None:
|
||||
return None
|
||||
self.client = genai.Client(api_key=api_key, vertexai=False)
|
||||
self.client = genai.Client(**args)
|
||||
|
||||
resolution, aspect_ratio = get_size_buckets(width, height)
|
||||
duration = num_frames // 24
|
||||
|
|
@ -115,11 +136,12 @@ class GoogleVeoVideoPipeline():
|
|||
log.error(f'Cloud video: model="{self.model}" {operation} {e}')
|
||||
return None
|
||||
|
||||
if operation is None or operation.response is None or operation.response.generated_videos is None or len(operation.response.generated_videos) == 0:
|
||||
try:
|
||||
response: genai.types.GeneratedVideo = operation.response.generated_videos[0]
|
||||
except Exception:
|
||||
log.error(f'Cloud video: model="{self.model}" no response {operation}')
|
||||
return None
|
||||
try:
|
||||
response: genai.types.GeneratedVideo = operation.response.generated_videos[0]
|
||||
self.client.files.download(file=response.video)
|
||||
video_bytes = response.video.video_bytes
|
||||
return { 'bytes': video_bytes, 'images': [] }
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ def load_transformer(repo_id, cls_name, load_config=None, subfolder="transformer
|
|||
shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" loader={_loader("diffusers")} args={load_args}')
|
||||
if dtype is not None:
|
||||
load_args['torch_dtype'] = dtype
|
||||
load_args.pop('device_map', None) # single-file uses different syntax
|
||||
loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
|
||||
transformer = loader(
|
||||
local_file,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import io
|
||||
import os
|
||||
import time
|
||||
from PIL import Image
|
||||
from installer import install, reload, log
|
||||
|
||||
|
|
@ -67,14 +68,35 @@ class GoogleNanoBananaPipeline():
|
|||
],
|
||||
)
|
||||
|
||||
def get_args(self):
    """Build the keyword arguments used to construct the Google GenAI client.

    Each value is taken from its environment variable first and falls back to
    the matching ``opts`` setting: ``GOOGLE_API_KEY``, ``GOOGLE_CLOUD_PROJECT``
    and ``GOOGLE_CLOUD_LOCATION``. VertexAI mode is enabled when
    ``GOOGLE_GENAI_USE_VERTEXAI`` is set to any value or the
    ``google_use_vertexai`` option is on.

    Returns:
        A dict with ``api_key``, ``vertexai``, ``project`` and ``location``
        keys, or ``None`` (after logging an error) when neither an API key nor
        ``GOOGLE_APPLICATION_CREDENTIALS`` is available.
    """
    from modules.shared import opts
    api_key = os.getenv("GOOGLE_API_KEY") or opts.google_api_key
    vertex_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
    if not api_key and not vertex_credentials:
        log.error(f'Cloud: model="{self.model}" API key not provided')
        return None
    use_vertexai = (os.getenv("GOOGLE_GENAI_USE_VERTEXAI") is not None) or opts.google_use_vertexai
    project_id = os.getenv("GOOGLE_CLOUD_PROJECT") or opts.google_project_id
    location_id = os.getenv("GOOGLE_CLOUD_LOCATION") or opts.google_location_id
    args = {
        'api_key': api_key,
        'vertexai': use_vertexai,
        # Empty or unset values become None; `or` also tolerates an option that is None rather than ''.
        'project': project_id or None,
        'location': location_id or None,
    }
    # Log a redacted copy only; the key may be missing entirely when
    # authenticating through application credentials, so guard the slice.
    args_copy = args.copy()
    args_copy['api_key'] = '...' + (api_key or '')[-4:] # last 4 chars
    args_copy['credentials'] = vertex_credentials
    log.debug(f'Cloud: model="{self.model}" args={args_copy}')
    return args
|
||||
|
||||
def __call__(self, prompt: list[str], width: int, height: int, image: Image.Image = None):
|
||||
from google import genai
|
||||
if self.client is None:
|
||||
api_key = os.getenv("GOOGLE_API_KEY", None)
|
||||
if api_key is None:
|
||||
log.error(f'Cloud: model="{self.model}" GOOGLE_API_KEY environment variable not set')
|
||||
args = self.get_args()
|
||||
if args is None:
|
||||
return None
|
||||
self.client = genai.Client(api_key=api_key, vertexai=False)
|
||||
self.client = genai.Client(**args)
|
||||
|
||||
image_size, aspect_ratio = get_size_buckets(width, height)
|
||||
if 'gemini-3' in self.model:
|
||||
|
|
@ -85,14 +107,21 @@ class GoogleNanoBananaPipeline():
|
|||
response_modalities=["IMAGE"],
|
||||
image_config=image_config
|
||||
)
|
||||
log.debug(f'Cloud: prompt="{prompt}" size={image_size} ar={aspect_ratio} image={image} model="{self.model}"')
|
||||
log.debug(f'Cloud: model="{self.model}" prompt="{prompt}" size={image_size} ar={aspect_ratio} image={image}')
|
||||
# log.debug(f'Cloud: config={self.config}')
|
||||
|
||||
try:
|
||||
t0 = time.time()
|
||||
if image is not None:
|
||||
response = self.img2img(prompt, image)
|
||||
else:
|
||||
response = self.txt2img(prompt)
|
||||
t1 = time.time()
|
||||
try:
|
||||
tokens = response.usage_metadata.total_token_count
|
||||
except Exception:
|
||||
tokens = 0
|
||||
log.debug(f'Cloud: model="{self.model}" tokens={tokens} time={(t1 - t0):.2f}')
|
||||
except Exception as e:
|
||||
log.error(f'Cloud: model="{self.model}" {e}')
|
||||
return None
|
||||
|
|
@ -100,10 +129,16 @@ class GoogleNanoBananaPipeline():
|
|||
image = None
|
||||
if getattr(response, 'prompt_feedback', None) is not None:
|
||||
log.error(f'Cloud: model="{self.model}" {response.prompt_feedback}')
|
||||
if not hasattr(response, 'candidates') or (response.candidates is None) or (len(response.candidates) == 0):
|
||||
|
||||
parts = []
|
||||
try:
|
||||
for candidate in response.candidates:
|
||||
parts.extend(candidate.content.parts)
|
||||
except Exception:
|
||||
log.error(f'Cloud: model="{self.model}" no images received')
|
||||
return None
|
||||
for part in response.candidates[0].content.parts:
|
||||
|
||||
for part in parts:
|
||||
if part.inline_data is not None:
|
||||
image = Image.open(io.BytesIO(part.inline_data.data))
|
||||
return image
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
import transformers
|
||||
import diffusers
|
||||
from modules import shared, devices, sd_models, model_quant, sd_hijack_te
|
||||
from pipelines import generic
|
||||
|
||||
|
||||
def load_longcat(checkpoint_info, diffusers_load_config=None):
    """Load a LongCat image pipeline together with its transformer and text encoder.

    The transformer, text encoder and text processor are loaded separately and
    handed to the pipeline's ``from_pretrained``. ``LongCatImageEditPipeline``
    is used when the repo id contains ``edit`` (case-insensitive), otherwise
    ``LongCatImagePipeline``.

    Args:
        checkpoint_info: Checkpoint descriptor resolved to a repo id via ``sd_models.path_to_repo``.
        diffusers_load_config: Optional load configuration; an empty dict is used when omitted.

    Returns:
        The constructed pipeline.
    """
    if diffusers_load_config is None:
        diffusers_load_config = {}
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)

    load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
    # Log the derived load_args under `args=`; the raw config is already reported under `config=`.
    shared.log.debug(f'Load model: type=LongCat repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')

    transformer = generic.load_transformer(repo_id, cls_name=diffusers.LongCatImageTransformer2DModel, load_config=diffusers_load_config)
    text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.Qwen2_5_VLForConditionalGeneration, load_config=diffusers_load_config)
    text_processor = transformers.Qwen2VLProcessor.from_pretrained(repo_id, subfolder='tokenizer', cache_dir=shared.opts.hfcache_dir)

    cls = diffusers.LongCatImageEditPipeline if 'edit' in repo_id.lower() else diffusers.LongCatImagePipeline

    pipe = cls.from_pretrained(
        repo_id,
        cache_dir=shared.opts.diffusers_dir,
        transformer=transformer,
        text_encoder=text_encoder,
        text_processor=text_processor,
        **load_args,
    )

    # Drop the local references now that the pipeline holds the components.
    del transformer
    del text_encoder
    del text_processor
    sd_hijack_te.init_hijack(pipe)

    devices.torch_gc(force=True, reason='load')
    return pipe
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
import transformers
|
||||
import diffusers
|
||||
from modules import shared, devices, sd_models, model_quant, sd_hijack_te
|
||||
from pipelines import generic
|
||||
|
||||
|
||||
def load_ovis(checkpoint_info, diffusers_load_config=None):
    """Load an Ovis image pipeline together with its transformer and text encoder.

    The transformer and text encoder are loaded separately and handed to
    ``OvisImagePipeline.from_pretrained``. The pipeline's ``task_args`` is set
    so that ``output_type`` is ``'np'``.

    Args:
        checkpoint_info: Checkpoint descriptor resolved to a repo id via ``sd_models.path_to_repo``.
        diffusers_load_config: Optional load configuration; an empty dict is used when omitted.

    Returns:
        The constructed pipeline.
    """
    if diffusers_load_config is None:
        diffusers_load_config = {}
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)

    load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
    # Log the derived load_args under `args=`; the raw config is already reported under `config=`.
    shared.log.debug(f'Load model: type=OvisImage repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')

    transformer = generic.load_transformer(repo_id, cls_name=diffusers.OvisImageTransformer2DModel, load_config=diffusers_load_config)
    text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.Qwen3Model, load_config=diffusers_load_config)

    pipe = diffusers.OvisImagePipeline.from_pretrained(
        repo_id,
        cache_dir=shared.opts.diffusers_dir,
        transformer=transformer,
        text_encoder=text_encoder,
        **load_args,
    )

    pipe.task_args = {
        'output_type': 'np',
    }

    # Drop the local references now that the pipeline holds the components.
    del transformer
    del text_encoder
    sd_hijack_te.init_hijack(pipe)

    devices.torch_gc(force=True, reason='load')
    return pipe
|
||||
|
|
@ -15,7 +15,7 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
|
|||
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model')
|
||||
shared.log.debug(f'Load model: type=Qwen model="{checkpoint_info.name}" repo="{repo_id}" offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
|
||||
|
||||
if '2509' in repo_id :
|
||||
if '2509' in repo_id or '2511' in repo_id:
|
||||
cls_name = diffusers.QwenImageEditPlusPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPlusPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPlusPipeline
|
||||
|
|
@ -25,6 +25,11 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
|
|||
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageEditPipeline
|
||||
elif 'Layered' in repo_id:
|
||||
cls_name = diffusers.QwenImageLayeredPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["qwen-layered"] = diffusers.QwenImageLayeredPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["qwen-layered"] = diffusers.QwenImageLayeredPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["qwen-layered"] = diffusers.QwenImageLayeredPipeline
|
||||
else:
|
||||
cls_name = diffusers.QwenImagePipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImagePipeline
|
||||
|
|
@ -69,6 +74,11 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
|
|||
pipe.task_args = {
|
||||
'output_type': 'np',
|
||||
}
|
||||
if 'Layered' in repo_id:
|
||||
pipe.task_args['use_en_prompt'] = True
|
||||
pipe.task_args['cfg_normalize'] = False
|
||||
pipe.task_args['layers'] = shared.opts.model_qwen_layers
|
||||
pipe.task_args['resolution'] = 640
|
||||
|
||||
del text_encoder
|
||||
del transformer
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ def load_z_image(checkpoint_info, diffusers_load_config=None):
|
|||
sd_models.hf_auth_check(checkpoint_info)
|
||||
|
||||
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
|
||||
shared.log.debug(f'Load model: type=Z-Image repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={diffusers_load_config}')
|
||||
shared.log.debug(f'Load model: type=ZImage repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={diffusers_load_config}')
|
||||
|
||||
transformer = generic.load_transformer(repo_id, cls_name=diffusers.ZImageTransformer2DModel, load_config=diffusers_load_config)
|
||||
text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.Qwen3ForCausalLM, load_config=diffusers_load_config)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ PyWavelets
|
|||
pi-heif
|
||||
|
||||
# versioned
|
||||
fastapi==0.124.4
|
||||
rich==14.1.0
|
||||
safetensors==0.7.0
|
||||
tensordict==0.8.3
|
||||
|
|
|
|||
2
wiki
2
wiki
|
|
@ -1 +1 @@
|
|||
Subproject commit 12af554d26d12ce84d43e35f81da89bdbeac4057
|
||||
Subproject commit 01a5b7af78897212a8d1b32def6ff4bd3d03a352
|
||||
Loading…
Reference in New Issue