mirror of https://github.com/vladmandic/automatic
Merge remote-tracking branch 'origin/dev' into refactor/remove-face-restoration
# Conflicts: # .pylintrc # .ruff.toml
pull/4637/head
commit
385532154f
|
|
@ -9,22 +9,22 @@ body:
|
|||
attributes:
|
||||
label: Model name
|
||||
description: Enter model name
|
||||
value:
|
||||
value:
|
||||
- type: textarea
|
||||
id: type
|
||||
attributes:
|
||||
label: Model type
|
||||
description: Describe model type
|
||||
value:
|
||||
value:
|
||||
- type: textarea
|
||||
id: url
|
||||
attributes:
|
||||
label: Model URL
|
||||
description: Enter URL to the model page
|
||||
value:
|
||||
value:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Reason
|
||||
description: Enter a reason why you would like to see this model supported
|
||||
value:
|
||||
value:
|
||||
|
|
|
|||
283
.pylintrc
283
.pylintrc
|
|
@ -1,283 +0,0 @@
|
|||
[MAIN]
|
||||
analyse-fallback-blocks=no
|
||||
clear-cache-post-run=no
|
||||
extension-pkg-allow-list=
|
||||
prefer-stubs=yes
|
||||
extension-pkg-whitelist=
|
||||
fail-on=
|
||||
fail-under=10
|
||||
ignore=CVS
|
||||
ignore-paths=/usr/lib/.*$,
|
||||
venv,
|
||||
.git,
|
||||
.ruff_cache,
|
||||
.vscode,
|
||||
modules/apg,
|
||||
modules/cfgzero,
|
||||
modules/control/proc,
|
||||
modules/control/units,
|
||||
modules/dml,
|
||||
modules/flash_attn_triton_amd,
|
||||
modules/ggml,
|
||||
modules/hidiffusion,
|
||||
modules/hijack/ddpm_edit.py,
|
||||
modules/intel,
|
||||
modules/intel/ipex,
|
||||
modules/framepack/pipeline,
|
||||
modules/onnx_impl,
|
||||
modules/pag,
|
||||
modules/postprocess/aurasr_arch.py,
|
||||
modules/prompt_parser_xhinker.py,
|
||||
modules/ras,
|
||||
modules/seedvr,
|
||||
modules/rife,
|
||||
modules/schedulers,
|
||||
modules/taesd,
|
||||
modules/teacache,
|
||||
modules/todo,
|
||||
modules/res4lyf,
|
||||
pipelines/bria,
|
||||
pipelines/flex2,
|
||||
pipelines/f_lite,
|
||||
pipelines/hidream,
|
||||
pipelines/hdm,
|
||||
pipelines/meissonic,
|
||||
pipelines/omnigen2,
|
||||
pipelines/segmoe,
|
||||
pipelines/xomni,
|
||||
pipelines/chrono,
|
||||
scripts/consistory,
|
||||
scripts/ctrlx,
|
||||
scripts/daam,
|
||||
scripts/demofusion,
|
||||
scripts/freescale,
|
||||
scripts/infiniteyou,
|
||||
scripts/instantir,
|
||||
scripts/lbm,
|
||||
scripts/layerdiffuse,
|
||||
scripts/mod,
|
||||
scripts/pixelsmith,
|
||||
scripts/differential_diffusion.py,
|
||||
scripts/pulid,
|
||||
scripts/xadapter,
|
||||
repositories,
|
||||
extensions-builtin/sd-extension-chainner/nodes,
|
||||
extensions-builtin/sd-webui-agent-scheduler,
|
||||
extensions-builtin/sdnext-modernui/node_modules,
|
||||
extensions-builtin/sdnext-kanvas/node_modules,
|
||||
ignore-patterns=.*test*.py$,
|
||||
.*_model.py$,
|
||||
.*_arch.py$,
|
||||
.*_model_arch.py*,
|
||||
.*_model_arch_v2.py$,
|
||||
ignored-modules=
|
||||
jobs=8
|
||||
limit-inference-results=100
|
||||
load-plugins=
|
||||
persistent=no
|
||||
py-version=3.10
|
||||
recursive=no
|
||||
source-roots=
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
[BASIC]
|
||||
argument-naming-style=snake_case
|
||||
attr-naming-style=snake_case
|
||||
bad-names=foo, bar, baz, toto, tutu, tata
|
||||
bad-names-rgxs=
|
||||
class-attribute-naming-style=any
|
||||
class-const-naming-style=UPPER_CASE
|
||||
class-naming-style=PascalCase
|
||||
const-naming-style=snake_case
|
||||
docstring-min-length=-1
|
||||
function-naming-style=snake_case
|
||||
good-names=i,j,k,e,ex,ok,p,x,y,id
|
||||
good-names-rgxs=
|
||||
include-naming-hint=no
|
||||
inlinevar-naming-style=any
|
||||
method-naming-style=snake_case
|
||||
module-naming-style=snake_case
|
||||
name-group=
|
||||
no-docstring-rgx=^_
|
||||
property-classes=abc.abstractproperty
|
||||
variable-naming-style=snake_case
|
||||
|
||||
[CLASSES]
|
||||
check-protected-access-in-special-methods=no
|
||||
defining-attr-methods=__init__, __new__,
|
||||
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
|
||||
valid-classmethod-first-arg=cls
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
[DESIGN]
|
||||
exclude-too-few-public-methods=
|
||||
ignored-parents=
|
||||
max-args=199
|
||||
max-attributes=99
|
||||
max-bool-expr=99
|
||||
max-branches=199
|
||||
max-locals=99
|
||||
max-parents=99
|
||||
max-public-methods=99
|
||||
max-returns=99
|
||||
max-statements=199
|
||||
min-public-methods=1
|
||||
|
||||
[EXCEPTIONS]
|
||||
overgeneral-exceptions=builtins.BaseException,builtins.Exception
|
||||
|
||||
[FORMAT]
|
||||
expected-line-ending-format=
|
||||
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
||||
indent-after-paren=4
|
||||
indent-string=' '
|
||||
max-line-length=200
|
||||
max-module-lines=9999
|
||||
single-line-class-stmt=no
|
||||
single-line-if-stmt=no
|
||||
|
||||
[IMPORTS]
|
||||
allow-any-import-level=
|
||||
allow-reexport-from-package=no
|
||||
allow-wildcard-with-all=no
|
||||
deprecated-modules=
|
||||
ext-import-graph=
|
||||
import-graph=
|
||||
int-import-graph=
|
||||
known-standard-library=
|
||||
known-third-party=enchant
|
||||
preferred-modules=
|
||||
|
||||
[LOGGING]
|
||||
logging-format-style=new
|
||||
logging-modules=logging
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
confidence=HIGH,
|
||||
CONTROL_FLOW,
|
||||
INFERENCE,
|
||||
INFERENCE_FAILURE,
|
||||
UNDEFINED
|
||||
# disable=C,R,W
|
||||
disable=abstract-method,
|
||||
bad-inline-option,
|
||||
bare-except,
|
||||
broad-exception-caught,
|
||||
chained-comparison,
|
||||
consider-iterating-dictionary,
|
||||
consider-merging-isinstance,
|
||||
consider-using-dict-items,
|
||||
consider-using-enumerate,
|
||||
consider-using-from-import,
|
||||
consider-using-generator,
|
||||
consider-using-get,
|
||||
consider-using-in,
|
||||
consider-using-max-builtin,
|
||||
consider-using-min-builtin,
|
||||
consider-using-sys-exit,
|
||||
cyclic-import,
|
||||
dangerous-default-value,
|
||||
deprecated-pragma,
|
||||
duplicate-code,
|
||||
file-ignored,
|
||||
import-error,
|
||||
import-outside-toplevel,
|
||||
invalid-name,
|
||||
line-too-long,
|
||||
locally-disabled,
|
||||
logging-fstring-interpolation,
|
||||
missing-class-docstring,
|
||||
missing-function-docstring,
|
||||
missing-module-docstring,
|
||||
no-else-raise,
|
||||
no-else-return,
|
||||
not-callable,
|
||||
pointless-string-statement,
|
||||
raw-checker-failed,
|
||||
simplifiable-if-expression,
|
||||
suppressed-message,
|
||||
too-few-public-methods,
|
||||
too-many-instance-attributes,
|
||||
too-many-locals,
|
||||
too-many-nested-blocks,
|
||||
too-many-positional-arguments,
|
||||
too-many-statements,
|
||||
unidiomatic-typecheck,
|
||||
unknown-option-value,
|
||||
unnecessary-dict-index-lookup,
|
||||
unnecessary-dunder-call,
|
||||
unnecessary-lambda-assignment,
|
||||
unnecessary-lambda,
|
||||
unused-wildcard-import,
|
||||
unpacking-non-sequence,
|
||||
unsubscriptable-object,
|
||||
useless-return,
|
||||
use-dict-literal,
|
||||
use-symbolic-message-instead,
|
||||
useless-suppression,
|
||||
wrong-import-position,
|
||||
enable=c-extension-no-member
|
||||
|
||||
[METHOD_ARGS]
|
||||
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
|
||||
|
||||
[MISCELLANEOUS]
|
||||
notes=FIXME,
|
||||
XXX,
|
||||
TODO
|
||||
notes-rgx=
|
||||
|
||||
[REFACTORING]
|
||||
max-nested-blocks=5
|
||||
never-returning-functions=sys.exit,argparse.parse_error
|
||||
|
||||
[REPORTS]
|
||||
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
|
||||
msg-template=
|
||||
reports=no
|
||||
score=no
|
||||
|
||||
[SIMILARITIES]
|
||||
ignore-comments=yes
|
||||
ignore-docstrings=yes
|
||||
ignore-imports=yes
|
||||
ignore-signatures=yes
|
||||
min-similarity-lines=4
|
||||
|
||||
[SPELLING]
|
||||
max-spelling-suggestions=4
|
||||
spelling-dict=
|
||||
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
|
||||
spelling-ignore-words=
|
||||
spelling-private-dict-file=
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
[STRING]
|
||||
check-quote-consistency=no
|
||||
check-str-concat-over-line-jumps=no
|
||||
|
||||
[TYPECHECK]
|
||||
contextmanager-decorators=contextlib.contextmanager
|
||||
generated-members=numpy.*,logging.*,torch.*,cv2.*
|
||||
ignore-none=yes
|
||||
ignore-on-opaque-inference=yes
|
||||
ignored-checks-for-mixins=no-member,
|
||||
not-async-context-manager,
|
||||
not-context-manager,
|
||||
attribute-defined-outside-init
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
|
||||
missing-member-hint=yes
|
||||
missing-member-hint-distance=1
|
||||
missing-member-max-choices=1
|
||||
mixin-class-rgx=.*[Mm]ixin
|
||||
signature-mutators=
|
||||
|
||||
[VARIABLES]
|
||||
additional-builtins=
|
||||
allow-global-unused-variables=yes
|
||||
allowed-redefined-builtins=
|
||||
callbacks=cb_,
|
||||
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
|
||||
ignored-argument-names=_.*|^ignored_|^unused_
|
||||
init-import=no
|
||||
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
|
||||
108
.ruff.toml
108
.ruff.toml
|
|
@ -1,108 +0,0 @@
|
|||
line-length = 250
|
||||
indent-width = 4
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
"venv",
|
||||
".git",
|
||||
".ruff_cache",
|
||||
".vscode",
|
||||
|
||||
"modules/cfgzero",
|
||||
"modules/flash_attn_triton_amd",
|
||||
"modules/hidiffusion",
|
||||
"modules/intel/ipex",
|
||||
"modules/pag",
|
||||
"modules/schedulers",
|
||||
"modules/teacache",
|
||||
"modules/seedvr",
|
||||
|
||||
"modules/control/proc",
|
||||
"modules/control/units",
|
||||
"modules/control/units/xs_pipe.py",
|
||||
"modules/postprocess/aurasr_arch.py",
|
||||
|
||||
"pipelines/meissonic",
|
||||
"pipelines/omnigen2",
|
||||
"pipelines/hdm",
|
||||
"pipelines/segmoe",
|
||||
"pipelines/xomni",
|
||||
"pipelines/chrono",
|
||||
|
||||
"scripts/lbm",
|
||||
"scripts/daam",
|
||||
"scripts/xadapter",
|
||||
"scripts/pulid",
|
||||
"scripts/instantir",
|
||||
"scripts/freescale",
|
||||
"scripts/consistory",
|
||||
|
||||
"repositories",
|
||||
|
||||
"extensions-builtin/Lora",
|
||||
"extensions-builtin/sd-extension-chainner/nodes",
|
||||
"extensions-builtin/sd-webui-agent-scheduler",
|
||||
"extensions-builtin/sdnext-modernui/node_modules",
|
||||
]
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
"F",
|
||||
"E",
|
||||
"W",
|
||||
"C",
|
||||
"B",
|
||||
"I",
|
||||
"YTT",
|
||||
"ASYNC",
|
||||
"RUF",
|
||||
"AIR",
|
||||
"NPY",
|
||||
"C4",
|
||||
"T10",
|
||||
"EXE",
|
||||
"ISC",
|
||||
"ICN",
|
||||
"RSE",
|
||||
"TCH",
|
||||
"TID",
|
||||
"INT",
|
||||
"PLE",
|
||||
]
|
||||
ignore = [
|
||||
"B006", # Do not use mutable data structures for argument defaults
|
||||
"B008", # Do not perform function call in argument defaults
|
||||
"B905", # Strict zip() usage
|
||||
"C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead
|
||||
"C408", # Unnecessary `dict` call
|
||||
"I001", # Import block is un-sorted or un-formatted
|
||||
"E402", # Module level import not at top of file
|
||||
"E501", # Line too long
|
||||
"E721", # Do not compare types, use `isinstance()`
|
||||
"E731", # Do not assign a `lambda` expression, use a `def`
|
||||
"E741", # Ambiguous variable name
|
||||
"F401", # Imported but unused
|
||||
"EXE001", # file with shebang is not marked executable
|
||||
"NPY002", # replace legacy random
|
||||
"RUF005", # Consider iterable unpacking
|
||||
"RUF008", # Do not use mutable default values for dataclass
|
||||
"RUF010", # Use explicit conversion flag
|
||||
"RUF012", # Mutable class attributes
|
||||
"RUF013", # PEP 484 prohibits implicit `Optional`
|
||||
"RUF015", # Prefer `next(...)` over single element slice
|
||||
"RUF046", # Value being cast to `int` is already an integer
|
||||
"RUF059", # Unpacked variables are not used
|
||||
"RUF051", # Prefer pop over del
|
||||
]
|
||||
fixable = ["ALL"]
|
||||
unfixable = []
|
||||
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||
|
||||
[format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
docstring-code-format = false
|
||||
|
||||
[lint.mccabe]
|
||||
max-complexity = 150
|
||||
11
CHANGELOG.md
11
CHANGELOG.md
|
|
@ -1,6 +1,6 @@
|
|||
# Change Log for SD.Next
|
||||
|
||||
## Update for 2026-02-07
|
||||
## Update for 2026-02-09
|
||||
|
||||
- **Upscalers**
|
||||
- add support for [spandrel](https://github.com/chaiNNer-org/spandrel)
|
||||
|
|
@ -11,12 +11,21 @@
|
|||
- pipelines: add **ZImageInpaint**, thanks @CalamitousFelicitousness
|
||||
- add `--remote` command line flag that reduces client/server chatter and improves link stability
|
||||
for long-running generates, useful when running on remote servers
|
||||
- hires: allow using different lora in refiner prompt
|
||||
- **nunchaku** models are now listed in networks tab as reference models
|
||||
instead of being used implicitly via quantization, thanks @CalamitousFelicitousness
|
||||
- **UI**
|
||||
- ui: **themes** add *CTD-NT64Light* and *CTD-NT64Dark*, thanks @resonantsky
|
||||
- ui: **gallery** add option to auto-refresh gallery, thanks @awsr
|
||||
- **Internal**
|
||||
- refactor: switch to `pyproject.toml` for tool configs
|
||||
- refactor: reorganize `cli` scripts
|
||||
- refactor: move tests to dedicated `/test/`
|
||||
- update `lint` rules, thanks @awsr
|
||||
- update `requirements`
|
||||
- **Fixes**
|
||||
- fix: handle `clip` installer doing unwanted `setuptools` update
|
||||
- fix: cleanup for `uv` installer fallback
|
||||
- fix: add metadata restore to always-on scripts
|
||||
- fix: improve wildcard weights parsing, thanks @Tillerz
|
||||
- fix: ui gallery cache recursive cleanup, thanks @awsr
|
||||
|
|
|
|||
|
|
@ -0,0 +1,209 @@
|
|||
{
|
||||
"FLUX.1-Dev Nunchaku SVDQuant": {
|
||||
"path": "black-forest-labs/FLUX.1-dev",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "black-forest-labs--FLUX.1-dev.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of FLUX.1-dev transformer with INT4 and SVD rank 32",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model", "TE"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"FLUX.1-Schnell Nunchaku SVDQuant": {
|
||||
"path": "black-forest-labs/FLUX.1-schnell",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "black-forest-labs--FLUX.1-schnell.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of FLUX.1-schnell transformer with INT4 and SVD rank 32",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model", "TE"],
|
||||
"tags": "nunchaku",
|
||||
"extras": "sampler: Default, cfg_scale: 1.0, steps: 4",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"FLUX.1-Kontext Nunchaku SVDQuant": {
|
||||
"path": "black-forest-labs/FLUX.1-Kontext-dev",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "black-forest-labs--FLUX.1-Kontext-dev.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of FLUX.1-Kontext-dev transformer with INT4 and SVD rank 32",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model", "TE"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"FLUX.1-Krea Nunchaku SVDQuant": {
|
||||
"path": "black-forest-labs/FLUX.1-Krea-dev",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "black-forest-labs--FLUX.1-Krea-dev.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of FLUX.1-Krea-dev transformer with INT4 and SVD rank 32",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model", "TE"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"FLUX.1-Fill Nunchaku SVDQuant": {
|
||||
"path": "black-forest-labs/FLUX.1-Fill-dev",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "black-forest-labs--FLUX.1-Fill-dev.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of FLUX.1-Fill-dev transformer for inpainting",
|
||||
"skip": true,
|
||||
"hidden": true,
|
||||
"nunchaku": ["Model", "TE"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"FLUX.1-Depth Nunchaku SVDQuant": {
|
||||
"path": "black-forest-labs/FLUX.1-Depth-dev",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "black-forest-labs--FLUX.1-Depth-dev.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of FLUX.1-Depth-dev transformer for depth-conditioned generation",
|
||||
"skip": true,
|
||||
"hidden": true,
|
||||
"nunchaku": ["Model", "TE"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Shuttle Jaguar Nunchaku SVDQuant": {
|
||||
"path": "shuttleai/shuttle-jaguar",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "shuttleai--shuttle-jaguar.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Shuttle Jaguar transformer",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model", "TE"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Qwen-Image Nunchaku SVDQuant": {
|
||||
"path": "Qwen/Qwen-Image",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "Qwen--Qwen-Image.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Qwen-Image transformer with INT4 and SVD rank 128",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Qwen-Lightning (8-step) Nunchaku SVDQuant": {
|
||||
"path": "vladmandic/Qwen-Lightning",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "vladmandic--Qwen-Lightning.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Qwen-Lightning (8-step distilled) transformer with INT4 and SVD rank 128",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"extras": "steps: 8",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Qwen-Lightning (4-step) Nunchaku SVDQuant": {
|
||||
"path": "vladmandic/Qwen-Lightning",
|
||||
"subfolder": "nunchaku-4step",
|
||||
"preview": "vladmandic--Qwen-Lightning.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Qwen-Lightning (4-step distilled) transformer with INT4 and SVD rank 128",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"extras": "steps: 4",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Qwen-Image-Edit Nunchaku SVDQuant": {
|
||||
"path": "Qwen/Qwen-Image-Edit",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "Qwen--Qwen-Image-Edit.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Qwen-Image-Edit transformer with INT4 and SVD rank 128",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Qwen-Lightning-Edit (8-step) Nunchaku SVDQuant": {
|
||||
"path": "vladmandic/Qwen-Lightning-Edit",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "vladmandic--Qwen-Lightning-Edit.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Qwen-Lightning-Edit (8-step distilled editing) transformer with INT4 and SVD rank 128",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"extras": "steps: 8",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Qwen-Lightning-Edit (4-step) Nunchaku SVDQuant": {
|
||||
"path": "vladmandic/Qwen-Lightning-Edit",
|
||||
"subfolder": "nunchaku-4step",
|
||||
"preview": "vladmandic--Qwen-Lightning-Edit.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Qwen-Lightning-Edit (4-step distilled editing) transformer with INT4 and SVD rank 128",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"extras": "steps: 4",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Qwen-Image-Edit-2509 Nunchaku SVDQuant": {
|
||||
"path": "Qwen/Qwen-Image-Edit-2509",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "Qwen--Qwen-Image-Edit-2509.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Qwen-Image-Edit-2509 transformer with INT4 and SVD rank 128",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 September"
|
||||
},
|
||||
"Sana 1.6B 1k Nunchaku SVDQuant": {
|
||||
"path": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "Efficient-Large-Model--Sana_1600M_1024px_BF16_diffusers.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Sana 1.6B 1024px transformer with INT4 and SVD rank 32",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"Z-Image-Turbo Nunchaku SVDQuant": {
|
||||
"path": "Tongyi-MAI/Z-Image-Turbo",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "Tongyi-MAI--Z-Image-Turbo.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of Z-Image-Turbo transformer with INT4 and SVD rank 128",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"extras": "sampler: Default, cfg_scale: 1.0, steps: 9",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"SDXL Base Nunchaku SVDQuant": {
|
||||
"path": "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "stabilityai--stable-diffusion-xl-base-1.0.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of SDXL Base 1.0 UNet with INT4 and SVD rank 32",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
},
|
||||
"SDXL Turbo Nunchaku SVDQuant": {
|
||||
"path": "stabilityai/sdxl-turbo",
|
||||
"subfolder": "nunchaku",
|
||||
"preview": "stabilityai--sdxl-turbo.jpg",
|
||||
"desc": "Nunchaku SVDQuant quantization of SDXL Turbo UNet with INT4 and SVD rank 32",
|
||||
"skip": true,
|
||||
"nunchaku": ["Model"],
|
||||
"tags": "nunchaku",
|
||||
"extras": "sampler: Default, cfg_scale: 1.0, steps: 4",
|
||||
"size": 0,
|
||||
"date": "2025 June"
|
||||
}
|
||||
}
|
||||
|
|
@ -337,6 +337,8 @@
|
|||
{"id":"","label":"Model Options","localized":"","reload":"","hint":"Settings related to behavior of specific models"},
|
||||
{"id":"","label":"Model Offloading","localized":"","reload":"","hint":"Settings related to model offloading and memory management"},
|
||||
{"id":"","label":"Model Quantization","localized":"","reload":"","hint":"Settings related to model quantization which is used to reduce memory usage"},
|
||||
{"id":"","label":"Nunchaku attention","localized":"","reload":"","hint":"Replaces default attention with Nunchaku's custom FP16 attention kernel for faster inference on consumer NVIDIA GPUs.<br>Might provide performance improvement on GPUs which have higher FP16 tensor cores throughput than BF16.<br><br>Currently only affects Flux-based models (Dev, Schnell, Kontext, Fill, Depth, etc.). Has no effect on Qwen, SDXL, Sana, or other architectures.<br><br>Disabled by default."},
|
||||
{"id":"","label":"Nunchaku offloading","localized":"","reload":"","hint":"Enables Nunchaku's own per-block CPU offloading with asynchronous CUDA streams to reduce VRAM usage.<br>Uses a ping-pong buffer strategy: while one transformer block computes on GPU, the next block preloads from CPU in the background, hiding most of the transfer latency.<br><br>Can reduce VRAM usage at the cost of slower inference.<br>This replaces SD.Next's pipeline offloading for the transformer component.<br><br>Only useful on low-VRAM GPUs. If your GPU has enough memory to hold the quantized model (16+ GB), keep this disabled for maximum speed.<br>Supports Flux and Qwen models. Not supported for SDXL where this setting is ignored.<br>Disabled by default."},
|
||||
{"id":"","label":"Image Metadata","localized":"","reload":"","hint":"Settings related to handling of metadata that is created with generated images"},
|
||||
{"id":"","label":"Legacy Options","localized":"","reload":"","hint":"Settings related to legacy options - should not be used"},
|
||||
{"id":"","label":"Restart server","localized":"","reload":"","hint":"Restart server"},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg id="Layer_1" xmlns="http://www.w3.org/2000/svg" version="1.1" xmlns:svg="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
|
||||
<!-- Generator: Adobe Illustrator 30.2.0, SVG Export Plug-In . SVG Version: 2.1.1 Build 105) -->
|
||||
<path d="M333.62,470.26l-9.4,4.63-39.05-21.52,16.16,32.8-9.4,4.63-24.4-49.52,9.4-4.63,39.09,21.59-16.19-32.87,9.4-4.63,24.4,49.52h-.01ZM442.86,365.07h0ZM334.71,418.37l6.05,12.27h.01l22.32-11.01,4.03,8.18-22.32,11,6.26,12.7,38.07-18.76,3.55-23.51-18.59-10.28-39.38,19.41h0ZM457.12,338.1l-205,101,31.64,65.17,205.12-100.53-31.76-65.64h0ZM423.12,426.16l-21.57-12.47-2.75,24.46-11.34,5.59h0l-41.79,20.59-24.4-49.52,43.06-21.21.19.07-.16-.09h-.01.01l11.54-5.68,19.04,10.95,2.2-21.42,11.34-5.59-4.76,31.42,30.96,17.21-11.55,5.69h-.01ZM433.46,369.7l-14.85,7.32-4.03-8.18,39.11-19.27,4.03,8.18-14.86,7.32,20.37,41.34-9.4,4.63-20.37-41.34h0Z" fill="#231f20"/>
|
||||
<polygon points="248.11 427.73 453.11 326.73 453.11 113.73 248.11 7.73 248.11 117.73 333.11 167.73 333.11 267.73 248.11 317.73 248.11 427.73" fill="#231f20"/>
|
||||
<polygon points="23.11 150.73 23.11 275.73 103.11 319.73 103.11 331.73 23.11 375.73 23.11 500.73 228.11 386.73 228.11 261.73 143.11 221.73 143.11 209.73 228.11 161.73 228.11 36.73 23.11 150.73" fill="#231f20"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg id="Layer_1" xmlns="http://www.w3.org/2000/svg" version="1.1" xmlns:svg="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
|
||||
<!-- Generator: Adobe Illustrator 30.2.0, SVG Export Plug-In . SVG Version: 2.1.1 Build 105) -->
|
||||
<path d="M333.62,470.26l-9.4,4.63-39.05-21.52,16.16,32.8-9.4,4.63-24.4-49.52,9.4-4.63,39.09,21.59-16.19-32.87,9.4-4.63,24.4,49.52h-.01ZM442.86,365.07h0ZM334.71,418.37l6.05,12.27h.01l22.32-11.01,4.03,8.18-22.32,11,6.26,12.7,38.07-18.76,3.55-23.51-18.59-10.28-39.38,19.41h0ZM457.12,338.1l-205,101,31.64,65.17,205.12-100.53-31.76-65.64h0ZM423.12,426.16l-21.57-12.47-2.75,24.46-11.34,5.59h0l-41.79,20.59-24.4-49.52,43.06-21.21.19.07-.16-.09h-.01.01l11.54-5.68,19.04,10.95,2.2-21.42,11.34-5.59-4.76,31.42,30.96,17.21-11.55,5.69h-.01ZM433.46,369.7l-14.85,7.32-4.03-8.18,39.11-19.27,4.03,8.18-14.86,7.32,20.37,41.34-9.4,4.63-20.37-41.34h0Z" fill="#fff"/>
|
||||
<polygon points="248.11 427.73 453.11 326.73 453.11 113.73 248.11 7.73 248.11 117.73 333.11 167.73 333.11 267.73 248.11 317.73 248.11 427.73" fill="#fff"/>
|
||||
<polygon points="23.11 150.73 23.11 275.73 103.11 319.73 103.11 331.73 23.11 375.73 23.11 500.73 228.11 386.73 228.11 261.73 143.11 221.73 143.11 209.73 228.11 161.73 228.11 36.73 23.11 150.73" fill="#fff"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 17 KiB |
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 473 KiB |
BIN
html/logo-wm.png
BIN
html/logo-wm.png
Binary file not shown.
|
Before Width: | Height: | Size: 14 KiB |
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 10 KiB |
31
installer.py
31
installer.py
|
|
@ -431,6 +431,26 @@ def run(cmd: str, arg: str):
|
|||
return txt
|
||||
|
||||
|
||||
def cleanup_broken_packages():
    """Remove ``*.dist-info`` directories that have no ``RECORD`` file.

    Such metadata directories may be left behind by a failed ``uv`` install.
    This is a best-effort cleanup: it never raises. Failures are reported at
    debug level instead of being silently discarded, and a site-packages
    directory that cannot be listed no longer prevents the remaining
    directories from being scanned.
    """
    try:
        import site  # local import: only needed on this rarely-taken path

        for site_dir in site.getsitepackages():
            if not os.path.isdir(site_dir):
                continue
            try:
                entries = os.listdir(site_dir)
            except OSError as e:
                # unreadable directory: skip it but keep scanning the others
                log.debug(f'Install: cleanup broken packages: path={site_dir} {e}')
                continue
            for entry in entries:
                if not entry.endswith('.dist-info'):
                    continue
                dist_info = os.path.join(site_dir, entry)
                if os.path.isfile(os.path.join(dist_info, 'RECORD')):
                    continue  # metadata has its RECORD file, leave it alone
                # dist-info directories are named "<name>-<version>.dist-info"
                pkg_name = entry.split('-')[0]
                log.warning(f'Install: removing broken package metadata: {pkg_name} path={dist_info}')
                shutil.rmtree(dist_info, ignore_errors=True)
    except Exception as e:
        # cleanup must never abort the install flow; record why it stopped
        log.debug(f'Install: cleanup broken packages: {e}')
|
||||
|
||||
|
||||
def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
|
||||
t_start = time.time()
|
||||
originalArg = arg
|
||||
|
|
@ -454,6 +474,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
|
|||
err = result.stderr.decode(encoding="utf8", errors="ignore")
|
||||
log.warning(f'Install: cmd="{pipCmd}" args="{all_args}" cannot use uv, fallback to pip')
|
||||
debug(f'Install: uv pip error: {err}')
|
||||
cleanup_broken_packages()
|
||||
return pip(originalArg, ignore, quiet, uv=False)
|
||||
else:
|
||||
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
|
||||
|
|
@ -468,7 +489,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
|
|||
|
||||
|
||||
# install package using pip if not already installed
|
||||
def install(package, friendly: str = None, ignore: bool = False, reinstall: bool = False, no_deps: bool = False, quiet: bool = False, force: bool = False):
|
||||
def install(package, friendly: str = None, ignore: bool = False, reinstall: bool = False, no_deps: bool = False, quiet: bool = False, force: bool = False, no_build_isolation: bool = False):
|
||||
t_start = time.time()
|
||||
res = ''
|
||||
if args.reinstall or args.upgrade:
|
||||
|
|
@ -476,7 +497,8 @@ def install(package, friendly: str = None, ignore: bool = False, reinstall: bool
|
|||
quick_allowed = False
|
||||
if (args.reinstall) or (reinstall) or (not installed(package, friendly, quiet=quiet)):
|
||||
deps = '' if not no_deps else '--no-deps '
|
||||
cmd = f"install{' --upgrade' if not args.uv else ''}{' --force-reinstall' if force else ''} {deps}{package}"
|
||||
isolation = '' if not no_build_isolation else '--no-build-isolation '
|
||||
cmd = f"install{' --upgrade' if not args.uv else ''}{' --force-reinstall' if force else ''} {deps}{isolation}{package}"
|
||||
res = pip(cmd, ignore=ignore, uv=package != "uv" and not package.startswith('git+'))
|
||||
try:
|
||||
importlib.reload(pkg_resources)
|
||||
|
|
@ -665,7 +687,7 @@ def check_diffusers():
|
|||
t_start = time.time()
|
||||
if args.skip_all:
|
||||
return
|
||||
sha = '99e2cfff27dec514a43e260e885c5e6eca038b36' # diffusers commit hash
|
||||
sha = '5bf248ddd8796b4f4958559429071a28f9b2dd3a' # diffusers commit hash
|
||||
# if args.use_rocm or args.use_zluda or args.use_directml:
|
||||
# sha = '043ab2520f6a19fce78e6e060a68dbc947edb9f9' # lock diffusers versions for now
|
||||
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
|
||||
|
|
@ -1100,8 +1122,9 @@ def install_packages():
|
|||
pr.enable()
|
||||
# log.info('Install: verifying packages')
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git")
|
||||
install(clip_package, 'clip', quiet=True)
|
||||
install(clip_package, 'clip', quiet=True, no_build_isolation=True)
|
||||
install('open-clip-torch', no_deps=True, quiet=True)
|
||||
install(clip_package, 'ftfy', quiet=True, no_build_isolation=True)
|
||||
# tensorflow_package = os.environ.get('TENSORFLOW_PACKAGE', 'tensorflow==2.13.0')
|
||||
# tensorflow_package = os.environ.get('TENSORFLOW_PACKAGE', None)
|
||||
# if tensorflow_package is not None:
|
||||
|
|
|
|||
|
|
@ -171,6 +171,12 @@ async function filterExtraNetworksForTab(searchTerm) {
|
|||
.toLowerCase()
|
||||
.includes('quantized') ? '' : 'none';
|
||||
});
|
||||
} else if (searchTerm === 'nunchaku/') {
|
||||
cards.forEach((elem) => {
|
||||
elem.style.display = elem.dataset.tags
|
||||
.toLowerCase()
|
||||
.includes('nunchaku') ? '' : 'none';
|
||||
});
|
||||
} else if (searchTerm === 'local/') {
|
||||
cards.forEach((elem) => {
|
||||
elem.style.display = elem.dataset.name
|
||||
|
|
|
|||
|
|
@ -1353,6 +1353,7 @@ async function initGallery() { // triggered on gradio change to monitor when ui
|
|||
|
||||
monitorGalleries();
|
||||
updateFolders();
|
||||
initGalleryAutoRefresh();
|
||||
[
|
||||
'browser_folders',
|
||||
'outdir_samples',
|
||||
|
|
|
|||
|
|
@ -156,7 +156,6 @@ def parse_prompt(prompt: str | None) -> tuple[str, defaultdict[str, list[ExtraNe
|
|||
if prompt is None:
|
||||
return "", res
|
||||
if isinstance(prompt, list):
|
||||
shared.log.warning(f"parse_prompt was called with a list instead of a string: {prompt}")
|
||||
return parse_prompts(prompt)
|
||||
|
||||
def found(m: re.Match[str]):
|
||||
|
|
@ -168,13 +167,17 @@ def parse_prompt(prompt: str | None) -> tuple[str, defaultdict[str, list[ExtraNe
|
|||
return updated_prompt, res
|
||||
|
||||
|
||||
def parse_prompts(prompts: list[str], extra_data=None):
    """Run `parse_prompt` over every prompt in a batch.

    Args:
        prompts: prompt strings to parse, in order.
        extra_data: optional previously collected extra-network data used as
            the starting value; a fresh `defaultdict(list)` is used when it is
            missing or empty.

    Returns:
        A `(updated_prompts, extra_data)` tuple: the prompts as returned by
        `parse_prompt`, and the extra-network data of the most recent prompt
        that produced any (or the starting value when none did).
    """
    updated_prompt_list: list[str] = []
    extra_data: defaultdict[str, list[ExtraNetworkParams]] = extra_data or defaultdict(list)
    for prompt in prompts:
        updated_prompt, parsed_extra_data = parse_prompt(prompt)
        # Take the freshly parsed data when nothing has been collected yet or
        # when this prompt actually yielded something; otherwise keep what we have.
        if not extra_data or parsed_extra_data:
            extra_data = parsed_extra_data
        updated_prompt_list.append(updated_prompt)

    return updated_prompt_list, extra_data
|
||||
|
|
|
|||
|
|
@ -12,24 +12,10 @@ import piexif
|
|||
import piexif.helper
|
||||
from PIL import Image, PngImagePlugin, ExifTags, ImageDraw
|
||||
from modules import sd_samplers, shared, script_callbacks, errors, paths
|
||||
from modules.images_grid import (
|
||||
image_grid as image_grid,
|
||||
get_grid_size as get_grid_size,
|
||||
split_grid as split_grid,
|
||||
combine_grid as combine_grid,
|
||||
check_grid_size as check_grid_size,
|
||||
get_font as get_font,
|
||||
draw_grid_annotations as draw_grid_annotations,
|
||||
draw_prompt_matrix as draw_prompt_matrix,
|
||||
GridAnnotation as GridAnnotation,
|
||||
Grid as Grid,
|
||||
)
|
||||
from modules.images_resize import resize_image as resize_image
|
||||
from modules.images_namegen import (
|
||||
FilenameGenerator as FilenameGenerator,
|
||||
get_next_sequence_number as get_next_sequence_number,
|
||||
)
|
||||
from modules.video import save_video as save_video
|
||||
from modules.images_grid import image_grid, get_grid_size, split_grid, combine_grid, check_grid_size, get_font, draw_grid_annotations, draw_prompt_matrix, GridAnnotation, Grid # pylint: disable=unused-import
|
||||
from modules.images_resize import resize_image # pylint: disable=unused-import
|
||||
from modules.images_namegen import FilenameGenerator, get_next_sequence_number # pylint: disable=unused-import
|
||||
from modules.video import save_video # pylint: disable=unused-import
|
||||
|
||||
|
||||
debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
|
|
|||
|
|
@ -4,10 +4,27 @@ from installer import log, pip
|
|||
from modules import devices
|
||||
|
||||
|
||||
nunchaku_ver = '1.1.0'
|
||||
nunchaku_versions = {
|
||||
'2.5': '1.0.1',
|
||||
'2.6': '1.0.1',
|
||||
'2.7': '1.1.0',
|
||||
'2.8': '1.1.0',
|
||||
'2.9': '1.1.0',
|
||||
'2.10': '1.0.2',
|
||||
'2.11': '1.1.0',
|
||||
}
|
||||
ok = False
|
||||
|
||||
|
||||
def _expected_ver():
    """Return the nunchaku version mapped to the installed torch release.

    The torch version string is cut at the first '+' and reduced to its first
    two dot-separated components before being looked up in `nunchaku_versions`.

    Returns:
        The mapped version string, or None when the table has no entry for
        this torch release or when anything fails along the way (for example
        torch cannot be imported).
    """
    try:
        import torch
        release = torch.__version__.split('+')[0]  # drop any '+...' suffix
        major_minor = '.'.join(release.split('.')[:2])
        return nunchaku_versions.get(major_minor)
    except Exception:
        return None
|
||||
|
||||
|
||||
def check():
|
||||
global ok # pylint: disable=global-statement
|
||||
if ok:
|
||||
|
|
@ -16,8 +33,9 @@ def check():
|
|||
import nunchaku
|
||||
import nunchaku.utils
|
||||
from nunchaku import __version__
|
||||
expected = _expected_ver()
|
||||
log.info(f'Nunchaku: path={nunchaku.__path__} version={__version__.__version__} precision={nunchaku.utils.get_precision()}')
|
||||
if __version__.__version__ != nunchaku_ver:
|
||||
if expected is not None and __version__.__version__ != expected:
|
||||
ok = False
|
||||
return False
|
||||
ok = True
|
||||
|
|
@ -49,14 +67,16 @@ def install_nunchaku():
|
|||
if devices.backend not in ['cuda']:
|
||||
log.error(f'Nunchaku: backend={devices.backend} unsupported')
|
||||
return False
|
||||
torch_ver = torch.__version__[:3]
|
||||
if torch_ver not in ['2.5', '2.6', '2.7', '2.8', '2.9', '2.10']:
|
||||
torch_ver = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
|
||||
nunchaku_ver = nunchaku_versions.get(torch_ver)
|
||||
if nunchaku_ver is None:
|
||||
log.error(f'Nunchaku: torch={torch.__version__} unsupported')
|
||||
return False
|
||||
suffix = 'x86_64' if arch == 'linux' else 'win_amd64'
|
||||
url = os.environ.get('NUNCHAKU_COMMAND', None)
|
||||
if url is None:
|
||||
arch = f'{arch}_' if arch == 'linux' else ''
|
||||
url = f'https://huggingface.co/nunchaku-tech/nunchaku/resolve/main/nunchaku-{nunchaku_ver}'
|
||||
url = f'https://huggingface.co/nunchaku-ai/nunchaku/resolve/main/nunchaku-{nunchaku_ver}'
|
||||
url += f'+torch{torch_ver}-cp{python_ver}-cp{python_ver}-{arch}{suffix}.whl'
|
||||
cmd = f'install --upgrade {url}'
|
||||
log.debug(f'Nunchaku: install="{url}"')
|
||||
|
|
|
|||
|
|
@ -255,13 +255,25 @@ def check_quant(module: str = ''):
|
|||
|
||||
def check_nunchaku(module: str = ''):
    """Tell whether `module` should be loaded through nunchaku for the selected checkpoint.

    The checkpoint name is read from `shared.opts.sd_model_checkpoint`; it must
    contain '+nunchaku'. The text before the first '+' is matched against the
    'path' of each entry in `shared.reference_models`, whose 'nunchaku' value
    may be True (meaning both 'Model' and 'TE') or a list of module names.

    Args:
        module: module name to look for in the matching entry's nunchaku list.

    Returns:
        `mit_nunchaku.ok` after calling `mit_nunchaku.install_nunchaku()` when
        a matching entry lists `module`; False otherwise.
    """
    from modules import shared
    model_name = getattr(shared.opts, 'sd_model_checkpoint', '')
    if '+nunchaku' not in model_name:
        return False
    base_path = model_name.partition('+')[0]
    for entry in shared.reference_models.values():
        if entry.get('path', '') != base_path:
            continue
        enabled = entry.get('nunchaku', None)
        if enabled is True:
            # a bare True enables both supported modules
            enabled = ['Model', 'TE']
        if not isinstance(enabled, list) or module not in enabled:
            # missing, False, or malformed values and non-listed modules: keep searching
            continue
        from modules import mit_nunchaku
        mit_nunchaku.install_nunchaku()
        return mit_nunchaku.ok
    return False
|
||||
|
||||
|
||||
def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list = None, modules_dtype_dict: dict = None):
|
||||
|
|
|
|||
|
|
@ -270,7 +270,6 @@ def process_init(p: StableDiffusionProcessing):
|
|||
p.all_prompts, p.all_negative_prompts = shared.prompt_styles.apply_styles_to_prompts(p.all_prompts, p.all_negative_prompts, p.styles, p.all_seeds)
|
||||
p.prompts = p.all_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
|
||||
p.negative_prompts = p.all_negative_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
|
||||
p.prompts, _ = extra_networks.parse_prompts(p.prompts)
|
||||
|
||||
|
||||
def process_samples(p: StableDiffusionProcessing, samples):
|
||||
|
|
|
|||
|
|
@ -171,6 +171,7 @@ def process_base(p: processing.StableDiffusionProcessing):
|
|||
modelstats.analyze()
|
||||
try:
|
||||
t0 = time.time()
|
||||
p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts, p.network_data)
|
||||
extra_networks.activate(p, exclude=['text_encoder', 'text_encoder_2', 'text_encoder_3'])
|
||||
|
||||
if hasattr(shared.sd_model, 'tgate') and getattr(p, 'gate_step', -1) > 0:
|
||||
|
|
@ -297,10 +298,20 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
|
|||
p.denoising_strength = strength
|
||||
orig_image = p.task_args.pop('image', None) # remove image override from hires
|
||||
process_pre(p)
|
||||
|
||||
prompts = p.prompts
|
||||
reset_prompts = False
|
||||
if len(p.refiner_prompt) > 0:
|
||||
prompts = len(output.images)* [p.refiner_prompt]
|
||||
prompts, p.network_data = extra_networks.parse_prompts(prompts)
|
||||
reset_prompts = True
|
||||
if reset_prompts or ('base' in p.skip):
|
||||
extra_networks.activate(p)
|
||||
|
||||
hires_args = set_pipeline_args(
|
||||
p=p,
|
||||
model=shared.sd_model,
|
||||
prompts=len(output.images)* [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
|
||||
prompts=prompts,
|
||||
negative_prompts=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
|
||||
prompts_2=len(output.images) * [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
|
||||
negative_prompts_2=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
|
||||
|
|
@ -314,11 +325,10 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
|
|||
strength=strength,
|
||||
desc='Hires',
|
||||
)
|
||||
|
||||
hires_steps = hires_args.get('prior_num_inference_steps', None) or p.hr_second_pass_steps or hires_args.get('num_inference_steps', None)
|
||||
shared.state.update(get_job_name(p, shared.sd_model), hires_steps, 1)
|
||||
try:
|
||||
if 'base' in p.skip:
|
||||
extra_networks.activate(p)
|
||||
taskid = shared.state.begin('Inference')
|
||||
output = shared.sd_model(**hires_args) # pylint: disable=not-callable
|
||||
shared.state.end(taskid)
|
||||
|
|
|
|||
|
|
@ -18,14 +18,15 @@ def load_unet_sdxl_nunchaku(repo_id):
|
|||
shared.log.error(f'Load module: quant=Nunchaku module=unet repo="{repo_id}" low nunchaku version')
|
||||
return None
|
||||
if 'turbo' in repo_id.lower():
|
||||
nunchaku_repo = 'nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors'
|
||||
nunchaku_repo = 'nunchaku-ai/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors'
|
||||
else:
|
||||
nunchaku_repo = 'nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors'
|
||||
nunchaku_repo = 'nunchaku-ai/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors'
|
||||
|
||||
shared.log.debug(f'Load module: quant=Nunchaku module=unet repo="{nunchaku_repo}" offload={shared.opts.nunchaku_offload}')
|
||||
if shared.opts.nunchaku_offload:
|
||||
shared.log.warning('Load module: quant=Nunchaku module=unet offload not supported for SDXL, ignoring')
|
||||
shared.log.debug(f'Load module: quant=Nunchaku module=unet repo="{nunchaku_repo}"')
|
||||
unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
|
||||
nunchaku_repo,
|
||||
offload=shared.opts.nunchaku_offload,
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,53 +4,23 @@ import os
|
|||
import sys
|
||||
import time
|
||||
import contextlib
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from installer import (
|
||||
log as log,
|
||||
print_dict,
|
||||
console as console,
|
||||
get_version as get_version,
|
||||
)
|
||||
|
||||
log.debug("Initializing: shared module")
|
||||
from installer import log, print_dict, console, get_version # pylint: disable=unused-import
|
||||
log.debug('Initializing: shared module')
|
||||
|
||||
import modules.memmon
|
||||
import modules.paths as paths
|
||||
from modules.json_helpers import (
|
||||
readfile as readfile,
|
||||
writefile as writefile,
|
||||
)
|
||||
from modules.shared_helpers import (
|
||||
listdir as listdir,
|
||||
walk_files as walk_files,
|
||||
html_path as html_path,
|
||||
html as html,
|
||||
req as req,
|
||||
total_tqdm as total_tqdm,
|
||||
)
|
||||
from modules.json_helpers import readfile, writefile # pylint: disable=W0611
|
||||
from modules.shared_helpers import listdir, walk_files, html_path, html, req, total_tqdm # pylint: disable=W0611
|
||||
from modules import errors, devices, shared_state, cmd_args, theme, history, files_cache
|
||||
from modules.shared_defaults import get_default_modes
|
||||
from modules.paths import (
|
||||
models_path as models_path, # For compatibility, do not modify from here...
|
||||
script_path as script_path,
|
||||
data_path as data_path,
|
||||
sd_configs_path as sd_configs_path,
|
||||
sd_default_config as sd_default_config,
|
||||
sd_model_file as sd_model_file,
|
||||
default_sd_model_file as default_sd_model_file,
|
||||
extensions_dir as extensions_dir,
|
||||
extensions_builtin_dir as extensions_builtin_dir, # ... to here.
|
||||
)
|
||||
from modules.memstats import (
|
||||
memory_stats,
|
||||
ram_stats as ram_stats,
|
||||
)
|
||||
from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
|
||||
from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import
|
||||
|
||||
log.debug("Initializing: pipelines")
|
||||
log.debug('Initializing: pipelines')
|
||||
from modules import shared_items
|
||||
from modules.interrogate.openclip import caption_models, caption_types, get_clip_models, refresh_clip_models
|
||||
from modules.interrogate.vqa import vlm_models, vlm_prompts, vlm_system, vlm_default
|
||||
|
|
@ -280,7 +250,6 @@ options_templates.update(options_section(("quantization", "Model Quantization"),
|
|||
"sdnq_quantize_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox),
|
||||
|
||||
"nunchaku_sep": OptionInfo("<h2>Nunchaku Engine</h2>", "", gr.HTML),
|
||||
"nunchaku_quantization": OptionInfo([], "SVDQuant enabled", gr.CheckboxGroup, {"choices": ["Model", "TE"]}),
|
||||
"nunchaku_attention": OptionInfo(False, "Nunchaku attention", gr.Checkbox),
|
||||
"nunchaku_offload": OptionInfo(False, "Nunchaku offloading", gr.Checkbox),
|
||||
|
||||
|
|
|
|||
|
|
@ -305,6 +305,7 @@ class ExtraNetworksPage:
|
|||
subdirs['Reference'] = 1
|
||||
subdirs['Distilled'] = 1
|
||||
subdirs['Quantized'] = 1
|
||||
subdirs['Nunchaku'] = 1
|
||||
subdirs['Community'] = 1
|
||||
subdirs['Cloud'] = 1
|
||||
subdirs[diffusers_base] = 1
|
||||
|
|
@ -324,6 +325,8 @@ class ExtraNetworksPage:
|
|||
subdirs.move_to_end('Distilled', last=True)
|
||||
if 'Quantized' in subdirs:
|
||||
subdirs.move_to_end('Quantized', last=True)
|
||||
if 'Nunchaku' in subdirs:
|
||||
subdirs.move_to_end('Nunchaku', last=True)
|
||||
if 'Community' in subdirs:
|
||||
subdirs.move_to_end('Community', last=True)
|
||||
if 'Cloud' in subdirs:
|
||||
|
|
@ -332,7 +335,7 @@ class ExtraNetworksPage:
|
|||
for subdir in subdirs:
|
||||
if len(subdir) == 0:
|
||||
continue
|
||||
if subdir in ['All', 'Local', 'Diffusers', 'Reference', 'Distilled', 'Quantized', 'Community', 'Cloud']:
|
||||
if subdir in ['All', 'Local', 'Diffusers', 'Reference', 'Distilled', 'Quantized', 'Nunchaku', 'Community', 'Cloud']:
|
||||
style = 'network-reference'
|
||||
else:
|
||||
style = 'network-folder'
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import html
|
|||
import json
|
||||
import concurrent
|
||||
from datetime import datetime
|
||||
from modules import shared, ui_extra_networks, sd_models, modelstats, paths
|
||||
from modules import shared, ui_extra_networks, sd_models, modelstats, paths, devices
|
||||
from modules.json_helpers import readfile
|
||||
|
||||
|
||||
|
|
@ -48,16 +48,21 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||
reference_distilled = readfile(os.path.join('data', 'reference-distilled.json'), as_type="dict")
|
||||
reference_community = readfile(os.path.join('data', 'reference-community.json'), as_type="dict")
|
||||
reference_cloud = readfile(os.path.join('data', 'reference-cloud.json'), as_type="dict")
|
||||
reference_nunchaku = readfile(os.path.join('data', 'reference-nunchaku.json'), as_type="dict")
|
||||
shared.reference_models = {}
|
||||
shared.reference_models.update(reference_base)
|
||||
shared.reference_models.update(reference_quant)
|
||||
shared.reference_models.update(reference_community)
|
||||
shared.reference_models.update(reference_distilled)
|
||||
shared.reference_models.update(reference_cloud)
|
||||
shared.reference_models.update(reference_nunchaku)
|
||||
|
||||
for k, v in shared.reference_models.items():
|
||||
count['total'] += 1
|
||||
url = v['path']
|
||||
if v.get('hidden', False):
|
||||
count['hidden'] += 1
|
||||
continue
|
||||
experimental = v.get('experimental', False)
|
||||
if experimental:
|
||||
if shared.cmd_opts.experimental:
|
||||
|
|
@ -83,6 +88,9 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||
path = f'{v.get("path", "")}'
|
||||
|
||||
tag = v.get('tags', '')
|
||||
if tag == 'nunchaku' and (devices.backend != 'cuda' and not shared.cmd_opts.experimental):
|
||||
count['hidden'] += 1
|
||||
continue
|
||||
if tag in count:
|
||||
count[tag] += 1
|
||||
elif tag != '':
|
||||
|
|
|
|||
|
|
@ -9,19 +9,19 @@ def load_flux_nunchaku(repo_id):
|
|||
if 'srpo' in repo_id.lower():
|
||||
pass
|
||||
elif 'flux.1-dev' in repo_id.lower():
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
|
||||
elif 'flux.1-schnell' in repo_id.lower():
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
|
||||
elif 'flux.1-kontext' in repo_id.lower():
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
|
||||
elif 'flux.1-krea' in repo_id.lower():
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-krea-dev/svdq-{nunchaku_precision}_r32-flux.1-krea-dev.safetensors"
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-krea-dev/svdq-{nunchaku_precision}_r32-flux.1-krea-dev.safetensors"
|
||||
elif 'flux.1-fill' in repo_id.lower():
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-fill-dev/svdq-{nunchaku_precision}-flux.1-fill-dev.safetensors"
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-fill-dev/svdq-{nunchaku_precision}-flux.1-fill-dev.safetensors"
|
||||
elif 'flux.1-depth' in repo_id.lower():
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-flux.1-depth-dev/svdq-{nunchaku_precision}-flux.1-depth-dev.safetensors"
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-flux.1-depth-dev/svdq-{nunchaku_precision}-flux.1-depth-dev.safetensors"
|
||||
elif 'shuttle' in repo_id.lower():
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}-shuttle-jaguar.safetensors"
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}-shuttle-jaguar.safetensors"
|
||||
else:
|
||||
shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
|
||||
if nunchaku_repo is not None:
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ def load_text_encoder(repo_id, cls_name, load_config=None, subfolder="text_encod
|
|||
elif cls_name == transformers.T5EncoderModel and allow_shared and shared.opts.te_shared_t5:
|
||||
if model_quant.check_nunchaku('TE'):
|
||||
import nunchaku
|
||||
repo_id = 'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
|
||||
repo_id = 'nunchaku-ai/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
|
||||
cls_name = nunchaku.NunchakuT5EncoderModel
|
||||
shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant" loader={_loader("transformers")}')
|
||||
text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained(
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
|
|||
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["qwen-image"] = diffusers.QwenImageInpaintPipeline
|
||||
|
||||
if model_quant.check_nunchaku('Model'):
|
||||
transformer = qwen.load_qwen_nunchaku(repo_id)
|
||||
transformer = qwen.load_qwen_nunchaku(repo_id, subfolder=repo_subfolder)
|
||||
|
||||
if 'Qwen-Image-Distill-Full' in repo_id:
|
||||
repo_transformer = repo_id
|
||||
|
|
@ -63,6 +63,8 @@ def load_qwen(checkpoint_info, diffusers_load_config=None):
|
|||
text_encoder = generic.load_text_encoder(repo_te, cls_name=transformers.Qwen2_5_VLForConditionalGeneration, load_config=diffusers_load_config)
|
||||
|
||||
repo_id, repo_subfolder = qwen.check_qwen_pruning(repo_id, repo_subfolder)
|
||||
if repo_subfolder is not None and repo_subfolder.startswith('nunchaku'):
|
||||
repo_subfolder = None
|
||||
pipe = cls_name.from_pretrained(
|
||||
repo_id,
|
||||
transformer=transformer,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def load_quants(kwargs, repo_id, cache_dir):
|
|||
if 'Sana_1600M_1024px' in repo_id and model_quant.check_nunchaku('Model'): # only available model
|
||||
import nunchaku
|
||||
nunchaku_precision = nunchaku.utils.get_precision()
|
||||
nunchaku_repo = "nunchaku-tech/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
|
||||
nunchaku_repo = "nunchaku-ai/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
|
||||
shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} attention={shared.opts.nunchaku_attention}')
|
||||
kwargs['transformer'] = nunchaku.NunchakuSanaTransformer2DModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype, cache_dir=cache_dir)
|
||||
elif model_quant.check_quant('Model'):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ def load_nunchaku():
|
|||
import nunchaku
|
||||
nunchaku_precision = nunchaku.utils.get_precision()
|
||||
nunchaku_rank = 128
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-z-image-turbo/svdq-{nunchaku_precision}_r{nunchaku_rank}-z-image-turbo.safetensors"
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-z-image-turbo/svdq-{nunchaku_precision}_r{nunchaku_rank}-z-image-turbo.safetensors"
|
||||
shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" attention={shared.opts.nunchaku_attention}')
|
||||
transformer = nunchaku.NunchakuZImageTransformer2DModel.from_pretrained(
|
||||
nunchaku_repo,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from modules import shared, devices
|
||||
|
||||
|
||||
def load_qwen_nunchaku(repo_id):
|
||||
def load_qwen_nunchaku(repo_id, subfolder=None):
|
||||
import nunchaku
|
||||
nunchaku_precision = nunchaku.utils.get_precision()
|
||||
nunchaku_repo = None
|
||||
transformer = None
|
||||
four_step = subfolder is not None and '4step' in subfolder
|
||||
try:
|
||||
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
|
||||
except Exception:
|
||||
|
|
@ -14,15 +15,21 @@ def load_qwen_nunchaku(repo_id):
|
|||
if 'pruning' in repo_id.lower() or 'distill' in repo_id.lower():
|
||||
return None
|
||||
elif repo_id.lower().endswith('qwen-image'):
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image.safetensors" # r32 vs r128
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image.safetensors"
|
||||
elif repo_id.lower().endswith('qwen-lightning'):
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.1-8steps.safetensors" # 8-step variant
|
||||
if four_step:
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.0-4steps.safetensors"
|
||||
else:
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image-lightningv1.1-8steps.safetensors"
|
||||
elif repo_id.lower().endswith('qwen-image-edit-2509'):
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit-2509/svdq-{nunchaku_precision}_r128-qwen-image-edit-2509.safetensors" # 8-step variant
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit-2509/svdq-{nunchaku_precision}_r128-qwen-image-edit-2509.safetensors"
|
||||
elif repo_id.lower().endswith('qwen-image-edit'):
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit.safetensors" # 8-step variant
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit.safetensors"
|
||||
elif repo_id.lower().endswith('qwen-lightning-edit'):
|
||||
nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-8steps.safetensors" # 8-step variant
|
||||
if four_step:
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-4steps.safetensors"
|
||||
else:
|
||||
nunchaku_repo = f"nunchaku-ai/nunchaku-qwen-image-edit/svdq-{nunchaku_precision}_r128-qwen-image-edit-lightningv1.0-8steps.safetensors"
|
||||
else:
|
||||
shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
|
||||
if nunchaku_repo is not None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,360 @@
|
|||
[project]
|
||||
name = "SD.Next"
|
||||
version = "0.0.0"
|
||||
description = "SD.Next: All-in-one WebUI for AI generative image and video creation and captioning"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 250
|
||||
indent-width = 4
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
"venv",
|
||||
".git",
|
||||
".ruff_cache",
|
||||
".vscode",
|
||||
|
||||
"modules/cfgzero",
|
||||
"modules/flash_attn_triton_amd",
|
||||
"modules/hidiffusion",
|
||||
"modules/intel/ipex",
|
||||
"modules/pag",
|
||||
"modules/schedulers",
|
||||
"modules/teacache",
|
||||
"modules/seedvr",
|
||||
|
||||
"modules/control/proc",
|
||||
"modules/control/units",
|
||||
"modules/control/units/xs_pipe.py",
|
||||
"modules/postprocess/aurasr_arch.py",
|
||||
|
||||
"pipelines/meissonic",
|
||||
"pipelines/omnigen2",
|
||||
"pipelines/hdm",
|
||||
"pipelines/segmoe",
|
||||
"pipelines/xomni",
|
||||
"pipelines/chrono",
|
||||
|
||||
"scripts/lbm",
|
||||
"scripts/daam",
|
||||
"scripts/xadapter",
|
||||
"scripts/pulid",
|
||||
"scripts/instantir",
|
||||
"scripts/freescale",
|
||||
"scripts/consistory",
|
||||
|
||||
"repositories",
|
||||
|
||||
"extensions-builtin/Lora",
|
||||
"extensions-builtin/sd-extension-chainner/nodes",
|
||||
"extensions-builtin/sd-webui-agent-scheduler",
|
||||
"extensions-builtin/sdnext-modernui/node_modules",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"F",
|
||||
"E",
|
||||
"W",
|
||||
"C",
|
||||
"B",
|
||||
"I",
|
||||
"YTT",
|
||||
"ASYNC",
|
||||
"RUF",
|
||||
"AIR",
|
||||
"NPY",
|
||||
"C4",
|
||||
"T10",
|
||||
"EXE",
|
||||
"ISC",
|
||||
"ICN",
|
||||
"RSE",
|
||||
"TC",
|
||||
"TID",
|
||||
"INT",
|
||||
"PLE",
|
||||
]
|
||||
ignore = [
|
||||
"B006", # Do not use mutable data structures for argument defaults
|
||||
"B008", # Do not perform function call in argument defaults
|
||||
"B905", # Strict zip() usage
|
||||
"C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead
|
||||
"C408", # Unnecessary `dict` call
|
||||
"I001", # Import block is un-sorted or un-formatted
|
||||
"E402", # Module level import not at top of file
|
||||
"E501", # Line too long
|
||||
"E721", # Do not compare types, use `isinstance()`
|
||||
"E731", # Do not assign a `lambda` expression, use a `def`
|
||||
"E741", # Ambiguous variable name
|
||||
"F401", # Imported by unused
|
||||
"EXE001", # file with shebang is not marked executable
|
||||
"NPY002", # replace legacy random
|
||||
"RUF005", # Consider iterable unpacking
|
||||
"RUF008", # Do not use mutable default values for dataclass
|
||||
"RUF010", # Use explicit conversion flag
|
||||
"RUF012", # Mutable class attributes
|
||||
"RUF013", # PEP 484 prohibits implicit `Optional`
|
||||
"RUF015", # Prefer `next(...)` over single element slice
|
||||
"RUF046", # Value being cast to `int` is already an integer
|
||||
"RUF059", # Unpacked variables are not used
|
||||
"RUF051", # Prefer pop over del
|
||||
]
|
||||
fixable = ["ALL"]
|
||||
unfixable = []
|
||||
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
docstring-code-format = false
|
||||
|
||||
[tool.ruff.lint.mccabe]
|
||||
max-complexity = 150
|
||||
|
||||
[tool.pylint]
|
||||
main.py-version="3.10"
|
||||
main.analyse-fallback-blocks=false
|
||||
main.clear-cache-post-run=false
|
||||
main.extension-pkg-allow-list=""
|
||||
main.prefer-stubs=true
|
||||
main.extension-pkg-whitelist=""
|
||||
main.fail-on=""
|
||||
main.fail-under=10
|
||||
main.ignore="CVS"
|
||||
main.ignore-paths=[
|
||||
"venv",
|
||||
".git",
|
||||
".ruff_cache",
|
||||
".vscode",
|
||||
"modules/apg",
|
||||
"modules/cfgzero",
|
||||
"modules/control/proc",
|
||||
"modules/control/units",
|
||||
"modules/dml",
|
||||
"modules/flash_attn_triton_amd",
|
||||
"modules/ggml",
|
||||
"modules/hidiffusion",
|
||||
"modules/hijack/ddpm_edit.py",
|
||||
"modules/intel",
|
||||
"modules/intel/ipex",
|
||||
"modules/framepack/pipeline",
|
||||
"modules/onnx_impl",
|
||||
"modules/pag",
|
||||
"modules/postprocess/aurasr_arch.py",
|
||||
"modules/prompt_parser_xhinker.py",
|
||||
"modules/ras",
|
||||
"modules/seedvr",
|
||||
"modules/rife",
|
||||
"modules/schedulers",
|
||||
"modules/taesd",
|
||||
"modules/teacache",
|
||||
"modules/todo",
|
||||
"modules/res4lyf",
|
||||
"pipelines/bria",
|
||||
"pipelines/flex2",
|
||||
"pipelines/f_lite",
|
||||
"pipelines/hidream",
|
||||
"pipelines/hdm",
|
||||
"pipelines/meissonic",
|
||||
"pipelines/omnigen2",
|
||||
"pipelines/segmoe",
|
||||
"pipelines/xomni",
|
||||
"pipelines/chrono",
|
||||
"scripts/consistory",
|
||||
"scripts/ctrlx",
|
||||
"scripts/daam",
|
||||
"scripts/demofusion",
|
||||
"scripts/freescale",
|
||||
"scripts/infiniteyou",
|
||||
"scripts/instantir",
|
||||
"scripts/lbm",
|
||||
"scripts/layerdiffuse",
|
||||
"scripts/mod",
|
||||
"scripts/pixelsmith",
|
||||
"scripts/differential_diffusion.py",
|
||||
"scripts/pulid",
|
||||
"scripts/xadapter",
|
||||
"repositories",
|
||||
"extensions-builtin/sd-extension-chainner/nodes",
|
||||
"extensions-builtin/sd-webui-agent-scheduler",
|
||||
"extensions-builtin/sdnext-modernui/node_modules",
|
||||
"extensions-builtin/sdnext-kanvas/node_modules",
|
||||
]
|
||||
main.ignore-patterns=[
|
||||
".*test*.py$",
|
||||
".*_model.py$",
|
||||
".*_arch.py$",
|
||||
".*_model_arch.py*",
|
||||
".*_model_arch_v2.py$",
|
||||
]
|
||||
main.ignored-modules=""
|
||||
main.jobs=8
|
||||
main.limit-inference-results=100
|
||||
main.load-plugins=""
|
||||
main.persistent=false
|
||||
main.recursive=false
|
||||
main.source-roots=""
|
||||
main.unsafe-load-any-extension=false
|
||||
basic.argument-naming-style="snake_case"
|
||||
basic.attr-naming-style="snake_case"
|
||||
basic.bad-names=["foo", "bar", "baz", "toto", "tutu", "tata"]
|
||||
basic.bad-names-rgxs=""
|
||||
basic.class-attribute-naming-style="any"
|
||||
basic.class-const-naming-style="UPPER_CASE"
|
||||
basic.class-naming-style="PascalCase"
|
||||
basic.const-naming-style="snake_case"
|
||||
basic.docstring-min-length=-1
|
||||
basic.function-naming-style="snake_case"
|
||||
basic.good-names=["i","j","k","e","ex","ok","p","x","y","id"]
|
||||
basic.good-names-rgxs=""
|
||||
basic.include-naming-hint=false
|
||||
basic.inlinevar-naming-style="any"
|
||||
basic.method-naming-style="snake_case"
|
||||
basic.module-naming-style="snake_case"
|
||||
basic.name-group=""
|
||||
basic.no-docstring-rgx="^_"
|
||||
basic.property-classes="abc.abstractproperty"
|
||||
basic.variable-naming-style="snake_case"
|
||||
classes.check-protected-access-in-special-methods=false
|
||||
classes.defining-attr-methods=["__init__", "__new__"]
|
||||
classes.exclude-protected=["_asdict","_fields","_replace","_source","_make","os._exit"]
|
||||
classes.valid-classmethod-first-arg="cls"
|
||||
classes.valid-metaclass-classmethod-first-arg="mcs"
|
||||
design.exclude-too-few-public-methods=""
|
||||
design.ignored-parents=""
|
||||
design.max-args=199
|
||||
design.max-attributes=99
|
||||
design.max-bool-expr=99
|
||||
design.max-branches=199
|
||||
design.max-locals=99
|
||||
design.max-parents=99
|
||||
design.max-public-methods=99
|
||||
design.max-returns=99
|
||||
design.max-statements=199
|
||||
design.min-public-methods=1
|
||||
exceptions.overgeneral-exceptions=["builtins.BaseException","builtins.Exception"]
|
||||
format.expected-line-ending-format=""
|
||||
# format.ignore-long-lines="^\s*(# )?<?https?://\S+>?$"
|
||||
format.indent-after-paren=4
|
||||
format.indent-string=' '
|
||||
format.max-line-length=200
|
||||
format.max-module-lines=9999
|
||||
format.single-line-class-stmt=false
|
||||
format.single-line-if-stmt=false
|
||||
imports.allow-any-import-level=""
|
||||
imports.allow-reexport-from-package=false
|
||||
imports.allow-wildcard-with-all=false
|
||||
imports.deprecated-modules=""
|
||||
imports.ext-import-graph=""
|
||||
imports.import-graph=""
|
||||
imports.int-import-graph=""
|
||||
imports.known-standard-library=""
|
||||
imports.known-third-party="enchant"
|
||||
imports.preferred-modules=""
|
||||
logging.logging-format-style="new"
|
||||
logging.logging-modules="logging"
|
||||
messages_control.confidence=["HIGH","CONTROL_FLOW","INFERENCE","INFERENCE_FAILURE","UNDEFINED"]
|
||||
messages_control.disable=[
|
||||
"abstract-method",
|
||||
"bad-inline-option",
|
||||
"bare-except",
|
||||
"broad-exception-caught",
|
||||
"chained-comparison",
|
||||
"consider-iterating-dictionary",
|
||||
"consider-merging-isinstance",
|
||||
"consider-using-dict-items",
|
||||
"consider-using-enumerate",
|
||||
"consider-using-from-import",
|
||||
"consider-using-generator",
|
||||
"consider-using-get",
|
||||
"consider-using-in",
|
||||
"consider-using-max-builtin",
|
||||
"consider-using-min-builtin",
|
||||
"consider-using-sys-exit",
|
||||
"cyclic-import",
|
||||
"dangerous-default-value",
|
||||
"deprecated-pragma",
|
||||
"duplicate-code",
|
||||
"file-ignored",
|
||||
"import-error",
|
||||
"import-outside-toplevel",
|
||||
"invalid-name",
|
||||
"line-too-long",
|
||||
"locally-disabled",
|
||||
"logging-fstring-interpolation",
|
||||
"missing-class-docstring",
|
||||
"missing-function-docstring",
|
||||
"missing-module-docstring",
|
||||
"no-else-raise",
|
||||
"no-else-return",
|
||||
"not-callable",
|
||||
"pointless-string-statement",
|
||||
"raw-checker-failed",
|
||||
"simplifiable-if-expression",
|
||||
"suppressed-message",
|
||||
"too-few-public-methods",
|
||||
"too-many-instance-attributes",
|
||||
"too-many-locals",
|
||||
"too-many-nested-blocks",
|
||||
"too-many-positional-arguments",
|
||||
"too-many-statements",
|
||||
"unidiomatic-typecheck",
|
||||
"unknown-option-value",
|
||||
"unnecessary-dict-index-lookup",
|
||||
"unnecessary-dunder-call",
|
||||
"unnecessary-lambda-assigment",
|
||||
"unnecessary-lambda",
|
||||
"unused-wildcard-import",
|
||||
"unpacking-non-sequence",
|
||||
"unsubscriptable-object",
|
||||
"useless-return",
|
||||
"use-dict-literal",
|
||||
"use-symbolic-message-instead",
|
||||
"useless-suppression",
|
||||
"wrong-import-position",
|
||||
]
|
||||
messages_control.enable="c-extension-no-member"
|
||||
method_args.timeout-methods=["requests.api.delete","requests.api.get","requests.api.head","requests.api.options","requests.api.patch","requests.api.post","requests.api.put","requests.api.request"]
|
||||
miscellaneous.notes=["FIXME","XXX","TODO"]
|
||||
miscellaneous.notes-rgx=""
|
||||
refactoring.max-nested-blocks=5
|
||||
refactoring.never-returning-functions=["sys.exit","argparse.parse_error"]
|
||||
reports.msg-template=""
|
||||
reports.reports=false
|
||||
reports.score=false
|
||||
similarities.ignore-comments=true
|
||||
similarities.ignore-docstrings=true
|
||||
similarities.ignore-imports=true
|
||||
similarities.ignore-signatures=true
|
||||
similarities.min-similarity-lines=4
|
||||
spelling.max-spelling-suggestions=4
|
||||
# spelling.dict=""
|
||||
# spelling.ignore-comment-directives=["fmt: on","fmt: off","noqa:","noqa","nosec","isort:skip","mypy:"]
|
||||
# spelling.ignore-words=""
|
||||
# spelling.private-dict-file=""
|
||||
# spelling.store-unknown-words=false
|
||||
string.check-quote-consistency=false
|
||||
string.check-str-concat-over-line-jumps=false
|
||||
typecheck.contextmanager-decorators="contextlib.contextmanager"
|
||||
typecheck.generated-members=["numpy.*","logging.*","torch.*","cv2.*"]
|
||||
typecheck.ignore-none=true
|
||||
typecheck.ignore-on-opaque-inference=true
|
||||
typecheck.ignored-checks-for-mixins=["no-member","not-async-context-manager","not-context-manager","attribute-defined-outside-init"]
|
||||
typecheck.ignored-classes=["optparse.Values","thread._local","_thread._local","argparse.Namespace","unittest.case._AssertRaisesContext","unittest.case._AssertWarnsContext"]
|
||||
typecheck.missing-member-hint=true
|
||||
typecheck.missing-member-hint-distance=1
|
||||
typecheck.missing-member-max-choices=1
|
||||
typecheck.mixin-class-rgx=".*[Mm]ixin"
|
||||
typecheck.signature-mutators=""
|
||||
variables.additional-builtins=""
|
||||
variables.allow-global-unused-variables=true
|
||||
variables.allowed-redefined-builtins=""
|
||||
variables.callbacks=["cb_",]
|
||||
variables.dummy-variables-rgx="_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_"
|
||||
variables.ignored-argument-names="_.*|^ignored_|^unused_"
|
||||
variables.init-import=false
|
||||
variables.redefining-builtins-modules=["six.moves","past.builtins","future.builtins","builtins","io"]
|
||||
|
|
@ -55,7 +55,7 @@ protobuf==4.25.3
|
|||
pytorch_lightning==2.6.0
|
||||
urllib3==1.26.19
|
||||
Pillow==10.4.0
|
||||
timm==1.0.16
|
||||
timm==1.0.24
|
||||
pyparsing==3.2.3
|
||||
typing-extensions==4.14.1
|
||||
sentencepiece==0.2.1
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class Script(scripts_manager.Script):
|
|||
with gr.Row():
|
||||
gr.HTML('<a href="https://blackforestlabs.ai/flux-1-tools/">  Flux.1 Redux</a><br>')
|
||||
with gr.Row():
|
||||
tool = gr.Dropdown(label='Tool', choices=['None', 'Redux', 'Fill', 'Canny', 'Depth'], value='None')
|
||||
tool = gr.Dropdown(label='Tool', choices=['None', 'Redux', 'Fill', 'Fill (Nunchaku)', 'Canny', 'Depth', 'Depth (Nunchaku)'], value='None')
|
||||
with gr.Row():
|
||||
prompt = gr.Slider(label='Redux prompt strength', minimum=0, maximum=2, step=0.01, value=0, visible=False)
|
||||
process = gr.Checkbox(label='Control preprocess input images', value=True, visible=False)
|
||||
|
|
@ -34,8 +34,8 @@ class Script(scripts_manager.Script):
|
|||
def display(tool):
|
||||
return [
|
||||
gr.update(visible=tool in ['Redux']),
|
||||
gr.update(visible=tool in ['Canny', 'Depth']),
|
||||
gr.update(visible=tool in ['Canny', 'Depth']),
|
||||
gr.update(visible=tool in ['Canny', 'Depth', 'Depth (Nunchaku)']),
|
||||
gr.update(visible=tool in ['Canny', 'Depth', 'Depth (Nunchaku)']),
|
||||
]
|
||||
|
||||
tool.change(fn=display, inputs=[tool], outputs=[prompt, process, strength])
|
||||
|
|
@ -91,13 +91,15 @@ class Script(scripts_manager.Script):
|
|||
shared.log.debug(f'{title}: tool=Redux unload')
|
||||
redux_pipe = None
|
||||
|
||||
if tool == 'Fill':
|
||||
if tool in ['Fill', 'Fill (Nunchaku)']:
|
||||
# pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, revision="refs/pr/4").to("cuda")
|
||||
if p.image_mask is None:
|
||||
shared.log.error(f'{title}: tool={tool} no image_mask')
|
||||
return None
|
||||
if shared.sd_model.__class__.__name__ != 'FluxFillPipeline':
|
||||
shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Fill-dev"
|
||||
nunchaku_suffix = '+nunchaku' if tool == 'Fill (Nunchaku)' else ''
|
||||
checkpoint = f"black-forest-labs/FLUX.1-Fill-dev{nunchaku_suffix}"
|
||||
if shared.sd_model.__class__.__name__ != 'FluxFillPipeline' or shared.opts.sd_model_checkpoint != checkpoint:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint
|
||||
sd_models.reload_model_weights(op='model', revision="refs/pr/4")
|
||||
p.task_args['image'] = image
|
||||
p.task_args['mask_image'] = p.image_mask
|
||||
|
|
@ -124,11 +126,13 @@ class Script(scripts_manager.Script):
|
|||
shared.log.debug(f'{title}: tool=Canny unload processor')
|
||||
processor_canny = None
|
||||
|
||||
if tool == 'Depth':
|
||||
if tool in ['Depth', 'Depth (Nunchaku)']:
|
||||
# pipe = diffusers.FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda")
|
||||
install('git+https://github.com/huggingface/image_gen_aux.git', 'image_gen_aux')
|
||||
if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or 'Depth' not in shared.opts.sd_model_checkpoint:
|
||||
shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Depth-dev"
|
||||
nunchaku_suffix = '+nunchaku' if tool == 'Depth (Nunchaku)' else ''
|
||||
checkpoint = f"black-forest-labs/FLUX.1-Depth-dev{nunchaku_suffix}"
|
||||
if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or shared.opts.sd_model_checkpoint != checkpoint:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint
|
||||
sd_models.reload_model_weights(op='model', revision="refs/pr/1")
|
||||
if processor_depth is None:
|
||||
from image_gen_aux import DepthPreprocessor
|
||||
|
|
|
|||
|
|
@ -1,18 +1,25 @@
|
|||
from typing import Dict, Any
|
||||
import time
|
||||
import warnings
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Any
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
warmup = 2
|
||||
repeats = 50
|
||||
dtypes = [torch.bfloat16] # , torch.float16]
|
||||
# if hasattr(torch, "float8_e4m3fn"):
|
||||
# dtypes.append(torch.float8_e4m3fn)
|
||||
|
||||
PROFILES = {
|
||||
backends = [
|
||||
"sdpa_math",
|
||||
"sdpa_mem_efficient",
|
||||
"sdpa_flash",
|
||||
"sdpa_all",
|
||||
"flex_attention",
|
||||
"xformers",
|
||||
"flash_attn",
|
||||
"sage_attn",
|
||||
]
|
||||
profiles = {
|
||||
"sdxl": {"l_q": 4096, "l_k": 4096, "h": 32, "d": 128},
|
||||
"flux.1": {"l_q": 16717, "l_k": 16717, "h": 24, "d": 128},
|
||||
"sd35": {"l_q": 16538, "l_k": 16538, "h": 24, "d": 128},
|
||||
|
|
@ -30,19 +37,19 @@ def get_stats(reset: bool = False):
|
|||
m = torch.cuda.max_memory_allocated()
|
||||
t = time.perf_counter()
|
||||
return m / (1024 ** 2), t
|
||||
|
||||
|
||||
def print_gpu_info():
|
||||
if not torch.cuda.is_available():
|
||||
print("GPU: Not available")
|
||||
return
|
||||
|
||||
|
||||
device = torch.cuda.current_device()
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
total_mem = props.total_memory / (1024**3)
|
||||
free_mem, _ = torch.cuda.mem_get_info(device)
|
||||
free_mem = free_mem / (1024**3)
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
|
||||
|
||||
print(f"gpu: {torch.cuda.get_device_name(device)}")
|
||||
print(f"vram: total={total_mem:.2f}GB free={free_mem:.2f}GB")
|
||||
print(f"cuda: capability={major}.{minor} version={torch.version.cuda}")
|
||||
|
|
@ -56,11 +63,9 @@ def benchmark_attention(
|
|||
l_k: int = 4096,
|
||||
h: int = 32,
|
||||
d: int = 128,
|
||||
warmup: int = 10,
|
||||
repeats: int = 100
|
||||
) -> Dict[str, Any]:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
# Initialize tensors
|
||||
q = torch.randn(b, h, l_q, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype)
|
||||
k = torch.randn(b, h, l_k, d, device=device, dtype=torch.float16 if dtype.is_floating_point and dtype.itemsize == 1 else dtype, requires_grad=False).to(dtype)
|
||||
|
|
@ -78,7 +83,7 @@ def benchmark_attention(
|
|||
try:
|
||||
if backend.startswith("sdpa_"):
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
sdp_type = backend[len("sdpa_"):]
|
||||
sdp_type = backend[len("sdpa_"):]
|
||||
# Map friendly names to new SDPA backends
|
||||
backend_map = {
|
||||
"math": [SDPBackend.MATH],
|
||||
|
|
@ -88,21 +93,21 @@ def benchmark_attention(
|
|||
}
|
||||
if sdp_type not in backend_map:
|
||||
raise ValueError(f"Unknown SDPA type: {sdp_type}")
|
||||
|
||||
|
||||
results["version"] = torch.__version__
|
||||
|
||||
|
||||
with sdpa_kernel(backend_map[sdp_type]):
|
||||
# Warmup
|
||||
for _ in range(warmup):
|
||||
_ = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
|
||||
start_mem, start_time = get_stats(True)
|
||||
|
||||
|
||||
for _ in range(repeats):
|
||||
_ = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
|
||||
end_mem, end_time = get_stats()
|
||||
|
||||
|
||||
results["latency_ms"] = (end_time - start_time) / repeats * 1000
|
||||
results["memory_mb"] = end_mem - start_mem
|
||||
|
||||
|
|
@ -113,17 +118,17 @@ def benchmark_attention(
|
|||
q_fa = q.transpose(1, 2)
|
||||
k_fa = k.transpose(1, 2)
|
||||
v_fa = v.transpose(1, 2)
|
||||
|
||||
|
||||
for _ in range(warmup):
|
||||
_ = flash_attn_func(q_fa, k_fa, v_fa)
|
||||
|
||||
|
||||
start_mem, start_time = get_stats(True)
|
||||
|
||||
|
||||
for _ in range(repeats):
|
||||
_ = flash_attn_func(q_fa, k_fa, v_fa)
|
||||
|
||||
|
||||
end_mem, end_time = get_stats()
|
||||
|
||||
|
||||
results["latency_ms"] = (end_time - start_time) / repeats * 1000
|
||||
results["memory_mb"] = end_mem - start_mem
|
||||
|
||||
|
|
@ -135,17 +140,17 @@ def benchmark_attention(
|
|||
q_xf = q.transpose(1, 2)
|
||||
k_xf = k.transpose(1, 2)
|
||||
v_xf = v.transpose(1, 2)
|
||||
|
||||
|
||||
for _ in range(warmup):
|
||||
_ = memory_efficient_attention(q_xf, k_xf, v_xf)
|
||||
|
||||
|
||||
start_mem, start_time = get_stats(True)
|
||||
|
||||
|
||||
for _ in range(repeats):
|
||||
_ = memory_efficient_attention(q_xf, k_xf, v_xf)
|
||||
|
||||
|
||||
end_mem, end_time = get_stats()
|
||||
|
||||
|
||||
results["latency_ms"] = (end_time - start_time) / repeats * 1000
|
||||
results["memory_mb"] = end_mem - start_mem
|
||||
|
||||
|
|
@ -158,18 +163,18 @@ def benchmark_attention(
|
|||
results["version"] = importlib.metadata.version("sageattention")
|
||||
except Exception:
|
||||
results["version"] = getattr(sageattention, "__version__", "N/A")
|
||||
|
||||
|
||||
# SageAttention expects tensors in (B, H, L, D) layout, so q/k/v are passed without transposing
|
||||
for _ in range(warmup):
|
||||
_ = sageattn(q, k, v)
|
||||
|
||||
|
||||
start_mem, start_time = get_stats(True)
|
||||
|
||||
|
||||
for _ in range(repeats):
|
||||
_ = sageattn(q, k, v)
|
||||
|
||||
|
||||
end_mem, end_time = get_stats()
|
||||
|
||||
|
||||
results["latency_ms"] = (end_time - start_time) / repeats * 1000
|
||||
results["memory_mb"] = end_mem - start_mem
|
||||
|
||||
|
|
@ -179,62 +184,52 @@ def benchmark_attention(
|
|||
|
||||
# flex_attention requires torch.compile for performance
|
||||
flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
|
||||
|
||||
|
||||
# Warmup (important to trigger compilation)
|
||||
for _ in range(warmup):
|
||||
_ = flex_attention_compiled(q, k, v)
|
||||
|
||||
|
||||
start_mem, start_time = get_stats(True)
|
||||
|
||||
|
||||
for _ in range(repeats):
|
||||
_ = flex_attention_compiled(q, k, v)
|
||||
|
||||
|
||||
end_mem, end_time = get_stats()
|
||||
|
||||
|
||||
results["latency_ms"] = (end_time - start_time) / repeats * 1000
|
||||
results["memory_mb"] = end_mem - start_mem
|
||||
except Exception as e:
|
||||
results["status"] = "fail"
|
||||
results["error"] = str(e)[:49]
|
||||
print(e)
|
||||
|
||||
return results
|
||||
|
||||
def main():
|
||||
backends = [
|
||||
"sdpa_math",
|
||||
"sdpa_mem_efficient",
|
||||
"sdpa_flash",
|
||||
"flex_attention",
|
||||
"xformers",
|
||||
"flash_attn",
|
||||
"sage_attn",
|
||||
]
|
||||
|
||||
def main():
|
||||
|
||||
all_results = []
|
||||
|
||||
print_gpu_info()
|
||||
print(f'config: warmup={warmup} repeats={repeats} dtypes={dtypes}')
|
||||
for name, config in PROFILES.items():
|
||||
print(f'config: warmup={warmup} repeats={repeats} dtypes={dtypes}')
|
||||
for name, config in profiles.items():
|
||||
print(f"profile: {name} (L_q={config['l_q']}, L_k={config['l_k']}, H={config['h']}, D={config['d']})")
|
||||
for dtype in dtypes:
|
||||
print(f" dtype: {dtype}")
|
||||
print(f" {'backend':<20} | {'version':<12} | {'status':<8} | {'latency':<10} | {'memory':<12} | ")
|
||||
for backend in backends:
|
||||
res = benchmark_attention(
|
||||
backend,
|
||||
dtype,
|
||||
l_q=config["l_q"],
|
||||
l_k=config["l_k"],
|
||||
h=config["h"],
|
||||
d=config["d"],
|
||||
warmup=warmup,
|
||||
repeats=repeats
|
||||
backend,
|
||||
dtype,
|
||||
l_q=config["l_q"],
|
||||
l_k=config["l_k"],
|
||||
h=config["h"],
|
||||
d=config["d"],
|
||||
)
|
||||
all_results.append(res)
|
||||
|
||||
|
||||
latency = f"{res['latency_ms']:.4f} ms"
|
||||
memory = f"{res['memory_mb']:.2f} MB"
|
||||
|
||||
|
||||
print(f" {res['backend']:<20} | {res['version']:<12} | {res['status']:<8} | {latency:<10} | {memory:<12} | {res['error']}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import sys, os
|
||||
import sys
|
||||
import os
|
||||
from collections import Counter
|
||||
|
||||
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
|
@ -20,25 +21,25 @@ tolerance_pct = 5
|
|||
# tests
|
||||
tests = [
|
||||
# - empty
|
||||
["", { '': 100 } ],
|
||||
["", { '': 100 } ],
|
||||
# - no weights
|
||||
[ "red|blonde|black", { 'black': 33, 'red': 33, 'blonde': 33 } ],
|
||||
[ "red|blonde|black", { 'black': 33, 'red': 33, 'blonde': 33 } ],
|
||||
# - full weights <= 1
|
||||
[ "red:0.1|blonde:0.9", { 'blonde': 90, 'red': 10 } ],
|
||||
[ "red:0.1|blonde:0.9", { 'blonde': 90, 'red': 10 } ],
|
||||
# - weights > 1 to test normalization
|
||||
[ "red:1|blonde:2|black:5", { 'blonde': 25, 'red': 12.5, 'black': 62.5 } ],
|
||||
[ "red:1|blonde:2|black:5", { 'blonde': 25, 'red': 12.5, 'black': 62.5 } ],
|
||||
# - disabling 0 weights to force one result
|
||||
[ "red:0|blonde|black:0", { 'blonde': 100 } ],
|
||||
[ "red:0|blonde|black:0", { 'blonde': 100 } ],
|
||||
# - weights <= 1 with distribution of the leftover
|
||||
[ "red:0.5|blonde|black:0.3|brown", { 'red': 50, 'black': 30, 'brown': 10, 'blonde': 10 } ],
|
||||
[ "red:0.5|blonde|black:0.3|brown", { 'red': 50, 'black': 30, 'brown': 10, 'blonde': 10 } ],
|
||||
# - weights > 1, unweighted entries should get the default weight of 1
|
||||
[ "red:2|blonde|black", { 'red': 50, 'blonde': 25, 'black': 25 } ],
|
||||
[ "red:2|blonde|black", { 'red': 50, 'blonde': 25, 'black': 25 } ],
|
||||
# - ignore content of ()
|
||||
[ "red:0.5|(blonde:1.3)", { 'red': 50, '(blonde:1.3)': 50 } ],
|
||||
[ "red:0.5|(blonde:1.3)", { 'red': 50, '(blonde:1.3)': 50 } ],
|
||||
# - ignore content of []
|
||||
[ "red:0.5|[stuff:1.3]", { '[stuff:1.3]': 50, 'red': 50 } ],
|
||||
[ "red:0.5|[stuff:1.3]", { '[stuff:1.3]': 50, 'red': 50 } ],
|
||||
# - ignore content of <>
|
||||
[ "red:0.5|<lora:1.0>", { '<lora:1.0>': 50, 'red': 50 } ]
|
||||
[ "red:0.5|<lora:1.0>", { '<lora:1.0>': 50, 'red': 50 } ]
|
||||
]
|
||||
|
||||
# -------------------------------------------------
|
||||
|
|
@ -109,4 +110,4 @@ with open(fn, 'r', encoding='utf-8') as f:
|
|||
print("RESULT: FAILED (distribution)")
|
||||
else:
|
||||
print("RESULT: PASSED")
|
||||
print('')
|
||||
print('')
|
||||
Loading…
Reference in New Issue