Merge remote-tracking branch 'origin/dev' into refactor/remove-face-restoration

# Conflicts:
#	.pylintrc
#	.ruff.toml
pull/4637/head
CalamitousFelicitousness 2026-02-10 14:38:14 +00:00
commit 385532154f
42 changed files with 853 additions and 643 deletions

View File

@ -9,22 +9,22 @@ body:
attributes:
label: Model name
description: Enter model name
value:
value:
- type: textarea
id: type
attributes:
label: Model type
description: Describe model type
value:
value:
- type: textarea
id: url
attributes:
label: Model URL
description: Enter URL to the model page
value:
value:
- type: textarea
id: description
attributes:
label: Reason
description: Enter a reason why you would like to see this model supported
value:
value:

283
.pylintrc
View File

@ -1,283 +0,0 @@
[MAIN]
analyse-fallback-blocks=no
clear-cache-post-run=no
extension-pkg-allow-list=
prefer-stubs=yes
extension-pkg-whitelist=
fail-on=
fail-under=10
ignore=CVS
ignore-paths=/usr/lib/.*$,
venv,
.git,
.ruff_cache,
.vscode,
modules/apg,
modules/cfgzero,
modules/control/proc,
modules/control/units,
modules/dml,
modules/flash_attn_triton_amd,
modules/ggml,
modules/hidiffusion,
modules/hijack/ddpm_edit.py,
modules/intel,
modules/intel/ipex,
modules/framepack/pipeline,
modules/onnx_impl,
modules/pag,
modules/postprocess/aurasr_arch.py,
modules/prompt_parser_xhinker.py,
modules/ras,
modules/seedvr,
modules/rife,
modules/schedulers,
modules/taesd,
modules/teacache,
modules/todo,
modules/res4lyf,
pipelines/bria,
pipelines/flex2,
pipelines/f_lite,
pipelines/hidream,
pipelines/hdm,
pipelines/meissonic,
pipelines/omnigen2,
pipelines/segmoe,
pipelines/xomni,
pipelines/chrono,
scripts/consistory,
scripts/ctrlx,
scripts/daam,
scripts/demofusion,
scripts/freescale,
scripts/infiniteyou,
scripts/instantir,
scripts/lbm,
scripts/layerdiffuse,
scripts/mod,
scripts/pixelsmith,
scripts/differential_diffusion.py,
scripts/pulid,
scripts/xadapter,
repositories,
extensions-builtin/sd-extension-chainner/nodes,
extensions-builtin/sd-webui-agent-scheduler,
extensions-builtin/sdnext-modernui/node_modules,
extensions-builtin/sdnext-kanvas/node_modules,
ignore-patterns=.*test*.py$,
.*_model.py$,
.*_arch.py$,
.*_model_arch.py*,
.*_model_arch_v2.py$,
ignored-modules=
jobs=8
limit-inference-results=100
load-plugins=
persistent=no
py-version=3.10
recursive=no
source-roots=
unsafe-load-any-extension=no
[BASIC]
argument-naming-style=snake_case
attr-naming-style=snake_case
bad-names=foo, bar, baz, toto, tutu, tata
bad-names-rgxs=
class-attribute-naming-style=any
class-const-naming-style=UPPER_CASE
class-naming-style=PascalCase
const-naming-style=snake_case
docstring-min-length=-1
function-naming-style=snake_case
good-names=i,j,k,e,ex,ok,p,x,y,id
good-names-rgxs=
include-naming-hint=no
inlinevar-naming-style=any
method-naming-style=snake_case
module-naming-style=snake_case
name-group=
no-docstring-rgx=^_
property-classes=abc.abstractproperty
variable-naming-style=snake_case
[CLASSES]
check-protected-access-in-special-methods=no
defining-attr-methods=__init__, __new__,
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
valid-classmethod-first-arg=cls
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
exclude-too-few-public-methods=
ignored-parents=
max-args=199
max-attributes=99
max-bool-expr=99
max-branches=199
max-locals=99
max-parents=99
max-public-methods=99
max-returns=99
max-statements=199
min-public-methods=1
[EXCEPTIONS]
overgeneral-exceptions=builtins.BaseException,builtins.Exception
[FORMAT]
expected-line-ending-format=
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
indent-after-paren=4
indent-string=' '
max-line-length=200
max-module-lines=9999
single-line-class-stmt=no
single-line-if-stmt=no
[IMPORTS]
allow-any-import-level=
allow-reexport-from-package=no
allow-wildcard-with-all=no
deprecated-modules=
ext-import-graph=
import-graph=
int-import-graph=
known-standard-library=
known-third-party=enchant
preferred-modules=
[LOGGING]
logging-format-style=new
logging-modules=logging
[MESSAGES CONTROL]
confidence=HIGH,
CONTROL_FLOW,
INFERENCE,
INFERENCE_FAILURE,
UNDEFINED
# disable=C,R,W
disable=abstract-method,
bad-inline-option,
bare-except,
broad-exception-caught,
chained-comparison,
consider-iterating-dictionary,
consider-merging-isinstance,
consider-using-dict-items,
consider-using-enumerate,
consider-using-from-import,
consider-using-generator,
consider-using-get,
consider-using-in,
consider-using-max-builtin,
consider-using-min-builtin,
consider-using-sys-exit,
cyclic-import,
dangerous-default-value,
deprecated-pragma,
duplicate-code,
file-ignored,
import-error,
import-outside-toplevel,
invalid-name,
line-too-long,
locally-disabled,
logging-fstring-interpolation,
missing-class-docstring,
missing-function-docstring,
missing-module-docstring,
no-else-raise,
no-else-return,
not-callable,
pointless-string-statement,
raw-checker-failed,
simplifiable-if-expression,
suppressed-message,
too-few-public-methods,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-positional-arguments,
too-many-statements,
unidiomatic-typecheck,
unknown-option-value,
unnecessary-dict-index-lookup,
unnecessary-dunder-call,
unnecessary-lambda-assigment,
unnecessary-lambda,
unused-wildcard-import,
unpacking-non-sequence,
unsubscriptable-object,
useless-return,
use-dict-literal,
use-symbolic-message-instead,
useless-suppression,
wrong-import-position,
enable=c-extension-no-member
[METHOD_ARGS]
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
[MISCELLANEOUS]
notes=FIXME,
XXX,
TODO
notes-rgx=
[REFACTORING]
max-nested-blocks=5
never-returning-functions=sys.exit,argparse.parse_error
[REPORTS]
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
msg-template=
reports=no
score=no
[SIMILARITIES]
ignore-comments=yes
ignore-docstrings=yes
ignore-imports=yes
ignore-signatures=yes
min-similarity-lines=4
[SPELLING]
max-spelling-suggestions=4
spelling-dict=
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
spelling-ignore-words=
spelling-private-dict-file=
spelling-store-unknown-words=no
[STRING]
check-quote-consistency=no
check-str-concat-over-line-jumps=no
[TYPECHECK]
contextmanager-decorators=contextlib.contextmanager
generated-members=numpy.*,logging.*,torch.*,cv2.*
ignore-none=yes
ignore-on-opaque-inference=yes
ignored-checks-for-mixins=no-member,
not-async-context-manager,
not-context-manager,
attribute-defined-outside-init
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
missing-member-hint=yes
missing-member-hint-distance=1
missing-member-max-choices=1
mixin-class-rgx=.*[Mm]ixin
signature-mutators=
[VARIABLES]
additional-builtins=
allow-global-unused-variables=yes
allowed-redefined-builtins=
callbacks=cb_,
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
ignored-argument-names=_.*|^ignored_|^unused_
init-import=no
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io

View File

@ -1,108 +0,0 @@
line-length = 250
indent-width = 4
target-version = "py310"
exclude = [
"venv",
".git",
".ruff_cache",
".vscode",
"modules/cfgzero",
"modules/flash_attn_triton_amd",
"modules/hidiffusion",
"modules/intel/ipex",
"modules/pag",
"modules/schedulers",
"modules/teacache",
"modules/seedvr",
"modules/control/proc",
"modules/control/units",
"modules/control/units/xs_pipe.py",
"modules/postprocess/aurasr_arch.py",
"pipelines/meissonic",
"pipelines/omnigen2",
"pipelines/hdm",
"pipelines/segmoe",
"pipelines/xomni",
"pipelines/chrono",
"scripts/lbm",
"scripts/daam",
"scripts/xadapter",
"scripts/pulid",
"scripts/instantir",
"scripts/freescale",
"scripts/consistory",
"repositories",
"extensions-builtin/Lora",
"extensions-builtin/sd-extension-chainner/nodes",
"extensions-builtin/sd-webui-agent-scheduler",
"extensions-builtin/sdnext-modernui/node_modules",
]
[lint]
select = [
"F",
"E",
"W",
"C",
"B",
"I",
"YTT",
"ASYNC",
"RUF",
"AIR",
"NPY",
"C4",
"T10",
"EXE",
"ISC",
"ICN",
"RSE",
"TCH",
"TID",
"INT",
"PLE",
]
ignore = [
"B006", # Do not use mutable data structures for argument defaults
"B008", # Do not perform function call in argument defaults
"B905", # Strict zip() usage
"C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead
"C408", # Unnecessary `dict` call
"I001", # Import block is un-sorted or un-formatted
"E402", # Module level import not at top of file
"E501", # Line too long
"E721", # Do not compare types, use `isinstance()`
"E731", # Do not assign a `lambda` expression, use a `def`
"E741", # Ambiguous variable name
"F401", # Imported by unused
"EXE001", # file with shebang is not marked executable
"NPY002", # replace legacy random
"RUF005", # Consider iterable unpacking
"RUF008", # Do not use mutable default values for dataclass
"RUF010", # Use explicit conversion flag
"RUF012", # Mutable class attributes
"RUF013", # PEP 484 prohibits implicit `Optional`
"RUF015", # Prefer `next(...)` over single element slice
"RUF046", # Value being cast to `int` is already an integer
"RUF059", # Unpacked variables are not used
"RUF051", # Prefer pop over del
]
fixable = ["ALL"]
unfixable = []
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[lint.mccabe]
max-complexity = 150

View File

@ -1,6 +1,6 @@
# Change Log for SD.Next
## Update for 2026-02-07
## Update for 2026-02-09
- **Upscalers**
- add support for [spandrel](https://github.com/chaiNNer-org/spandrel)
@ -11,12 +11,21 @@
- pipelines: add **ZImageInpaint**, thanks @CalamitousFelicitousness
- add `--remote` command line flag that reduces client/server chatter and improves link stability
for long-running generates, useful when running on remote servers
- hires: allow using different lora in refiner prompt
- **nunchaku** models are now listed in networks tab as reference models
instead of being used implicitly via quantization, thanks @CalamitousFelicitousness
- **UI**
- ui: **themes** add *CTD-NT64Light* and *CTD-NT64Dark*, thanks @resonantsky
- ui: **gallery** add option to auto-refresh gallery, thanks @awsr
- **Internal**
- refactor: switch to `pyproject.toml` for tool configs
- refactor: reorganize `cli` scripts
- refactor: move tests to dedicated `/test/`
- update `lint` rules, thanks @awsr
- update `requirements`
- **Fixes**
- fix: handle `clip` installer doing unwanted `setuptools` update
- fix: cleanup for `uv` installer fallback
- fix: add metadata restore to always-on scripts
- fix: improve wildcard weights parsing, thanks @Tillerz
- fix: ui gallery cace recursive cleanup, thanks @awsr

View File

@ -0,0 +1,209 @@
{
"FLUX.1-Dev Nunchaku SVDQuant": {
"path": "black-forest-labs/FLUX.1-dev",
"subfolder": "nunchaku",
"preview": "black-forest-labs--FLUX.1-dev.jpg",
"desc": "Nunchaku SVDQuant quantization of FLUX.1-dev transformer with INT4 and SVD rank 32",
"skip": true,
"nunchaku": ["Model", "TE"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"FLUX.1-Schnell Nunchaku SVDQuant": {
"path": "black-forest-labs/FLUX.1-schnell",
"subfolder": "nunchaku",
"preview": "black-forest-labs--FLUX.1-schnell.jpg",
"desc": "Nunchaku SVDQuant quantization of FLUX.1-schnell transformer with INT4 and SVD rank 32",
"skip": true,
"nunchaku": ["Model", "TE"],
"tags": "nunchaku",
"extras": "sampler: Default, cfg_scale: 1.0, steps: 4",
"size": 0,
"date": "2025 June"
},
"FLUX.1-Kontext Nunchaku SVDQuant": {
"path": "black-forest-labs/FLUX.1-Kontext-dev",
"subfolder": "nunchaku",
"preview": "black-forest-labs--FLUX.1-Kontext-dev.jpg",
"desc": "Nunchaku SVDQuant quantization of FLUX.1-Kontext-dev transformer with INT4 and SVD rank 32",
"skip": true,
"nunchaku": ["Model", "TE"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"FLUX.1-Krea Nunchaku SVDQuant": {
"path": "black-forest-labs/FLUX.1-Krea-dev",
"subfolder": "nunchaku",
"preview": "black-forest-labs--FLUX.1-Krea-dev.jpg",
"desc": "Nunchaku SVDQuant quantization of FLUX.1-Krea-dev transformer with INT4 and SVD rank 32",
"skip": true,
"nunchaku": ["Model", "TE"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"FLUX.1-Fill Nunchaku SVDQuant": {
"path": "black-forest-labs/FLUX.1-Fill-dev",
"subfolder": "nunchaku",
"preview": "black-forest-labs--FLUX.1-Fill-dev.jpg",
"desc": "Nunchaku SVDQuant quantization of FLUX.1-Fill-dev transformer for inpainting",
"skip": true,
"hidden": true,
"nunchaku": ["Model", "TE"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"FLUX.1-Depth Nunchaku SVDQuant": {
"path": "black-forest-labs/FLUX.1-Depth-dev",
"subfolder": "nunchaku",
"preview": "black-forest-labs--FLUX.1-Depth-dev.jpg",
"desc": "Nunchaku SVDQuant quantization of FLUX.1-Depth-dev transformer for depth-conditioned generation",
"skip": true,
"hidden": true,
"nunchaku": ["Model", "TE"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"Shuttle Jaguar Nunchaku SVDQuant": {
"path": "shuttleai/shuttle-jaguar",
"subfolder": "nunchaku",
"preview": "shuttleai--shuttle-jaguar.jpg",
"desc": "Nunchaku SVDQuant quantization of Shuttle Jaguar transformer",
"skip": true,
"nunchaku": ["Model", "TE"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"Qwen-Image Nunchaku SVDQuant": {
"path": "Qwen/Qwen-Image",
"subfolder": "nunchaku",
"preview": "Qwen--Qwen-Image.jpg",
"desc": "Nunchaku SVDQuant quantization of Qwen-Image transformer with INT4 and SVD rank 128",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"Qwen-Lightning (8-step) Nunchaku SVDQuant": {
"path": "vladmandic/Qwen-Lightning",
"subfolder": "nunchaku",
"preview": "vladmandic--Qwen-Lightning.jpg",
"desc": "Nunchaku SVDQuant quantization of Qwen-Lightning (8-step distilled) transformer with INT4 and SVD rank 128",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"extras": "steps: 8",
"size": 0,
"date": "2025 June"
},
"Qwen-Lightning (4-step) Nunchaku SVDQuant": {
"path": "vladmandic/Qwen-Lightning",
"subfolder": "nunchaku-4step",
"preview": "vladmandic--Qwen-Lightning.jpg",
"desc": "Nunchaku SVDQuant quantization of Qwen-Lightning (4-step distilled) transformer with INT4 and SVD rank 128",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"extras": "steps: 4",
"size": 0,
"date": "2025 June"
},
"Qwen-Image-Edit Nunchaku SVDQuant": {
"path": "Qwen/Qwen-Image-Edit",
"subfolder": "nunchaku",
"preview": "Qwen--Qwen-Image-Edit.jpg",
"desc": "Nunchaku SVDQuant quantization of Qwen-Image-Edit transformer with INT4 and SVD rank 128",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"Qwen-Lightning-Edit (8-step) Nunchaku SVDQuant": {
"path": "vladmandic/Qwen-Lightning-Edit",
"subfolder": "nunchaku",
"preview": "vladmandic--Qwen-Lightning-Edit.jpg",
"desc": "Nunchaku SVDQuant quantization of Qwen-Lightning-Edit (8-step distilled editing) transformer with INT4 and SVD rank 128",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"extras": "steps: 8",
"size": 0,
"date": "2025 June"
},
"Qwen-Lightning-Edit (4-step) Nunchaku SVDQuant": {
"path": "vladmandic/Qwen-Lightning-Edit",
"subfolder": "nunchaku-4step",
"preview": "vladmandic--Qwen-Lightning-Edit.jpg",
"desc": "Nunchaku SVDQuant quantization of Qwen-Lightning-Edit (4-step distilled editing) transformer with INT4 and SVD rank 128",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"extras": "steps: 4",
"size": 0,
"date": "2025 June"
},
"Qwen-Image-Edit-2509 Nunchaku SVDQuant": {
"path": "Qwen/Qwen-Image-Edit-2509",
"subfolder": "nunchaku",
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
"desc": "Nunchaku SVDQuant quantization of Qwen-Image-Edit-2509 transformer with INT4 and SVD rank 128",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"size": 0,
"date": "2025 September"
},
"Sana 1.6B 1k Nunchaku SVDQuant": {
"path": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
"subfolder": "nunchaku",
"preview": "Efficient-Large-Model--Sana_1600M_1024px_BF16_diffusers.jpg",
"desc": "Nunchaku SVDQuant quantization of Sana 1.6B 1024px transformer with INT4 and SVD rank 32",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"Z-Image-Turbo Nunchaku SVDQuant": {
"path": "Tongyi-MAI/Z-Image-Turbo",
"subfolder": "nunchaku",
"preview": "Tongyi-MAI--Z-Image-Turbo.jpg",
"desc": "Nunchaku SVDQuant quantization of Z-Image-Turbo transformer with INT4 and SVD rank 128",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"extras": "sampler: Default, cfg_scale: 1.0, steps: 9",
"size": 0,
"date": "2025 June"
},
"SDXL Base Nunchaku SVDQuant": {
"path": "stabilityai/stable-diffusion-xl-base-1.0",
"subfolder": "nunchaku",
"preview": "stabilityai--stable-diffusion-xl-base-1.0.jpg",
"desc": "Nunchaku SVDQuant quantization of SDXL Base 1.0 UNet with INT4 and SVD rank 32",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"size": 0,
"date": "2025 June"
},
"SDXL Turbo Nunchaku SVDQuant": {
"path": "stabilityai/sdxl-turbo",
"subfolder": "nunchaku",
"preview": "stabilityai--sdxl-turbo.jpg",
"desc": "Nunchaku SVDQuant quantization of SDXL Turbo UNet with INT4 and SVD rank 32",
"skip": true,
"nunchaku": ["Model"],
"tags": "nunchaku",
"extras": "sampler: Default, cfg_scale: 1.0, steps: 4",
"size": 0,
"date": "2025 June"
}
}

View File

@ -337,6 +337,8 @@
{"id":"","label":"Model Options","localized":"","reload":"","hint":"Settings related to behavior of specific models"},
{"id":"","label":"Model Offloading","localized":"","reload":"","hint":"Settings related to model offloading and memory management"},
{"id":"","label":"Model Quantization","localized":"","reload":"","hint":"Settings related to model quantization which is used to reduce memory usage"},
{"id":"","label":"Nunchaku attention","localized":"","reload":"","hint":"Replaces default attention with Nunchaku's custom FP16 attention kernel for faster inference on consumer NVIDIA GPUs.<br>Might provide performance improvement on GPUs which have higher FP16 tensor cores throughput than BF16.<br><br>Currently only affects Flux-based models (Dev, Schnell, Kontext, Fill, Depth, etc.). Has no effect on Qwen, SDXL, Sana, or other architectures.<br><br>Disabled by default."},
{"id":"","label":"Nunchaku offloading","localized":"","reload":"","hint":"Enables Nunchaku's own per-block CPU offloading with asynchronous CUDA streams to reduce VRAM usage.<br>Uses a ping-pong buffer strategy: while one transformer block computes on GPU, the next block preloads from CPU in the background, hiding most of the transfer latency.<br><br>Can reduce VRAM usage at the cost of slower inference.<br>This replaces SD.Next's pipeline offloading for the transformer component.<br><br>Only useful on low-VRAM GPUs. If your GPU has enough memory to hold the quantized model (16+ GB), keep this disabled for maximum speed.<br>Supports Flux and Qwen models. Not supported for SDXL where this setting is ignored.<br>Disabled by default."},
{"id":"","label":"Image Metadata","localized":"","reload":"","hint":"Settings related to handling of metadata that is created with generated images"},
{"id":"","label":"Legacy Options","localized":"","reload":"","hint":"Settings related to legacy options - should not be used"},
{"id":"","label":"Restart server","localized":"","reload":"","hint":"Restart server"},

7
html/logo-dark.svg Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg id="Layer_1" xmlns="http://www.w3.org/2000/svg" version="1.1" xmlns:svg="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
<!-- Generator: Adobe Illustrator 30.2.0, SVG Export Plug-In . SVG Version: 2.1.1 Build 105) -->
<path d="M333.62,470.26l-9.4,4.63-39.05-21.52,16.16,32.8-9.4,4.63-24.4-49.52,9.4-4.63,39.09,21.59-16.19-32.87,9.4-4.63,24.4,49.52h-.01ZM442.86,365.07h0ZM334.71,418.37l6.05,12.27h.01l22.32-11.01,4.03,8.18-22.32,11,6.26,12.7,38.07-18.76,3.55-23.51-18.59-10.28-39.38,19.41h0ZM457.12,338.1l-205,101,31.64,65.17,205.12-100.53-31.76-65.64h0ZM423.12,426.16l-21.57-12.47-2.75,24.46-11.34,5.59h0l-41.79,20.59-24.4-49.52,43.06-21.21.19.07-.16-.09h-.01.01l11.54-5.68,19.04,10.95,2.2-21.42,11.34-5.59-4.76,31.42,30.96,17.21-11.55,5.69h-.01ZM433.46,369.7l-14.85,7.32-4.03-8.18,39.11-19.27,4.03,8.18-14.86,7.32,20.37,41.34-9.4,4.63-20.37-41.34h0Z" fill="#231f20"/>
<polygon points="248.11 427.73 453.11 326.73 453.11 113.73 248.11 7.73 248.11 117.73 333.11 167.73 333.11 267.73 248.11 317.73 248.11 427.73" fill="#231f20"/>
<polygon points="23.11 150.73 23.11 275.73 103.11 319.73 103.11 331.73 23.11 375.73 23.11 500.73 228.11 386.73 228.11 261.73 143.11 221.73 143.11 209.73 228.11 161.73 228.11 36.73 23.11 150.73" fill="#231f20"/>
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

7
html/logo-light.svg Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg id="Layer_1" xmlns="http://www.w3.org/2000/svg" version="1.1" xmlns:svg="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
<!-- Generator: Adobe Illustrator 30.2.0, SVG Export Plug-In . SVG Version: 2.1.1 Build 105) -->
<path d="M333.62,470.26l-9.4,4.63-39.05-21.52,16.16,32.8-9.4,4.63-24.4-49.52,9.4-4.63,39.09,21.59-16.19-32.87,9.4-4.63,24.4,49.52h-.01ZM442.86,365.07h0ZM334.71,418.37l6.05,12.27h.01l22.32-11.01,4.03,8.18-22.32,11,6.26,12.7,38.07-18.76,3.55-23.51-18.59-10.28-39.38,19.41h0ZM457.12,338.1l-205,101,31.64,65.17,205.12-100.53-31.76-65.64h0ZM423.12,426.16l-21.57-12.47-2.75,24.46-11.34,5.59h0l-41.79,20.59-24.4-49.52,43.06-21.21.19.07-.16-.09h-.01.01l11.54-5.68,19.04,10.95,2.2-21.42,11.34-5.59-4.76,31.42,30.96,17.21-11.55,5.69h-.01ZM433.46,369.7l-14.85,7.32-4.03-8.18,39.11-19.27,4.03,8.18-14.86,7.32,20.37,41.34-9.4,4.63-20.37-41.34h0Z" fill="#fff"/>
<polygon points="248.11 427.73 453.11 326.73 453.11 113.73 248.11 7.73 248.11 117.73 333.11 167.73 333.11 267.73 248.11 317.73 248.11 427.73" fill="#fff"/>
<polygon points="23.11 150.73 23.11 275.73 103.11 319.73 103.11 331.73 23.11 375.73 23.11 500.73 228.11 386.73 228.11 261.73 143.11 221.73 143.11 209.73 228.11 161.73 228.11 36.73 23.11 150.73" fill="#fff"/>
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

19
html/logo-robot.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 473 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 10 KiB

View File

@ -431,6 +431,26 @@ def run(cmd: str, arg: str):
return txt
def cleanup_broken_packages():
"""Remove dist-info directories with missing RECORD files that uv may have left behind"""
try:
import site
for site_dir in site.getsitepackages():
if not os.path.isdir(site_dir):
continue
for entry in os.listdir(site_dir):
if not entry.endswith('.dist-info'):
continue
dist_info = os.path.join(site_dir, entry)
record_file = os.path.join(dist_info, 'RECORD')
if not os.path.isfile(record_file):
pkg_name = entry.split('-')[0]
log.warning(f'Install: removing broken package metadata: {pkg_name} path={dist_info}')
shutil.rmtree(dist_info, ignore_errors=True)
except Exception:
pass
def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
t_start = time.time()
originalArg = arg
@ -454,6 +474,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
err = result.stderr.decode(encoding="utf8", errors="ignore")
log.warning(f'Install: cmd="{pipCmd}" args="{all_args}" cannot use uv, fallback to pip')
debug(f'Install: uv pip error: {err}')
cleanup_broken_packages()
return pip(originalArg, ignore, quiet, uv=False)
else:
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
@ -468,7 +489,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
# install package using pip if not already installed
def install(package, friendly: str = None, ignore: bool = False, reinstall: bool = False, no_deps: bool = False, quiet: bool = False, force: bool = False):
def install(package, friendly: str = None, ignore: bool = False, reinstall: bool = False, no_deps: bool = False, quiet: bool = False, force: bool = False, no_build_isolation: bool = False):
t_start = time.time()
res = ''
if args.reinstall or args.upgrade:
@ -476,7 +497,8 @@ def install(package, friendly: str = None, ignore: bool = False, reinstall: bool
quick_allowed = False
if (args.reinstall) or (reinstall) or (not installed(package, friendly, quiet=quiet)):
deps = '' if not no_deps else '--no-deps '
cmd = f"install{' --upgrade' if not args.uv else ''}{' --force-reinstall' if force else ''} {deps}{package}"
isolation = '' if not no_build_isolation else '--no-build-isolation '
cmd = f"install{' --upgrade' if not args.uv else ''}{' --force-reinstall' if force else ''} {deps}{isolation}{package}"
res = pip(cmd, ignore=ignore, uv=package != "uv" and not package.startswith('git+'))
try:
importlib.reload(pkg_resources)
@ -665,7 +687,7 @@ def check_diffusers():
t_start = time.time()
if args.skip_all:
return
sha = '99e2cfff27dec514a43e260e885c5e6eca038b36' # diffusers commit hash
sha = '5bf248ddd8796b4f4958559429071a28f9b2dd3a' # diffusers commit hash
# if args.use_rocm or args.use_zluda or args.use_directml:
# sha = '043ab2520f6a19fce78e6e060a68dbc947edb9f9' # lock diffusers versions for now
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
@ -1100,8 +1122,9 @@ def install_packages():
pr.enable()
# log.info('Install: verifying packages')
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git")
install(clip_package, 'clip', quiet=True)
install(clip_package, 'clip', quiet=True, no_build_isolation=True)
install('open-clip-torch', no_deps=True, quiet=True)
install(clip_package, 'ftfy', quiet=True, no_build_isolation=True)
# tensorflow_package = os.environ.get('TENSORFLOW_PACKAGE', 'tensorflow==2.13.0')
# tensorflow_package = os.environ.get('TENSORFLOW_PACKAGE', None)
# if tensorflow_package is not None:

View File

@ -171,6 +171,12 @@ async function filterExtraNetworksForTab(searchTerm) {
.toLowerCase()
.includes('quantized') ? '' : 'none';
});
} else if (searchTerm === 'nunchaku/') {
cards.forEach((elem) => {
elem.style.display = elem.dataset.tags
.toLowerCase()
.includes('nunchaku') ? '' : 'none';
});
} else if (searchTerm === 'local/') {
cards.forEach((elem) => {
elem.style.display = elem.dataset.name

View File

@ -1353,6 +1353,7 @@ async function initGallery() { // triggered on gradio change to monitor when ui
monitorGalleries();
updateFolders();
initGalleryAutoRefresh();
[
'browser_folders',
'outdir_samples',

View File

@ -156,7 +156,6 @@ def parse_prompt(prompt: str | None) -> tuple[str, defaultdict[str, list[ExtraNe
if prompt is None:
return "", res
if isinstance(prompt, list):
shared.log.warning(f"parse_prompt was called with a list instead of a string: {prompt}")
return parse_prompts(prompt)
def found(m: re.Match[str]):
@ -168,13 +167,17 @@ def parse_prompt(prompt: str | None) -> tuple[str, defaultdict[str, list[ExtraNe
return updated_prompt, res
def parse_prompts(prompts: list[str]):
def parse_prompts(prompts: list[str], extra_data=None):
updated_prompt_list: list[str] = []
extra_data: defaultdict[str, list[ExtraNetworkParams]] = defaultdict(list)
extra_data: defaultdict[str, list[ExtraNetworkParams]] = extra_data or defaultdict(list)
for prompt in prompts:
updated_prompt, parsed_extra_data = parse_prompt(prompt)
if not extra_data:
extra_data = parsed_extra_data
elif parsed_extra_data:
extra_data = parsed_extra_data
else:
pass
updated_prompt_list.append(updated_prompt)
return updated_prompt_list, extra_data

View File

@ -12,24 +12,10 @@ import piexif
import piexif.helper
from PIL import Image, PngImagePlugin, ExifTags, ImageDraw
from modules import sd_samplers, shared, script_callbacks, errors, paths
from modules.images_grid import (
image_grid as image_grid,
get_grid_size as get_grid_size,
split_grid as split_grid,
combine_grid as combine_grid,
check_grid_size as check_grid_size,
get_font as get_font,
draw_grid_annotations as draw_grid_annotations,
draw_prompt_matrix as draw_prompt_matrix,
GridAnnotation as GridAnnotation,
Grid as Grid,
)
from modules.images_resize import resize_image as resize_image
from modules.images_namegen import (
FilenameGenerator as FilenameGenerator,
get_next_sequence_number as get_next_sequence_number,
)
from modules.video import save_video as save_video
from modules.images_grid import image_grid, get_grid_size, split_grid, combine_grid, check_grid_size, get_font, draw_grid_annotations, draw_prompt_matrix, GridAnnotation, Grid # pylint: disable=unused-import
from modules.images_resize import resize_image # pylint: disable=unused-import
from modules.images_namegen import FilenameGenerator, get_next_sequence_number # pylint: disable=unused-import
from modules.video import save_video # pylint: disable=unused-import
debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None

View File

@ -4,10 +4,27 @@ from installer import log, pip
from modules import devices
nunchaku_ver = '1.1.0'
nunchaku_versions = {
'2.5': '1.0.1',
'2.6': '1.0.1',
'2.7': '1.1.0',
'2.8': '1.1.0',
'2.9': '1.1.0',
'2.10': '1.0.2',
'2.11': '1.1.0',
}
ok = False
def _expected_ver():
try:
import torch
torch_ver = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
return nunchaku_versions.get(torch_ver)
except Exception:
return None
def check():
global ok # pylint: disable=global-statement
if ok:
@ -16,8 +33,9 @@ def check():
import nunchaku
import nunchaku.utils
from nunchaku import __version__
expected = _expected_ver()
log.info(f'Nunchaku: path={nunchaku.__path__} version={__version__.__version__} precision={nunchaku.utils.get_precision()}')
if __version__.__version__ != nunchaku_ver:
if expected is not None and __version__.__version__ != expected:
ok = False
return False
ok = True
@ -49,14 +67,16 @@ def install_nunchaku():
if devices.backend not in ['cuda']:
log.error(f'Nunchaku: backend={devices.backend} unsupported')
return False
torch_ver = torch.__version__[:3]
if torch_ver not in ['2.5', '2.6', '2.7', '2.8', '2.9', '2.10']:
torch_ver = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
nunchaku_ver = nunchaku_versions.get(torch_ver)
if nunchaku_ver is None:
log.error(f'Nunchaku: torch={torch.__version__} unsupported')
return False
suffix = 'x86_64' if arch == 'linux' else 'win_amd64'
url = os.environ.get('NUNCHAKU_COMMAND', None)
if url is None:
arch = f'{arch}_' if arch == 'linux' else ''
url = f'https://huggingface.co/nunchaku-tech/nunchaku/resolve/main/nunchaku-{nunchaku_ver}'
url = f'https://huggingface.co/nunchaku-ai/nunchaku/resolve/main/nunchaku-{nunchaku_ver}'
url += f'+torch{torch_ver}-cp{python_ver}-cp{python_ver}-{arch}{suffix}.whl'
cmd = f'install --upgrade {url}'
log.debug(f'Nunchaku: install="{url}"')

View File

@ -255,13 +255,25 @@ def check_quant(module: str = ''):
def check_nunchaku(module: str = ''):
from modules import shared
if module not in shared.opts.nunchaku_quantization:
model_name = getattr(shared.opts, 'sd_model_checkpoint', '')
if '+nunchaku' not in model_name:
return False
from modules import mit_nunchaku
mit_nunchaku.install_nunchaku()
if not mit_nunchaku.ok:
return False
return True
base_path = model_name.split('+')[0]
for v in shared.reference_models.values():
if v.get('path', '') != base_path:
continue
nunchaku_modules = v.get('nunchaku', None)
if nunchaku_modules is None:
continue
if isinstance(nunchaku_modules, bool) and nunchaku_modules:
nunchaku_modules = ['Model', 'TE']
if not isinstance(nunchaku_modules, list):
continue
if module in nunchaku_modules:
from modules import mit_nunchaku
mit_nunchaku.install_nunchaku()
return mit_nunchaku.ok
return False
def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None, modules_dtype_dict: dict = None):

View File

@ -270,7 +270,6 @@ def process_init(p: StableDiffusionProcessing):
p.all_prompts, p.all_negative_prompts = shared.prompt_styles.apply_styles_to_prompts(p.all_prompts, p.all_negative_prompts, p.styles, p.all_seeds)
p.prompts = p.all_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
p.negative_prompts = p.all_negative_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
p.prompts, _ = extra_networks.parse_prompts(p.prompts)
def process_samples(p: StableDiffusionProcessing, samples):

View File

@ -171,6 +171,7 @@ def process_base(p: processing.StableDiffusionProcessing):
modelstats.analyze()
try:
t0 = time.time()
p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts, p.network_data)
extra_networks.activate(p, exclude=['text_encoder', 'text_encoder_2', 'text_encoder_3'])
if hasattr(shared.sd_model, 'tgate') and getattr(p, 'gate_step', -1) > 0:
@ -297,10 +298,20 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
p.denoising_strength = strength
orig_image = p.task_args.pop('image', None) # remove image override from hires
process_pre(p)
prompts = p.prompts
reset_prompts = False
if len(p.refiner_prompt) > 0:
prompts = len(output.images)* [p.refiner_prompt]
prompts, p.network_data = extra_networks.parse_prompts(prompts)
reset_prompts = True
if reset_prompts or ('base' in p.skip):
extra_networks.activate(p)
hires_args = set_pipeline_args(
p=p,
model=shared.sd_model,
prompts=len(output.images)* [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
prompts=prompts,
negative_prompts=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
prompts_2=len(output.images) * [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
negative_prompts_2=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
@ -314,11 +325,10 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
strength=strength,
desc='Hires',
)
hires_steps = hires_args.get('prior_num_inference_steps', None) or p.hr_second_pass_steps or hires_args.get('num_inference_steps', None)
shared.state.update(get_job_name(p, shared.sd_model), hires_steps, 1)
try:
if 'base' in p.skip:
extra_networks.activate(p)
taskid = shared.state.begin('Inference')
output = shared.sd_model(**hires_args) # pylint: disable=not-callable
shared.state.end(taskid)

View File

@ -18,14 +18,15 @@ def load_unet_sdxl_nunchaku(repo_id):
shared.log.error(f'Load module: quant=Nunchaku module=unet repo="{repo_id}" low nunchaku version')
return None
if 'turbo' in repo_id.lower():
nunchaku_repo = 'nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors'
nunchaku_repo = 'nunchaku-ai/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors'
else:
nunchaku_repo = 'nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors'
nunchaku_repo = 'nunchaku-ai/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors'
shared.log.debug(f'Load module: quant=Nunchaku module=unet repo="{nunchaku_repo}" offload={shared.opts.nunchaku_offload}')
if shared.opts.nunchaku_offload:
shared.log.warning('Load module: quant=Nunchaku module=unet offload not supported for SDXL, ignoring')
shared.log.debug(f'Load module: quant=Nunchaku module=unet repo="{nunchaku_repo}"')
unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
nunchaku_repo,
offload=shared.opts.nunchaku_offload,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
)

View File

@ -4,53 +4,23 @@ import os
import sys
import time
import contextlib
from enum import Enum
from typing import TYPE_CHECKING
import gradio as gr
from installer import (
log as log,
print_dict,
console as console,
get_version as get_version,
)
log.debug("Initializing: shared module")
from installer import log, print_dict, console, get_version # pylint: disable=unused-import
log.debug('Initializing: shared module')
import modules.memmon
import modules.paths as paths
from modules.json_helpers import (
readfile as readfile,
writefile as writefile,
)
from modules.shared_helpers import (
listdir as listdir,
walk_files as walk_files,
html_path as html_path,
html as html,
req as req,
total_tqdm as total_tqdm,
)
from modules.json_helpers import readfile, writefile # pylint: disable=W0611
from modules.shared_helpers import listdir, walk_files, html_path, html, req, total_tqdm # pylint: disable=W0611
from modules import errors, devices, shared_state, cmd_args, theme, history, files_cache
from modules.shared_defaults import get_default_modes
from modules.paths import (
models_path as models_path, # For compatibility, do not modify from here...
script_path as script_path,
data_path as data_path,
sd_configs_path as sd_configs_path,
sd_default_config as sd_default_config,
sd_model_file as sd_model_file,
default_sd_model_file as default_sd_model_file,
extensions_dir as extensions_dir,
extensions_builtin_dir as extensions_builtin_dir, # ... to here.
)
from modules.memstats import (
memory_stats,
ram_stats as ram_stats,
)
from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import
log.debug("Initializing: pipelines")
log.debug('Initializing: pipelines')
from modules import shared_items
from modules.interrogate.openclip import caption_models, caption_types, get_clip_models, refresh_clip_models
from modules.interrogate.vqa import vlm_models, vlm_prompts, vlm_system, vlm_default
@ -280,7 +250,6 @@ options_templates.update(options_section(("quantization", "Model Quantization"),
"sdnq_quantize_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox),
"nunchaku_sep": OptionInfo("<h2>Nunchaku Engine</h2>", "", gr.HTML),
"nunchaku_quantization": OptionInfo([], "SVDQuant enabled", gr.CheckboxGroup, {"choices": ["Model", "TE"]}),
"nunchaku_attention": OptionInfo(False, "Nunchaku attention", gr.Checkbox),
"nunchaku_offload": OptionInfo(False, "Nunchaku offloading", gr.Checkbox),

View File

@ -305,6 +305,7 @@ class ExtraNetworksPage:
subdirs['Reference'] = 1
subdirs['Distilled'] = 1
subdirs['Quantized'] = 1
subdirs['Nunchaku'] = 1
subdirs['Community'] = 1
subdirs['Cloud'] = 1
subdirs[diffusers_base] = 1
@ -324,6 +325,8 @@ class ExtraNetworksPage:
subdirs.move_to_end('Distilled', last=True)
if 'Quantized' in subdirs:
subdirs.move_to_end('Quantized', last=True)
if 'Nunchaku' in subdirs:
subdirs.move_to_end('Nunchaku', last=True)
if 'Community' in subdirs:
subdirs.move_to_end('Community', last=True)
if 'Cloud' in subdirs:
@ -332,7 +335,7 @@ class ExtraNetworksPage:
for subdir in subdirs:
if len(subdir) == 0:
continue
if subdir in ['All', 'Local', 'Diffusers', 'Reference', 'Distilled', 'Quantized', 'Community', 'Cloud']:
if subdir in ['All', 'Local', 'Diffusers', 'Reference', 'Distilled', 'Quantized', 'Nunchaku', 'Community', 'Cloud']:
style = 'network-reference'
else:
style = 'network-folder'

View File

@ -3,7 +3,7 @@ import html
import json
import concurrent
from datetime import datetime
from modules import shared, ui_extra_networks, sd_models, modelstats, paths
from modules import shared, ui_extra_networks, sd_models, modelstats, paths, devices
from modules.json_helpers import readfile
@ -48,16 +48,21 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
reference_distilled = readfile(os.path.join('data', 'reference-distilled.json'), as_type="dict")
reference_community = readfile(os.path.join('data', 'reference-community.json'), as_type="dict")
reference_cloud = readfile(os.path.join('data', 'reference-cloud.json'), as_type="dict")
reference_nunchaku = readfile(os.path.join('data', 'reference-nunchaku.json'), as_type="dict")
shared.reference_models = {}
shared.reference_models.update(reference_base)
shared.reference_models.update(reference_quant)
shared.reference_models.update(reference_community)
shared.reference_models.update(reference_distilled)
shared.reference_models.update(reference_cloud)
shared.reference_models.update(reference_nunchaku)
for k, v in shared.reference_models.items():
count['total'] += 1
url = v['path']
if v.get('hidden', False):
count['hidden'] += 1
continue
experimental = v.get('experimental', False)
if experimental:
if shared.cmd_opts.experimental:
@ -83,6 +88,9 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
path = f'{v.get("path", "")}'
tag = v.get('tags', '')
if tag == 'nunchaku' and (devices.backend != 'cuda' and not shared.cmd_opts.experimental):
count['hidden'] += 1
continue
if tag in count:
count[tag] += 1
elif tag != '':

View File

@ -9,19 +9,19 @@ def load_flux_nunchaku(repo_id):
if 'srpo' in repo_id.lower():
pass
elif 'flux.1-dev' in repo_id.lower():
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
elif 'flux.1-schnell' in repo_id.lower():
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-kontext' in repo_id.lower():
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
elif 'flux.1-krea' in repo_id.lower():
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-krea-dev/svdq-{nunchaku_precision}_r32-flux.1-krea-dev.safetensors"
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-krea-dev/svdq-{nunchaku_precision}_r32-flux.1-krea-dev.safetensors"
elif 'flux.1-fill' in repo_id.lower():
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-fill-dev/svdq-{nunchaku_precision}-flux.1-fill-dev.safetensors"
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-fill-dev/svdq-{nunchaku_precision}-flux.1-fill-dev.safetensors"
elif 'flux.1-depth' in repo_id.lower():
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-depth-dev/svdq-{nunchaku_precision}-flux.1-depth-dev.safetensors"
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-depth-dev/svdq-{nunchaku_precision}-flux.1-depth-dev.safetensors"
elif 'shuttle' in repo_id.lower():
nunchaku_repo = f"nunchaku-tech/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}-shuttle-jaguar.safetensors"
nunchaku_repo = f"nunchaku-ai/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}-shuttle-jaguar.safetensors"
else:
shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
if nunchaku_repo is not None:

View File

@ -152,7 +152,7 @@ def load_text_encoder(repo_id, cls_name, load_config=None, subfolder="text_encod
elif cls_name == transformers.T5EncoderModel and allow_shared and shared.opts.te_shared_t5:
if model_quant.check_nunchaku('TE'):
import nunchaku
repo_id = 'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
repo_id = 'nunchaku-ai/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
cls_name = nunchaku.NunchakuT5EncoderModel
shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant" loader={_loader("transformers")}')
text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained(

View File

@ -37,7 +37,7 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageInpaintPipeline
if model_quant.check_nunchaku('Model'):
transformer = qwen.load_qwen_nunchaku(repo_id)
transformer = qwen.load_qwen_nunchaku(repo_id, subfolder=repo_subfolder)
if 'Qwen-Image-Distill-Full' in repo_id:
repo_transformer = repo_id
@ -63,6 +63,8 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
text_encoder = generic.load_text_encoder(repo_te, cls_name=transformers.Qwen2_5_VLForConditionalGeneration, load_config=diffusers_load_config)
repo_id, repo_subfolder = qwen.check_qwen_pruning(repo_id, repo_subfolder)
if repo_subfolder is not None and repo_subfolder.startswith('nunchaku'):
repo_subfolder = None
pipe = cls_name.from_pretrained(
repo_id,
transformer=transformer,

View File

@ -9,7 +9,7 @@ def load_quants(kwargs, repo_id, cache_dir):
if 'Sana_1600M_1024px' in repo_id and model_quant.check_nunchaku('Model'): # only available model
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = "nunchaku-tech/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
nunchaku_repo = "nunchaku-ai/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} attention={shared.opts.nunchaku_attention}')
kwargs['transformer'] = nunchaku.NunchakuSanaTransformer2DModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype, cache_dir=cache_dir)
elif model_quant.check_quant('Model'):

View File

@ -8,7 +8,7 @@ def load_nunchaku():
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_rank = 128
nunchaku_repo = f"nunchaku-tech/nunchaku-z-image-turbo/svdq-{nunchaku_precision}_r{nunchaku_rank}-z-image-turbo.safetensors"
nunchaku_repo = f"nunchaku-ai/nunchaku-z-image-turbo/svdq-{nunchaku_precision}_r{nunchaku_rank}-z-image-turbo.safetensors"
shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" attention={shared.opts.nunchaku_attention}')
transformer = nunchaku.NunchakuZImageTransformer2DModel.from_pretrained(
nunchaku_repo,

View File

@ -1,11 +1,12 @@
from modules import shared, devices
def load_qwen_nunchaku(repo_id):
def load_qwen_nunchaku(repo_id, subfolder=None):
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = None
transformer = None
four_step = subfolder is not None and '4step' in subfolder
try:
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
except Exception:
@ -14,15 +15,21 @@ def load_qwen_nunchaku(repo_id):
if 'pruning' in repo_id.lower() or 'distill' in repo_id.lower():
return None
elif repo_id.lower().endswith('qwen-image'):
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image.safetensors" # r32 vs r128
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image.safetensors"
elif repo_id.lower().endswith('qwen-lightning'):
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.1-8steps.safetensors" # 8-step variant
if four_step:
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.0-4steps.safetensors"
else:
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.1-8steps.safetensors"
elif repo_id.lower().endswith('qwen-image-edit-2509'):
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit-2509/svdq-{nunchaku_precision}_r128-qwen-image-edit-2509.safetensors" # 8-step variant
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit-2509/svdq-{nunchaku_precision}_r128-qwen-image-edit-2509.safetensors"
elif repo_id.lower().endswith('qwen-image-edit'):
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit.safetensors" # 8-step variant
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit.safetensors"
elif repo_id.lower().endswith('qwen-lightning-edit'):
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-8steps.safetensors" # 8-step variant
if four_step:
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-4steps.safetensors"
else:
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-8steps.safetensors"
else:
shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
if nunchaku_repo is not None:

360
pyproject.toml Normal file
View File

@ -0,0 +1,360 @@
[project]
name = "SD.Next"
version = "0.0.0"
description = "SD.Next: All-in-one WebUI for AI generative image and video creation and captioning"
readme = "README.md"
requires-python = ">=3.10"
[tool.ruff]
line-length = 250
indent-width = 4
target-version = "py310"
exclude = [
"venv",
".git",
".ruff_cache",
".vscode",
"modules/cfgzero",
"modules/flash_attn_triton_amd",
"modules/hidiffusion",
"modules/intel/ipex",
"modules/pag",
"modules/schedulers",
"modules/teacache",
"modules/seedvr",
"modules/control/proc",
"modules/control/units",
"modules/control/units/xs_pipe.py",
"modules/postprocess/aurasr_arch.py",
"pipelines/meissonic",
"pipelines/omnigen2",
"pipelines/hdm",
"pipelines/segmoe",
"pipelines/xomni",
"pipelines/chrono",
"scripts/lbm",
"scripts/daam",
"scripts/xadapter",
"scripts/pulid",
"scripts/instantir",
"scripts/freescale",
"scripts/consistory",
"repositories",
"extensions-builtin/Lora",
"extensions-builtin/sd-extension-chainner/nodes",
"extensions-builtin/sd-webui-agent-scheduler",
"extensions-builtin/sdnext-modernui/node_modules",
]
[tool.ruff.lint]
select = [
"F",
"E",
"W",
"C",
"B",
"I",
"YTT",
"ASYNC",
"RUF",
"AIR",
"NPY",
"C4",
"T10",
"EXE",
"ISC",
"ICN",
"RSE",
"TC",
"TID",
"INT",
"PLE",
]
ignore = [
"B006", # Do not use mutable data structures for argument defaults
"B008", # Do not perform function call in argument defaults
"B905", # Strict zip() usage
"C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead
"C408", # Unnecessary `dict` call
"I001", # Import block is un-sorted or un-formatted
"E402", # Module level import not at top of file
"E501", # Line too long
"E721", # Do not compare types, use `isinstance()`
"E731", # Do not assign a `lambda` expression, use a `def`
"E741", # Ambiguous variable name
"F401", # Imported by unused
"EXE001", # file with shebang is not marked executable
"NPY002", # replace legacy random
"RUF005", # Consider iterable unpacking
"RUF008", # Do not use mutable default values for dataclass
"RUF010", # Use explicit conversion flag
"RUF012", # Mutable class attributes
"RUF013", # PEP 484 prohibits implicit `Optional`
"RUF015", # Prefer `next(...)` over single element slice
"RUF046", # Value being cast to `int` is already an integer
"RUF059", # Unpacked variables are not used
"RUF051", # Prefer pop over del
]
fixable = ["ALL"]
unfixable = []
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.ruff.lint.mccabe]
max-complexity = 150
[tool.pylint]
main.py-version="3.10"
main.analyse-fallback-blocks=false
main.clear-cache-post-run=false
main.extension-pkg-allow-list=""
main.prefer-stubs=true
main.extension-pkg-whitelist=""
main.fail-on=""
main.fail-under=10
main.ignore="CVS"
main.ignore-paths=[
"venv",
".git",
".ruff_cache",
".vscode",
"modules/apg",
"modules/cfgzero",
"modules/control/proc",
"modules/control/units",
"modules/dml",
"modules/flash_attn_triton_amd",
"modules/ggml",
"modules/hidiffusion",
"modules/hijack/ddpm_edit.py",
"modules/intel",
"modules/intel/ipex",
"modules/framepack/pipeline",
"modules/onnx_impl",
"modules/pag",
"modules/postprocess/aurasr_arch.py",
"modules/prompt_parser_xhinker.py",
"modules/ras",
"modules/seedvr",
"modules/rife",
"modules/schedulers",
"modules/taesd",
"modules/teacache",
"modules/todo",
"modules/res4lyf",
"pipelines/bria",
"pipelines/flex2",
"pipelines/f_lite",
"pipelines/hidream",
"pipelines/hdm",
"pipelines/meissonic",
"pipelines/omnigen2",
"pipelines/segmoe",
"pipelines/xomni",
"pipelines/chrono",
"scripts/consistory",
"scripts/ctrlx",
"scripts/daam",
"scripts/demofusion",
"scripts/freescale",
"scripts/infiniteyou",
"scripts/instantir",
"scripts/lbm",
"scripts/layerdiffuse",
"scripts/mod",
"scripts/pixelsmith",
"scripts/differential_diffusion.py",
"scripts/pulid",
"scripts/xadapter",
"repositories",
"extensions-builtin/sd-extension-chainner/nodes",
"extensions-builtin/sd-webui-agent-scheduler",
"extensions-builtin/sdnext-modernui/node_modules",
"extensions-builtin/sdnext-kanvas/node_modules",
]
main.ignore-patterns=[
".*test*.py$",
".*_model.py$",
".*_arch.py$",
".*_model_arch.py*",
".*_model_arch_v2.py$",
]
main.ignored-modules=""
main.jobs=8
main.limit-inference-results=100
main.load-plugins=""
main.persistent=false
main.recursive=false
main.source-roots=""
main.unsafe-load-any-extension=false
basic.argument-naming-style="snake_case"
basic.attr-naming-style="snake_case"
basic.bad-names=["foo", "bar", "baz", "toto", "tutu", "tata"]
basic.bad-names-rgxs=""
basic.class-attribute-naming-style="any"
basic.class-const-naming-style="UPPER_CASE"
basic.class-naming-style="PascalCase"
basic.const-naming-style="snake_case"
basic.docstring-min-length=-1
basic.function-naming-style="snake_case"
basic.good-names=["i","j","k","e","ex","ok","p","x","y","id"]
basic.good-names-rgxs=""
basic.include-naming-hint=false
basic.inlinevar-naming-style="any"
basic.method-naming-style="snake_case"
basic.module-naming-style="snake_case"
basic.name-group=""
basic.no-docstring-rgx="^_"
basic.property-classes="abc.abstractproperty"
basic.variable-naming-style="snake_case"
classes.check-protected-access-in-special-methods=false
classes.defining-attr-methods=["__init__", "__new__"]
classes.exclude-protected=["_asdict","_fields","_replace","_source","_make","os._exit"]
classes.valid-classmethod-first-arg="cls"
classes.valid-metaclass-classmethod-first-arg="mcs"
design.exclude-too-few-public-methods=""
design.ignored-parents=""
design.max-args=199
design.max-attributes=99
design.max-bool-expr=99
design.max-branches=199
design.max-locals=99
design.max-parents=99
design.max-public-methods=99
design.max-returns=99
design.max-statements=199
design.min-public-methods=1
exceptions.overgeneral-exceptions=["builtins.BaseException","builtins.Exception"]
format.expected-line-ending-format=""
# format.ignore-long-lines="^\s*(# )?<?https?://\S+>?$"
format.indent-after-paren=4
format.indent-string=' '
format.max-line-length=200
format.max-module-lines=9999
format.single-line-class-stmt=false
format.single-line-if-stmt=false
imports.allow-any-import-level=""
imports.allow-reexport-from-package=false
imports.allow-wildcard-with-all=false
imports.deprecated-modules=""
imports.ext-import-graph=""
imports.import-graph=""
imports.int-import-graph=""
imports.known-standard-library=""
imports.known-third-party="enchant"
imports.preferred-modules=""
logging.logging-format-style="new"
logging.logging-modules="logging"
messages_control.confidence=["HIGH","CONTROL_FLOW","INFERENCE","INFERENCE_FAILURE","UNDEFINED"]
messages_control.disable=[
"abstract-method",
"bad-inline-option",
"bare-except",
"broad-exception-caught",
"chained-comparison",
"consider-iterating-dictionary",
"consider-merging-isinstance",
"consider-using-dict-items",
"consider-using-enumerate",
"consider-using-from-import",
"consider-using-generator",
"consider-using-get",
"consider-using-in",
"consider-using-max-builtin",
"consider-using-min-builtin",
"consider-using-sys-exit",
"cyclic-import",
"dangerous-default-value",
"deprecated-pragma",
"duplicate-code",
"file-ignored",
"import-error",
"import-outside-toplevel",
"invalid-name",
"line-too-long",
"locally-disabled",
"logging-fstring-interpolation",
"missing-class-docstring",
"missing-function-docstring",
"missing-module-docstring",
"no-else-raise",
"no-else-return",
"not-callable",
"pointless-string-statement",
"raw-checker-failed",
"simplifiable-if-expression",
"suppressed-message",
"too-few-public-methods",
"too-many-instance-attributes",
"too-many-locals",
"too-many-nested-blocks",
"too-many-positional-arguments",
"too-many-statements",
"unidiomatic-typecheck",
"unknown-option-value",
"unnecessary-dict-index-lookup",
"unnecessary-dunder-call",
"unnecessary-lambda-assigment",
"unnecessary-lambda",
"unused-wildcard-import",
"unpacking-non-sequence",
"unsubscriptable-object",
"useless-return",
"use-dict-literal",
"use-symbolic-message-instead",
"useless-suppression",
"wrong-import-position",
]
messages_control.enable="c-extension-no-member"
method_args.timeout-methods=["requests.api.delete","requests.api.get","requests.api.head","requests.api.options","requests.api.patch","requests.api.post","requests.api.put","requests.api.request"]
miscellaneous.notes=["FIXME","XXX","TODO"]
miscellaneous.notes-rgx=""
refactoring.max-nested-blocks=5
refactoring.never-returning-functions=["sys.exit","argparse.parse_error"]
reports.msg-template=""
reports.reports=false
reports.score=false
similarities.ignore-comments=true
similarities.ignore-docstrings=true
similarities.ignore-imports=true
similarities.ignore-signatures=true
similarities.min-similarity-lines=4
spelling.max-spelling-suggestions=4
# spelling.dict=""
# spelling.ignore-comment-directives=["fmt: on","fmt: off","noqa:","noqa","nosec","isort:skip","mypy:"]
# spelling.ignore-words=""
# spelling.private-dict-file=""
# spelling.store-unknown-words=false
string.check-quote-consistency=false
string.check-str-concat-over-line-jumps=false
typecheck.contextmanager-decorators="contextlib.contextmanager"
typecheck.generated-members=["numpy.*","logging.*","torch.*","cv2.*"]
typecheck.ignore-none=true
typecheck.ignore-on-opaque-inference=true
typecheck.ignored-checks-for-mixins=["no-member","not-async-context-manager","not-context-manager","attribute-defined-outside-init"]
typecheck.ignored-classes=["optparse.Values","thread._local","_thread._local","argparse.Namespace","unittest.case._AssertRaisesContext","unittest.case._AssertWarnsContext"]
typecheck.missing-member-hint=true
typecheck.missing-member-hint-distance=1
typecheck.missing-member-max-choices=1
typecheck.mixin-class-rgx=".*[Mm]ixin"
typecheck.signature-mutators=""
variables.additional-builtins=""
variables.allow-global-unused-variables=true
variables.allowed-redefined-builtins=""
variables.callbacks=["cb_",]
variables.dummy-variables-rgx="_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_"
variables.ignored-argument-names="_.*|^ignored_|^unused_"
variables.init-import=false
variables.redefining-builtins-modules=["six.moves","past.builtins","future.builtins","builtins","io"]

View File

@ -55,7 +55,7 @@ protobuf==4.25.3
pytorch_lightning==2.6.0
urllib3==1.26.19
Pillow==10.4.0
timm==1.0.16
timm==1.0.24
pyparsing==3.2.3
typing-extensions==4.14.1
sentencepiece==0.2.1

View File

@ -25,7 +25,7 @@ class Script(scripts_manager.Script):
with gr.Row():
gr.HTML('<a href="https://blackforestlabs.ai/flux-1-tools/">&nbsp Flux.1 Redux</a><br>')
with gr.Row():
tool = gr.Dropdown(label='Tool', choices=['None', 'Redux', 'Fill', 'Canny', 'Depth'], value='None')
tool = gr.Dropdown(label='Tool', choices=['None', 'Redux', 'Fill', 'Fill (Nunchaku)', 'Canny', 'Depth', 'Depth (Nunchaku)'], value='None')
with gr.Row():
prompt = gr.Slider(label='Redux prompt strength', minimum=0, maximum=2, step=0.01, value=0, visible=False)
process = gr.Checkbox(label='Control preprocess input images', value=True, visible=False)
@ -34,8 +34,8 @@ class Script(scripts_manager.Script):
def display(tool):
return [
gr.update(visible=tool in ['Redux']),
gr.update(visible=tool in ['Canny', 'Depth']),
gr.update(visible=tool in ['Canny', 'Depth']),
gr.update(visible=tool in ['Canny', 'Depth', 'Depth (Nunchaku)']),
gr.update(visible=tool in ['Canny', 'Depth', 'Depth (Nunchaku)']),
]
tool.change(fn=display, inputs=[tool], outputs=[prompt, process, strength])
@ -91,13 +91,15 @@ class Script(scripts_manager.Script):
shared.log.debug(f'{title}: tool=Redux unload')
redux_pipe = None
if tool == 'Fill':
if tool in ['Fill', 'Fill (Nunchaku)']:
# pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, revision="refs/pr/4").to("cuda")
if p.image_mask is None:
shared.log.error(f'{title}: tool={tool} no image_mask')
return None
if shared.sd_model.__class__.__name__ != 'FluxFillPipeline':
shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Fill-dev"
nunchaku_suffix = '+nunchaku' if tool == 'Fill (Nunchaku)' else ''
checkpoint = f"black-forest-labs/FLUX.1-Fill-dev{nunchaku_suffix}"
if shared.sd_model.__class__.__name__ != 'FluxFillPipeline' or shared.opts.sd_model_checkpoint != checkpoint:
shared.opts.data["sd_model_checkpoint"] = checkpoint
sd_models.reload_model_weights(op='model', revision="refs/pr/4")
p.task_args['image'] = image
p.task_args['mask_image'] = p.image_mask
@ -124,11 +126,13 @@ class Script(scripts_manager.Script):
shared.log.debug(f'{title}: tool=Canny unload processor')
processor_canny = None
if tool == 'Depth':
if tool in ['Depth', 'Depth (Nunchaku)']:
# pipe = diffusers.FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda")
install('git+https://github.com/huggingface/image_gen_aux.git', 'image_gen_aux')
if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or 'Depth' not in shared.opts.sd_model_checkpoint:
shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Depth-dev"
nunchaku_suffix = '+nunchaku' if tool == 'Depth (Nunchaku)' else ''
checkpoint = f"black-forest-labs/FLUX.1-Depth-dev{nunchaku_suffix}"
if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or shared.opts.sd_model_checkpoint != checkpoint:
shared.opts.data["sd_model_checkpoint"] = checkpoint
sd_models.reload_model_weights(op='model', revision="refs/pr/1")
if processor_depth is None:
from image_gen_aux import DepthPreprocessor

View File

@ -1,18 +1,25 @@
from typing import Dict, Any
import time
import warnings
import torch
import torch.nn.functional as F
from typing import Dict, Any
warnings.filterwarnings("ignore", category=UserWarning)
warmup = 2
repeats = 50
dtypes = [torch.bfloat16] # , torch.float16]
# if hasattr(torch, "float8_e4m3fn"):
# dtypes.append(torch.float8_e4m3fn)
PROFILES = {
backends = [
"sdpa_math",
"sdpa_mem_efficient",
"sdpa_flash",
"sdpa_all",
"flex_attention",
"xformers",
"flash_attn",
"sage_attn",
]
profiles = {
"sdxl": {"l_q": 4096, "l_k": 4096, "h": 32, "d": 128},
"flux.1": {"l_q": 16717, "l_k": 16717, "h": 24, "d": 128},
"sd35": {"l_q": 16538, "l_k": 16538, "h": 24, "d": 128},
@ -30,19 +37,19 @@ def get_stats(reset: bool = False):
m = torch.cuda.max_memory_allocated()
t = time.perf_counter()
return m / (1024 ** 2), t
def print_gpu_info():
    """Report the active CUDA device: name, total/free VRAM, CUDA capability and runtime version."""
    if not torch.cuda.is_available():
        print("GPU: Not available")
        return
    idx = torch.cuda.current_device()
    # memory figures converted from bytes to GiB
    total_mem = torch.cuda.get_device_properties(idx).total_memory / (1024**3)
    free_bytes, _ = torch.cuda.mem_get_info(idx)
    free_mem = free_bytes / (1024**3)
    major, minor = torch.cuda.get_device_capability(idx)
    print(f"gpu: {torch.cuda.get_device_name(idx)}")
    print(f"vram: total={total_mem:.2f}GB free={free_mem:.2f}GB")
    print(f"cuda: capability={major}.{minor} version={torch.version.cuda}")
@ -56,11 +63,9 @@ def benchmark_attention(
l_k: int = 4096,
h: int = 32,
d: int = 128,
warmup: int = 10,
repeats: int = 100
) -> Dict[str, Any]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize tensors
q = torch.randn(b, h, l_q, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype)
k = torch.randn(b, h, l_k, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype)
@ -78,7 +83,7 @@ def benchmark_attention(
try:
if backend.startswith("sdpa_"):
from torch.nn.attention import sdpa_kernel, SDPBackend
sdp_type = backend[len("sdpa_"):]
sdp_type = backend[len("sdpa_"):]
# Map friendly names to new SDPA backends
backend_map = {
"math": [SDPBackend.MATH],
@ -88,21 +93,21 @@ def benchmark_attention(
}
if sdp_type not in backend_map:
raise ValueError(f"Unknown SDPA type: {sdp_type}")
results["version"] = torch.__version__
with sdpa_kernel(backend_map[sdp_type]):
# Warmup
for _ in range(warmup):
_ = F.scaled_dot_product_attention(q, k, v)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = F.scaled_dot_product_attention(q, k, v)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
@ -113,17 +118,17 @@ def benchmark_attention(
q_fa = q.transpose(1, 2)
k_fa = k.transpose(1, 2)
v_fa = v.transpose(1, 2)
for _ in range(warmup):
_ = flash_attn_func(q_fa, k_fa, v_fa)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = flash_attn_func(q_fa, k_fa, v_fa)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
@ -135,17 +140,17 @@ def benchmark_attention(
q_xf = q.transpose(1, 2)
k_xf = k.transpose(1, 2)
v_xf = v.transpose(1, 2)
for _ in range(warmup):
_ = memory_efficient_attention(q_xf, k_xf, v_xf)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = memory_efficient_attention(q_xf, k_xf, v_xf)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
@ -158,18 +163,18 @@ def benchmark_attention(
results["version"] = importlib.metadata.version("sageattention")
except Exception:
results["version"] = getattr(sageattention, "__version__", "N/A")
# SageAttention expects (B, H, L, D) logic
for _ in range(warmup):
_ = sageattn(q, k, v)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = sageattn(q, k, v)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
@ -179,62 +184,52 @@ def benchmark_attention(
# flex_attention requires torch.compile for performance
flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
# Warmup (important to trigger compilation)
for _ in range(warmup):
_ = flex_attention_compiled(q, k, v)
start_mem, start_time = get_stats(True)
for _ in range(repeats):
_ = flex_attention_compiled(q, k, v)
end_mem, end_time = get_stats()
results["latency_ms"] = (end_time - start_time) / repeats * 1000
results["memory_mb"] = end_mem - start_mem
except Exception as e:
results["status"] = "fail"
results["error"] = str(e)[:49]
print(e)
return results
def main():
backends = [
"sdpa_math",
"sdpa_mem_efficient",
"sdpa_flash",
"flex_attention",
"xformers",
"flash_attn",
"sage_attn",
]
def main():
all_results = []
print_gpu_info()
print(f'config: warmup={warmup} repeats={repeats} dtypes={dtypes}')
for name, config in PROFILES.items():
print(f'config: warmup={warmup} repeats={repeats} dtypes={dtypes}')
for name, config in profiles.items():
print(f"profile: {name} (L_q={config['l_q']}, L_k={config['l_k']}, H={config['h']}, D={config['d']})")
for dtype in dtypes:
print(f" dtype: {dtype}")
print(f" {'backend':<20} | {'version':<12} | {'status':<8} | {'latency':<10} | {'memory':<12} | ")
for backend in backends:
res = benchmark_attention(
backend,
dtype,
l_q=config["l_q"],
l_k=config["l_k"],
h=config["h"],
d=config["d"],
warmup=warmup,
repeats=repeats
backend,
dtype,
l_q=config["l_q"],
l_k=config["l_k"],
h=config["h"],
d=config["d"],
)
all_results.append(res)
latency = f"{res['latency_ms']:.4f} ms"
memory = f"{res['memory_mb']:.2f} MB"
print(f" {res['backend']:<20} | {res['version']:<12} | {res['status']:<8} | {latency:<10} | {memory:<12} | {res['error']}")
if __name__ == "__main__":

0
cli/full-test.sh → test/full-test.sh Executable file → Normal file
View File

View File

0
cli/localize.js → test/localize.js Executable file → Normal file
View File

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python
import sys, os
import sys
import os
from collections import Counter
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@ -20,25 +21,25 @@ tolerance_pct = 5
# tests
tests = [
# - empty
["", { '': 100 } ],
["", { '': 100 } ],
# - no weights
[ "red|blonde|black", { 'black': 33, 'red': 33, 'blonde': 33 } ],
[ "red|blonde|black", { 'black': 33, 'red': 33, 'blonde': 33 } ],
# - full weights <= 1
[ "red:0.1|blonde:0.9", { 'blonde': 90, 'red': 10 } ],
[ "red:0.1|blonde:0.9", { 'blonde': 90, 'red': 10 } ],
# - weights > 1 to test normalization
[ "red:1|blonde:2|black:5", { 'blonde': 25, 'red': 12.5, 'black': 62.5 } ],
[ "red:1|blonde:2|black:5", { 'blonde': 25, 'red': 12.5, 'black': 62.5 } ],
# - disabling 0 weights to force one result
[ "red:0|blonde|black:0", { 'blonde': 100 } ],
[ "red:0|blonde|black:0", { 'blonde': 100 } ],
# - weights <= 1 with distribution of the leftover
[ "red:0.5|blonde|black:0.3|brown", { 'red': 50, 'black': 30, 'brown': 10, 'blonde': 10 } ],
[ "red:0.5|blonde|black:0.3|brown", { 'red': 50, 'black': 30, 'brown': 10, 'blonde': 10 } ],
# - weights > 1, unweightes should get default of 1
[ "red:2|blonde|black", { 'red': 50, 'blonde': 25, 'black': 25 } ],
[ "red:2|blonde|black", { 'red': 50, 'blonde': 25, 'black': 25 } ],
# - ignore content of ()
[ "red:0.5|(blonde:1.3)", { 'red': 50, '(blonde:1.3)': 50 } ],
[ "red:0.5|(blonde:1.3)", { 'red': 50, '(blonde:1.3)': 50 } ],
# - ignore content of []
[ "red:0.5|[stuff:1.3]", { '[stuff:1.3]': 50, 'red': 50 } ],
[ "red:0.5|[stuff:1.3]", { '[stuff:1.3]': 50, 'red': 50 } ],
# - ignore content of <>
[ "red:0.5|<lora:1.0>", { '<lora:1.0>': 50, 'red': 50 } ]
[ "red:0.5|<lora:1.0>", { '<lora:1.0>': 50, 'red': 50 } ]
]
# -------------------------------------------------
@ -109,4 +110,4 @@ with open(fn, 'r', encoding='utf-8') as f:
print("RESULT: FAILED (distribution)")
else:
print("RESULT: PASSED")
print('')
print('')

0
cli/validate-locale.py → test/validate-locale.py Executable file → Normal file
View File