ai-agents

Signed-off-by: vladmandic <mandic00@live.com>
pull/4752/head
vladmandic 2026-04-10 12:27:21 +02:00
parent 8d871af63d
commit b9a38b5955
24 changed files with 2785 additions and 60 deletions


@ -58,3 +58,49 @@ General app structure is:
- Shared mutable global state can create subtle regressions; prefer narrow, explicit changes.
- Device/backend-specific code paths (**CUDA/ROCm/IPEX/DirectML/OpenVINO**) should not assume one platform.
- Script and extension loading is dynamic; failures may appear only when specific extensions or models are present.
## Repo-Local Skills
Use these repo-local skills for recurring SD.Next model integration work:
- `port-model`
File: `.github/skills/port-model/SKILL.md`
Use when adding a new model family, porting a standalone script into a Diffusers pipeline, or wiring an upstream Diffusers model into SD.Next.
- `debug-model`
File: `.github/skills/debug-model/SKILL.md`
Use when a new or existing SD.Next/Diffusers model integration fails during detection, loading, prompt encoding, sampling, or output handling.
- `check-api`
File: `.github/skills/check-api/SKILL.md`
Use when auditing API routes in `modules/api/api.py` and validating endpoint parameters plus request/response signatures.
- `check-schedulers`
File: `.github/skills/check-schedulers/SKILL.md`
Use when auditing scheduler registrations in `modules/sd_samplers_diffusers.py` for class loadability, config validity, and `SamplerData` mapping correctness.
- `check-models`
File: `.github/skills/check-models/SKILL.md`
Use when running end-to-end model integration audits across loaders, detect/routing parity, reference catalogs, and pipeline API contracts.
- `check-processing`
File: `.github/skills/check-processing/SKILL.md`
Use when validating txt2img/img2img/control processing workflows from UI submit definitions through backend processing and Diffusers execution, including parameter, type, and initialization checks.
- `check-scripts`
File: `.github/skills/check-scripts/SKILL.md`
Use when auditing scripts in `scripts/*.py` for standard `Script` overrides (`__init__`, `title`, `show`) and validating `ui()` output against `run()` or `process()` parameters.
- `github-issues`
File: `.github/skills/github-issues/SKILL.md`
Use when reading SD.Next GitHub issues with `[Issues]` in the title and producing a markdown report with a short summary, status, and suggested next steps for each issue.
- `github-features`
File: `.github/skills/github-features/SKILL.md`
Use when reading SD.Next GitHub issues with `[Feature]` in the title and producing a markdown report with a short summary, status, and suggested next steps for each issue.
- `analyze-model`
File: `.github/skills/analyze-model/SKILL.md`
Use when analyzing an external model URL to identify implementation style and estimate how difficult it is to port into SD.Next.
When creating and updating skills, update this file and the index in `.github/skills/README.md` accordingly.

.github/skills/README.md

@ -0,0 +1,51 @@
# Repo Skills
This folder contains repo-local Copilot skills for recurring SD.Next tasks.
## Available Skills
- `port-model`
File: `port-model/SKILL.md`
Use when adding or porting a model family into SD.Next and Diffusers.
- `debug-model`
File: `debug-model/SKILL.md`
Use when a new or existing SD.Next/Diffusers model integration fails during detection, loading, prompt encoding, sampling, or output handling.
- `check-api`
File: `check-api/SKILL.md`
Use when auditing API endpoints in `modules/api/api.py` and delegated API modules for route parameter correctness plus request/response signature consistency.
- `check-schedulers`
File: `check-schedulers/SKILL.md`
Use when auditing scheduler registrations from `modules/sd_samplers_diffusers.py` to verify class loadability, config validity against scheduler capabilities, and `SamplerData` correctness.
- `check-models`
File: `check-models/SKILL.md`
Use when running an end-to-end model integration audit covering loaders, detect/routing parity, reference catalogs, and custom pipeline API contracts.
- `check-processing`
File: `check-processing/SKILL.md`
Use when validating txt2img/img2img/control processing workflows from UI submit definitions to backend execution with parameter, type, and initialization checks.
- `check-scripts`
File: `check-scripts/SKILL.md`
Use when auditing `scripts/*.py` for correct Script overrides (`__init__`, `title`, `show`) and verifying `ui()` output compatibility with `run()` or `process()` parameters.
- `github-issues`
File: `github-issues/SKILL.md`
Use when reading SD.Next GitHub issues with `[Issues]` in the title and producing a markdown summary with status and suggested next steps for each issue.
- `github-features`
File: `github-features/SKILL.md`
Use when reading SD.Next GitHub issues with `[Feature]` in the title and producing a markdown summary with status and suggested next steps for each issue.
- `analyze-model`
File: `analyze-model/SKILL.md`
Use when analyzing an external model URL to classify implementation style and estimate SD.Next porting difficulty before coding.
## Notes
- Keep skills narrowly task-oriented and reusable.
- Prefer referencing existing repo patterns over generic framework advice.
- Update this index when adding new repo-local skills.

.github/skills/analyze-model/SKILL.md

@ -0,0 +1,133 @@
---
name: analyze-model
description: "Analyze an external model URL (typically Hugging Face) to determine implementation style and estimate SD.Next porting difficulty using the port-model workflow."
argument-hint: "Provide model URL and optional target scope: text2img, img2img, edit, video, or full integration"
---
# Analyze External Model For SD.Next Porting
Given an external model URL, inspect how the model is implemented and estimate how hard it is to port into SD.Next according to the port-model skill.
## When To Use
- User provides a Hugging Face model URL and asks if or how it can be ported
- User wants effort estimation before implementation work
- User wants to classify whether integration should reuse Diffusers, use custom Diffusers code, or require full custom implementation
## Accepted Inputs
- Hugging Face model URL (preferred)
- Hugging Face repo id
- Optional source links to custom code repos or PRs
Example:
- https://huggingface.co/jdopensource/JoyAI-Image-Edit
## Required Outputs
For each analyzed model, provide:
1. Implementation classification
2. Evidence for classification
3. SD.Next porting path recommendation
4. Difficulty rating and effort breakdown
5. Main risks and blockers
6. First implementation steps aligned with port-model skill
## Implementation Classification Buckets
Classify into one of these (or closest fit):
1. Integrated into upstream Diffusers library
2. Custom Diffusers implementation in model repo (custom pipeline/classes, not upstream)
3. Fully custom implementation (non-Diffusers inference stack)
4. Existing integration in ComfyUI (node or PR) but not in Diffusers
5. Other (describe clearly)
## Procedure
### 1. Inspect Model Repository Artifacts
From the provided URL/repo, collect:
- model card details
- files such as `model_index.json`, `config.json`, scheduler config, and tokenizer files
- presence of Diffusers-style folder layout
- references to custom Python modules or remote code requirements
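As a sketch, the collected file listing (for example from `huggingface_hub.list_repo_files`) can be reduced to a rough layout classification. The helper below is hypothetical, not SD.Next code, and only keys off a couple of common Diffusers markers:

```python
def classify_layout(files):
    """Roughly classify a model repo's implementation style from its file names.

    Assumptions: `model_index.json` marks a Diffusers-style layout, and any
    top-level .py file suggests custom/remote code shipped with the repo.
    """
    names = set(files)
    has_index = "model_index.json" in names
    has_custom_code = any(f.endswith(".py") for f in names)
    if has_index and not has_custom_code:
        return "upstream-diffusers-layout"
    if has_index and has_custom_code:
        return "custom-diffusers-code"
    return "non-diffusers"
```

This is only a first-pass signal; the model card and any custom pipeline classes still need manual inspection.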
### 2. Determine Runtime Stack
Identify whether model usage is:
- standard Diffusers pipeline call
- custom Diffusers pipeline class with `trust_remote_code`
- pure custom inference script or framework
- node-based integration in ComfyUI or another host
### 3. Cross-Check Integration Surface
Determine required SD.Next touchpoints if ported:
- loader file in `pipelines/model_name.py`
- detect and dispatch updates in `modules/sd_detect.py` and `modules/sd_models.py`
- model type mapping in `modules/modeldata.py`
- optional custom pipeline package under `pipelines/model/`
- reference catalog updates and preview asset requirements
### 4. Estimate Porting Difficulty
Use this scale:
- Low: mostly loader wiring to existing upstream Diffusers pipeline
- Medium: custom Diffusers classes or limited checkpoint/config adaptation
- High: full custom architecture, major prompt/sampler/output differences, or sparse docs
- Very High: no usable Diffusers path plus major runtime assumptions mismatch
Break down difficulty by:
- loader complexity
- pipeline/API contract complexity
- scheduler/sampler compatibility
- prompt encoding complexity
- checkpoint conversion/remapping complexity
- validation and testing burden
### 5. Identify Risks
Call out concrete risks:
- missing or incompatible scheduler config
- unclear output domain (latent vs pixel)
- custom text encoder or processor constraints
- nonstandard checkpoint format
- dependency on external runtime features unavailable in SD.Next
### 6. Recommend Porting Path
Map recommendation to port-model workflow:
- Upstream Diffusers reuse path
- Custom Diffusers pipeline package path
- Raw checkpoint plus remap path
Provide a concise first-step plan with smallest viable integration milestone.
## Reporting Format
Return sections in this order:
1. Classification
2. Evidence
3. Porting difficulty
4. Recommended SD.Next integration path
5. Risks and unknowns
6. Next actions
If critical information is missing, explicitly list unknowns and what to inspect next.
## Notes
- Prefer concrete evidence from repository files and model card over assumptions.
- If source suggests multiple possible paths, compare at least two and state why one is preferred.
- Keep recommendations aligned with the conventions defined in the port-model skill.

.github/skills/check-api/SKILL.md

@ -0,0 +1,121 @@
---
name: check-api
description: "Audit SD.Next API route definitions and verify endpoint parameters plus request/response signatures against declared FastAPI contracts."
argument-hint: "Optionally focus on specific route prefixes or endpoint groups"
---
# Check API Endpoints And Signatures
Read `modules/api/api.py`, enumerate all registered endpoints, and validate that each endpoint has coherent request parameters and response signatures.
## When To Use
- The user asks to audit API correctness
- A change touched API routes, endpoint handlers, or API models
- OpenAPI docs look wrong or clients report schema mismatches
- You need a pre-PR API contract sanity pass
## Primary File
- `modules/api/api.py`
This file is the route registration hub and must be treated as the source of truth for direct `add_api_route(...)` registrations.
## Secondary Files To Inspect
- `modules/api/models.py`
- `modules/api/endpoints.py`
- `modules/api/generate.py`
- `modules/api/process.py`
- `modules/api/control.py`
- Any module loaded via `register_api(...)` from `modules/api/*` and feature modules (caption, lora, gallery, civitai, rembg, etc.)
## Audit Goals
For every endpoint, verify:
1. Route method and path are valid and unique after subpath handling.
2. Handler call signature is compatible with the route declaration.
3. Declared `response_model` is coherent with returned payload shape.
4. Request body or query params implied by handler type hints are consistent with expected client usage.
5. Authentication behavior is intentional (`auth=True` default in `add_api_route`).
6. OpenAPI schema exposure is correct (including trailing-slash duplicate suppression).
## Procedure
### 1. Enumerate Routes In modules/api/api.py
- Collect every `self.add_api_route(...)` call.
- Capture: path, methods, handler symbol, response_model, tags, auth flag.
- Note route groups: server, generation, processing, scripts, enumerators, functional.
### 2. Resolve Handler Definitions
For each handler symbol:
- Locate the callable definition.
- Read its function signature and type hints.
- Identify required positional args, optional args, and body model annotations.
Flag issues such as:
- Required parameters that cannot be supplied by FastAPI.
- Mismatch between route method and handler intent (for example body expected on GET).
- Ambiguous or missing type hints for public API handlers.
### 3. Validate Request Signatures
- Confirm request model classes exist and are importable.
- Confirm endpoint signature reflects expected request source (path/query/body).
- Check optional vs required semantics for compatibility-sensitive endpoints.
### 4. Validate Response Signatures
- Compare declared `response_model` to handler return shape.
- Ensure list/dict wrappers match actual payload structure.
- Flag obvious drift where `response_model` implies fields never returned.
### 5. Include register_api(...) Modules
`Api.register()` delegates additional endpoints through module-level `register_api(...)` calls.
- Inspect each delegated module for routes.
- Apply the same request/response signature checks.
- Include these findings in the final report, not only direct routes in api.py.
### 6. Verify Runtime Schema (Optional But Preferred)
If feasible in the current environment:
- Build app and inspect `app.routes` metadata.
- Generate OpenAPI schema and spot-check key endpoints.
- Confirm trailing-slash duplicate suppression behavior remains correct.
If runtime schema checks are not feasible, explicitly state that and rely on static validation.
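When a schema is available (for example the dict returned by FastAPI's `app.openapi()`), the trailing-slash duplicate spot-check reduces to a plain dict traversal. The helper below is an illustrative sketch:

```python
def find_duplicate_slash_paths(openapi_schema: dict):
    """Return schema paths that appear both with and without a trailing slash."""
    paths = set(openapi_schema.get("paths", {}))
    return sorted(
        p for p in paths
        if p != "/" and p.endswith("/") and p.rstrip("/") in paths
    )
```

A non-empty result means duplicate suppression is not working for those routes.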
## Reporting Format
Report findings ordered by severity:
1. Breaking contract mismatches
2. Likely runtime errors
3. Schema quality or consistency issues
4. Minor style/typing improvements
For each finding include:
- Route path and method
- File and function
- Why it is a mismatch
- Minimal fix recommendation
If no issues are found, state that explicitly and mention residual risk (for example runtime-only behavior not executed).
## Output Expectations
When this skill is used, return:
- Total endpoints checked
- Direct routes vs delegated `register_api(...)` routes checked
- Findings list with severity and locations
- Clear pass/fail summary for request and response signature consistency

.github/skills/check-models/SKILL.md

@ -0,0 +1,151 @@
---
name: check-models
description: "Audit SD.Next model integrations end-to-end: loaders, detect/routing, reference catalogs, and pipeline API contracts."
argument-hint: "Optionally focus on a model family, repo id, or a subset: loader, detect-routing, references, pipeline-contracts"
---
# Check Model Integrations End-To-End
Run a consolidated model-integration audit that combines loader checks, detect/routing checks, reference-catalog checks, and pipeline contract checks.
## When To Use
- A new model family was added and needs a completeness audit
- Existing model support appears inconsistent across detection, loading, and UI references
- A custom pipeline was ported and needs contract validation
- You want a pre-PR integration quality gate for model-related changes
## Combined Scope
This skill combines four audit surfaces:
1. Loader consistency (`check-loaders` equivalent)
2. Detect/routing parity (`check-detect-routing` equivalent)
3. Reference-catalog integrity (`check-reference-catalog` equivalent)
4. Pipeline API contract conformance (`check-pipeline-contracts` equivalent)
## Primary Files
- `pipelines/model_*.py`
- `modules/sd_detect.py`
- `modules/sd_models.py`
- `modules/modeldata.py`
- `data/reference.json`
- `data/reference-cloud.json`
- `data/reference-quant.json`
- `data/reference-distilled.json`
- `data/reference-nunchaku.json`
- `data/reference-community.json`
- `models/Reference/`
Pipeline files as needed:
- `pipelines/<model>/pipeline.py`
- `pipelines/<model>/model.py`
## Audit A: Loader Consistency
For each target model loader in `pipelines/model_*.py`, verify:
- Correct `sd_models.path_to_repo(checkpoint_info)` and `sd_models.hf_auth_check(...)` usage
- Load args built with `model_quant.get_dit_args(...)` where applicable
- No duplicated kwargs (for example duplicate `torch_dtype`)
- Correct component loading path (`generic.load_transformer`, `generic.load_text_encoder`, tokenizer/processor)
- Proper post-load hooks (`sd_hijack_te`, `sd_hijack_vae`) where required
- Correct `pipe.task_args` defaults where needed
- Cleanup and `devices.torch_gc(...)` present
Flag stale patterns, missing hooks, or conflicting load behavior.
## Audit B: Detect/Routing Parity
Verify model family alignment across:
- `modules/sd_detect.py` detection heuristics
- `modules/sd_models.py` load dispatch branch
- `modules/modeldata.py` reverse classification from loaded pipeline class
Checks:
- Family is detectable by name/repo conventions
- Dispatch routes to the intended loader
- Loaded pipeline class is classified back to the same model family
- Branch ordering does not cause broad matches to shadow specific families
## Audit C: Reference Catalog Integrity
Verify references for model families intended to appear in model references.
Checks:
- Correct category file placement by type:
- base -> `data/reference.json`
- cloud -> `data/reference-cloud.json`
- quant -> `data/reference-quant.json`
- distilled -> `data/reference-distilled.json`
- nunchaku -> `data/reference-nunchaku.json`
- community -> `data/reference-community.json`
- Required fields present per entry (`path`, `preview`, `desc` when expected)
- Duplicate repo/path collisions across reference files are intentional or flagged
- Preview filename convention is consistent
- Referenced preview file exists in `models/Reference/` (or explicitly placeholder if intentional)
- JSON validity for touched reference files
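A minimal audit helper for one reference file might look like the following sketch. The preview directory default and the per-entry shape (a `preview` key on each entry) are assumptions about the catalog format, not a verified contract:

```python
import json
import os

def audit_reference_file(path, preview_dir="models/Reference"):
    """Validate a reference JSON file and report entries with missing previews."""
    with open(path, encoding="utf-8") as f:
        entries = json.load(f)  # raises on invalid JSON, which is itself a finding
    missing = []
    for name, entry in entries.items():
        preview = entry.get("preview")
        if preview and not os.path.exists(os.path.join(preview_dir, preview)):
            missing.append((name, preview))
    return missing
```

Intentional placeholders should be filtered from the result rather than silently passed.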
## Audit D: Pipeline API Contracts
For custom pipelines (`pipelines/<model>/pipeline.py`), verify:
- Inherits from `diffusers.DiffusionPipeline`
- Registers modules correctly
- `from_pretrained` wiring is coherent with actual artifact layout
- `encode_prompt` semantics are consistent with tokenizer/text encoder setup
- `__call__` supports expected public args for its task and does not expose unsupported generic args
- Batch and negative prompt behavior are coherent
- Output conversion aligns with model output domain (latent vs pixel space)
- `output_type` and `return_dict` behavior are consistent
## Runtime Validation (Preferred)
When feasible:
- Import-level smoke tests for loaders and pipeline modules
- Lightweight loader construction checks without full heavy generation where possible
- One minimal generation/sampling pass for changed model families
If runtime checks are not feasible, report limitations clearly.
## Reporting Format
Return findings by severity:
1. Blocking integration failures
2. Contract mismatches (load/detect/reference/pipeline)
3. Consistency and quality issues
4. Optional improvements
For each finding include:
- model family
- layer (`loader`, `detect-routing`, `reference`, `pipeline-contract`)
- file location
- mismatch summary
- minimal fix
Also include summary counts:
- loaders checked
- model families checked for detect/routing parity
- reference files checked
- pipeline contracts checked
- runtime checks executed vs skipped
## Pass Criteria
A full pass requires all of the following in audited scope:
- loader path is coherent and non-conflicting
- detect/routing/modeldata parity holds
- reference entries are valid, categorized correctly, and have preview files
- custom pipeline contracts are consistent with actual model behavior
If any area is intentionally out of scope, mark as partial pass with explicit exclusions.

.github/skills/check-processing/SKILL.md

@ -0,0 +1,178 @@
---
name: check-processing
description: "Validate txt2img/img2img/control/caption processing workflows from UI submit bindings to backend processing execution and confirm parameter/type/init correctness."
argument-hint: "Optionally focus on txt2img, img2img, control, caption, or process-only and include changed files"
---
# Check Processing Workflow Contracts
Trace generation workflows from UI definitions and submit bindings to backend execution, then validate that parameters are passed, typed, and initialized correctly.
## When To Use
- A change touched UI submit wiring for txt2img, img2img, control, or caption workflows
- Processing code changed and regressions are suspected in argument ordering or defaults
- A new parameter was added to UI or processing classes/functions and needs end-to-end validation
- You want a pre-PR contract audit for generation flow integrity
## Required Workflow Coverage
Start from UI definitions and follow each workflow to final implementation:
1. `txt2img`: `modules/ui_txt2img.py` -> `modules/txt2img.py` -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
2. `img2img`: `modules/ui_img2img.py` -> `modules/img2img.py` -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
3. `control/process`: `modules/ui_control.py` -> `modules/control/run.py` (and related control processing entrypoints) -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
4. `caption/process`: `modules/ui_caption.py` -> caption handler module(s) -> `modules/processing.py:process_images` and/or postprocess/caption execution module(s), depending on selected caption backend
Also validate script hooks when present:
- `modules/scripts_manager.py` (`run`, `before_process`, `process`, `process_images`, `after`)
## Primary Files
- `modules/ui_txt2img.py`
- `modules/txt2img.py`
- `modules/ui_img2img.py`
- `modules/img2img.py`
- `modules/ui_control.py`
- `modules/ui_caption.py` (and `modules/ui_captions.py` if present)
- `modules/control/run.py`
- `modules/processing.py`
- `modules/processing_diffusers.py`
- `modules/scripts_manager.py`
## Audit Goals
For each covered workflow, verify all three dimensions:
1. Parameter pass-through correctness (name, order, semantic meaning)
2. Type correctness (UI component output shape vs function signature expectations)
3. Initialization correctness (defaults, `None` handling, fallback logic, and object state)
## Procedure
### 1. Build End-To-End Call Graph
For each workflow (`txt2img`, `img2img`, `control`, `caption`):
- Locate submit/click bindings in UI modules.
- Capture the exact `inputs=[...]` list order and target function (`fn=...`).
- Resolve wrappers (`call_queue.wrap_gradio_gpu_call`, queued wrappers) to actual function signatures.
- Follow function flow through processing class construction and execution (`processing.process_images`, then `process_diffusers` when applicable).
Produce a normalized mapping table per workflow:
- UI input component name
- UI expected output type
- receiving argument in submit target
- receiving processing-object field (if applicable)
- downstream consumption point
### 2. Validate Argument Order And Arity
For each submit path:
- Compare UI `inputs` order against function positional parameter order.
- Validate handling of `*args` and script arguments.
- Confirm `state`, task id, mode flags, and tab selections align with function signatures.
- Flag positional drift where adding/removing an argument in one layer is not propagated.
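The arity comparison can be sketched with `inspect`. This hypothetical helper ignores keyword-only parameters and the details of script `*args` unpacking, so treat it as a first-pass filter rather than a full contract check:

```python
import inspect

def check_binding_arity(fn, ui_inputs):
    """Compare a UI inputs list against a submit target's positional parameters."""
    params = [p for p in inspect.signature(fn).parameters.values()
              if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
    has_varargs = any(p.kind == p.VAR_POSITIONAL
                      for p in inspect.signature(fn).parameters.values())
    if has_varargs:
        return len(ui_inputs) >= len(params)  # extras absorbed by *args
    required = [p for p in params if p.default is p.empty]
    # inputs must cover all required params without exceeding total capacity
    return len(required) <= len(ui_inputs) <= len(params)
```

A `False` result flags positional drift between the UI binding and the handler.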
### 3. Validate Name And Semantic Parity
Check that semantically related parameters remain coherent across layers:
- sampler fields (`sampler_index`, `hr_sampler_index`, sampler name conversion)
- guidance/cfg fields
- size/resize fields
- denoise/refiner/hires fields
- detailer, hdr, grading fields
- override settings fields
Flag mismatches such as:
- same concept with divergent naming and wrong destination
- field sent but never consumed
- required field consumed but never provided
### 4. Validate Type Contracts
Audit type compatibility from UI component to processing target:
- `gr.Image`/`gr.File`/`gr.Video` outputs vs expected Python types (`PIL.Image`, bytes, list, path-like, etc.)
- radios returning index/value and expected downstream representation
- sliders/number inputs (`int` vs `float`) and conversion points
- optional objects (`None`) and `.name`/attribute access safety
Flag ambiguous or unsafe assumptions, especially for optional file inputs and mixed scalar/list values.
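One way to make the boundary expectations explicit is a small contract table plus a checker. The component kinds below are hypothetical labels, and the `bool`-vs-`int` special case matters because `bool` is an `int` subclass in Python:

```python
# Hypothetical expectations: UI component kind -> accepted Python types.
UI_TYPE_CONTRACTS = {
    "slider_int": (int,),
    "slider_float": (int, float),      # ints are acceptable where floats are expected
    "checkbox": (bool,),
    "textbox": (str,),
    "optional_file": (str, type(None)),
}

def check_ui_value(kind, value):
    """Return True when a submitted value satisfies the boundary contract."""
    accepted = UI_TYPE_CONTRACTS[kind]
    if bool in accepted and isinstance(value, bool):
        return True
    if bool not in accepted and isinstance(value, bool):
        return False  # reject bools for numeric fields despite bool subclassing int
    return isinstance(value, accepted)
```

Real Gradio components return richer objects (temp-file wrappers, numpy arrays), so a production table would need more entries.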
### 5. Validate Initialization And Defaults
In target modules (`txt2img.py`, `img2img.py`, `control/run.py`, processing classes):
- verify defaults/fallbacks for invalid or missing inputs
- verify guards for unset model/state and expected error paths
- verify object fields are initialized before first use
- verify flow-specific defaults are not leaking across workflows
Include checks for common regressions:
- `None` passed into required processing fields
- missing fallback for sampler/seed/size values
- stale fields retained from prior job state
### 6. Validate Script Hook Contracts
Where script systems are involved:
- verify `scripts_*.run(...)` fallback behavior to `processing.process_images(...)`
- verify `scripts_*.after(...)` receives compatible processed object
- ensure script args wiring matches `setup_ui(...)` order
### 7. Runtime Spot Check (Preferred)
If feasible, run lightweight smoke validation for each workflow:
- one minimal txt2img run
- one minimal img2img run
- one minimal control run
- one minimal caption run
Use very small dimensions/steps to limit runtime.
If runtime checks are not feasible, explicitly report static-only limitations.
## Reporting Format
Return findings by severity:
1. Blocking contract failures (will break execution)
2. High-risk mismatches (likely wrong behavior)
3. Type/init safety issues
4. Non-blocking consistency issues
For each finding include:
- workflow (`txt2img` | `img2img` | `control` | `caption`)
- layer transition (`ui -> handler`, `handler -> processing`, `processing -> diffusers`)
- file location
- mismatch summary
- minimal fix
Also include summary counts:
- workflows checked
- UI bindings checked
- function signatures checked
- parameter mappings validated
- runtime checks executed vs skipped
## Pass Criteria
A full pass requires all of the following:
- UI submit input order matches target function signatures
- all required parameters are passed end-to-end with correct semantics
- type expectations are explicit and compatible at each boundary
- initialization/default logic prevents unset/invalid state usage
- scripts fallback path to `processing.process_images` is coherent
If only part of the workflow scope was checked, report partial pass with explicit exclusions.

.github/skills/check-schedulers/SKILL.md

@ -0,0 +1,149 @@
---
name: check-schedulers
description: "Audit scheduler registrations starting from modules/sd_samplers_diffusers.py and verify class loadability, config validity against scheduler capabilities, and SamplerData correctness."
argument-hint: "Optionally focus on a scheduler subset, such as flow-matching, res4lyf, or parallel schedulers"
---
# Check Scheduler Registry And Contracts
Use `modules/sd_samplers_diffusers.py` as the starting point and verify that scheduler classes, scheduler config, and `SamplerData` mappings are coherent and executable.
## Required Guarantees
The audit must explicitly verify all three:
1. All scheduler classes can be loaded and compiled.
2. All scheduler config entries are valid and match scheduler capabilities in `__init__`.
3. All scheduler classes have valid associated `SamplerData` entries and mapping correctness.
## Scope
Primary file:
- `modules/sd_samplers_diffusers.py`
Related files:
- `modules/sd_samplers_common.py` for `SamplerData` definition and sampler expectations
- `modules/sd_samplers.py` for sampler selection flow and runtime wiring
- `modules/schedulers/**/*.py` for custom scheduler implementations
- `modules/res4lyf/**/*.py` for Res4Lyf scheduler classes (if installed/enabled)
## What "Loaded And Compiled" Means
Treat this as a two-level check:
1. Import and class resolution:
every scheduler class referenced by `SamplerData(... DiffusionSampler(..., SchedulerClass, ...))` resolves without missing symbol errors.
2. Construction and compile sanity:
scheduler instances can be created from their intended config path and survive a lightweight execution-level check (including compile path when available).
Notes:
- For non-`torch.nn.Module` schedulers, "compiled" means the scheduler integration path is executable in runtime checks (not necessarily `torch.compile`).
- If the environment cannot run compile checks, report this explicitly and still complete static validation.
## Procedure
### 1. Build Scheduler Inventory
- Enumerate scheduler classes imported in `modules/sd_samplers_diffusers.py`.
- Enumerate all entries in `samplers_data_diffusers`.
- Enumerate all config keys in `config`.
Create a joined table by sampler name with:
- sampler name
- scheduler class
- config key used
- custom scheduler category (diffusers, SD.Next custom, Res4Lyf)
### 2. Validate Scheduler Class Resolution
For each mapped scheduler class:
- confirm symbol exists and is importable
- confirm class object is callable for construction paths used in SD.Next
Flag missing imports, dead entries, or stale class names.
### 3. Validate Config Against Scheduler __init__ Capabilities
For each sampler config entry:
- inspect scheduler class `__init__` signature and accepted config fields
- flag config keys that are unsupported or misspelled
- flag required scheduler init/config fields that are absent
- verify defaults and overrides are compatible with scheduler family behavior
Special attention:
- flow-matching schedulers: shift/base_shift/max_shift/use_dynamic_shifting
- DPM families: algorithm_type/solver_order/solver_type/final_sigmas_type
- compatibility-only keys that are intentionally ignored should be documented, not silently assumed
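A static approximation of this check compares config keys against the `__init__` signature. Note that real diffusers schedulers (typically decorated with `@register_to_config`) may accept `**kwargs`, so this sketch, shown with a dummy class standing in, only catches the strict-signature case:

```python
import inspect

class DummyScheduler:
    """Stand-in for a scheduler class with a strict __init__ signature."""
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, shift=1.0):
        pass

def unsupported_config_keys(scheduler_cls, config):
    """Return config keys the scheduler __init__ does not accept."""
    sig = inspect.signature(scheduler_cls.__init__)
    if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
        return []  # **kwargs accepts everything; nothing to flag statically
    accepted = set(sig.parameters) - {"self"}
    return sorted(k for k in config if k not in accepted)
```

Flagged keys are either misspellings or compatibility-only entries that should be documented.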
### 4. Validate SamplerData Mapping Correctness
For each `SamplerData` entry:
- scheduler label matches config key intent
- callable builds `DiffusionSampler` with the expected scheduler class
- mapping is not accidentally pointing to a different named preset
- no duplicate names with conflicting class/config behavior
Flag mismatches such as wrong display name, wrong class wired to name, or stale aliasing.
### 5. Runtime Smoke Checks (Preferred)
If feasible, run lightweight checks:
- instantiate each scheduler from config
- execute minimal scheduler setup path (`set_timesteps` and a dummy step where possible)
- verify no immediate runtime contract errors
If runtime checks are not feasible for some schedulers, mark those explicitly as unverified-at-runtime.
### 6. Compile Path Validation
Where scheduler runtime path supports compile-related checks in SD.Next:
- verify scheduler integration path remains compatible with compile options
- detect obvious compile blockers introduced by signature/config mismatches
Do not mark compile as passed if only static checks were done.
## Reporting Format
Return findings ordered by severity:
1. Blocking scheduler load failures
2. Config/signature contract mismatches
3. `SamplerData` mapping inconsistencies
4. Non-blocking improvements
For each finding include:
- sampler name
- scheduler class
- file location
- mismatch reason
- minimal corrective action
Also include summary counts:
- total scheduler classes discovered
- total `SamplerData` entries checked
- total config entries checked
- runtime-validated count
- compile-path validated count
## Pass Criteria
The check passes only if all are true:
- all referenced scheduler classes resolve
- each scheduler config entry is compatible with scheduler capabilities
- each `SamplerData` entry is correctly mapped and usable
- no blocking runtime or compile-path failures in validated scope
If scope is partial due to environment limitations, report pass with explicit limitations, not a full pass.

.github/skills/check-scripts/SKILL.md

@ -0,0 +1,129 @@
---
name: check-scripts
description: "Audit scripts/*.py and verify Script override contracts (init/title/show) plus ui() output compatibility with run() or process() parameters."
argument-hint: "Optionally focus on a subset of scripts or only run-vs-ui or process-vs-ui checks"
---
# Check Script Class Contracts
Audit all Python scripts in `scripts/*.py` and validate that script class overrides and UI-to-execution parameter contracts are correct.
## When To Use
- New or changed files were added under `scripts/*.py`
- A script crashes when selected or executed from UI
- A script UI was changed and runtime args no longer match
- You want a pre-PR quality gate for script API compatibility
## Scope
Primary audit scope:
- `scripts/*.py`
Contract references:
- `modules/scripts_manager.py` (`Script` base class contracts for `title`, `show`, `ui`, `run`, `process`)
- `modules/scripts_postprocessing.py` (`ScriptPostprocessing` contracts for `ui` and `process`)
## Required Checks
### A. Standard Overrides: `__init__`, `title`, `show`
For each class in `scripts/*.py` that subclasses `scripts.Script` or `scripts_manager.Script`:
1. `title`:
- method exists
- callable signature is valid
- returns non-empty string value
2. `show`:
- method exists
- signature is compatible with script runner usage (`show(is_img2img)` or permissive `*args/**kwargs`)
- return behavior is compatible (`bool` or `scripts.AlwaysVisible` / equivalent)
3. `__init__` (if overridden):
- does not require mandatory constructor args that would break loader instantiation
- avoids side effects that require runtime-only globals at import time
- leaves class in a usable state before `ui()`/`run()`/`process()` are called
Notes:
- `__init__` is optional; do not fail scripts that rely on inherited constructor.
- For dynamic patterns, flag as warning with rationale instead of hard fail.
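The override checks above can be sketched statically with `inspect`; the `Script` stand-in and `MyScript` class below are hypothetical placeholders, and the real base-class contract lives in `modules/scripts_manager.py`:

```python
import inspect

class Script:  # minimal stand-in for the modules.scripts_manager.Script base class
    def title(self):
        raise NotImplementedError
    def show(self, is_img2img):
        return False

class MyScript(Script):  # hypothetical script under audit
    def title(self):
        return "My Script"
    def show(self, is_img2img):
        return True

def audit_overrides(cls):
    findings = []
    # title: must exist, be callable, and return a non-empty string
    title = getattr(cls, "title", None)
    if not callable(title):
        findings.append("title missing or not callable")
    else:
        try:
            value = cls().title()
            if not isinstance(value, str) or not value.strip():
                findings.append("title returns empty or non-string value")
        except TypeError:
            findings.append("constructor requires args; loader cannot instantiate class")
    # show: must accept is_img2img, or be permissive via *args/**kwargs
    show = getattr(cls, "show", None)
    if callable(show):
        params = inspect.signature(show).parameters.values()
        permissive = any(p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD) for p in params)
        if len(params) < 2 and not permissive:  # self + is_img2img
            findings.append("show() cannot accept is_img2img")
    return findings

print(audit_overrides(MyScript))  # → []
```

The same pattern extends to any additional override the runner calls before `ui()`.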
### B. `ui()` Output vs `run()`/`process()` Parameters
For each script class:
1. Determine execution target:
- Prefer `run()` if present for generation scripts
- Use `process()` if present and `run()` is absent or script is postprocessing-oriented
2. Compare `ui()` output shape to target method parameter expectations:
- `ui()` list/tuple output count should match target positional argument capacity after the first processing arg (`p` or `pp`), unless target uses `*args`
- if target is strict positional (no `*args`/`**kwargs`), detect missing/extra UI values
- if target uses keyword-driven processing, ensure UI dict keys map to accepted params or `**kwargs`
3. Validate ordering assumptions:
- UI control order should align with positional parameter order when positional binding is used
- detect obvious drift when new UI control was added but method signature was not updated
4. Validate optionality/defaults:
- required target parameters should be satisfiable by UI outputs
- defaulted target params are acceptable even if UI omits them
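A minimal sketch of the `ui()`-vs-`run()` arity comparison under positional binding; `DemoScript` is a hypothetical example rather than a real SD.Next script:

```python
import inspect

class DemoScript:  # hypothetical generation script under audit
    def ui(self, is_img2img):
        # two UI control values that the runner forwards to run()
        return ["strength_slider", "mode_dropdown"]
    def run(self, p, strength, mode="fast"):
        return (strength, mode)

def check_ui_run_contract(script):
    """Compare ui() output count against run() positional capacity after `p`."""
    ui_outputs = script.ui(False)
    params = list(inspect.signature(script.run).parameters.values())
    if any(p.kind is p.VAR_POSITIONAL for p in params):
        return "ok: run() accepts *args"
    # drop the leading processing argument (`p` or `pp`)
    positional = [p for p in params if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)][1:]
    required = [p for p in positional if p.default is inspect.Parameter.empty]
    if len(ui_outputs) < len(required):
        return f"missing UI values: run() requires {len(required)}, ui() provides {len(ui_outputs)}"
    if len(ui_outputs) > len(positional):
        return f"extra UI values: run() accepts {len(positional)}, ui() provides {len(ui_outputs)}"
    return "ok"

print(check_ui_run_contract(DemoScript()))  # → ok
```

Keyword-driven scripts need a dict-key comparison instead, as described above.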
### C. Runner Compatibility
Confirm script methods align with runner expectations in `modules/scripts_manager.py`:
- `ui()` return type is compatible with runner collection (`list/tuple` or recognized mapping pattern where used)
- `run()`/`process()` receive args in expected form from runner slices
- no obvious mismatch between `args_from/args_to` assumptions and script method arity
For postprocessing-style scripts in `scripts/*.py`:
- verify compatibility with `modules/scripts_postprocessing.py` conventions (`ui()` list/dict, `process(pp, *args, **kwargs)`)
## Procedure
1. Enumerate all classes in `scripts/*.py` and classify by base class type.
2. For each generation script class, validate `title`, `show`, optional `__init__`, and `ui` -> `run/process` contracts.
3. For each postprocessing script class under `scripts/*.py`, validate `ui` -> `process` mapping semantics.
4. Cross-check ambiguous cases against script runner behavior from `modules/scripts_manager.py` and `modules/scripts_postprocessing.py`.
5. Report concrete mismatches with minimal fixes.
## Reporting Format
Return findings by severity:
1. Blocking script contract failures
2. Likely runtime arg/arity mismatches
3. Signature/type compatibility warnings
4. Style/consistency improvements
For each finding include:
- script file
- class name
- failing contract area (`init`, `title`, `show`, `ui->run`, `ui->process`)
- mismatch summary
- minimal fix
Also include summary counts:
- total `scripts/*.py` files checked
- total script classes checked
- classes with `run` contract checked
- classes with `process` contract checked
- override issues found (`init/title/show`)
## Pass Criteria
A full pass requires all of the following across audited `scripts/*.py` classes:
- `title` and `show` overrides are valid and runner-compatible for generation scripts
- overridden `__init__` methods are safely instantiable
- `ui()` output contracts are compatible with `run()` or `process()` args
- no blocking arity/signature mismatch remains
If a class uses highly dynamic argument routing that cannot be proven statically, mark as conditional pass with explicit runtime validation recommendation.

`.github/skills/debug-model/SKILL.md`
---
name: debug-model
description: "Debug a broken SD.Next or Diffusers model integration. Use when a newly added or ported model fails to load, misdetects, crashes during prompt encoding or sampling, or produces incorrect outputs."
argument-hint: "Describe the failing model, where it fails, the error message, and whether the model is upstream Diffusers, custom pipeline, or raw checkpoint based"
---
# Debug SD.Next And Diffusers Model Port
Read the error, identify which integration layer is failing, isolate the smallest reproducible failure point, fix the root cause, and validate the fix without expanding scope.
## When To Use
- A newly added SD.Next model type does not autodetect correctly
- The loader fails to instantiate a pipeline or component
- A custom pipeline imports but fails during `from_pretrained`
- Prompt encoding fails because of tokenizer, processor, or text encoder mismatch
- Sampling fails due to tensor shape, dtype, device, or scheduler issues
- The model loads but outputs corrupted images, wrong output type, or obviously incorrect results
## Debugging Order
Always debug from the outside in.
1. Detection and routing
2. Loader arguments and component selection
3. Checkpoint path and artifact layout
4. Weight loading and key mapping
5. Prompt encoding
6. Sampling forward path
7. Output postprocessing and SD.Next task integration
Do not start by rewriting the architecture if the failure is likely in detection, loader wiring, or output handling.
## Files To Check First
- `.github/copilot-instructions.md`
- `.github/instructions/core.instructions.md`
- `modules/sd_detect.py`
- `modules/sd_models.py`
- `modules/modeldata.py`
- `pipelines/model_<name>.py`
- `pipelines/<model>/model.py`
- `pipelines/<model>/pipeline.py`
- `pipelines/generic.py`
If the port is based on a standalone script, compare the failing path against the original reference implementation and identify the first semantic divergence.
## Failure Classification
### 1. Model Not Detected Or Misclassified
Check:
- Filename and repo-name heuristics in `modules/sd_detect.py`
- Loader dispatch branch in `modules/sd_models.py`
- Reverse pipeline classification in `modules/modeldata.py`
Typical symptoms:
- Wrong loader called
- Pipeline classified as a broader family such as `chroma` instead of a custom `zetachroma`
- Task switching behaves incorrectly because the loaded pipeline type is wrong
### 2. Loader Fails Before Pipeline Construction
Check:
- `sd_models.path_to_repo(checkpoint_info)` output
- `generic.load_transformer(...)` and `generic.load_text_encoder(...)` arguments
- Duplicate kwargs such as `torch_dtype`
- Wrong class chosen for text encoder, tokenizer, or processor
- Whether the source is really a Diffusers repo or only a raw checkpoint
Typical symptoms:
- Missing subfolder errors
- `from_pretrained` argument mismatch
- Component class mismatch
### 3. Raw Checkpoint Load Fails
Check:
- Checkpoint path resolution for local file, local directory, and Hub repo
- State dict load method
- Key remapping logic
- Config inference from tensor shapes
- Missing versus unexpected keys after `load_state_dict`
Typical symptoms:
- Key mismatch explosion
- Wrong inferred head counts, dimensions, or decoder settings
- Silent shape corruption caused by a bad remap
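Key remapping of this kind can be sketched as a pure dictionary transform; the prefix rules below are illustrative only, since each model family needs its own mapping table:

```python
def remap_state_dict(raw):
    """Remap hypothetical training-format keys to an inference layout."""
    rules = [
        ("module.", ""),                # strip a DDP-style wrapper prefix
        ("backbone.", "transformer."),  # rename the core block namespace
    ]
    remapped = {}
    for key, tensor in raw.items():
        new_key = key
        for old, new in rules:
            if new_key.startswith(old):
                new_key = new + new_key[len(old):]
        remapped[new_key] = tensor
    return remapped

def report_key_mismatch(expected_keys, loaded_keys):
    """Surface missing vs unexpected keys instead of hiding them behind strict=False."""
    missing = sorted(set(expected_keys) - set(loaded_keys))
    unexpected = sorted(set(loaded_keys) - set(expected_keys))
    return missing, unexpected

raw = {"module.backbone.blocks.0.attn.qkv.weight": "W", "module.time_embed.weight": "T"}
print(remap_state_dict(raw))
```

Comparing the remapped keys against the model's expected keys makes a bad remap visible immediately instead of letting it corrupt shapes silently.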
### 4. Prompt Encoding Fails
Check:
- Tokenizer or processor choice
- `trust_remote_code` requirements
- Chat template or custom prompt formatting
- Hidden state index selection
- Padding and batch alignment between positive and negative prompts
Typical symptoms:
- Tokenizer attribute errors
- Hidden state shape mismatch
- CFG failures when negative prompts do not match prompt batch length
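The batch-alignment issue can be illustrated with a minimal helper that broadcasts a single negative prompt to match the prompt batch; this is a sketch, not the SD.Next implementation:

```python
def align_negative_prompts(prompts, negative_prompts):
    """Repeat or validate negatives so CFG batches stay shape-aligned."""
    if not negative_prompts:
        negative_prompts = [""]
    if len(negative_prompts) == 1 and len(prompts) > 1:
        # broadcast a single negative prompt across the whole batch
        negative_prompts = negative_prompts * len(prompts)
    if len(negative_prompts) != len(prompts):
        raise ValueError(f"negative prompt count {len(negative_prompts)} does not match prompt count {len(prompts)}")
    return negative_prompts

print(align_negative_prompts(["a cat", "a dog"], ["blurry"]))  # → ['blurry', 'blurry']
```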
### 5. Sampling Or Forward Pass Fails
Check:
- Input tensor shape and channel count
- Device and dtype alignment across all components
- Scheduler timesteps and expected timestep convention
- Classifier-free guidance concatenation and split logic
- Pixel-space versus latent-space assumptions
Typical symptoms:
- Shape mismatch in attention or decoder blocks
- Device mismatch between text encoder output and model tensors
- Images exploding to NaNs because timestep semantics are inverted
### 6. Output Is Wrong But No Exception Is Raised
Check:
- Whether the model predicts `x0`, noise, or velocity
- Whether the Euler or other sampler update matches the model objective
- Final scaling and clamp path
- `output_type` handling and `pipe.task_args`
- Whether a VAE is being applied incorrectly to direct pixel-space output
Typical symptoms:
- Black, gray, washed-out, or heavily clipped images
- Output with correct size but obviously broken semantics
- Correct tensors but wrong SD.Next display behavior because output type is mismatched
## Minimal Debug Procedure
### 1. Reproduce Narrowly
Capture the smallest failing operation.
- Pure import failure
- Loader-only failure
- `from_pretrained` failure
- Prompt encode failure
- Single forward pass failure
- First sampler step failure
Prefer narrow Python checks before attempting a full generation run.
### 2. Compare Against Working Pattern
Find the closest working in-repo analogue and compare:
- Loader structure
- Registered module names
- Pipeline class name and module registration
- Prompt encoding path
- Output conversion path
### 3. Fix The Root Cause
Examples:
- Add the missing `modeldata` branch instead of patching downstream task handling
- Fix checkpoint remapping rather than forcing `strict=False` and ignoring real mismatches
- Correct the output path for pixel-space models instead of routing through a VAE
- Make config inference fail explicitly when ambiguous instead of guessing silently
### 4. Validate In Layers
After each meaningful fix, validate the narrowest relevant layer first.
- `compileall` or syntax check
- `ruff` on touched files
- Import smoke test
- Loader-only smoke test
- Full run only when the lower layers are stable
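A layered validation helper might look like this sketch: syntax first, then a focused `ruff` pass only when the tool is available; it is illustrative, not the project's actual tooling:

```python
import py_compile
import shutil
import subprocess
import tempfile

def validate_file(path):
    """Run the narrowest check first; stop at the first failing layer."""
    # layer 1: syntax / bytecode compilation
    try:
        py_compile.compile(path, doraise=True)
    except py_compile.PyCompileError as err:
        return f"syntax: {err}"
    # layer 2: focused lint, only when ruff is available on this machine
    if shutil.which("ruff"):
        result = subprocess.run(["ruff", "check", path], capture_output=True, text=True)
        if result.returncode != 0:
            return f"lint: {result.stdout.strip()}"
    return "ok"

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write("x = 1\n")
print(validate_file(f.name))
```

Import and loader smoke tests follow the same principle: each layer runs only after the one below it passes.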
## Common Root Causes
- `modules/modeldata.py` not updated after adding a new custom pipeline family
- `modules/sd_detect.py` branch order causes overbroad detection to win first
- Loader passes duplicated keyword args like `torch_dtype`
- Shared text encoder assumptions do not match the actual model variant
- `from_pretrained` assumes `transformer/` or `text_encoder/` subfolders that do not exist
- Key remapping merges QKV in the wrong order
- CFG path concatenates embeddings or latents incorrectly
- Direct pixel-space models are postprocessed like latent-space diffusion outputs
- Negative prompts are not padded or repeated to match prompt batch shape
- Pipeline class naming collides with broader family checks in `modeldata`
## Validation Checklist
When closing the task, report which of these were completed:
1. Exact failing layer identified
2. Root cause fixed
3. Syntax check passed
4. Focused lint passed
5. Import or loader smoke test passed
6. Real generation tested, or explicitly not tested
## Example Request Shapes
- "The new model port fails in from_pretrained"
- "SD.Next detects my custom pipeline as the wrong model type"
- "The loader works but generation returns black images"
- "This standalone-script port loads weights but crashes in attention"

`.github/skills/github-features/SKILL.md`
---
name: github-features
description: "Read SD.Next GitHub issues with [Feature] in the title and generate a markdown report with short summary, status, and suggested next steps per issue."
argument-hint: "Optionally specify state (open/closed/all), max issues, and whether to include labels/assignees"
---
# Summarize SD.Next [Feature] GitHub Issues
Fetch issues from the SD.Next GitHub repository that contain `[Feature]` in the title, then produce a concise markdown report with one entry per issue.
## When To Use
- The user asks for periodic feature-request triage summaries
- You need an actionable status report for `[Feature]` tracker items
- You want suggested next actions for each matching issue
## Repository
Default target repository:
- owner: `vladmandic`
- name: `sdnext`
## Required Output
Create markdown containing, for each matching issue:
- issue link and title
- short summary (1-3 sentences)
- status (open/closed and relevant labels)
- suggested next steps (1-3 concrete actions)
## Procedure
### 1. Search Matching Issues
Use GitHub search to find issues with `[Feature]` in title.
Preferred search query template:
- `is:issue in:title "[Feature]" repo:vladmandic/sdnext`
State filters (when requested):
- open only: add `is:open` (default)
- closed only: add `is:closed`
Use `github-pull-request_doSearch` for the search step.
### 2. Fetch Full Issue Details
For each matched issue (within requested limit):
- fetch details with `github-pull-request_issue_fetch`
- capture body, labels, assignees, state, updated time, and key discussion context
### 3. Build Per-Issue Summary
For each issue, produce:
1. Short summary:
- describe feature request in plain language
- include current progress signal if present
2. Status:
- open/closed
- notable labels (for example: enhancement, planned, blocked, stale)
- optional assignee and last update signal
3. Suggested next steps:
- propose concrete, minimal actions
- tailor actions to issue state and content
- avoid generic filler
### 4. Produce Markdown Report
Return a markdown table or bullet report with one row/section per issue.
Recommended table columns:
- Issue
- Summary
- Status
- Suggested Next Steps
If there are many issues, keep summaries short and prioritize clarity.
## Reporting Rules
- Keep each issue summary concise and actionable.
- Do not invent facts not present in issue data.
- If issue body is sparse, state assumptions explicitly.
- If no matching issues are found, output a clear "no matches" report.
## Pass Criteria
A successful run must:
- search SD.Next issues with `[Feature]` in title
- include all matched issues in scope (or explicitly mention applied limit)
- provide summary, status, and suggested next steps for each issue
- return the final result as markdown

`.github/skills/github-issues/SKILL.md`
---
name: github-issues
description: "Read SD.Next GitHub issues with [Issues] in the title and generate a markdown report with short summary, status, and suggested next steps per issue."
argument-hint: "Optionally specify state (open/closed/all), max issues, and whether to include labels/assignees"
---
# Summarize SD.Next [Issues] GitHub Issues
Fetch issues from the SD.Next GitHub repository that contain `[Issues]` in the title, then produce a concise markdown report with one entry per issue.
## When To Use
- The user asks for periodic issue triage summaries
- You need an actionable status report for `[Issues]` tracker items
- You want suggested next actions for each matching issue
## Repository
Default target repository:
- owner: `vladmandic`
- name: `sdnext`
## Required Output
Create markdown containing, for each matching issue:
- issue link and title
- short summary (1-3 sentences)
- status (open/closed and relevant labels)
- suggested next steps (1-3 concrete actions)
## Procedure
### 1. Search Matching Issues
Use GitHub search to find issues with `[Issues]` in title.
Preferred search query template:
- `is:issue in:title "[Issues]" repo:vladmandic/sdnext`
State filters (when requested):
- open only: add `is:open` (default)
- closed only: add `is:closed`
Use `github-pull-request_doSearch` for the search step.
### 2. Fetch Full Issue Details
For each matched issue (within requested limit):
- fetch details with `github-pull-request_issue_fetch`
- capture body, labels, assignees, state, updated time, and key discussion context
### 3. Build Per-Issue Summary
For each issue, produce:
1. Short summary:
- describe problem/request in plain language
- include current progress signal if present
2. Status:
- open/closed
- notable labels (for example: bug, enhancement, stale, blocked)
- optional assignee and last update signal
3. Suggested next steps:
- propose concrete, minimal actions
- tailor actions to issue state and content
- avoid generic filler
### 4. Produce Markdown Report
Return a markdown table or bullet report with one row/section per issue.
Recommended table columns:
- Issue
- Summary
- Status
- Suggested Next Steps
If there are many issues, keep summaries short and prioritize clarity.
## Reporting Rules
- Keep each issue summary concise and actionable.
- Do not invent facts not present in issue data.
- If issue body is sparse, state assumptions explicitly.
- If no matching issues are found, output a clear "no matches" report.
## Pass Criteria
A successful run must:
- search SD.Next issues with `[Issues]` in title
- include all matched issues in scope (or explicitly mention applied limit)
- provide summary, status, and suggested next steps for each issue
- return the final result as markdown

`.github/skills/port-model/SKILL.md`
---
name: port-model
description: "Port or add a model to SD.Next using existing Diffusers and custom pipeline patterns. Use when implementing a new model loader, custom pipeline, checkpoint conversion path, or SD.Next model-type integration."
argument-hint: "Describe the source model, target task, checkpoint format, and whether the model already has a Diffusers pipeline"
---
# Port Model To SD.Next And Diffusers
Read the task, identify the model architecture and artifact layout, choose the narrowest integration path that matches existing SD.Next patterns, implement the loader and pipeline wiring, and validate the result.
## When To Use
- The user wants to add a new model family to SD.Next
- A standalone inference script needs to become a Diffusers-style pipeline
- A raw checkpoint or safetensors repo needs an SD.Next loader
- A model already exists in Diffusers but is not yet wired into SD.Next
- A custom architecture needs a repo-local `pipelines/<model>` package and loader
## Core Rule
Prefer the smallest correct integration path.
- If upstream Diffusers already supports the model cleanly, reuse the upstream pipeline and only add SD.Next wiring.
- If the model requires custom architecture or sampler behavior, add a repo-local pipeline package under `pipelines/<model>`.
- If the model ships as a raw checkpoint instead of a Diffusers repo, load or remap weights explicitly instead of pretending it is a standard Diffusers layout.
## Inputs To Collect First
Before editing anything, determine these facts:
- Model task: text-to-image, img2img, inpaint, editing, video, multimodal, or other
- Artifact format: Diffusers repo, local folder, single-file safetensors, ckpt, GGUF, or custom layout
- Core components: transformer or UNet, VAE or VQ model, text encoder, tokenizer or processor, scheduler, image encoder, adapters
- Whether output is latent-space or direct pixel-space
- Whether the model family has one fixed architecture or multiple variants
- Whether prompt encoding follows a normal tokenizer path or a custom chat/template path
- Whether there is an existing in-repo analogue to copy structurally
If any of these remain unclear after reading the repo and source artifacts, ask concise clarifying questions before implementing.
## Mandatory Category Question
Before implementing model-reference updates, explicitly ask the user which category the model belongs to:
- `base`
- `cloud`
- `quant`
- `distilled`
- `nunchaku`
- `community`
Do not guess this category. Use the user's answer to decide which reference JSON file(s) to update.
## Repo Files To Check
Start by reading the task description, then inspect the closest matching implementations.
- `.github/copilot-instructions.md`
- `.github/instructions/core.instructions.md`
- `pipelines/generic.py`
- `pipelines/model_*.py` files similar to the target model
- `modules/sd_models.py`
- `modules/sd_detect.py`
- `modules/modeldata.py`
Useful examples by pattern:
- Existing upstream Diffusers loader: `pipelines/model_chroma.py`, `pipelines/model_z_image.py`
- Custom in-repo pipeline package: `pipelines/f_lite/`
- Shared Qwen loader pattern: `pipelines/model_z_image.py`, `pipelines/model_flux2_klein.py`
## Integration Decision Tree
### 1. Upstream Diffusers Support Exists
Use this path when the model already has a usable Diffusers pipeline and component classes.
Implement:
- `pipelines/model_<name>.py`
- `modules/sd_models.py` dispatch branch
- `modules/sd_detect.py` filename autodetect branch if appropriate
- `modules/modeldata.py` model type detection branch
Reuse:
- `generic.load_transformer(...)`
- `generic.load_text_encoder(...)`
- Existing processor or tokenizer loading patterns
- `sd_hijack_te.init_hijack(pipe)` and `sd_hijack_vae.init_hijack(pipe)` where relevant
### 2. Custom Pipeline Needed
Use this path when the model architecture, sampler, or prompt encoding is not available upstream.
Implement:
- `pipelines/<model>/__init__.py`
- `pipelines/<model>/model.py`
- `pipelines/<model>/pipeline.py`
- `pipelines/model_<name>.py`
- SD.Next registration in `modules/sd_models.py`, `modules/sd_detect.py`, and `modules/modeldata.py`
Model module responsibilities:
- Architecture classes
- Config handling and defaults
- Weight remapping or checkpoint conversion helpers
- Raw checkpoint loading helpers if needed
Pipeline module responsibilities:
- `DiffusionPipeline` subclass
- `__init__`
- `from_pretrained`
- `encode_prompt` or equivalent prompt preparation
- `__call__`
- Output dataclass
- Optional callback handling and output conversion
### 3. Raw Checkpoint Or Single-File Weights
Use this path when the model source is not a normal Diffusers repository.
Requirements:
- Resolve checkpoint path from local file, local directory, or Hub repo
- Load state dict directly
- Remap keys when training format differs from inference format
- Infer config from tensor shapes only if the model family truly varies and no config file exists
- Raise explicit errors for ambiguous or incomplete layouts
Do not fake a `from_pretrained` implementation that silently assumes missing subfolders exist.
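Checkpoint-source resolution can be sketched as follows; the classification labels and the `model_index.json` heuristic for Diffusers folders are assumptions for illustration:

```python
import os

def resolve_checkpoint_source(path):
    """Classify a model source; a hypothetical helper, not the SD.Next implementation."""
    if os.path.isfile(path):
        ext = os.path.splitext(path)[1].lower()
        if ext in (".safetensors", ".ckpt", ".pt", ".gguf"):
            return "single-file"
        raise ValueError(f"unsupported checkpoint extension: {ext}")
    if os.path.isdir(path):
        # presence of model_index.json marks a standard Diffusers folder layout
        if os.path.isfile(os.path.join(path, "model_index.json")):
            return "diffusers-folder"
        return "raw-folder"
    if path.count("/") == 1 and not path.startswith((".", "/")):
        return "hub-repo"  # looks like an `owner/name` Hub id
    raise ValueError(f"cannot resolve checkpoint source: {path}")

print(resolve_checkpoint_source("org/model"))  # → hub-repo
```

Raising explicitly on ambiguous layouts, as here, is the behavior the requirements above ask for.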
## Required SD.Next Touchpoints
Most new model families need all of these:
- `pipelines/model_<name>.py`
Purpose: SD.Next loader entry point
- `modules/sd_models.py`
Purpose: route detected model type to the correct loader
- `modules/sd_detect.py`
Purpose: detect model family from filename or repo name
- `modules/modeldata.py`
Purpose: classify loaded pipeline instance back into SD.Next model type
Add only what the model actually needs.
Reference catalog touchpoints are also required for model ports intended to appear in SD.Next model references.
- `data/reference.json` for `base`
- `data/reference-cloud.json` for `cloud`
- `data/reference-quant.json` for `quant`
- `data/reference-distilled.json` for `distilled`
- `data/reference-nunchaku.json` for `nunchaku`
- `data/reference-community.json` for `community`
If the model belongs to multiple categories, update each corresponding `data/reference*.json` file.
Possible extra integration points:
- `diffusers.pipelines.auto_pipeline.AUTO_*_PIPELINES_MAPPING` when task switching matters
- `pipe.task_args` when SD.Next needs default runtime kwargs such as `output_type`
- VAE hijack, TE hijack, or task-specific processors
## Loader Conventions
In `pipelines/model_<name>.py`:
- Use `sd_models.path_to_repo(checkpoint_info)`
- Call `sd_models.hf_auth_check(checkpoint_info)`
- Use `model_quant.get_dit_args(...)` for load args
- Respect `devices.dtype`
- Log meaningful loader details
- Reuse `generic.load_transformer(...)` and `generic.load_text_encoder(...)` when possible
- Set `pipe.task_args = {'output_type': 'np'}` when the pipeline should default to numpy output for SD.Next
- Clean up temporary references and call `devices.torch_gc(...)`
Do not hardcode assumptions about CUDA-only execution, local paths, or one-off environment state.
## Pipeline Conventions
When building a custom pipeline:
- Inherit from `diffusers.DiffusionPipeline`
- Register modules with `DiffusionPipeline.register_modules(...)`
- Keep prompt encoding and output conversion inside the pipeline
- Support `prompt`, `negative_prompt`, `generator`, `output_type`, and `return_dict` when the task is text-to-image-like
- Expose only parameters that are part of the model's actual sampling surface
- Preserve direct pixel-space output if the model does not use a VAE
- Use a VAE only if the model genuinely outputs latents
Do not add generic Stable Diffusion arguments that the model does not support.
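The sampling surface described above can be sketched without importing Diffusers; `HypotheticalPipeline` is a placeholder, and a real implementation would subclass `diffusers.DiffusionPipeline` and register its modules:

```python
from dataclasses import dataclass

@dataclass
class HypotheticalPipelineOutput:
    # mirrors the Diffusers convention of a single `images` field
    images: list

class HypotheticalPipeline:
    """Skeleton of the sampling surface only; a real pipeline subclasses
    diffusers.DiffusionPipeline and registers its modules."""
    def __call__(self, prompt, negative_prompt=None, num_inference_steps=20,
                 generator=None, output_type="np", return_dict=True):
        # placeholder for the real encode -> sample -> decode loop
        images = [f"image for: {prompt}"]
        if not return_dict:
            return (images,)
        return HypotheticalPipelineOutput(images=images)

out = HypotheticalPipeline()(prompt="a cat")
print(out.images)  # → ['image for: a cat']
```

Note that only task-relevant parameters appear in `__call__`; nothing Stable Diffusion-specific is added.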
## Validation Checklist
After implementation, validate in this order:
1. Syntax or bytecode compilation for new files
2. Focused linting on touched files
3. Import-level smoke test for new modules when safe
4. Loader-path validation without doing a full generation if the model is too large
5. One real load and generation pass if feasible
Always report what was validated and what was not.
## Reference Asset Update Requirements
When the user asks to add or port a model for references, also perform these steps:
1. Update the correct `data/reference*.json` file(s) based on the user-confirmed category.
2. Create a placeholder zero-byte thumbnail file in `models/Reference` for the new model.
Notes:
- In this repo, the folder is `models/Reference` (capital `R`).
- Use a deterministic filename that matches the model entry naming convention used in the target reference JSON.
- If a real thumbnail already exists, do not overwrite it with a zero-byte file.
## Common Failure Modes
- Model type is added to `sd_models.py` but not `sd_detect.py`
- Loader exists but `modules/modeldata.py` still classifies the pipeline incorrectly
- A custom pipeline is named in a way that collides with a broader existing branch such as `Chroma`
- `torch_dtype` or other loader args are passed twice
- The code assumes a Diffusers repo layout when the source is really a single checkpoint file
- Negative prompt handling does not match prompt batch size
- Output is postprocessed as latents even though the model is pixel-space
- Custom config inference uses fragile defaults without clear failure paths
## Output Expectations
When using this skill, the final implementation should usually include:
- A clear integration path choice
- The minimal set of new files required
- SD.Next routing updates
- Targeted validation results
- Any remaining runtime risks called out explicitly
## Example Request Shapes
- "Port this standalone inference script into an SD.Next Diffusers pipeline"
- "Add support for this Hugging Face model repo to SD.Next"
- "Wire this upstream Diffusers pipeline into SD.Next autodetect and loading"
- "Convert this single-file checkpoint model into a custom Diffusers pipeline for SD.Next"

# SD.Next: AGENTS.md Project Guidelines
SD.Next is a complex codebase with specific patterns and conventions.
General app structure is:
- Python backend server
Uses Torch for model inference, FastAPI for API routes and Gradio for creation of UI components.
- JavaScript/CSS frontend
**SD.Next** is a complex codebase with specific patterns and conventions:
- Read main instructions file `.github/copilot-instructions.md` for general project guidelines, tools, structure, style, and conventions.
- For core tasks, also review instructions `.github/instructions/core.instructions.md`
- For UI tasks, also review instructions `.github/instructions/ui.instructions.md`
## Tools
- `venv` for Python environment management, activated with `source venv/bin/activate` (Linux) or `venv\Scripts\activate` (Windows).
venv MUST be activated before running any Python commands or scripts to ensure correct dependencies and environment variables.
- `python` 3.10+.
- `pyproject.toml` for Python configuration, including linting and type checking settings.
- `eslint` configured for both core and UI code.
- `pnpm` for managing JavaScript dependencies and scripts, with key commands defined in `package.json`.
- `ruff` and `pylint` for Python linting, with configurations in `pyproject.toml` and executed via `pnpm ruff` and `pnpm pylint`.
- `pre-commit` hooks which also check line-endings and other formatting issues, configured in `.pre-commit-config.yaml`.
## Project Structure
- Entry/startup flow: `webui.sh` -> `launch.py` -> `webui.py` -> modules under `modules/`.
- Install: `installer.py` takes care of installing dependencies and setting up the environment.
- Core runtime state is centralized in `modules/shared.py` (shared.opts, model state, backend/device state).
- API/server routes are under `modules/api/`.
- UI codebase is split between base JS in `javascript/` and actual UI in `extensions-builtin/sdnext-modernui/`.
- Model and pipeline logic is split between `modules/sd_*` and `pipelines/`.
- Additional plug-ins live in `scripts/` and are used only when specified.
- Extensions live in `extensions-builtin/` and `extensions/` and are loaded dynamically.
- Tests and CLI scripts are under `test/` and `cli/`, with some API smoke checks in `test/full-test.sh`.
## Code Style
- Prefer existing project patterns over strict generic style rules; this codebase intentionally allows patterns often flagged by default linters, such as long lines.
## Build And Test
- Activate environment: `source venv/bin/activate` (always ensure this is active when working with Python code).
- Test startup: `python launch.py --test`
- Full startup: `python launch.py`
- Full lint sequence: `pnpm lint`
- Python checks individually: `pnpm ruff`, `pnpm pylint`
- JS checks: `pnpm eslint` and `pnpm eslint-ui`
## Conventions
- Keep PR-ready changes targeted to `dev` branch.
- Use conventions from `CONTRIBUTING`.
- Do not include unrelated edits or submodule changes when preparing contributions.
- Use existing CLI/API tool patterns in `cli/` and `test/` when adding automation scripts.
- Respect environment-driven behavior (`SD_*` flags and options) instead of hardcoding platform/model assumptions.
- For startup/init edits, preserve error handling and partial-failure tolerance in parallel scans and extension loading.
## Pitfalls
- Initialization order matters: startup paths in `launch.py` and `webui.py` are sensitive to import/load timing.
- Shared mutable global state can create subtle regressions; prefer narrow, explicit changes.
- Device/backend-specific code paths (**CUDA/ROCm/IPEX/DirectML/OpenVINO**) should not assume one platform.
- Script and extension loading is dynamic; failures may appear only when specific extensions or models are present.
For specific SKILLS, also review the relevant skill files specified in `.github/skills/README.md` and listed `.github/skills/*/SKILL.md`

# Change Log for SD.Next
## Update for 2026-04-10
- **Models**
- [AiArtLab SDXS-1B](https://huggingface.co/AiArtLab/sdxs-1b) Simple Diffusion XS *(training still in progress)*
this model combines Qwen3.5-1.8B text encoder with SDXL-style UNET with only 1.6B parameters and custom 32ch VAE
- [Anima Preview-v3](https://huggingface.co/circlestone-labs/Anima) new version of Anima
- **Agents**
framework for AI agent-based work in `.github/`
- general instructions: `copilot-instructions.md` (not copilot specific)
- additional instructions in `instructions/`: `core.instructions.md`, `ui.instructions.md`
- skills in `skills/README.md`:
*validation skills*: `check-models`, `check-api`, `check-schedulers`, `check-processing`, `check-scripts`
*model skills*: `port-model`, `debug-model`, `analyze-model`
*github skills*: `github-issues`, `github-features`
- **Caption & Prompt Enhance**
- [Google Gemma 4] in *E2B* and *E4B* variants
- **Compute**
@@ -18,8 +26,19 @@
- enhanced filename pattern processing
allows for any *processing* property name (as defined in `modules/processing_class.py` and saved to `ui-config.json`)
allows for any *settings* property name (as defined in `modules/ui_definitions.py` and saved to `config.json`)
- **Agents**
created framework for AI-agent-based work in `.github/`
*note*: all skills are agent-model agnostic
- general instructions:
`AGENTS.md`, `copilot-instructions.md`
- additional instructions in `instructions/`:
`core.instructions.md`, `ui.instructions.md`
- skills in `skills/README.md`:
*validation skills*: `check-models`, `check-api`, `check-schedulers`, `check-processing`, `check-scripts`
*model skills*: `port-model`, `debug-model`, `analyze-model`
*github skills*: `github-issues`, `github-features`
- **Obsoleted**
- remove *system-info* from *extensions-builtin*
- removed *system-info* from *extensions-builtin*
- **Internal**
- additional typing and typechecks, thanks @awsr
- wrap hf download methods

@@ -46,7 +46,8 @@ TODO: Investigate which models are diffusers-compatible and prioritize!
### Image-Base
- [Mugen](https://huggingface.co/CabalResearch/Mugen)
- [NucleusMoe](https://github.com/huggingface/diffusers/pull/13317)
- [Chroma Zeta](https://huggingface.co/lodestones/Zeta-Chroma): Image and video generator for creative effects and professional filters
- [Chroma Radiance](https://huggingface.co/lodestones/Chroma1-Radiance): Pixel-space model eliminating VAE artifacts for high visual fidelity
- [Bria FIBO](https://huggingface.co/briaai/FIBO): Fully JSON based
@@ -58,6 +59,7 @@ TODO: Investigate which models are diffusers-compatible and prioritize!
### Image-Edit
- [Tencent HY-WU](https://huggingface.co/tencent/HY-WU)
- [JoyAI Image Edit](https://huggingface.co/jdopensource/JoyAI-Image-Edit)
- [Bria FIBO-Edit](https://huggingface.co/briaai/Fibo-Edit-RMBG): Fully JSON-based instruction-following image editing framework
- [Meituan LongCat-Image-Edit-Turbo](https://huggingface.co/meituan-longcat/LongCat-Image-Edit-Turbo): 6B instruction-following image editing with high visual consistency
@@ -71,6 +73,7 @@ TODO: Investigate which models are diffusers-compatible and prioritize!
### Video
- [HY-OmniWeaving](https://huggingface.co/tencent/HY-OmniWeaving)
- [LTX-Condition](https://github.com/huggingface/diffusers/pull/13058)
- [LTX-Distilled](https://github.com/huggingface/diffusers/pull/12934)
- [OpenMOSS MOVA](https://huggingface.co/OpenMOSS-Team/MOVA-720p): Unified foundation model for synchronized high-fidelity video and audio

@@ -244,6 +244,15 @@
"size": 26.84,
"date": "2025 July"
},
"lodestones Zeta-Chroma": {
"path": "lodestones/Zeta-Chroma",
"preview": "lodestones--Zeta-Chroma.jpg",
"desc": "Zeta-Chroma is a pixel-space diffusion transformer image model from lodestones that generates images directly in RGB space using a NextDiT-style architecture.",
"skip": true,
"extras": "sampler: Default, cfg_scale: 3.0, steps: 30",
"size": 12.11,
"date": "2026 April"
},
"Meituan LongCat Image": {
"path": "meituan-longcat/LongCat-Image",

@@ -42,6 +42,8 @@ def get_model_type(pipe):
model_type = 'sc'
elif "AuraFlow" in name:
model_type = 'auraflow'
elif 'ZetaChroma' in name:
model_type = 'zetachroma'
elif 'Chroma' in name:
model_type = 'chroma'
elif "Flux2" in name:

@@ -91,6 +91,8 @@ def guess_by_name(fn, current_guess):
new_guess = 'Stable Diffusion 3'
elif 'hidream' in fn.lower():
new_guess = 'HiDream'
elif 'zeta-chroma' in fn.lower() or 'zetachroma' in fn.lower():
new_guess = 'ZetaChroma'
elif 'chroma' in fn.lower() and 'xl' not in fn.lower():
new_guess = 'Chroma'
elif 'flux.2' in fn.lower() and 'klein' in fn.lower():

@@ -385,6 +385,10 @@ def load_diffuser_force(detected_model_type, checkpoint_info, diffusers_load_con
from pipelines.model_flex import load_flex
sd_model = load_flex(checkpoint_info, diffusers_load_config)
allow_post_quant = False
elif model_type in ['ZetaChroma']:
from pipelines.model_zetachroma import load_zetachroma
sd_model = load_zetachroma(checkpoint_info, diffusers_load_config)
allow_post_quant = False
elif model_type in ['Chroma']:
from pipelines.model_chroma import load_chroma
sd_model = load_chroma(checkpoint_info, diffusers_load_config)

@@ -0,0 +1,59 @@
import sys
import diffusers
import transformers
from modules import devices, model_quant, sd_hijack_te, sd_models, shared
from modules.logger import log
TEXT_ENCODER_REPO = "Tongyi-MAI/Z-Image-Turbo"
def load_zetachroma(checkpoint_info, diffusers_load_config=None):
if diffusers_load_config is None:
diffusers_load_config = {}
repo_id = sd_models.path_to_repo(checkpoint_info)
sd_models.hf_auth_check(checkpoint_info)
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
load_args.setdefault("torch_dtype", devices.dtype)
log.debug(
f'Load model: type=ZetaChroma repo="{repo_id}" config={diffusers_load_config} '
f'offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}'
)
from pipelines import generic, zetachroma
diffusers.ZetaChromaPipeline = zetachroma.ZetaChromaPipeline
sys.modules["zetachroma"] = zetachroma
text_encoder = generic.load_text_encoder(
TEXT_ENCODER_REPO,
cls_name=transformers.Qwen3ForCausalLM,
load_config=diffusers_load_config,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
TEXT_ENCODER_REPO,
subfolder="tokenizer",
cache_dir=shared.opts.hfcache_dir,
trust_remote_code=True,
)
pipe = zetachroma.ZetaChromaPipeline.from_pretrained(
repo_id,
text_encoder=text_encoder,
tokenizer=tokenizer,
cache_dir=shared.opts.diffusers_dir,
trust_remote_code=True,
**load_args,
)
pipe.task_args = {
"output_type": "np",
}
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["zetachroma"] = zetachroma.ZetaChromaPipeline
del tokenizer
del text_encoder
sd_hijack_te.init_hijack(pipe)
devices.torch_gc(force=True, reason="load")
return pipe

@@ -0,0 +1,19 @@
from .model import (
DEFAULT_TEXT_ENCODER_REPO,
NextDiTPixelSpace,
infer_model_config,
load_zetachroma_transformer,
remap_checkpoint_keys,
)
from .pipeline import ZetaChromaPipeline, ZetaChromaPipelineOutput
__all__ = [
"DEFAULT_TEXT_ENCODER_REPO",
"NextDiTPixelSpace",
"ZetaChromaPipeline",
"ZetaChromaPipelineOutput",
"infer_model_config",
"load_zetachroma_transformer",
"remap_checkpoint_keys",
]

@@ -0,0 +1,732 @@
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from huggingface_hub import hf_hub_download, list_repo_files
from safetensors.torch import load_file
DEFAULT_TEXT_ENCODER_REPO = "Tongyi-MAI/Z-Image-Turbo"
DEFAULT_CHECKPOINT_FILENAME = "zeta-chroma-base-x0-pixel-dino-distance.safetensors"
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding.to(t.dtype) if torch.is_floating_point(t) else embedding
def rope(pos: torch.Tensor, dim: int, theta: float) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=pos.device)
omega = 1.0 / (theta ** scale)
out = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
rot = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
return rot.reshape(*out.shape, 2, 2).float()
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: float, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(ids.shape[-1])], dim=-3)
return emb.unsqueeze(1)
def _apply_rope_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out = x_out + freqs_cis[..., 1] * x_[..., 1]
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
return _apply_rope_single(xq, freqs_cis), _apply_rope_single(xk, freqs_cis)
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256, output_size: int | None = None):
super().__init__()
if output_size is None:
output_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, output_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
def forward(self, t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
return self.mlp(t_freq)
class JointAttention(nn.Module):
def __init__(self, dim: int, n_heads: int, n_kv_heads: int | None = None, qk_norm: bool = True):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
self.head_dim = dim // n_heads
self.n_rep = n_heads // self.n_kv_heads
self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias=False)
self.out = nn.Linear(n_heads * self.head_dim, dim, bias=False)
if qk_norm:
self.q_norm = nn.RMSNorm(self.head_dim, elementwise_affine=True)
self.k_norm = nn.RMSNorm(self.head_dim, elementwise_affine=True)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, _ = x.shape
qkv = self.qkv(x)
xq, xk, xv = torch.split(
qkv,
[
self.n_heads * self.head_dim,
self.n_kv_heads * self.head_dim,
self.n_kv_heads * self.head_dim,
],
dim=-1,
)
xq = self.q_norm(xq.view(batch_size, sequence_length, self.n_heads, self.head_dim))
xk = self.k_norm(xk.view(batch_size, sequence_length, self.n_kv_heads, self.head_dim))
xv = xv.view(batch_size, sequence_length, self.n_kv_heads, self.head_dim)
xq, xk = apply_rope(xq, xk, freqs_cis)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
if self.n_rep > 1:
xk = xk.unsqueeze(2).repeat(1, 1, self.n_rep, 1, 1).flatten(1, 2)
xv = xv.unsqueeze(2).repeat(1, 1, self.n_rep, 1, 1).flatten(1, 2)
out = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask)
return self.out(out.transpose(1, 2).reshape(batch_size, sequence_length, -1))
class FeedForward(nn.Module):
def __init__(self, dim: int, inner_dim: int):
super().__init__()
self.w1 = nn.Linear(dim, inner_dim, bias=False)
self.w2 = nn.Linear(inner_dim, dim, bias=False)
self.w3 = nn.Linear(dim, inner_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class JointTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int | None,
ffn_hidden_dim: int,
norm_eps: float,
qk_norm: bool,
modulation: bool = True,
z_image_modulation: bool = False,
):
super().__init__()
self.modulation = modulation
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
self.feed_forward = FeedForward(dim, ffn_hidden_dim)
self.attention_norm1 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)
self.ffn_norm1 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)
self.attention_norm2 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)
self.ffn_norm2 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)
if modulation:
if z_image_modulation:
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, 256), 4 * dim, bias=True))
else:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(min(dim, 1024), 4 * dim, bias=True))
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None,
freqs_cis: torch.Tensor,
adaln_input: torch.Tensor | None = None,
) -> torch.Tensor:
if self.modulation:
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
self.attention(self.attention_norm1(x) * (1 + scale_msa.unsqueeze(1)), mask, freqs_cis)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
self.feed_forward(self.ffn_norm1(x) * (1 + scale_mlp.unsqueeze(1)))
)
return x
x = x + self.attention_norm2(self.attention(self.attention_norm1(x), mask, freqs_cis))
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
return x
class NerfEmbedder(nn.Module):
def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int = 8):
super().__init__()
self.max_freqs = max_freqs
self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input))
self._pos_cache: dict[tuple[int, str, str], torch.Tensor] = {}
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
cache_key = (patch_size, str(device), str(dtype))
if cache_key in self._pos_cache:
return self._pos_cache[cache_key]
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
pos_x = pos_x.reshape(-1, 1, 1)
pos_y = pos_y.reshape(-1, 1, 1)
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
freqs_x = freqs[None, :, None]
freqs_y = freqs[None, None, :]
coeffs = (1 + freqs_x * freqs_y) ** -1
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)
self._pos_cache[cache_key] = dct
return dct
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
batch_size, patch_area, channels = inputs.shape
_ = channels
patch_size = int(patch_area**0.5)
dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype).repeat(batch_size, 1, 1)
return self.embedder(torch.cat((inputs, dct), dim=-1))
class PixelResBlock(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.in_ln = nn.LayerNorm(channels, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(channels, channels, bias=True),
nn.SiLU(),
nn.Linear(channels, channels, bias=True),
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 3 * channels, bias=True),
)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
hidden = self.in_ln(x) * (1 + scale) + shift
return x + gate * self.mlp(hidden)
class DCTFinalLayer(nn.Module):
def __init__(self, model_channels: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(model_channels, out_channels, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.norm_final(x))
class SimpleMLPAdaLN(nn.Module):
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
z_channels: int,
num_res_blocks: int,
max_freqs: int = 8,
):
super().__init__()
self.cond_embed = nn.Linear(z_channels, model_channels)
self.input_embedder = NerfEmbedder(in_channels, model_channels, max_freqs)
self.res_blocks = nn.ModuleList([PixelResBlock(model_channels) for _ in range(num_res_blocks)])
self.final_layer = DCTFinalLayer(model_channels, out_channels)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
x = self.input_embedder(x)
y = self.cond_embed(c).unsqueeze(1)
for block in self.res_blocks:
x = block(x, y)
return self.final_layer(x)
def pad_to_patch_size(x: torch.Tensor, patch_size: tuple[int, int]) -> torch.Tensor:
_, _, height, width = x.shape
patch_h, patch_w = patch_size
pad_h = (patch_h - height % patch_h) % patch_h
pad_w = (patch_w - width % patch_w) % patch_w
if pad_h or pad_w:
x = F.pad(x, (0, pad_w, 0, pad_h), mode="constant", value=0)
return x
def pad_zimage(feats: torch.Tensor, pad_token: torch.Tensor, pad_tokens_multiple: int) -> tuple[torch.Tensor, int]:
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
if pad_extra > 0:
feats = torch.cat(
[
feats,
pad_token.to(device=feats.device, dtype=feats.dtype).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1),
],
dim=1,
)
return feats, pad_extra
class NextDiTPixelSpace(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): # type: ignore[misc]
@register_to_config
def __init__(
self,
patch_size: int = 32,
in_channels: int = 3,
dim: int = 3840,
n_layers: int = 30,
n_refiner_layers: int = 2,
n_heads: int = 30,
n_kv_heads: int | None = None,
ffn_hidden_dim: int = 10240,
norm_eps: float = 1e-5,
qk_norm: bool = True,
cap_feat_dim: int = 2560,
axes_dims: list[int] | None = None,
axes_lens: list[int] | None = None,
rope_theta: float = 256.0,
time_scale: float = 1000.0,
pad_tokens_multiple: int | None = 128,
decoder_hidden_size: int = 3840,
decoder_num_res_blocks: int = 4,
decoder_max_freqs: int = 8,
):
super().__init__()
axes_dims = axes_dims or [32, 48, 48]
axes_lens = axes_lens or [1536, 512, 512]
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.dim = dim
self.n_heads = n_heads
self.time_scale = time_scale
self.pad_tokens_multiple = pad_tokens_multiple
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.x_embedder = nn.Linear(patch_size * patch_size * in_channels, dim, bias=True)
self.noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
dim,
n_heads,
n_kv_heads,
ffn_hidden_dim,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
)
for _ in range(n_refiner_layers)
]
)
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
dim,
n_heads,
n_kv_heads,
ffn_hidden_dim,
norm_eps,
qk_norm,
modulation=False,
)
for _ in range(n_refiner_layers)
]
)
self.t_embedder = TimestepEmbedder(
hidden_size=min(dim, 1024),
frequency_embedding_size=256,
output_size=256,
)
self.cap_embedder = nn.Sequential(
nn.RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True),
nn.Linear(cap_feat_dim, dim, bias=True),
)
self.layers = nn.ModuleList(
[
JointTransformerBlock(
dim,
n_heads,
n_kv_heads,
ffn_hidden_dim,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
)
for _ in range(n_layers)
]
)
if pad_tokens_multiple is not None:
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
assert dim // n_heads == sum(axes_dims)
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=list(axes_dims))
dec_in_ch = patch_size**2 * in_channels
self.dec_net = SimpleMLPAdaLN(
in_channels=dec_in_ch,
model_channels=decoder_hidden_size,
out_channels=dec_in_ch,
z_channels=dim,
num_res_blocks=decoder_num_res_blocks,
max_freqs=decoder_max_freqs,
)
self.register_buffer("__x0__", torch.tensor([]))
def embed_cap(self, cap_feats: torch.Tensor, offset: float = 0, bsz: int = 1, device=None, dtype=None):
_ = dtype
cap_feats = self.cap_embedder(cap_feats)
cap_feats_len = cap_feats.shape[1]
if self.pad_tokens_multiple is not None:
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
freqs_cis = self.rope_embedder(cap_pos_ids).movedim(1, 2)
return cap_feats, freqs_cis, cap_feats_len
def pos_ids_x(self, start_t: float, h_tokens: int, w_tokens: int, batch_size: int, device):
x_pos_ids = torch.zeros((batch_size, h_tokens * w_tokens, 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = start_t
x_pos_ids[:, :, 1] = (
torch.arange(h_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, w_tokens).flatten()
)
x_pos_ids[:, :, 2] = (
torch.arange(w_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(h_tokens, 1).flatten()
)
return x_pos_ids
def unpatchify(self, x: torch.Tensor, img_size: list[tuple[int, int]], cap_size: list[int]) -> torch.Tensor:
patch_h = patch_w = self.patch_size
imgs = []
for index in range(x.size(0)):
height, width = img_size[index]
begin = cap_size[index]
end = begin + (height // patch_h) * (width // patch_w)
imgs.append(
x[index][begin:end]
.view(height // patch_h, width // patch_w, patch_h, patch_w, self.out_channels)
.permute(4, 0, 2, 1, 3)
.flatten(3, 4)
.flatten(1, 2)
)
return torch.stack(imgs, dim=0)
def _forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, num_tokens: int | None) -> torch.Tensor:
_ = num_tokens
timesteps = 1.0 - timesteps
_, _, original_h, original_w = x.shape
x = pad_to_patch_size(x, (self.patch_size, self.patch_size))
batch_size, channels, height, width = x.shape
t_emb = self.t_embedder(timesteps * self.time_scale, dtype=x.dtype)
patch_h = patch_w = self.patch_size
pixel_patches = (
x.view(batch_size, channels, height // patch_h, patch_h, width // patch_w, patch_w)
.permute(0, 2, 4, 3, 5, 1)
.flatten(3)
.flatten(1, 2)
)
num_image_tokens = pixel_patches.shape[1]
pixel_values = pixel_patches.reshape(batch_size * num_image_tokens, 1, patch_h * patch_w * channels)
cap_feats_emb, cap_freqs_cis, _ = self.embed_cap(context, offset=0, bsz=batch_size, device=x.device, dtype=x.dtype)
x_tokens = self.x_embedder(
x.view(batch_size, channels, height // patch_h, patch_h, width // patch_w, patch_w)
.permute(0, 2, 4, 3, 5, 1)
.flatten(3)
.flatten(1, 2)
)
cap_feats_len_total = cap_feats_emb.shape[1]
x_pos_ids = self.pos_ids_x(cap_feats_len_total + 1, height // patch_h, width // patch_w, batch_size, x.device)
if self.pad_tokens_multiple is not None:
x_tokens, x_pad_extra = pad_zimage(x_tokens, self.x_pad_token, self.pad_tokens_multiple)
x_pos_ids = F.pad(x_pos_ids, (0, 0, 0, x_pad_extra))
x_freqs_cis = self.rope_embedder(x_pos_ids).movedim(1, 2)
for layer in self.context_refiner:
cap_feats_emb = layer(cap_feats_emb, None, cap_freqs_cis)
for layer in self.noise_refiner:
x_tokens = layer(x_tokens, None, x_freqs_cis, t_emb)
padded_full_embed = torch.cat([cap_feats_emb, x_tokens], dim=1)
full_freqs_cis = torch.cat([cap_freqs_cis, x_freqs_cis], dim=1)
img_len = x_tokens.shape[1]
cap_size = [padded_full_embed.shape[1] - img_len] * batch_size
hidden_states = padded_full_embed
for layer in self.layers:
hidden_states = layer(hidden_states, None, full_freqs_cis.to(hidden_states.device), t_emb)
img_hidden = hidden_states[:, cap_size[0] : cap_size[0] + num_image_tokens, :]
decoder_cond = img_hidden.reshape(batch_size * num_image_tokens, self.dim)
output = self.dec_net(pixel_values, decoder_cond).reshape(batch_size, num_image_tokens, -1)
cap_placeholder = torch.zeros(
batch_size,
cap_size[0],
output.shape[-1],
device=output.device,
dtype=output.dtype,
)
img_size = [(height, width)] * batch_size
img_out = self.unpatchify(torch.cat([cap_placeholder, output], dim=1), img_size, cap_size)
return -img_out[:, :, :original_h, :original_w]
def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, num_tokens: int | None = None):
x0_pred = self._forward(x, timesteps, context, num_tokens)
return (x - x0_pred) / timesteps.view(-1, 1, 1, 1)
def remap_checkpoint_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
new_sd: dict[str, torch.Tensor] = {}
qkv_parts: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
if key == "__x0__":
new_sd[key] = value
continue
if ".attention.to_q." in key:
prefix = key.split(".attention.to_q.")[0]
suffix = key.split(".attention.to_q.")[1]
qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["q"] = value
continue
if ".attention.to_k." in key:
prefix = key.split(".attention.to_k.")[0]
suffix = key.split(".attention.to_k.")[1]
qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["k"] = value
continue
if ".attention.to_v." in key:
prefix = key.split(".attention.to_v.")[0]
suffix = key.split(".attention.to_v.")[1]
qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["v"] = value
continue
if ".attention.to_out.0." in key:
new_sd[key.replace(".attention.to_out.0.", ".attention.out.")] = value
continue
if ".attention.norm_q." in key:
new_sd[key.replace(".attention.norm_q.", ".attention.q_norm.")] = value
continue
if ".attention.norm_k." in key:
new_sd[key.replace(".attention.norm_k.", ".attention.k_norm.")] = value
continue
new_sd[key] = value
for combined_key, parts in qkv_parts.items():
if not {"q", "k", "v"}.issubset(parts):
raise ValueError(f"Incomplete QKV for {combined_key}: got {list(parts.keys())}")
prefix = combined_key.rsplit(".", 1)[0]
suffix = combined_key.rsplit(".", 1)[1]
new_sd[f"{prefix}.qkv.{suffix}"] = torch.cat([parts["q"], parts["k"], parts["v"]], dim=0)
return new_sd
def default_axes_dims_for_head_dim(head_dim: int) -> list[int]:
if head_dim % 8 == 0:
first = head_dim // 4
second = (head_dim * 3) // 8
third = head_dim - first - second
if first % 2 == 0 and second % 2 == 0 and third % 2 == 0:
return [first, second, third]
if head_dim == 128:
return [32, 48, 48]
if head_dim % 6 == 0:
split = head_dim // 3
return [split, split, split]
raise ValueError(f"Unable to infer axes_dims for head_dim={head_dim}; provide axes_dims explicitly")
def _count_module_blocks(state_dict: dict[str, torch.Tensor], prefix: str) -> int:
indices = set()
needle = f"{prefix}."
prefix_parts = prefix.count(".") + 1
for key in state_dict:
if not key.startswith(needle):
continue
parts = key.split(".")
if len(parts) <= prefix_parts:
continue
try:
indices.add(int(parts[prefix_parts]))
except ValueError:
continue
return max(indices) + 1 if indices else 0
def infer_model_config(
state_dict: dict[str, torch.Tensor],
axes_dims: list[int] | None = None,
axes_lens: list[int] | None = None,
rope_theta: float = 256.0,
time_scale: float = 1000.0,
pad_tokens_multiple: int | None = 128,
norm_eps: float = 1e-5,
) -> dict:
dim = state_dict["x_embedder.weight"].shape[0]
patch_channels = state_dict["x_embedder.weight"].shape[1]
in_channels = 3 if patch_channels % 3 == 0 else 1
patch_size = int(round((patch_channels / in_channels) ** 0.5))
head_dim = state_dict["layers.0.attention.q_norm.weight"].shape[0]
n_heads = dim // head_dim
total_proj_heads = state_dict["layers.0.attention.qkv.weight"].shape[0] // head_dim
n_kv_heads = (total_proj_heads - n_heads) // 2
if n_kv_heads == n_heads:
n_kv_heads = None
qk_norm = "layers.0.attention.q_norm.weight" in state_dict
decoder_input = state_dict["dec_net.input_embedder.embedder.0.weight"].shape[1]
decoder_max_freqs = int(round((decoder_input - patch_channels) ** 0.5))
config = {
"patch_size": patch_size,
"in_channels": in_channels,
"dim": dim,
"n_layers": _count_module_blocks(state_dict, "layers"),
"n_refiner_layers": _count_module_blocks(state_dict, "noise_refiner"),
"n_heads": n_heads,
"n_kv_heads": n_kv_heads,
"ffn_hidden_dim": state_dict["layers.0.feed_forward.w1.weight"].shape[0],
"norm_eps": norm_eps,
"qk_norm": qk_norm,
"cap_feat_dim": state_dict["cap_embedder.1.weight"].shape[1],
"axes_dims": axes_dims or default_axes_dims_for_head_dim(head_dim),
"axes_lens": axes_lens or [1536, 512, 512],
"rope_theta": rope_theta,
"time_scale": time_scale,
"pad_tokens_multiple": pad_tokens_multiple,
"decoder_hidden_size": state_dict["dec_net.cond_embed.weight"].shape[0],
"decoder_num_res_blocks": _count_module_blocks(state_dict, "dec_net.res_blocks"),
"decoder_max_freqs": decoder_max_freqs,
}
return config
def resolve_checkpoint_path(
pretrained_model_name_or_path: str,
checkpoint_filename: str | None = None,
cache_dir: str | None = None,
local_files_only: bool | None = None,
) -> str:
if os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
if os.path.isdir(pretrained_model_name_or_path):
if checkpoint_filename is not None:
candidate = os.path.join(pretrained_model_name_or_path, checkpoint_filename)
if os.path.isfile(candidate):
return candidate
candidates = [
os.path.join(pretrained_model_name_or_path, name)
for name in os.listdir(pretrained_model_name_or_path)
if name.lower().endswith(".safetensors")
]
if len(candidates) == 1:
return candidates[0]
if not candidates:
raise FileNotFoundError(f"No .safetensors checkpoint found in {pretrained_model_name_or_path}")
raise FileNotFoundError(
f"Multiple .safetensors checkpoints found in {pretrained_model_name_or_path}; provide checkpoint_filename"
)
filename = checkpoint_filename
if filename is None:
repo_files = [name for name in list_repo_files(pretrained_model_name_or_path) if name.lower().endswith(".safetensors")]
if not repo_files and DEFAULT_CHECKPOINT_FILENAME:
repo_files = [DEFAULT_CHECKPOINT_FILENAME]
if len(repo_files) == 1:
filename = repo_files[0]
elif DEFAULT_CHECKPOINT_FILENAME in repo_files:
filename = DEFAULT_CHECKPOINT_FILENAME
else:
raise FileNotFoundError(
f"Unable to determine checkpoint filename for {pretrained_model_name_or_path}; provide checkpoint_filename"
)
return hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename=filename,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
def load_zetachroma_transformer(
pretrained_model_name_or_path: str,
checkpoint_filename: str | None = None,
torch_dtype: torch.dtype | None = None,
cache_dir: str | None = None,
local_files_only: bool | None = None,
model_config: dict | None = None,
axes_dims: list[int] | None = None,
axes_lens: list[int] | None = None,
rope_theta: float = 256.0,
time_scale: float = 1000.0,
pad_tokens_multiple: int | None = 128,
) -> NextDiTPixelSpace:
checkpoint_path = resolve_checkpoint_path(
pretrained_model_name_or_path,
checkpoint_filename=checkpoint_filename,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
raw_state_dict = load_file(checkpoint_path, device="cpu")
state_dict = remap_checkpoint_keys(raw_state_dict)
del raw_state_dict
config = infer_model_config(
state_dict,
axes_dims=axes_dims,
axes_lens=axes_lens,
rope_theta=rope_theta,
time_scale=time_scale,
pad_tokens_multiple=pad_tokens_multiple,
)
if model_config is not None:
config.update(model_config)
model = NextDiTPixelSpace(**config)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
ignored_missing = {"__x0__"}
relevant_missing = [key for key in missing if key not in ignored_missing]
if relevant_missing or unexpected:
raise ValueError(
"Zeta-Chroma checkpoint load mismatch: "
f"missing={relevant_missing[:20]} unexpected={unexpected[:20]}"
)
if torch_dtype is not None:
model = model.to(dtype=torch_dtype)
return model

@@ -0,0 +1,302 @@
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, Qwen3ForCausalLM
from .model import DEFAULT_TEXT_ENCODER_REPO, NextDiTPixelSpace, load_zetachroma_transformer
logger = logging.getLogger(__name__)
@dataclass
class ZetaChromaPipelineOutput(BaseOutput):
images: Union[List[Image.Image], np.ndarray, torch.Tensor]
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> torch.Tensor:
timesteps = torch.linspace(1, 0, num_steps + 1)
if shift:
mu = base_shift + (image_seq_len / 4096) * (max_shift - base_shift)
timesteps = math.exp(mu) / (math.exp(mu) + (1.0 / timesteps.clamp(min=1e-9) - 1.0))
timesteps[0] = 1.0
timesteps[-1] = 0.0
return timesteps
class ZetaChromaPipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->dit_model"
dit_model: NextDiTPixelSpace
text_encoder: PreTrainedModel
tokenizer: PreTrainedTokenizerBase
_progress_bar_config: Dict[str, Any]
def __init__(self, dit_model: NextDiTPixelSpace, text_encoder: PreTrainedModel, tokenizer: PreTrainedTokenizerBase):
super().__init__()
DiffusionPipeline.register_modules(self, dit_model=dit_model, text_encoder=text_encoder, tokenizer=tokenizer)
if hasattr(self.text_encoder, "requires_grad_"):
self.text_encoder.requires_grad_(False)
if hasattr(self.dit_model, "requires_grad_"):
self.dit_model.requires_grad_(False)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
dit_model: Optional[NextDiTPixelSpace] = None,
text_encoder: Optional[PreTrainedModel] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
text_encoder_repo: str = DEFAULT_TEXT_ENCODER_REPO,
checkpoint_filename: Optional[str] = None,
cache_dir: Optional[str] = None,
trust_remote_code: bool = True,
torch_dtype: Optional[torch.dtype] = None,
local_files_only: Optional[bool] = None,
model_config: Optional[dict] = None,
        axes_dims: Optional[List[int]] = None,
        axes_lens: Optional[List[int]] = None,
        rope_theta: float = 256.0,
        time_scale: float = 1000.0,
        pad_tokens_multiple: Optional[int] = 128,
**kwargs,
):
_ = kwargs
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(
text_encoder_repo,
subfolder="tokenizer",
cache_dir=cache_dir,
trust_remote_code=trust_remote_code,
local_files_only=local_files_only,
)
if text_encoder is None:
text_encoder = Qwen3ForCausalLM.from_pretrained(
text_encoder_repo,
subfolder="text_encoder",
cache_dir=cache_dir,
trust_remote_code=trust_remote_code,
local_files_only=local_files_only,
torch_dtype=torch_dtype,
)
if dit_model is None:
dit_model = load_zetachroma_transformer(
pretrained_model_name_or_path,
checkpoint_filename=checkpoint_filename,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
local_files_only=local_files_only,
model_config=model_config,
axes_dims=axes_dims,
axes_lens=axes_lens,
rope_theta=rope_theta,
time_scale=time_scale,
pad_tokens_multiple=pad_tokens_multiple,
)
pipe = cls(dit_model=dit_model, text_encoder=text_encoder, tokenizer=tokenizer)
if torch_dtype is not None:
pipe.to(torch_dtype=torch_dtype)
return pipe
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
def progress_bar(self, iterable=None, **kwargs): # pylint: disable=arguments-differ
self._progress_bar_config = getattr(self, "_progress_bar_config", None) or {}
config = {**self._progress_bar_config, **kwargs}
return tqdm(iterable, **config)
def to(self, torch_device=None, torch_dtype=None, silence_dtype_warnings=False): # pylint: disable=arguments-differ
_ = silence_dtype_warnings
if hasattr(self, "text_encoder"):
self.text_encoder.to(device=torch_device, dtype=torch_dtype)
if hasattr(self, "dit_model"):
self.dit_model.to(device=torch_device, dtype=torch_dtype)
return self
def _encode_single_prompt(
self,
prompt: str,
device: torch.device,
dtype: torch.dtype,
max_sequence_length: int,
) -> torch.Tensor:
messages = [{"role": "user", "content": prompt}]
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
inputs = self.tokenizer(
text,
return_tensors="pt",
padding="max_length",
max_length=max_sequence_length,
truncation=True,
).to(device)
outputs = self.text_encoder(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
output_hidden_states=True,
)
hidden = outputs.hidden_states[-2].to(dtype)
hidden = hidden * inputs.attention_mask.unsqueeze(-1).to(dtype)
actual_len = int(inputs.attention_mask.sum(dim=1).max().item())
return hidden[:, :actual_len, :]
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(prompt, str):
prompt = [prompt]
if negative_prompt is None:
negative_prompt = [""] * len(prompt)
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * len(prompt)
elif len(negative_prompt) != len(prompt):
            raise ValueError("negative_prompt must have the same number of entries as prompt")
device = device or self._execution_device
dtype = dtype or next(self.text_encoder.parameters()).dtype
positive = [self._encode_single_prompt(text, device, dtype, max_sequence_length) for text in prompt]
negative = [self._encode_single_prompt(text, device, dtype, max_sequence_length) for text in negative_prompt]
max_len = max(max(tensor.shape[1] for tensor in positive), max(tensor.shape[1] for tensor in negative))
def pad_sequence(sequence: torch.Tensor) -> torch.Tensor:
if sequence.shape[1] == max_len:
return sequence
return torch.cat(
[
sequence,
torch.zeros(
sequence.shape[0],
max_len - sequence.shape[1],
sequence.shape[2],
dtype=sequence.dtype,
device=sequence.device,
),
],
dim=1,
)
prompt_embeds = torch.cat([pad_sequence(tensor) for tensor in positive], dim=0)
negative_embeds = torch.cat([pad_sequence(tensor) for tensor in negative], dim=0)
return prompt_embeds, negative_embeds
def _postprocess_images(self, images: torch.Tensor, output_type: str):
images = images.float().clamp(-1, 1) * 0.5 + 0.5
if output_type == "pt":
return images
images = images.permute(0, 2, 3, 1).cpu().numpy()
if output_type == "np":
return images
        if output_type == "pil":
            # numpy_to_pil expects float arrays in [0, 1] and performs the *255/uint8 conversion
            # itself; pre-converting to uint8 here would scale twice and wrap around.
            return self.numpy_to_pil(images)
raise ValueError(f"Unsupported output_type: {output_type}")
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = 1024,
width: Optional[int] = 1024,
num_inference_steps: int = 30,
guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 512,
shift_schedule: bool = True,
base_shift: float = 0.5,
max_shift: float = 1.15,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end=None,
**kwargs,
):
_ = kwargs
height = height or 1024
width = width or 1024
dtype = dtype or next(self.dit_model.parameters()).dtype
device = self._execution_device
prompt_embeds, negative_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
device=device,
dtype=dtype,
max_sequence_length=max_sequence_length,
)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_embeds = negative_embeds.repeat_interleave(num_images_per_prompt, dim=0)
batch_size = prompt_embeds.shape[0]
latents = randn_tensor((batch_size, 3, height, width), generator=generator, device=device, dtype=dtype)
image_seq_len = (height // self.dit_model.patch_size) * (width // self.dit_model.patch_size)
sigmas = get_schedule(
num_inference_steps,
image_seq_len,
base_shift=base_shift,
max_shift=max_shift,
shift=shift_schedule,
).to(device=device, dtype=dtype)
self.dit_model.eval()
do_classifier_free_guidance = guidance_scale > 1.0
        for step in self.progress_bar(range(num_inference_steps), desc="Sampling"):
t_curr = sigmas[step]
t_next = sigmas[step + 1]
t_batch = torch.full((batch_size,), t_curr, device=device, dtype=dtype)
if do_classifier_free_guidance:
latent_input = torch.cat([latents, latents], dim=0)
timestep_input = torch.cat([t_batch, t_batch], dim=0)
context_input = torch.cat([prompt_embeds, negative_embeds], dim=0)
velocity = self.dit_model(latent_input, timestep_input, context_input, context_input.shape[1])
v_cond, v_uncond = velocity.chunk(2, dim=0)
velocity = v_uncond + guidance_scale * (v_cond - v_uncond)
else:
velocity = self.dit_model(latents, t_batch, prompt_embeds, prompt_embeds.shape[1])
latents = latents + (t_next - t_curr) * velocity
if callback_on_step_end is not None:
callback_kwargs = {
"latents": latents,
"prompt_embeds": prompt_embeds,
"negative_prompt_embeds": negative_embeds,
}
callback_result = callback_on_step_end(self, step, int(t_curr.item() * 1000), callback_kwargs)
if isinstance(callback_result, dict) and "latents" in callback_result:
latents = callback_result["latents"]
images = self._postprocess_images(latents, output_type=output_type)
if not return_dict:
return (images,)
return ZetaChromaPipelineOutput(images=images)
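The denoising loop above is a first-order Euler integrator over the shifted sigma schedule (`x <- x + (t_next - t_curr) * v`). As a standalone sanity check — the `euler_sample` helper below is illustrative, not part of the pipeline — a constant velocity field integrates exactly regardless of step count:

```python
import torch

def euler_sample(x, sigmas, velocity_fn):
    # Same update rule as the pipeline loop: x <- x + (t_next - t_curr) * v(x, t_curr)
    for step in range(len(sigmas) - 1):
        t_curr, t_next = sigmas[step], sigmas[step + 1]
        x = x + (t_next - t_curr) * velocity_fn(x, t_curr)
    return x

# With a constant velocity v, the exact flow from t=1 to t=0 is x(0) = x(1) - v,
# and Euler reproduces it for any number of steps.
sigmas = torch.linspace(1.0, 0.0, 11)
x1 = torch.ones(3)
v = torch.full((3,), 2.0)
x0 = euler_sample(x1, sigmas, lambda x, t: v)
print(x0)  # close to tensor([-1., -1., -1.])
```

For non-constant fields the single-step error is O(dt^2), which is why the resolution-dependent shift in `get_schedule` matters: it concentrates steps where the velocity changes fastest.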