diff --git a/.github/ISSUE_TEMPLATE/model_request b/.github/ISSUE_TEMPLATE/model_request index adf8d10f8..72afc26ce 100644 --- a/.github/ISSUE_TEMPLATE/model_request +++ b/.github/ISSUE_TEMPLATE/model_request @@ -9,22 +9,22 @@ body: attributes: label: Model name description: Enter model name - value: + value: - type: textarea id: type attributes: label: Model type description: Describe model type - value: + value: - type: textarea id: url attributes: label: Model URL description: Enter URL to the model page - value: + value: - type: textarea id: description attributes: label: Reason description: Enter a reason why you would like to see this model supported - value: + value: diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 8fc99d274..000000000 --- a/.pylintrc +++ /dev/null @@ -1,283 +0,0 @@ -[MAIN] -analyse-fallback-blocks=no -clear-cache-post-run=no -extension-pkg-allow-list= -prefer-stubs=yes -extension-pkg-whitelist= -fail-on= -fail-under=10 -ignore=CVS -ignore-paths=/usr/lib/.*$, - venv, - .git, - .ruff_cache, - .vscode, - modules/apg, - modules/cfgzero, - modules/control/proc, - modules/control/units, - modules/dml, - modules/flash_attn_triton_amd, - modules/ggml, - modules/hidiffusion, - modules/hijack/ddpm_edit.py, - modules/intel, - modules/intel/ipex, - modules/framepack/pipeline, - modules/onnx_impl, - modules/pag, - modules/postprocess/aurasr_arch.py, - modules/prompt_parser_xhinker.py, - modules/ras, - modules/seedvr, - modules/rife, - modules/schedulers, - modules/taesd, - modules/teacache, - modules/todo, - modules/res4lyf, - pipelines/bria, - pipelines/flex2, - pipelines/f_lite, - pipelines/hidream, - pipelines/hdm, - pipelines/meissonic, - pipelines/omnigen2, - pipelines/segmoe, - pipelines/xomni, - pipelines/chrono, - scripts/consistory, - scripts/ctrlx, - scripts/daam, - scripts/demofusion, - scripts/freescale, - scripts/infiniteyou, - scripts/instantir, - scripts/lbm, - scripts/layerdiffuse, - scripts/mod, - scripts/pixelsmith, - 
scripts/differential_diffusion.py, - scripts/pulid, - scripts/xadapter, - repositories, - extensions-builtin/sd-extension-chainner/nodes, - extensions-builtin/sd-webui-agent-scheduler, - extensions-builtin/sdnext-modernui/node_modules, - extensions-builtin/sdnext-kanvas/node_modules, -ignore-patterns=.*test*.py$, - .*_model.py$, - .*_arch.py$, - .*_model_arch.py*, - .*_model_arch_v2.py$, -ignored-modules= -jobs=8 -limit-inference-results=100 -load-plugins= -persistent=no -py-version=3.10 -recursive=no -source-roots= -unsafe-load-any-extension=no - -[BASIC] -argument-naming-style=snake_case -attr-naming-style=snake_case -bad-names=foo, bar, baz, toto, tutu, tata -bad-names-rgxs= -class-attribute-naming-style=any -class-const-naming-style=UPPER_CASE -class-naming-style=PascalCase -const-naming-style=snake_case -docstring-min-length=-1 -function-naming-style=snake_case -good-names=i,j,k,e,ex,ok,p,x,y,id -good-names-rgxs= -include-naming-hint=no -inlinevar-naming-style=any -method-naming-style=snake_case -module-naming-style=snake_case -name-group= -no-docstring-rgx=^_ -property-classes=abc.abstractproperty -variable-naming-style=snake_case - -[CLASSES] -check-protected-access-in-special-methods=no -defining-attr-methods=__init__, __new__, -exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit -valid-classmethod-first-arg=cls -valid-metaclass-classmethod-first-arg=mcs - -[DESIGN] -exclude-too-few-public-methods= -ignored-parents= -max-args=199 -max-attributes=99 -max-bool-expr=99 -max-branches=199 -max-locals=99 -max-parents=99 -max-public-methods=99 -max-returns=99 -max-statements=199 -min-public-methods=1 - -[EXCEPTIONS] -overgeneral-exceptions=builtins.BaseException,builtins.Exception - -[FORMAT] -expected-line-ending-format= -ignore-long-lines=^\s*(# )??$ -indent-after-paren=4 -indent-string=' ' -max-line-length=200 -max-module-lines=9999 -single-line-class-stmt=no -single-line-if-stmt=no - -[IMPORTS] -allow-any-import-level= 
-allow-reexport-from-package=no -allow-wildcard-with-all=no -deprecated-modules= -ext-import-graph= -import-graph= -int-import-graph= -known-standard-library= -known-third-party=enchant -preferred-modules= - -[LOGGING] -logging-format-style=new -logging-modules=logging - -[MESSAGES CONTROL] -confidence=HIGH, - CONTROL_FLOW, - INFERENCE, - INFERENCE_FAILURE, - UNDEFINED -# disable=C,R,W -disable=abstract-method, - bad-inline-option, - bare-except, - broad-exception-caught, - chained-comparison, - consider-iterating-dictionary, - consider-merging-isinstance, - consider-using-dict-items, - consider-using-enumerate, - consider-using-from-import, - consider-using-generator, - consider-using-get, - consider-using-in, - consider-using-max-builtin, - consider-using-min-builtin, - consider-using-sys-exit, - cyclic-import, - dangerous-default-value, - deprecated-pragma, - duplicate-code, - file-ignored, - import-error, - import-outside-toplevel, - invalid-name, - line-too-long, - locally-disabled, - logging-fstring-interpolation, - missing-class-docstring, - missing-function-docstring, - missing-module-docstring, - no-else-raise, - no-else-return, - not-callable, - pointless-string-statement, - raw-checker-failed, - simplifiable-if-expression, - suppressed-message, - too-few-public-methods, - too-many-instance-attributes, - too-many-locals, - too-many-nested-blocks, - too-many-positional-arguments, - too-many-statements, - unidiomatic-typecheck, - unknown-option-value, - unnecessary-dict-index-lookup, - unnecessary-dunder-call, - unnecessary-lambda-assigment, - unnecessary-lambda, - unused-wildcard-import, - unpacking-non-sequence, - unsubscriptable-object, - useless-return, - use-dict-literal, - use-symbolic-message-instead, - useless-suppression, - wrong-import-position, -enable=c-extension-no-member - -[METHOD_ARGS] 
-timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request - -[MISCELLANEOUS] -notes=FIXME, - XXX, - TODO -notes-rgx= - -[REFACTORING] -max-nested-blocks=5 -never-returning-functions=sys.exit,argparse.parse_error - -[REPORTS] -evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) -msg-template= -reports=no -score=no - -[SIMILARITIES] -ignore-comments=yes -ignore-docstrings=yes -ignore-imports=yes -ignore-signatures=yes -min-similarity-lines=4 - -[SPELLING] -max-spelling-suggestions=4 -spelling-dict= -spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: -spelling-ignore-words= -spelling-private-dict-file= -spelling-store-unknown-words=no - -[STRING] -check-quote-consistency=no -check-str-concat-over-line-jumps=no - -[TYPECHECK] -contextmanager-decorators=contextlib.contextmanager -generated-members=numpy.*,logging.*,torch.*,cv2.* -ignore-none=yes -ignore-on-opaque-inference=yes -ignored-checks-for-mixins=no-member, - not-async-context-manager, - not-context-manager, - attribute-defined-outside-init -ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace -missing-member-hint=yes -missing-member-hint-distance=1 -missing-member-max-choices=1 -mixin-class-rgx=.*[Mm]ixin -signature-mutators= - -[VARIABLES] -additional-builtins= -allow-global-unused-variables=yes -allowed-redefined-builtins= -callbacks=cb_, -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ -ignored-argument-names=_.*|^ignored_|^unused_ -init-import=no -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/.ruff.toml b/.ruff.toml deleted file mode 100644 index 10e1e2d71..000000000 --- a/.ruff.toml +++ /dev/null @@ -1,108 +0,0 @@ -line-length = 250 -indent-width = 4 -target-version = "py310" -exclude = [ - "venv", - 
".git", - ".ruff_cache", - ".vscode", - - "modules/cfgzero", - "modules/flash_attn_triton_amd", - "modules/hidiffusion", - "modules/intel/ipex", - "modules/pag", - "modules/schedulers", - "modules/teacache", - "modules/seedvr", - - "modules/control/proc", - "modules/control/units", - "modules/control/units/xs_pipe.py", - "modules/postprocess/aurasr_arch.py", - - "pipelines/meissonic", - "pipelines/omnigen2", - "pipelines/hdm", - "pipelines/segmoe", - "pipelines/xomni", - "pipelines/chrono", - - "scripts/lbm", - "scripts/daam", - "scripts/xadapter", - "scripts/pulid", - "scripts/instantir", - "scripts/freescale", - "scripts/consistory", - - "repositories", - - "extensions-builtin/Lora", - "extensions-builtin/sd-extension-chainner/nodes", - "extensions-builtin/sd-webui-agent-scheduler", - "extensions-builtin/sdnext-modernui/node_modules", -] - -[lint] -select = [ - "F", - "E", - "W", - "C", - "B", - "I", - "YTT", - "ASYNC", - "RUF", - "AIR", - "NPY", - "C4", - "T10", - "EXE", - "ISC", - "ICN", - "RSE", - "TCH", - "TID", - "INT", - "PLE", -] -ignore = [ - "B006", # Do not use mutable data structures for argument defaults - "B008", # Do not perform function call in argument defaults - "B905", # Strict zip() usage - "C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead - "C408", # Unnecessary `dict` call - "I001", # Import block is un-sorted or un-formatted - "E402", # Module level import not at top of file - "E501", # Line too long - "E721", # Do not compare types, use `isinstance()` - "E731", # Do not assign a `lambda` expression, use a `def` - "E741", # Ambiguous variable name - "F401", # Imported by unused - "EXE001", # file with shebang is not marked executable - "NPY002", # replace legacy random - "RUF005", # Consider iterable unpacking - "RUF008", # Do not use mutable default values for dataclass - "RUF010", # Use explicit conversion flag - "RUF012", # Mutable class attributes - "RUF013", # PEP 484 prohibits implicit `Optional` - 
"RUF015", # Prefer `next(...)` over single element slice - "RUF046", # Value being cast to `int` is already an integer - "RUF059", # Unpacked variables are not used - "RUF051", # Prefer pop over del -] -fixable = ["ALL"] -unfixable = [] -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - -[format] -quote-style = "double" -indent-style = "space" -skip-magic-trailing-comma = false -line-ending = "auto" -docstring-code-format = false - -[lint.mccabe] -max-complexity = 150 diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a9090ffc..96d74dc25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change Log for SD.Next -## Update for 2026-02-07 +## Update for 2026-02-09 - **Upscalers** - add support for [spandrel](https://github.com/chaiNNer-org/spandrel) @@ -11,12 +11,21 @@ - pipelines: add **ZImageInpaint**, thanks @CalamitousFelicitousness - add `--remote` command line flag that reduces client/server chatter and improves link stability for long-running generates, useful when running on remote servers + - hires: allow using different lora in refiner prompt + - **nunchaku** models are now listed in networks tab as reference models + instead of being used implicitly via quantization, thanks @CalamitousFelicitousness - **UI** - ui: **themes** add *CTD-NT64Light* and *CTD-NT64Dark*, thanks @resonantsky - ui: **gallery** add option to auto-refresh gallery, thanks @awsr - **Internal** + - refactor: switch to `pyproject.toml` for tool configs - refactor: reorganize `cli` scripts + - refactor: move tests to dedicated `/test/` + - update `lint` rules, thanks @awsr + - update `requirements` - **Fixes** + - fix: handle `clip` installer doing unwanted `setuptools` update + - fix: cleanup for `uv` installer fallback - fix: add metadata restore to always-on scripts - fix: improve wildcard weights parsing, thanks @Tillerz - fix: ui gallery cace recursive cleanup, thanks @awsr diff --git a/data/reference-nunchaku.json b/data/reference-nunchaku.json new file mode 
100644 index 000000000..a22be9e91 --- /dev/null +++ b/data/reference-nunchaku.json @@ -0,0 +1,209 @@ +{ + "FLUX.1-Dev Nunchaku SVDQuant": { + "path": "black-forest-labs/FLUX.1-dev", + "subfolder": "nunchaku", + "preview": "black-forest-labs--FLUX.1-dev.jpg", + "desc": "Nunchaku SVDQuant quantization of FLUX.1-dev transformer with INT4 and SVD rank 32", + "skip": true, + "nunchaku": ["Model", "TE"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "FLUX.1-Schnell Nunchaku SVDQuant": { + "path": "black-forest-labs/FLUX.1-schnell", + "subfolder": "nunchaku", + "preview": "black-forest-labs--FLUX.1-schnell.jpg", + "desc": "Nunchaku SVDQuant quantization of FLUX.1-schnell transformer with INT4 and SVD rank 32", + "skip": true, + "nunchaku": ["Model", "TE"], + "tags": "nunchaku", + "extras": "sampler: Default, cfg_scale: 1.0, steps: 4", + "size": 0, + "date": "2025 June" + }, + "FLUX.1-Kontext Nunchaku SVDQuant": { + "path": "black-forest-labs/FLUX.1-Kontext-dev", + "subfolder": "nunchaku", + "preview": "black-forest-labs--FLUX.1-Kontext-dev.jpg", + "desc": "Nunchaku SVDQuant quantization of FLUX.1-Kontext-dev transformer with INT4 and SVD rank 32", + "skip": true, + "nunchaku": ["Model", "TE"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "FLUX.1-Krea Nunchaku SVDQuant": { + "path": "black-forest-labs/FLUX.1-Krea-dev", + "subfolder": "nunchaku", + "preview": "black-forest-labs--FLUX.1-Krea-dev.jpg", + "desc": "Nunchaku SVDQuant quantization of FLUX.1-Krea-dev transformer with INT4 and SVD rank 32", + "skip": true, + "nunchaku": ["Model", "TE"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "FLUX.1-Fill Nunchaku SVDQuant": { + "path": "black-forest-labs/FLUX.1-Fill-dev", + "subfolder": "nunchaku", + "preview": "black-forest-labs--FLUX.1-Fill-dev.jpg", + "desc": "Nunchaku SVDQuant quantization of FLUX.1-Fill-dev transformer for inpainting", + "skip": true, + "hidden": true, + "nunchaku": ["Model", "TE"], + "tags": 
"nunchaku", + "size": 0, + "date": "2025 June" + }, + "FLUX.1-Depth Nunchaku SVDQuant": { + "path": "black-forest-labs/FLUX.1-Depth-dev", + "subfolder": "nunchaku", + "preview": "black-forest-labs--FLUX.1-Depth-dev.jpg", + "desc": "Nunchaku SVDQuant quantization of FLUX.1-Depth-dev transformer for depth-conditioned generation", + "skip": true, + "hidden": true, + "nunchaku": ["Model", "TE"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "Shuttle Jaguar Nunchaku SVDQuant": { + "path": "shuttleai/shuttle-jaguar", + "subfolder": "nunchaku", + "preview": "shuttleai--shuttle-jaguar.jpg", + "desc": "Nunchaku SVDQuant quantization of Shuttle Jaguar transformer", + "skip": true, + "nunchaku": ["Model", "TE"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "Qwen-Image Nunchaku SVDQuant": { + "path": "Qwen/Qwen-Image", + "subfolder": "nunchaku", + "preview": "Qwen--Qwen-Image.jpg", + "desc": "Nunchaku SVDQuant quantization of Qwen-Image transformer with INT4 and SVD rank 128", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "Qwen-Lightning (8-step) Nunchaku SVDQuant": { + "path": "vladmandic/Qwen-Lightning", + "subfolder": "nunchaku", + "preview": "vladmandic--Qwen-Lightning.jpg", + "desc": "Nunchaku SVDQuant quantization of Qwen-Lightning (8-step distilled) transformer with INT4 and SVD rank 128", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "extras": "steps: 8", + "size": 0, + "date": "2025 June" + }, + "Qwen-Lightning (4-step) Nunchaku SVDQuant": { + "path": "vladmandic/Qwen-Lightning", + "subfolder": "nunchaku-4step", + "preview": "vladmandic--Qwen-Lightning.jpg", + "desc": "Nunchaku SVDQuant quantization of Qwen-Lightning (4-step distilled) transformer with INT4 and SVD rank 128", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "extras": "steps: 4", + "size": 0, + "date": "2025 June" + }, + "Qwen-Image-Edit Nunchaku SVDQuant": { + "path": 
"Qwen/Qwen-Image-Edit", + "subfolder": "nunchaku", + "preview": "Qwen--Qwen-Image-Edit.jpg", + "desc": "Nunchaku SVDQuant quantization of Qwen-Image-Edit transformer with INT4 and SVD rank 128", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "Qwen-Lightning-Edit (8-step) Nunchaku SVDQuant": { + "path": "vladmandic/Qwen-Lightning-Edit", + "subfolder": "nunchaku", + "preview": "vladmandic--Qwen-Lightning-Edit.jpg", + "desc": "Nunchaku SVDQuant quantization of Qwen-Lightning-Edit (8-step distilled editing) transformer with INT4 and SVD rank 128", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "extras": "steps: 8", + "size": 0, + "date": "2025 June" + }, + "Qwen-Lightning-Edit (4-step) Nunchaku SVDQuant": { + "path": "vladmandic/Qwen-Lightning-Edit", + "subfolder": "nunchaku-4step", + "preview": "vladmandic--Qwen-Lightning-Edit.jpg", + "desc": "Nunchaku SVDQuant quantization of Qwen-Lightning-Edit (4-step distilled editing) transformer with INT4 and SVD rank 128", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "extras": "steps: 4", + "size": 0, + "date": "2025 June" + }, + "Qwen-Image-Edit-2509 Nunchaku SVDQuant": { + "path": "Qwen/Qwen-Image-Edit-2509", + "subfolder": "nunchaku", + "preview": "Qwen--Qwen-Image-Edit-2509.jpg", + "desc": "Nunchaku SVDQuant quantization of Qwen-Image-Edit-2509 transformer with INT4 and SVD rank 128", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "size": 0, + "date": "2025 September" + }, + "Sana 1.6B 1k Nunchaku SVDQuant": { + "path": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", + "subfolder": "nunchaku", + "preview": "Efficient-Large-Model--Sana_1600M_1024px_BF16_diffusers.jpg", + "desc": "Nunchaku SVDQuant quantization of Sana 1.6B 1024px transformer with INT4 and SVD rank 32", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "Z-Image-Turbo Nunchaku 
SVDQuant": { + "path": "Tongyi-MAI/Z-Image-Turbo", + "subfolder": "nunchaku", + "preview": "Tongyi-MAI--Z-Image-Turbo.jpg", + "desc": "Nunchaku SVDQuant quantization of Z-Image-Turbo transformer with INT4 and SVD rank 128", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "extras": "sampler: Default, cfg_scale: 1.0, steps: 9", + "size": 0, + "date": "2025 June" + }, + "SDXL Base Nunchaku SVDQuant": { + "path": "stabilityai/stable-diffusion-xl-base-1.0", + "subfolder": "nunchaku", + "preview": "stabilityai--stable-diffusion-xl-base-1.0.jpg", + "desc": "Nunchaku SVDQuant quantization of SDXL Base 1.0 UNet with INT4 and SVD rank 32", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "size": 0, + "date": "2025 June" + }, + "SDXL Turbo Nunchaku SVDQuant": { + "path": "stabilityai/sdxl-turbo", + "subfolder": "nunchaku", + "preview": "stabilityai--sdxl-turbo.jpg", + "desc": "Nunchaku SVDQuant quantization of SDXL Turbo UNet with INT4 and SVD rank 32", + "skip": true, + "nunchaku": ["Model"], + "tags": "nunchaku", + "extras": "sampler: Default, cfg_scale: 1.0, steps: 4", + "size": 0, + "date": "2025 June" + } +} diff --git a/html/locale_en.json b/html/locale_en.json index c2a6b7c79..da8278188 100644 --- a/html/locale_en.json +++ b/html/locale_en.json @@ -337,6 +337,8 @@ {"id":"","label":"Model Options","localized":"","reload":"","hint":"Settings related to behavior of specific models"}, {"id":"","label":"Model Offloading","localized":"","reload":"","hint":"Settings related to model offloading and memory management"}, {"id":"","label":"Model Quantization","localized":"","reload":"","hint":"Settings related to model quantization which is used to reduce memory usage"}, + {"id":"","label":"Nunchaku attention","localized":"","reload":"","hint":"Replaces default attention with Nunchaku's custom FP16 attention kernel for faster inference on consumer NVIDIA GPUs.
Might provide a performance improvement on GPUs with higher FP16 tensor-core throughput than BF16.

Currently only affects Flux-based models (Dev, Schnell, Kontext, Fill, Depth, etc.). Has no effect on Qwen, SDXL, Sana, or other architectures.

Disabled by default."}, + {"id":"","label":"Nunchaku offloading","localized":"","reload":"","hint":"Enables Nunchaku's own per-block CPU offloading with asynchronous CUDA streams to reduce VRAM usage.
Uses a ping-pong buffer strategy: while one transformer block computes on GPU, the next block preloads from CPU in the background, hiding most of the transfer latency.

Can reduce VRAM usage at the cost of slower inference.
This replaces SD.Next's pipeline offloading for the transformer component.

Only useful on low-VRAM GPUs. If your GPU has enough memory to hold the quantized model (16+ GB), keep this disabled for maximum speed.
Supports Flux and Qwen models. Not supported for SDXL where this setting is ignored.
Disabled by default."}, {"id":"","label":"Image Metadata","localized":"","reload":"","hint":"Settings related to handling of metadata that is created with generated images"}, {"id":"","label":"Legacy Options","localized":"","reload":"","hint":"Settings related to legacy options - should not be used"}, {"id":"","label":"Restart server","localized":"","reload":"","hint":"Restart server"}, diff --git a/html/logo-dark.svg b/html/logo-dark.svg new file mode 100644 index 000000000..c949d8847 --- /dev/null +++ b/html/logo-dark.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/html/logo-light.svg b/html/logo-light.svg new file mode 100644 index 000000000..62c70dd2a --- /dev/null +++ b/html/logo-light.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/html/logo-margins.png b/html/logo-margins.png deleted file mode 100644 index 5bbc3f252..000000000 Binary files a/html/logo-margins.png and /dev/null differ diff --git a/html/logo-robot.svg b/html/logo-robot.svg new file mode 100644 index 000000000..52d86b16d --- /dev/null +++ b/html/logo-robot.svg @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/html/logo-wm.png b/html/logo-wm.png deleted file mode 100644 index 1bfdbf53a..000000000 Binary files a/html/logo-wm.png and /dev/null differ diff --git a/html/logo.svg b/html/logo.svg deleted file mode 100644 index 737b53886..000000000 --- a/html/logo.svg +++ /dev/null @@ -1,62 +0,0 @@ - - - - diff --git a/installer.py b/installer.py index 6111d6679..84bce28b6 100644 --- a/installer.py +++ b/installer.py @@ -431,6 +431,26 @@ def run(cmd: str, arg: str): return txt +def cleanup_broken_packages(): + """Remove dist-info directories with missing RECORD files that uv may have left behind""" + try: + import site + for site_dir in site.getsitepackages(): + if not os.path.isdir(site_dir): + continue + for entry in os.listdir(site_dir): + if not entry.endswith('.dist-info'): + continue + dist_info = os.path.join(site_dir, entry) + record_file = os.path.join(dist_info, 'RECORD') + if 
not os.path.isfile(record_file): + pkg_name = entry.split('-')[0] + log.warning(f'Install: removing broken package metadata: {pkg_name} path={dist_info}') + shutil.rmtree(dist_info, ignore_errors=True) + except Exception: + pass + + def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True): t_start = time.time() originalArg = arg @@ -454,6 +474,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True): err = result.stderr.decode(encoding="utf8", errors="ignore") log.warning(f'Install: cmd="{pipCmd}" args="{all_args}" cannot use uv, fallback to pip') debug(f'Install: uv pip error: {err}') + cleanup_broken_packages() return pip(originalArg, ignore, quiet, uv=False) else: txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore") @@ -468,7 +489,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True): # install package using pip if not already installed -def install(package, friendly: str = None, ignore: bool = False, reinstall: bool = False, no_deps: bool = False, quiet: bool = False, force: bool = False): +def install(package, friendly: str = None, ignore: bool = False, reinstall: bool = False, no_deps: bool = False, quiet: bool = False, force: bool = False, no_build_isolation: bool = False): t_start = time.time() res = '' if args.reinstall or args.upgrade: @@ -476,7 +497,8 @@ def install(package, friendly: str = None, ignore: bool = False, reinstall: bool quick_allowed = False if (args.reinstall) or (reinstall) or (not installed(package, friendly, quiet=quiet)): deps = '' if not no_deps else '--no-deps ' - cmd = f"install{' --upgrade' if not args.uv else ''}{' --force-reinstall' if force else ''} {deps}{package}" + isolation = '' if not no_build_isolation else '--no-build-isolation ' + cmd = f"install{' --upgrade' if not args.uv else ''}{' --force-reinstall' if force else ''} {deps}{isolation}{package}" res = pip(cmd, ignore=ignore, uv=package != "uv" and not 
package.startswith('git+')) try: importlib.reload(pkg_resources) @@ -665,7 +687,7 @@ def check_diffusers(): t_start = time.time() if args.skip_all: return - sha = '99e2cfff27dec514a43e260e885c5e6eca038b36' # diffusers commit hash + sha = '5bf248ddd8796b4f4958559429071a28f9b2dd3a' # diffusers commit hash # if args.use_rocm or args.use_zluda or args.use_directml: # sha = '043ab2520f6a19fce78e6e060a68dbc947edb9f9' # lock diffusers versions for now pkg = pkg_resources.working_set.by_key.get('diffusers', None) @@ -1100,8 +1122,9 @@ def install_packages(): pr.enable() # log.info('Install: verifying packages') clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git") - install(clip_package, 'clip', quiet=True) + install(clip_package, 'clip', quiet=True, no_build_isolation=True) install('open-clip-torch', no_deps=True, quiet=True) + install('ftfy', 'ftfy', quiet=True, no_build_isolation=True) # tensorflow_package = os.environ.get('TENSORFLOW_PACKAGE', 'tensorflow==2.13.0') # tensorflow_package = os.environ.get('TENSORFLOW_PACKAGE', None) # if tensorflow_package is not None: diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index acb305da8..eaf516ff2 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -171,6 +171,12 @@ async function filterExtraNetworksForTab(searchTerm) { .toLowerCase() .includes('quantized') ? '' : 'none'; }); + } else if (searchTerm === 'nunchaku/') { + cards.forEach((elem) => { + elem.style.display = elem.dataset.tags + .toLowerCase() + .includes('nunchaku') ? 
'' : 'none'; + }); } else if (searchTerm === 'local/') { cards.forEach((elem) => { elem.style.display = elem.dataset.name diff --git a/javascript/gallery.js b/javascript/gallery.js index 51ce478f9..5361cd8c7 100644 --- a/javascript/gallery.js +++ b/javascript/gallery.js @@ -1353,6 +1353,7 @@ async function initGallery() { // triggered on gradio change to monitor when ui monitorGalleries(); updateFolders(); + initGalleryAutoRefresh(); [ 'browser_folders', 'outdir_samples', diff --git a/modules/extra_networks.py b/modules/extra_networks.py index eab3ab7ad..01913b187 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -156,7 +156,6 @@ def parse_prompt(prompt: str | None) -> tuple[str, defaultdict[str, list[ExtraNe if prompt is None: return "", res if isinstance(prompt, list): - shared.log.warning(f"parse_prompt was called with a list instead of a string: {prompt}") return parse_prompts(prompt) def found(m: re.Match[str]): @@ -168,13 +167,17 @@ def parse_prompt(prompt: str | None) -> tuple[str, defaultdict[str, list[ExtraNe return updated_prompt, res -def parse_prompts(prompts: list[str]): +def parse_prompts(prompts: list[str], extra_data=None): updated_prompt_list: list[str] = [] - extra_data: defaultdict[str, list[ExtraNetworkParams]] = defaultdict(list) + extra_data: defaultdict[str, list[ExtraNetworkParams]] = extra_data or defaultdict(list) for prompt in prompts: updated_prompt, parsed_extra_data = parse_prompt(prompt) if not extra_data: extra_data = parsed_extra_data + elif parsed_extra_data: + extra_data = parsed_extra_data + else: + pass updated_prompt_list.append(updated_prompt) return updated_prompt_list, extra_data diff --git a/modules/images.py b/modules/images.py index 64b51401a..5e4c5c9dd 100644 --- a/modules/images.py +++ b/modules/images.py @@ -12,24 +12,10 @@ import piexif import piexif.helper from PIL import Image, PngImagePlugin, ExifTags, ImageDraw from modules import sd_samplers, shared, script_callbacks, errors, paths -from 
modules.images_grid import ( - image_grid as image_grid, - get_grid_size as get_grid_size, - split_grid as split_grid, - combine_grid as combine_grid, - check_grid_size as check_grid_size, - get_font as get_font, - draw_grid_annotations as draw_grid_annotations, - draw_prompt_matrix as draw_prompt_matrix, - GridAnnotation as GridAnnotation, - Grid as Grid, -) -from modules.images_resize import resize_image as resize_image -from modules.images_namegen import ( - FilenameGenerator as FilenameGenerator, - get_next_sequence_number as get_next_sequence_number, -) -from modules.video import save_video as save_video +from modules.images_grid import image_grid, get_grid_size, split_grid, combine_grid, check_grid_size, get_font, draw_grid_annotations, draw_prompt_matrix, GridAnnotation, Grid # pylint: disable=unused-import +from modules.images_resize import resize_image # pylint: disable=unused-import +from modules.images_namegen import FilenameGenerator, get_next_sequence_number # pylint: disable=unused-import +from modules.video import save_video # pylint: disable=unused-import debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None diff --git a/modules/mit_nunchaku.py b/modules/mit_nunchaku.py index b5e82c1da..6b4e524cc 100644 --- a/modules/mit_nunchaku.py +++ b/modules/mit_nunchaku.py @@ -4,10 +4,27 @@ from installer import log, pip from modules import devices -nunchaku_ver = '1.1.0' +nunchaku_versions = { + '2.5': '1.0.1', + '2.6': '1.0.1', + '2.7': '1.1.0', + '2.8': '1.1.0', + '2.9': '1.1.0', + '2.10': '1.0.2', + '2.11': '1.1.0', +} ok = False +def _expected_ver(): + try: + import torch + torch_ver = '.'.join(torch.__version__.split('+')[0].split('.')[:2]) + return nunchaku_versions.get(torch_ver) + except Exception: + return None + + def check(): global ok # pylint: disable=global-statement if ok: @@ -16,8 +33,9 @@ def check(): import nunchaku import nunchaku.utils from nunchaku import __version__ + expected = 
_expected_ver() log.info(f'Nunchaku: path={nunchaku.__path__} version={__version__.__version__} precision={nunchaku.utils.get_precision()}') - if __version__.__version__ != nunchaku_ver: + if expected is not None and __version__.__version__ != expected: ok = False return False ok = True @@ -49,14 +67,16 @@ def install_nunchaku(): if devices.backend not in ['cuda']: log.error(f'Nunchaku: backend={devices.backend} unsupported') return False - torch_ver = torch.__version__[:3] - if torch_ver not in ['2.5', '2.6', '2.7', '2.8', '2.9', '2.10']: + torch_ver = '.'.join(torch.__version__.split('+')[0].split('.')[:2]) + nunchaku_ver = nunchaku_versions.get(torch_ver) + if nunchaku_ver is None: log.error(f'Nunchaku: torch={torch.__version__} unsupported') + return False suffix = 'x86_64' if arch == 'linux' else 'win_amd64' url = os.environ.get('NUNCHAKU_COMMAND', None) if url is None: arch = f'{arch}_' if arch == 'linux' else '' - url = f'https://huggingface.co/nunchaku-tech/nunchaku/resolve/main/nunchaku-{nunchaku_ver}' + url = f'https://huggingface.co/nunchaku-ai/nunchaku/resolve/main/nunchaku-{nunchaku_ver}' url += f'+torch{torch_ver}-cp{python_ver}-cp{python_ver}-{arch}{suffix}.whl' cmd = f'install --upgrade {url}' log.debug(f'Nunchaku: install="{url}"') diff --git a/modules/model_quant.py b/modules/model_quant.py index 1a501be0a..3cad91181 100644 --- a/modules/model_quant.py +++ b/modules/model_quant.py @@ -255,13 +255,25 @@ def check_quant(module: str = ''): def check_nunchaku(module: str = ''): from modules import shared - if module not in shared.opts.nunchaku_quantization: + model_name = getattr(shared.opts, 'sd_model_checkpoint', '') + if '+nunchaku' not in model_name: return False - from modules import mit_nunchaku - mit_nunchaku.install_nunchaku() - if not mit_nunchaku.ok: - return False - return True + base_path = model_name.split('+')[0] + for v in shared.reference_models.values(): + if v.get('path', '') != base_path: + continue + nunchaku_modules = 
v.get('nunchaku', None) + if nunchaku_modules is None: + continue + if isinstance(nunchaku_modules, bool) and nunchaku_modules: + nunchaku_modules = ['Model', 'TE'] + if not isinstance(nunchaku_modules, list): + continue + if module in nunchaku_modules: + from modules import mit_nunchaku + mit_nunchaku.install_nunchaku() + return mit_nunchaku.ok + return False def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None, modules_dtype_dict: dict = None): diff --git a/modules/processing.py b/modules/processing.py index 3c6595e68..13f2c275c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -270,7 +270,6 @@ def process_init(p: StableDiffusionProcessing): p.all_prompts, p.all_negative_prompts = shared.prompt_styles.apply_styles_to_prompts(p.all_prompts, p.all_negative_prompts, p.styles, p.all_seeds) p.prompts = p.all_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)] p.negative_prompts = p.all_negative_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)] - p.prompts, _ = extra_networks.parse_prompts(p.prompts) def process_samples(p: StableDiffusionProcessing, samples): diff --git a/modules/processing_diffusers.py b/modules/processing_diffusers.py index a410497e2..04ae68ea2 100644 --- a/modules/processing_diffusers.py +++ b/modules/processing_diffusers.py @@ -171,6 +171,7 @@ def process_base(p: processing.StableDiffusionProcessing): modelstats.analyze() try: t0 = time.time() + p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts, p.network_data) extra_networks.activate(p, exclude=['text_encoder', 'text_encoder_2', 'text_encoder_3']) if hasattr(shared.sd_model, 'tgate') and getattr(p, 'gate_step', -1) > 0: @@ -297,10 +298,20 @@ def process_hires(p: processing.StableDiffusionProcessing, output): p.denoising_strength = strength orig_image = p.task_args.pop('image', None) # remove image override from hires process_pre(p) + + prompts = p.prompts + 
reset_prompts = False + if len(p.refiner_prompt) > 0: + prompts = len(output.images)* [p.refiner_prompt] + prompts, p.network_data = extra_networks.parse_prompts(prompts) + reset_prompts = True + if reset_prompts or ('base' in p.skip): + extra_networks.activate(p) + hires_args = set_pipeline_args( p=p, model=shared.sd_model, - prompts=len(output.images)* [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts, + prompts=prompts, negative_prompts=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts, prompts_2=len(output.images) * [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts, negative_prompts_2=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts, @@ -314,11 +325,10 @@ def process_hires(p: processing.StableDiffusionProcessing, output): strength=strength, desc='Hires', ) + hires_steps = hires_args.get('prior_num_inference_steps', None) or p.hr_second_pass_steps or hires_args.get('num_inference_steps', None) shared.state.update(get_job_name(p, shared.sd_model), hires_steps, 1) try: - if 'base' in p.skip: - extra_networks.activate(p) taskid = shared.state.begin('Inference') output = shared.sd_model(**hires_args) # pylint: disable=not-callable shared.state.end(taskid) diff --git a/modules/sd_unet.py b/modules/sd_unet.py index c73ca8dc5..643d1c08e 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -18,14 +18,15 @@ def load_unet_sdxl_nunchaku(repo_id): shared.log.error(f'Load module: quant=Nunchaku module=unet repo="{repo_id}" low nunchaku version') return None if 'turbo' in repo_id.lower(): - nunchaku_repo = 'nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors' + nunchaku_repo = 'nunchaku-ai/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors' else: - nunchaku_repo = 'nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors' + nunchaku_repo = 'nunchaku-ai/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors' - shared.log.debug(f'Load 
module: quant=Nunchaku module=unet repo="{nunchaku_repo}" offload={shared.opts.nunchaku_offload}') + if shared.opts.nunchaku_offload: + shared.log.warning('Load module: quant=Nunchaku module=unet offload not supported for SDXL, ignoring') + shared.log.debug(f'Load module: quant=Nunchaku module=unet repo="{nunchaku_repo}"') unet = NunchakuSDXLUNet2DConditionModel.from_pretrained( nunchaku_repo, - offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir, ) diff --git a/modules/shared.py b/modules/shared.py index dcf345f78..3c05030ad 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,53 +4,23 @@ import os import sys import time import contextlib + from enum import Enum from typing import TYPE_CHECKING - import gradio as gr - -from installer import ( - log as log, - print_dict, - console as console, - get_version as get_version, -) - -log.debug("Initializing: shared module") +from installer import log, print_dict, console, get_version # pylint: disable=unused-import +log.debug('Initializing: shared module') import modules.memmon import modules.paths as paths -from modules.json_helpers import ( - readfile as readfile, - writefile as writefile, -) -from modules.shared_helpers import ( - listdir as listdir, - walk_files as walk_files, - html_path as html_path, - html as html, - req as req, - total_tqdm as total_tqdm, -) +from modules.json_helpers import readfile, writefile # pylint: disable=W0611 +from modules.shared_helpers import listdir, walk_files, html_path, html, req, total_tqdm # pylint: disable=W0611 from modules import errors, devices, shared_state, cmd_args, theme, history, files_cache from modules.shared_defaults import get_default_modes -from modules.paths import ( - models_path as models_path, # For compatibility, do not modify from here... 
- script_path as script_path, - data_path as data_path, - sd_configs_path as sd_configs_path, - sd_default_config as sd_default_config, - sd_model_file as sd_model_file, - default_sd_model_file as default_sd_model_file, - extensions_dir as extensions_dir, - extensions_builtin_dir as extensions_builtin_dir, # ... to here. -) -from modules.memstats import ( - memory_stats, - ram_stats as ram_stats, -) +from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611 +from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import -log.debug("Initializing: pipelines") +log.debug('Initializing: pipelines') from modules import shared_items from modules.interrogate.openclip import caption_models, caption_types, get_clip_models, refresh_clip_models from modules.interrogate.vqa import vlm_models, vlm_prompts, vlm_system, vlm_default @@ -280,7 +250,6 @@ options_templates.update(options_section(("quantization", "Model Quantization"), "sdnq_quantize_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox), "nunchaku_sep": OptionInfo("

Nunchaku Engine

", "", gr.HTML), - "nunchaku_quantization": OptionInfo([], "SVDQuant enabled", gr.CheckboxGroup, {"choices": ["Model", "TE"]}), "nunchaku_attention": OptionInfo(False, "Nunchaku attention", gr.Checkbox), "nunchaku_offload": OptionInfo(False, "Nunchaku offloading", gr.Checkbox), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 3de17896f..d98e7a14b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -305,6 +305,7 @@ class ExtraNetworksPage: subdirs['Reference'] = 1 subdirs['Distilled'] = 1 subdirs['Quantized'] = 1 + subdirs['Nunchaku'] = 1 subdirs['Community'] = 1 subdirs['Cloud'] = 1 subdirs[diffusers_base] = 1 @@ -324,6 +325,8 @@ class ExtraNetworksPage: subdirs.move_to_end('Distilled', last=True) if 'Quantized' in subdirs: subdirs.move_to_end('Quantized', last=True) + if 'Nunchaku' in subdirs: + subdirs.move_to_end('Nunchaku', last=True) if 'Community' in subdirs: subdirs.move_to_end('Community', last=True) if 'Cloud' in subdirs: @@ -332,7 +335,7 @@ class ExtraNetworksPage: for subdir in subdirs: if len(subdir) == 0: continue - if subdir in ['All', 'Local', 'Diffusers', 'Reference', 'Distilled', 'Quantized', 'Community', 'Cloud']: + if subdir in ['All', 'Local', 'Diffusers', 'Reference', 'Distilled', 'Quantized', 'Nunchaku', 'Community', 'Cloud']: style = 'network-reference' else: style = 'network-folder' diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 9470ec9bb..65078ef97 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -3,7 +3,7 @@ import html import json import concurrent from datetime import datetime -from modules import shared, ui_extra_networks, sd_models, modelstats, paths +from modules import shared, ui_extra_networks, sd_models, modelstats, paths, devices from modules.json_helpers import readfile @@ -48,16 +48,21 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): 
reference_distilled = readfile(os.path.join('data', 'reference-distilled.json'), as_type="dict") reference_community = readfile(os.path.join('data', 'reference-community.json'), as_type="dict") reference_cloud = readfile(os.path.join('data', 'reference-cloud.json'), as_type="dict") + reference_nunchaku = readfile(os.path.join('data', 'reference-nunchaku.json'), as_type="dict") shared.reference_models = {} shared.reference_models.update(reference_base) shared.reference_models.update(reference_quant) shared.reference_models.update(reference_community) shared.reference_models.update(reference_distilled) shared.reference_models.update(reference_cloud) + shared.reference_models.update(reference_nunchaku) for k, v in shared.reference_models.items(): count['total'] += 1 url = v['path'] + if v.get('hidden', False): + count['hidden'] += 1 + continue experimental = v.get('experimental', False) if experimental: if shared.cmd_opts.experimental: @@ -83,6 +88,9 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): path = f'{v.get("path", "")}' tag = v.get('tags', '') + if tag == 'nunchaku' and (devices.backend != 'cuda' and not shared.cmd_opts.experimental): + count['hidden'] += 1 + continue if tag in count: count[tag] += 1 elif tag != '': diff --git a/pipelines/flux/flux_nunchaku.py b/pipelines/flux/flux_nunchaku.py index 9a737d103..d8761e186 100644 --- a/pipelines/flux/flux_nunchaku.py +++ b/pipelines/flux/flux_nunchaku.py @@ -9,19 +9,19 @@ def load_flux_nunchaku(repo_id): if 'srpo' in repo_id.lower(): pass elif 'flux.1-dev' in repo_id.lower(): - nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors" + nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors" elif 'flux.1-schnell' in repo_id.lower(): - nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" + nunchaku_repo = 
f"nunchaku-ai/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" elif 'flux.1-kontext' in repo_id.lower(): - nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors" + nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors" elif 'flux.1-krea' in repo_id.lower(): - nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-krea-dev/svdq-{nunchaku_precision}_r32-flux.1-krea-dev.safetensors" + nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-krea-dev/svdq-{nunchaku_precision}_r32-flux.1-krea-dev.safetensors" elif 'flux.1-fill' in repo_id.lower(): - nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-fill-dev/svdq-{nunchaku_precision}-flux.1-fill-dev.safetensors" + nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-fill-dev/svdq-{nunchaku_precision}-flux.1-fill-dev.safetensors" elif 'flux.1-depth' in repo_id.lower(): - nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-depth-dev/svdq-{nunchaku_precision}-flux.1-depth-dev.safetensors" + nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-depth-dev/svdq-{nunchaku_precision}-flux.1-depth-dev.safetensors" elif 'shuttle' in repo_id.lower(): - nunchaku_repo = f"nunchaku-tech/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}-shuttle-jaguar.safetensors" + nunchaku_repo = f"nunchaku-ai/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}-shuttle-jaguar.safetensors" else: shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported') if nunchaku_repo is not None: diff --git a/pipelines/generic.py b/pipelines/generic.py index 6ad9a3fd5..3b1b4bc66 100644 --- a/pipelines/generic.py +++ b/pipelines/generic.py @@ -152,7 +152,7 @@ def load_text_encoder(repo_id, cls_name, load_config=None, subfolder="text_encod elif cls_name == transformers.T5EncoderModel and allow_shared and shared.opts.te_shared_t5: if model_quant.check_nunchaku('TE'): import nunchaku - repo_id = 
'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors' + repo_id = 'nunchaku-ai/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors' cls_name = nunchaku.NunchakuT5EncoderModel shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant" loader={_loader("transformers")}') text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained( diff --git a/pipelines/model_qwen.py b/pipelines/model_qwen.py index 3bea5c121..546e755cc 100644 --- a/pipelines/model_qwen.py +++ b/pipelines/model_qwen.py @@ -37,7 +37,7 @@ def load_qwen(checkpoint_info, diffusers_load_config=None): diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageInpaintPipeline if model_quant.check_nunchaku('Model'): - transformer = qwen.load_qwen_nunchaku(repo_id) + transformer = qwen.load_qwen_nunchaku(repo_id, subfolder=repo_subfolder) if 'Qwen-Image-Distill-Full' in repo_id: repo_transformer = repo_id @@ -63,6 +63,8 @@ def load_qwen(checkpoint_info, diffusers_load_config=None): text_encoder = generic.load_text_encoder(repo_te, cls_name=transformers.Qwen2_5_VLForConditionalGeneration, load_config=diffusers_load_config) repo_id, repo_subfolder = qwen.check_qwen_pruning(repo_id, repo_subfolder) + if repo_subfolder is not None and repo_subfolder.startswith('nunchaku'): + repo_subfolder = None pipe = cls_name.from_pretrained( repo_id, transformer=transformer, diff --git a/pipelines/model_sana.py b/pipelines/model_sana.py index 73da28472..8b04fcfb0 100644 --- a/pipelines/model_sana.py +++ b/pipelines/model_sana.py @@ -9,7 +9,7 @@ def load_quants(kwargs, repo_id, cache_dir): if 'Sana_1600M_1024px' in repo_id and model_quant.check_nunchaku('Model'): # only available model import nunchaku nunchaku_precision = nunchaku.utils.get_precision() - nunchaku_repo = "nunchaku-tech/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors" + nunchaku_repo = "nunchaku-ai/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors" shared.log.debug(f'Load module: 
quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} attention={shared.opts.nunchaku_attention}') kwargs['transformer'] = nunchaku.NunchakuSanaTransformer2DModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype, cache_dir=cache_dir) elif model_quant.check_quant('Model'): diff --git a/pipelines/model_z_image.py b/pipelines/model_z_image.py index 9f8dd51e1..3aa269b48 100644 --- a/pipelines/model_z_image.py +++ b/pipelines/model_z_image.py @@ -8,7 +8,7 @@ def load_nunchaku(): import nunchaku nunchaku_precision = nunchaku.utils.get_precision() nunchaku_rank = 128 - nunchaku_repo = f"nunchaku-tech/nunchaku-z-image-turbo/svdq-{nunchaku_precision}_r{nunchaku_rank}-z-image-turbo.safetensors" + nunchaku_repo = f"nunchaku-ai/nunchaku-z-image-turbo/svdq-{nunchaku_precision}_r{nunchaku_rank}-z-image-turbo.safetensors" shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" attention={shared.opts.nunchaku_attention}') transformer = nunchaku.NunchakuZImageTransformer2DModel.from_pretrained( nunchaku_repo, diff --git a/pipelines/qwen/qwen_nunchaku.py b/pipelines/qwen/qwen_nunchaku.py index 4c89b7b1c..4fd964df3 100644 --- a/pipelines/qwen/qwen_nunchaku.py +++ b/pipelines/qwen/qwen_nunchaku.py @@ -1,11 +1,12 @@ from modules import shared, devices -def load_qwen_nunchaku(repo_id): +def load_qwen_nunchaku(repo_id, subfolder=None): import nunchaku nunchaku_precision = nunchaku.utils.get_precision() nunchaku_repo = None transformer = None + four_step = subfolder is not None and '4step' in subfolder try: from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel except Exception: @@ -14,15 +15,21 @@ def load_qwen_nunchaku(repo_id): if 'pruning' in repo_id.lower() or 'distill' in repo_id.lower(): return None elif repo_id.lower().endswith('qwen-image'): - nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image.safetensors" # r32 vs r128 
+ nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image.safetensors" elif repo_id.lower().endswith('qwen-lightning'): - nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.1-8steps.safetensors" # 8-step variant + if four_step: + nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.0-4steps.safetensors" + else: + nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.1-8steps.safetensors" elif repo_id.lower().endswith('qwen-image-edit-2509'): - nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit-2509/svdq-{nunchaku_precision}_r128-qwen-image-edit-2509.safetensors" # 8-step variant + nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit-2509/svdq-{nunchaku_precision}_r128-qwen-image-edit-2509.safetensors" elif repo_id.lower().endswith('qwen-image-edit'): - nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit.safetensors" # 8-step variant + nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit.safetensors" elif repo_id.lower().endswith('qwen-lightning-edit'): - nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-8steps.safetensors" # 8-step variant + if four_step: + nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-4steps.safetensors" + else: + nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-8steps.safetensors" else: shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported') if nunchaku_repo is not None: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..05bf11ff5 --- /dev/null +++ b/pyproject.toml @@ -0,0 
+1,360 @@ +[project] +name = "SD.Next" +version = "0.0.0" +description = "SD.Next: All-in-one WebUI for AI generative image and video creation and captioning" +readme = "README.md" +requires-python = ">=3.10" + +[tool.ruff] +line-length = 250 +indent-width = 4 +target-version = "py310" +exclude = [ + "venv", + ".git", + ".ruff_cache", + ".vscode", + + "modules/cfgzero", + "modules/flash_attn_triton_amd", + "modules/hidiffusion", + "modules/intel/ipex", + "modules/pag", + "modules/schedulers", + "modules/teacache", + "modules/seedvr", + + "modules/control/proc", + "modules/control/units", + "modules/control/units/xs_pipe.py", + "modules/postprocess/aurasr_arch.py", + + "pipelines/meissonic", + "pipelines/omnigen2", + "pipelines/hdm", + "pipelines/segmoe", + "pipelines/xomni", + "pipelines/chrono", + + "scripts/lbm", + "scripts/daam", + "scripts/xadapter", + "scripts/pulid", + "scripts/instantir", + "scripts/freescale", + "scripts/consistory", + + "repositories", + + "extensions-builtin/Lora", + "extensions-builtin/sd-extension-chainner/nodes", + "extensions-builtin/sd-webui-agent-scheduler", + "extensions-builtin/sdnext-modernui/node_modules", +] + +[tool.ruff.lint] +select = [ + "F", + "E", + "W", + "C", + "B", + "I", + "YTT", + "ASYNC", + "RUF", + "AIR", + "NPY", + "C4", + "T10", + "EXE", + "ISC", + "ICN", + "RSE", + "TC", + "TID", + "INT", + "PLE", +] +ignore = [ + "B006", # Do not use mutable data structures for argument defaults + "B008", # Do not perform function call in argument defaults + "B905", # Strict zip() usage + "C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead + "C408", # Unnecessary `dict` call + "I001", # Import block is un-sorted or un-formatted + "E402", # Module level import not at top of file + "E501", # Line too long + "E721", # Do not compare types, use `isinstance()` + "E731", # Do not assign a `lambda` expression, use a `def` + "E741", # Ambiguous variable name + "F401", # Imported by unused + "EXE001", # 
file with shebang is not marked executable + "NPY002", # replace legacy random + "RUF005", # Consider iterable unpacking + "RUF008", # Do not use mutable default values for dataclass + "RUF010", # Use explicit conversion flag + "RUF012", # Mutable class attributes + "RUF013", # PEP 484 prohibits implicit `Optional` + "RUF015", # Prefer `next(...)` over single element slice + "RUF046", # Value being cast to `int` is already an integer + "RUF059", # Unpacked variables are not used + "RUF051", # Prefer pop over del +] +fixable = ["ALL"] +unfixable = [] +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = false + +[tool.ruff.lint.mccabe] +max-complexity = 150 + +[tool.pylint] +main.py-version="3.10" +main.analyse-fallback-blocks=false +main.clear-cache-post-run=false +main.extension-pkg-allow-list="" +main.prefer-stubs=true +main.extension-pkg-whitelist="" +main.fail-on="" +main.fail-under=10 +main.ignore="CVS" +main.ignore-paths=[ + "venv", + ".git", + ".ruff_cache", + ".vscode", + "modules/apg", + "modules/cfgzero", + "modules/control/proc", + "modules/control/units", + "modules/dml", + "modules/flash_attn_triton_amd", + "modules/ggml", + "modules/hidiffusion", + "modules/hijack/ddpm_edit.py", + "modules/intel", + "modules/intel/ipex", + "modules/framepack/pipeline", + "modules/onnx_impl", + "modules/pag", + "modules/postprocess/aurasr_arch.py", + "modules/prompt_parser_xhinker.py", + "modules/ras", + "modules/seedvr", + "modules/rife", + "modules/schedulers", + "modules/taesd", + "modules/teacache", + "modules/todo", + "modules/res4lyf", + "pipelines/bria", + "pipelines/flex2", + "pipelines/f_lite", + "pipelines/hidream", + "pipelines/hdm", + "pipelines/meissonic", + "pipelines/omnigen2", + "pipelines/segmoe", + "pipelines/xomni", + "pipelines/chrono", + "scripts/consistory", + "scripts/ctrlx", + 
"scripts/daam", + "scripts/demofusion", + "scripts/freescale", + "scripts/infiniteyou", + "scripts/instantir", + "scripts/lbm", + "scripts/layerdiffuse", + "scripts/mod", + "scripts/pixelsmith", + "scripts/differential_diffusion.py", + "scripts/pulid", + "scripts/xadapter", + "repositories", + "extensions-builtin/sd-extension-chainner/nodes", + "extensions-builtin/sd-webui-agent-scheduler", + "extensions-builtin/sdnext-modernui/node_modules", + "extensions-builtin/sdnext-kanvas/node_modules", + ] +main.ignore-patterns=[ + ".*test*.py$", + ".*_model.py$", + ".*_arch.py$", + ".*_model_arch.py*", + ".*_model_arch_v2.py$", + ] +main.ignored-modules="" +main.jobs=8 +main.limit-inference-results=100 +main.load-plugins="" +main.persistent=false +main.recursive=false +main.source-roots="" +main.unsafe-load-any-extension=false +basic.argument-naming-style="snake_case" +basic.attr-naming-style="snake_case" +basic.bad-names=["foo", "bar", "baz", "toto", "tutu", "tata"] +basic.bad-names-rgxs="" +basic.class-attribute-naming-style="any" +basic.class-const-naming-style="UPPER_CASE" +basic.class-naming-style="PascalCase" +basic.const-naming-style="snake_case" +basic.docstring-min-length=-1 +basic.function-naming-style="snake_case" +basic.good-names=["i","j","k","e","ex","ok","p","x","y","id"] +basic.good-names-rgxs="" +basic.include-naming-hint=false +basic.inlinevar-naming-style="any" +basic.method-naming-style="snake_case" +basic.module-naming-style="snake_case" +basic.name-group="" +basic.no-docstring-rgx="^_" +basic.property-classes="abc.abstractproperty" +basic.variable-naming-style="snake_case" +classes.check-protected-access-in-special-methods=false +classes.defining-attr-methods=["__init__", "__new__"] +classes.exclude-protected=["_asdict","_fields","_replace","_source","_make","os._exit"] +classes.valid-classmethod-first-arg="cls" +classes.valid-metaclass-classmethod-first-arg="mcs" +design.exclude-too-few-public-methods="" +design.ignored-parents="" +design.max-args=199 
+design.max-attributes=99 +design.max-bool-expr=99 +design.max-branches=199 +design.max-locals=99 +design.max-parents=99 +design.max-public-methods=99 +design.max-returns=99 +design.max-statements=199 +design.min-public-methods=1 +exceptions.overgeneral-exceptions=["builtins.BaseException","builtins.Exception"] +format.expected-line-ending-format="" +# format.ignore-long-lines="^\s*(# )??$" +format.indent-after-paren=4 +format.indent-string=' ' +format.max-line-length=200 +format.max-module-lines=9999 +format.single-line-class-stmt=false +format.single-line-if-stmt=false +imports.allow-any-import-level="" +imports.allow-reexport-from-package=false +imports.allow-wildcard-with-all=false +imports.deprecated-modules="" +imports.ext-import-graph="" +imports.import-graph="" +imports.int-import-graph="" +imports.known-standard-library="" +imports.known-third-party="enchant" +imports.preferred-modules="" +logging.logging-format-style="new" +logging.logging-modules="logging" +messages_control.confidence=["HIGH","CONTROL_FLOW","INFERENCE","INFERENCE_FAILURE","UNDEFINED"] +messages_control.disable=[ + "abstract-method", + "bad-inline-option", + "bare-except", + "broad-exception-caught", + "chained-comparison", + "consider-iterating-dictionary", + "consider-merging-isinstance", + "consider-using-dict-items", + "consider-using-enumerate", + "consider-using-from-import", + "consider-using-generator", + "consider-using-get", + "consider-using-in", + "consider-using-max-builtin", + "consider-using-min-builtin", + "consider-using-sys-exit", + "cyclic-import", + "dangerous-default-value", + "deprecated-pragma", + "duplicate-code", + "file-ignored", + "import-error", + "import-outside-toplevel", + "invalid-name", + "line-too-long", + "locally-disabled", + "logging-fstring-interpolation", + "missing-class-docstring", + "missing-function-docstring", + "missing-module-docstring", + "no-else-raise", + "no-else-return", + "not-callable", + "pointless-string-statement", + 
"raw-checker-failed", + "simplifiable-if-expression", + "suppressed-message", + "too-few-public-methods", + "too-many-instance-attributes", + "too-many-locals", + "too-many-nested-blocks", + "too-many-positional-arguments", + "too-many-statements", + "unidiomatic-typecheck", + "unknown-option-value", + "unnecessary-dict-index-lookup", + "unnecessary-dunder-call", + "unnecessary-lambda-assigment", + "unnecessary-lambda", + "unused-wildcard-import", + "unpacking-non-sequence", + "unsubscriptable-object", + "useless-return", + "use-dict-literal", + "use-symbolic-message-instead", + "useless-suppression", + "wrong-import-position", + ] +messages_control.enable="c-extension-no-member" +method_args.timeout-methods=["requests.api.delete","requests.api.get","requests.api.head","requests.api.options","requests.api.patch","requests.api.post","requests.api.put","requests.api.request"] +miscellaneous.notes=["FIXME","XXX","TODO"] +miscellaneous.notes-rgx="" +refactoring.max-nested-blocks=5 +refactoring.never-returning-functions=["sys.exit","argparse.parse_error"] +reports.msg-template="" +reports.reports=false +reports.score=false +similarities.ignore-comments=true +similarities.ignore-docstrings=true +similarities.ignore-imports=true +similarities.ignore-signatures=true +similarities.min-similarity-lines=4 +spelling.max-spelling-suggestions=4 +# spelling.dict="" +# spelling.ignore-comment-directives=["fmt: on","fmt: off","noqa:","noqa","nosec","isort:skip","mypy:"] +# spelling.ignore-words="" +# spelling.private-dict-file="" +# spelling.store-unknown-words=false +string.check-quote-consistency=false +string.check-str-concat-over-line-jumps=false +typecheck.contextmanager-decorators="contextlib.contextmanager" +typecheck.generated-members=["numpy.*","logging.*","torch.*","cv2.*"] +typecheck.ignore-none=true +typecheck.ignore-on-opaque-inference=true 
+typecheck.ignored-checks-for-mixins=["no-member","not-async-context-manager","not-context-manager","attribute-defined-outside-init"] +typecheck.ignored-classes=["optparse.Values","thread._local","_thread._local","argparse.Namespace","unittest.case._AssertRaisesContext","unittest.case._AssertWarnsContext"] +typecheck.missing-member-hint=true +typecheck.missing-member-hint-distance=1 +typecheck.missing-member-max-choices=1 +typecheck.mixin-class-rgx=".*[Mm]ixin" +typecheck.signature-mutators="" +variables.additional-builtins="" +variables.allow-global-unused-variables=true +variables.allowed-redefined-builtins="" +variables.callbacks=["cb_",] +variables.dummy-variables-rgx="_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_" +variables.ignored-argument-names="_.*|^ignored_|^unused_" +variables.init-import=false +variables.redefining-builtins-modules=["six.moves","past.builtins","future.builtins","builtins","io"] diff --git a/requirements.txt b/requirements.txt index 227e20fa4..8aede7efd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -55,7 +55,7 @@ protobuf==4.25.3 pytorch_lightning==2.6.0 urllib3==1.26.19 Pillow==10.4.0 -timm==1.0.16 +timm==1.0.24 pyparsing==3.2.3 typing-extensions==4.14.1 sentencepiece==0.2.1 diff --git a/scripts/flux_tools.py b/scripts/flux_tools.py index aa0358ce1..b2e87c36c 100644 --- a/scripts/flux_tools.py +++ b/scripts/flux_tools.py @@ -25,7 +25,7 @@ class Script(scripts_manager.Script): with gr.Row(): gr.HTML('  Flux.1 Redux
') with gr.Row(): - tool = gr.Dropdown(label='Tool', choices=['None', 'Redux', 'Fill', 'Canny', 'Depth'], value='None') + tool = gr.Dropdown(label='Tool', choices=['None', 'Redux', 'Fill', 'Fill (Nunchaku)', 'Canny', 'Depth', 'Depth (Nunchaku)'], value='None') with gr.Row(): prompt = gr.Slider(label='Redux prompt strength', minimum=0, maximum=2, step=0.01, value=0, visible=False) process = gr.Checkbox(label='Control preprocess input images', value=True, visible=False) @@ -34,8 +34,8 @@ class Script(scripts_manager.Script): def display(tool): return [ gr.update(visible=tool in ['Redux']), - gr.update(visible=tool in ['Canny', 'Depth']), - gr.update(visible=tool in ['Canny', 'Depth']), + gr.update(visible=tool in ['Canny', 'Depth', 'Depth (Nunchaku)']), + gr.update(visible=tool in ['Canny', 'Depth', 'Depth (Nunchaku)']), ] tool.change(fn=display, inputs=[tool], outputs=[prompt, process, strength]) @@ -91,13 +91,15 @@ class Script(scripts_manager.Script): shared.log.debug(f'{title}: tool=Redux unload') redux_pipe = None - if tool == 'Fill': + if tool in ['Fill', 'Fill (Nunchaku)']: # pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, revision="refs/pr/4").to("cuda") if p.image_mask is None: shared.log.error(f'{title}: tool={tool} no image_mask') return None - if shared.sd_model.__class__.__name__ != 'FluxFillPipeline': - shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Fill-dev" + nunchaku_suffix = '+nunchaku' if tool == 'Fill (Nunchaku)' else '' + checkpoint = f"black-forest-labs/FLUX.1-Fill-dev{nunchaku_suffix}" + if shared.sd_model.__class__.__name__ != 'FluxFillPipeline' or shared.opts.sd_model_checkpoint != checkpoint: + shared.opts.data["sd_model_checkpoint"] = checkpoint sd_models.reload_model_weights(op='model', revision="refs/pr/4") p.task_args['image'] = image p.task_args['mask_image'] = p.image_mask @@ -124,11 +126,13 @@ class Script(scripts_manager.Script): shared.log.debug(f'{title}: 
tool=Canny unload processor') processor_canny = None - if tool == 'Depth': + if tool in ['Depth', 'Depth (Nunchaku)']: # pipe = diffusers.FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda") install('git+https://github.com/huggingface/image_gen_aux.git', 'image_gen_aux') - if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or 'Depth' not in shared.opts.sd_model_checkpoint: - shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Depth-dev" + nunchaku_suffix = '+nunchaku' if tool == 'Depth (Nunchaku)' else '' + checkpoint = f"black-forest-labs/FLUX.1-Depth-dev{nunchaku_suffix}" + if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or shared.opts.sd_model_checkpoint != checkpoint: + shared.opts.data["sd_model_checkpoint"] = checkpoint sd_models.reload_model_weights(op='model', revision="refs/pr/1") if processor_depth is None: from image_gen_aux import DepthPreprocessor diff --git a/cli/benchmark_attention.py b/test/benchmark_attention.py similarity index 87% rename from cli/benchmark_attention.py rename to test/benchmark_attention.py index dc20eb1f0..39b0be5b0 100644 --- a/cli/benchmark_attention.py +++ b/test/benchmark_attention.py @@ -1,18 +1,25 @@ +from typing import Dict, Any import time import warnings import torch import torch.nn.functional as F -from typing import Dict, Any warnings.filterwarnings("ignore", category=UserWarning) warmup = 2 repeats = 50 dtypes = [torch.bfloat16] # , torch.float16] -# if hasattr(torch, "float8_e4m3fn"): -# dtypes.append(torch.float8_e4m3fn) - -PROFILES = { +backends = [ + "sdpa_math", + "sdpa_mem_efficient", + "sdpa_flash", + "sdpa_all", + "flex_attention", + "xformers", + "flash_attn", + "sage_attn", +] +profiles = { "sdxl": {"l_q": 4096, "l_k": 4096, "h": 32, "d": 128}, "flux.1": {"l_q": 16717, "l_k": 16717, "h": 24, "d": 128}, "sd35": {"l_q": 16538, "l_k": 16538, "h": 24, "d": 128}, @@ -30,19 +37,19 @@ def 
get_stats(reset: bool = False): m = torch.cuda.max_memory_allocated() t = time.perf_counter() return m / (1024 ** 2), t - + def print_gpu_info(): if not torch.cuda.is_available(): print("GPU: Not available") return - + device = torch.cuda.current_device() props = torch.cuda.get_device_properties(device) total_mem = props.total_memory / (1024**3) free_mem, _ = torch.cuda.mem_get_info(device) free_mem = free_mem / (1024**3) major, minor = torch.cuda.get_device_capability(device) - + print(f"gpu: {torch.cuda.get_device_name(device)}") print(f"vram: total={total_mem:.2f}GB free={free_mem:.2f}GB") print(f"cuda: capability={major}.{minor} version={torch.version.cuda}") @@ -56,11 +63,9 @@ def benchmark_attention( l_k: int = 4096, h: int = 32, d: int = 128, - warmup: int = 10, - repeats: int = 100 ) -> Dict[str, Any]: device = "cuda" if torch.cuda.is_available() else "cpu" - + # Initialize tensors q = torch.randn(b, h, l_q, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype) k = torch.randn(b, h, l_k, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype) @@ -78,7 +83,7 @@ def benchmark_attention( try: if backend.startswith("sdpa_"): from torch.nn.attention import sdpa_kernel, SDPBackend - sdp_type = backend[len("sdpa_"):] + sdp_type = backend[len("sdpa_"):] # Map friendly names to new SDPA backends backend_map = { "math": [SDPBackend.MATH], @@ -88,21 +93,21 @@ def benchmark_attention( } if sdp_type not in backend_map: raise ValueError(f"Unknown SDPA type: {sdp_type}") - + results["version"] = torch.__version__ - + with sdpa_kernel(backend_map[sdp_type]): # Warmup for _ in range(warmup): _ = F.scaled_dot_product_attention(q, k, v) - + start_mem, start_time = get_stats(True) - + for _ in range(repeats): _ = F.scaled_dot_product_attention(q, k, v) - + end_mem, end_time = get_stats() - + results["latency_ms"] = (end_time 
- start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem @@ -113,17 +118,17 @@ def benchmark_attention( q_fa = q.transpose(1, 2) k_fa = k.transpose(1, 2) v_fa = v.transpose(1, 2) - + for _ in range(warmup): _ = flash_attn_func(q_fa, k_fa, v_fa) - + start_mem, start_time = get_stats(True) - + for _ in range(repeats): _ = flash_attn_func(q_fa, k_fa, v_fa) - + end_mem, end_time = get_stats() - + results["latency_ms"] = (end_time - start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem @@ -135,17 +140,17 @@ def benchmark_attention( q_xf = q.transpose(1, 2) k_xf = k.transpose(1, 2) v_xf = v.transpose(1, 2) - + for _ in range(warmup): _ = memory_efficient_attention(q_xf, k_xf, v_xf) - + start_mem, start_time = get_stats(True) - + for _ in range(repeats): _ = memory_efficient_attention(q_xf, k_xf, v_xf) - + end_mem, end_time = get_stats() - + results["latency_ms"] = (end_time - start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem @@ -158,18 +163,18 @@ def benchmark_attention( results["version"] = importlib.metadata.version("sageattention") except Exception: results["version"] = getattr(sageattention, "__version__", "N/A") - + # SageAttention expects (B, H, L, D) logic for _ in range(warmup): _ = sageattn(q, k, v) - + start_mem, start_time = get_stats(True) - + for _ in range(repeats): _ = sageattn(q, k, v) - + end_mem, end_time = get_stats() - + results["latency_ms"] = (end_time - start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem @@ -179,62 +184,52 @@ def benchmark_attention( # flex_attention requires torch.compile for performance flex_attention_compiled = torch.compile(flex_attention, dynamic=False) - + # Warmup (important to trigger compilation) for _ in range(warmup): _ = flex_attention_compiled(q, k, v) - + start_mem, start_time = get_stats(True) - + for _ in range(repeats): _ = flex_attention_compiled(q, k, v) - + end_mem, end_time = get_stats() - + results["latency_ms"] = (end_time - 
start_time) / repeats * 1000 results["memory_mb"] = end_mem - start_mem except Exception as e: results["status"] = "fail" results["error"] = str(e)[:49] + print(e) return results -def main(): - backends = [ - "sdpa_math", - "sdpa_mem_efficient", - "sdpa_flash", - "flex_attention", - "xformers", - "flash_attn", - "sage_attn", - ] - +def main(): + all_results = [] print_gpu_info() - print(f'config: warmup={warmup} repeats={repeats} dtypes={dtypes}') - for name, config in PROFILES.items(): + print(f'config: warmup={warmup} repeats={repeats} dtypes={dtypes}') + for name, config in profiles.items(): print(f"profile: {name} (L_q={config['l_q']}, L_k={config['l_k']}, H={config['h']}, D={config['d']})") for dtype in dtypes: print(f" dtype: {dtype}") print(f" {'backend':<20} | {'version':<12} | {'status':<8} | {'latency':<10} | {'memory':<12} | ") for backend in backends: res = benchmark_attention( - backend, - dtype, - l_q=config["l_q"], - l_k=config["l_k"], - h=config["h"], - d=config["d"], - warmup=warmup, - repeats=repeats + backend, + dtype, + l_q=config["l_q"], + l_k=config["l_k"], + h=config["h"], + d=config["d"], ) all_results.append(res) - + latency = f"{res['latency_ms']:.4f} ms" memory = f"{res['memory_mb']:.2f} MB" - + print(f" {res['backend']:<20} | {res['version']:<12} | {res['status']:<8} | {latency:<10} | {memory:<12} | {res['error']}") if __name__ == "__main__": diff --git a/cli/full-test.sh b/test/full-test.sh old mode 100755 new mode 100644 similarity index 100% rename from cli/full-test.sh rename to test/full-test.sh diff --git a/cli/locale-sanitize-override.py b/test/locale-sanitize-override.py old mode 100755 new mode 100644 similarity index 100% rename from cli/locale-sanitize-override.py rename to test/locale-sanitize-override.py diff --git a/cli/localize.js b/test/localize.js old mode 100755 new mode 100644 similarity index 100% rename from cli/localize.js rename to test/localize.js diff --git a/cli/test-schedulers.py b/test/test-schedulers.py 
similarity index 100% rename from cli/test-schedulers.py rename to test/test-schedulers.py diff --git a/cli/test-tagger.py b/test/test-tagger.py similarity index 100% rename from cli/test-tagger.py rename to test/test-tagger.py diff --git a/cli/test-weighted-lists.py b/test/test-weighted-lists.py similarity index 91% rename from cli/test-weighted-lists.py rename to test/test-weighted-lists.py index 439f08a15..38ceb0e41 100644 --- a/cli/test-weighted-lists.py +++ b/test/test-weighted-lists.py @@ -1,6 +1,7 @@ #!/usr/bin/env python -import sys, os +import sys +import os from collections import Counter script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -20,25 +21,25 @@ tolerance_pct = 5 # tests tests = [ # - empty - ["", { '': 100 } ], + ["", { '': 100 } ], # - no weights - [ "red|blonde|black", { 'black': 33, 'red': 33, 'blonde': 33 } ], + [ "red|blonde|black", { 'black': 33, 'red': 33, 'blonde': 33 } ], # - full weights <= 1 - [ "red:0.1|blonde:0.9", { 'blonde': 90, 'red': 10 } ], + [ "red:0.1|blonde:0.9", { 'blonde': 90, 'red': 10 } ], # - weights > 1 to test normalization - [ "red:1|blonde:2|black:5", { 'blonde': 25, 'red': 12.5, 'black': 62.5 } ], + [ "red:1|blonde:2|black:5", { 'blonde': 25, 'red': 12.5, 'black': 62.5 } ], # - disabling 0 weights to force one result - [ "red:0|blonde|black:0", { 'blonde': 100 } ], + [ "red:0|blonde|black:0", { 'blonde': 100 } ], # - weights <= 1 with distribution of the leftover - [ "red:0.5|blonde|black:0.3|brown", { 'red': 50, 'black': 30, 'brown': 10, 'blonde': 10 } ], + [ "red:0.5|blonde|black:0.3|brown", { 'red': 50, 'black': 30, 'brown': 10, 'blonde': 10 } ], # - weights > 1, unweightes should get default of 1 - [ "red:2|blonde|black", { 'red': 50, 'blonde': 25, 'black': 25 } ], + [ "red:2|blonde|black", { 'red': 50, 'blonde': 25, 'black': 25 } ], # - ignore content of () - [ "red:0.5|(blonde:1.3)", { 'red': 50, '(blonde:1.3)': 50 } ], + [ "red:0.5|(blonde:1.3)", { 'red': 50, '(blonde:1.3)': 50 } ], 
# - ignore content of [] - [ "red:0.5|[stuff:1.3]", { '[stuff:1.3]': 50, 'red': 50 } ], + [ "red:0.5|[stuff:1.3]", { '[stuff:1.3]': 50, 'red': 50 } ], # - ignore content of <> - [ "red:0.5|", { '': 50, 'red': 50 } ] + [ "red:0.5|", { '': 50, 'red': 50 } ] ] # ------------------------------------------------- @@ -109,4 +110,4 @@ with open(fn, 'r', encoding='utf-8') as f: print("RESULT: FAILED (distribution)") else: print("RESULT: PASSED") - print('') \ No newline at end of file + print('') diff --git a/cli/validate-locale.py b/test/validate-locale.py old mode 100755 new mode 100644 similarity index 100% rename from cli/validate-locale.py rename to test/validate-locale.py