mirror of https://github.com/vladmandic/automatic
parent 8d871af63d · commit b9a38b5955

@@ -58,3 +58,49 @@ General app structure is:
- Shared mutable global state can create subtle regressions; prefer narrow, explicit changes.
- Device/backend-specific code paths (**CUDA/ROCm/IPEX/DirectML/OpenVINO**) should not assume one platform.
- Script and extension loading is dynamic; failures may appear only when specific extensions or models are present.

## Repo-Local Skills

Use these repo-local skills for recurring SD.Next model integration work:

- `port-model`
  File: `.github/skills/port-model/SKILL.md`
  Use when adding a new model family, porting a standalone script into a Diffusers pipeline, or wiring an upstream Diffusers model into SD.Next.
- `debug-model`
  File: `.github/skills/debug-model/SKILL.md`
  Use when a new or existing SD.Next/Diffusers model integration fails during detection, loading, prompt encoding, sampling, or output handling.
- `check-api`
  File: `.github/skills/check-api/SKILL.md`
  Use when auditing API routes in `modules/api/api.py` and validating endpoint parameters plus request/response signatures.
- `check-schedulers`
  File: `.github/skills/check-schedulers/SKILL.md`
  Use when auditing scheduler registrations in `modules/sd_samplers_diffusers.py` for class loadability, config validity, and `SamplerData` mapping correctness.
- `check-models`
  File: `.github/skills/check-models/SKILL.md`
  Use when running end-to-end model integration audits across loaders, detect/routing parity, reference catalogs, and pipeline API contracts.
- `check-processing`
  File: `.github/skills/check-processing/SKILL.md`
  Use when validating txt2img/img2img/control processing workflows from UI submit definitions through backend processing and Diffusers execution, including parameter, type, and initialization checks.
- `check-scripts`
  File: `.github/skills/check-scripts/SKILL.md`
  Use when auditing scripts in `scripts/*.py` for standard `Script` overrides (`__init__`, `title`, `show`) and validating `ui()` output against `run()` or `process()` parameters.
- `github-issues`
  File: `.github/skills/github-issues/SKILL.md`
  Use when reading SD.Next GitHub issues with `[Issues]` in the title and producing a markdown report with a short summary, status, and suggested next steps for each issue.
- `github-features`
  File: `.github/skills/github-features/SKILL.md`
  Use when reading SD.Next GitHub issues with `[Feature]` in the title and producing a markdown report with a short summary, status, and suggested next steps for each issue.
- `analyze-model`
  File: `.github/skills/analyze-model/SKILL.md`
  Use when analyzing an external model URL to identify its implementation style and estimate how difficult it is to port into SD.Next.

When creating or updating skills, update this file and the index in `.github/skills/README.md` accordingly.

@@ -0,0 +1,51 @@
# Repo Skills

This folder contains repo-local Copilot skills for recurring SD.Next tasks.

## Available Skills

- `port-model`
  File: `port-model/SKILL.md`
  Use when adding or porting a model family into SD.Next and Diffusers.
- `debug-model`
  File: `debug-model/SKILL.md`
  Use when a new or existing SD.Next/Diffusers model integration fails during detection, loading, prompt encoding, sampling, or output handling.
- `check-api`
  File: `check-api/SKILL.md`
  Use when auditing API endpoints in `modules/api/api.py` and delegated API modules for route parameter correctness plus request/response signature consistency.
- `check-schedulers`
  File: `check-schedulers/SKILL.md`
  Use when auditing scheduler registrations from `modules/sd_samplers_diffusers.py` to verify class loadability, config validity against scheduler capabilities, and `SamplerData` correctness.
- `check-models`
  File: `check-models/SKILL.md`
  Use when running an end-to-end model integration audit covering loaders, detect/routing parity, reference catalogs, and custom pipeline API contracts.
- `check-processing`
  File: `check-processing/SKILL.md`
  Use when validating txt2img/img2img/control processing workflows from UI submit definitions to backend execution with parameter, type, and initialization checks.
- `check-scripts`
  File: `check-scripts/SKILL.md`
  Use when auditing `scripts/*.py` for correct `Script` overrides (`__init__`, `title`, `show`) and verifying `ui()` output compatibility with `run()` or `process()` parameters.
- `github-issues`
  File: `github-issues/SKILL.md`
  Use when reading SD.Next GitHub issues with `[Issues]` in the title and producing a markdown summary with status and suggested next steps for each issue.
- `github-features`
  File: `github-features/SKILL.md`
  Use when reading SD.Next GitHub issues with `[Feature]` in the title and producing a markdown summary with status and suggested next steps for each issue.
- `analyze-model`
  File: `analyze-model/SKILL.md`
  Use when analyzing an external model URL to classify its implementation style and estimate SD.Next porting difficulty before coding.

## Notes

- Keep skills narrowly task-oriented and reusable.
- Prefer referencing existing repo patterns over generic framework advice.
- Update this index when adding new repo-local skills.

@@ -0,0 +1,133 @@
---
name: analyze-model
description: "Analyze an external model URL (typically Hugging Face) to determine implementation style and estimate SD.Next porting difficulty using the port-model workflow."
argument-hint: "Provide model URL and optional target scope: text2img, img2img, edit, video, or full integration"
---

# Analyze External Model For SD.Next Porting

Given an external model URL, inspect how the model is implemented and estimate how hard it is to port into SD.Next according to the port-model skill.

## When To Use

- User provides a Hugging Face model URL and asks if or how it can be ported
- User wants an effort estimate before implementation work
- User wants to classify whether integration should reuse upstream Diffusers, use custom Diffusers code, or require a fully custom implementation

## Accepted Inputs

- Hugging Face model URL (preferred)
- Hugging Face repo id
- Optional source links to custom code repos or PRs

Example:

- https://huggingface.co/jdopensource/JoyAI-Image-Edit

## Required Outputs

For each analyzed model, provide:

1. Implementation classification
2. Evidence for the classification
3. SD.Next porting path recommendation
4. Difficulty rating and effort breakdown
5. Main risks and blockers
6. First implementation steps aligned with the port-model skill

## Implementation Classification Buckets

Classify into one of these (or the closest fit):

1. Integrated into the upstream Diffusers library
2. Custom Diffusers implementation in the model repo (custom pipeline/classes, not upstream)
3. Fully custom implementation (non-Diffusers inference stack)
4. Existing integration in ComfyUI (node or PR) but not in Diffusers
5. Other (describe clearly)

## Procedure

### 1. Inspect Model Repository Artifacts

From the provided URL/repo, collect:

- model card details
- files such as `model_index.json`, `config.json`, scheduler config, and tokenizer files
- presence of a Diffusers-style folder layout
- references to custom Python modules or remote-code requirements
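
The file-layout signals above can be turned into a rough first-pass heuristic. A minimal, stdlib-only sketch: the file names are standard Diffusers conventions, but the classification labels and the "shipped `.py` implies remote code" rule are illustrative assumptions, not a definitive classifier:

```python
def classify_layout(repo_files):
    """Rough implementation-style guess from a repo file listing (heuristic sketch)."""
    names = set(repo_files)
    if "model_index.json" in names:  # Diffusers-style multi-component layout
        if any(f.endswith(".py") for f in names):  # shipped Python suggests remote code
            return "custom-diffusers"
        return "diffusers"
    return "custom"

print(classify_layout(["model_index.json", "scheduler/scheduler_config.json"]))  # → diffusers
print(classify_layout(["model.ckpt", "inference.sh"]))                           # → custom
```

Any result from a heuristic like this still needs confirmation against the model card and actual pipeline code.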
### 2. Determine Runtime Stack

Identify whether model usage is:

- a standard Diffusers pipeline call
- a custom Diffusers pipeline class with `trust_remote_code`
- a pure custom inference script or framework
- a node-based integration in ComfyUI or another host

### 3. Cross-Check Integration Surface

Determine the SD.Next touchpoints required if ported:

- loader file in `pipelines/model_name.py`
- detect and dispatch updates in `modules/sd_detect.py` and `modules/sd_models.py`
- model type mapping in `modules/modeldata.py`
- optional custom pipeline package under `pipelines/model/`
- reference catalog updates and preview asset requirements

### 4. Estimate Porting Difficulty

Use this scale:

- Low: mostly loader wiring to an existing upstream Diffusers pipeline
- Medium: custom Diffusers classes or limited checkpoint/config adaptation
- High: fully custom architecture, major prompt/sampler/output differences, or sparse docs
- Very High: no usable Diffusers path plus a major runtime-assumption mismatch

Break down difficulty by:

- loader complexity
- pipeline/API contract complexity
- scheduler/sampler compatibility
- prompt encoding complexity
- checkpoint conversion/remapping complexity
- validation and testing burden

### 5. Identify Risks

Call out concrete risks:

- missing or incompatible scheduler config
- unclear output domain (latent vs pixel)
- custom text encoder or processor constraints
- nonstandard checkpoint format
- dependency on external runtime features unavailable in SD.Next

### 6. Recommend Porting Path

Map the recommendation to the port-model workflow:

- upstream Diffusers reuse path
- custom Diffusers pipeline package path
- raw checkpoint plus remap path

Provide a concise first-step plan with the smallest viable integration milestone.

## Reporting Format

Return sections in this order:

1. Classification
2. Evidence
3. Porting difficulty
4. Recommended SD.Next integration path
5. Risks and unknowns
6. Next actions

If critical information is missing, explicitly list the unknowns and what to inspect next.

## Notes

- Prefer concrete evidence from repository files and the model card over assumptions.
- If the source suggests multiple possible paths, compare at least two and state why one is preferred.
- Keep recommendations aligned with the conventions defined in the port-model skill.

@@ -0,0 +1,121 @@
---
name: check-api
description: "Audit SD.Next API route definitions and verify endpoint parameters plus request/response signatures against declared FastAPI contracts."
argument-hint: "Optionally focus on specific route prefixes or endpoint groups"
---

# Check API Endpoints And Signatures

Read `modules/api/api.py`, enumerate all registered endpoints, and validate that each endpoint has coherent request parameters and response signatures.

## When To Use

- The user asks to audit API correctness
- A change touched API routes, endpoint handlers, or API models
- OpenAPI docs look wrong or clients report schema mismatches
- You need a pre-PR API contract sanity pass

## Primary File

- `modules/api/api.py`

This file is the route registration hub and must be treated as the source of truth for direct `add_api_route(...)` registrations.

## Secondary Files To Inspect

- `modules/api/models.py`
- `modules/api/endpoints.py`
- `modules/api/generate.py`
- `modules/api/process.py`
- `modules/api/control.py`
- any module loaded via `register_api(...)` from `modules/api/*` and feature modules (caption, lora, gallery, civitai, rembg, etc.)

## Audit Goals

For every endpoint, verify:

1. Route method and path are valid and unique after subpath handling.
2. The handler call signature is compatible with the route declaration.
3. The declared `response_model` is coherent with the returned payload shape.
4. Request body or query params implied by handler type hints are consistent with expected client usage.
5. Authentication behavior is intentional (`auth=True` is the default in `add_api_route`).
6. OpenAPI schema exposure is correct (including trailing-slash duplicate suppression).

## Procedure

### 1. Enumerate Routes In `modules/api/api.py`

- Collect every `self.add_api_route(...)` call.
- Capture: path, methods, handler symbol, `response_model`, tags, auth flag.
- Note route groups: server, generation, processing, scripts, enumerators, functional.
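
Route enumeration can be done statically without importing the module. A stdlib-only sketch that walks the AST for `add_api_route` calls; the method name matches SD.Next's registration helper, but the captured fields are a simplification and the sample line is a hypothetical excerpt:

```python
import ast

def enumerate_routes(source: str):
    """Statically collect self.add_api_route(...) calls from module source text."""
    routes = []
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr == "add_api_route"):
            # first positional arg is the path when written as a string literal
            path = node.args[0].value if node.args and isinstance(node.args[0], ast.Constant) else None
            kwargs = {kw.arg for kw in node.keywords if kw.arg}
            routes.append({"path": path, "has_response_model": "response_model" in kwargs})
    return routes

# hypothetical one-line excerpt in the style of modules/api/api.py
sample = 'self.add_api_route("/sdapi/v1/txt2img", self.post_txt2img, methods=["POST"], response_model=ResTxt2Img)'
print(enumerate_routes(sample))  # → [{'path': '/sdapi/v1/txt2img', 'has_response_model': True}]
```

Joining this inventory against the handler definitions gives the worklist for the signature checks that follow.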
### 2. Resolve Handler Definitions

For each handler symbol:

- Locate the callable definition.
- Read its function signature and type hints.
- Identify required positional args, optional args, and body model annotations.

Flag issues such as:

- Required parameters that cannot be supplied by FastAPI.
- A mismatch between route method and handler intent (for example, a body expected on GET).
- Ambiguous or missing type hints for public API handlers.

### 3. Validate Request Signatures

- Confirm request model classes exist and are importable.
- Confirm the endpoint signature reflects the expected request source (path/query/body).
- Check optional vs required semantics for compatibility-sensitive endpoints.

### 4. Validate Response Signatures

- Compare the declared `response_model` to the handler return shape.
- Ensure list/dict wrappers match the actual payload structure.
- Flag obvious drift where `response_model` implies fields that are never returned.

### 5. Include register_api(...) Modules

`Api.register()` delegates additional endpoints through module-level `register_api(...)` calls.

- Inspect each delegated module for routes.
- Apply the same request/response signature checks.
- Include these findings in the final report, not only the direct routes in `api.py`.

### 6. Verify Runtime Schema (Optional But Preferred)

If feasible in the current environment:

- Build the app and inspect `app.routes` metadata.
- Generate the OpenAPI schema and spot-check key endpoints.
- Confirm trailing-slash duplicate suppression behavior remains correct.

If runtime schema checks are not feasible, explicitly state that and rely on static validation.
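
When only a generated schema is available (for example a saved `openapi.json`), the spot check can be as simple as this stdlib sketch; the sample schema fragment is illustrative, not SD.Next's actual schema:

```python
def exposes(schema: dict, path: str, method: str) -> bool:
    """Check whether an OpenAPI schema dict exposes a given path+method pair."""
    return method.lower() in schema.get("paths", {}).get(path, {})

sample_schema = {"paths": {"/sdapi/v1/txt2img": {"post": {"summary": "generate"}}}}
print(exposes(sample_schema, "/sdapi/v1/txt2img", "POST"))   # → True
print(exposes(sample_schema, "/sdapi/v1/txt2img/", "POST"))  # trailing-slash duplicate should be absent → False
```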
## Reporting Format

Report findings ordered by severity:

1. Breaking contract mismatches
2. Likely runtime errors
3. Schema quality or consistency issues
4. Minor style/typing improvements

For each finding include:

- Route path and method
- File and function
- Why it is a mismatch
- Minimal fix recommendation

If no issues are found, state that explicitly and mention residual risk (for example, runtime-only behavior not executed).

## Output Expectations

When this skill is used, return:

- Total endpoints checked
- Direct routes vs delegated `register_api(...)` routes checked
- Findings list with severity and locations
- Clear pass/fail summary for request and response signature consistency

@@ -0,0 +1,151 @@
---
name: check-models
description: "Audit SD.Next model integrations end-to-end: loaders, detect/routing, reference catalogs, and pipeline API contracts."
argument-hint: "Optionally focus on a model family, repo id, or a subset: loader, detect-routing, references, pipeline-contracts"
---

# Check Model Integrations End-To-End

Run a consolidated model-integration audit that combines loader checks, detect/routing checks, reference-catalog checks, and pipeline contract checks.

## When To Use

- A new model family was added and needs a completeness audit
- Existing model support appears inconsistent across detection, loading, and UI references
- A custom pipeline was ported and needs contract validation
- You want a pre-PR integration quality gate for model-related changes

## Combined Scope

This skill combines four audit surfaces:

1. Loader consistency (`check-loaders` equivalent)
2. Detect/routing parity (`check-detect-routing` equivalent)
3. Reference-catalog integrity (`check-reference-catalog` equivalent)
4. Pipeline API contract conformance (`check-pipeline-contracts` equivalent)

## Primary Files

- `pipelines/model_*.py`
- `modules/sd_detect.py`
- `modules/sd_models.py`
- `modules/modeldata.py`
- `data/reference.json`
- `data/reference-cloud.json`
- `data/reference-quant.json`
- `data/reference-distilled.json`
- `data/reference-nunchaku.json`
- `data/reference-community.json`
- `models/Reference/`

Pipeline files as needed:

- `pipelines/<model>/pipeline.py`
- `pipelines/<model>/model.py`

## Audit A: Loader Consistency

For each target model loader in `pipelines/model_*.py`, verify:

- Correct `sd_models.path_to_repo(checkpoint_info)` and `sd_models.hf_auth_check(...)` usage
- Load args built with `model_quant.get_dit_args(...)` where applicable
- No duplicated kwargs (for example, a duplicate `torch_dtype`)
- Correct component loading path (`generic.load_transformer`, `generic.load_text_encoder`, tokenizer/processor)
- Proper post-load hooks (`sd_hijack_te`, `sd_hijack_vae`) where required
- Correct `pipe.task_args` defaults where needed
- Cleanup and `devices.torch_gc(...)` present

Flag stale patterns, missing hooks, or conflicting load behavior.
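
The duplicated-kwargs check in particular can be mechanized. A small sketch; the dict contents are hypothetical loader args, not SD.Next's actual defaults:

```python
def merge_load_args(base: dict, extra: dict):
    """Merge loader kwargs and report keys set in both dicts (e.g. torch_dtype twice)."""
    dupes = sorted(set(base) & set(extra))
    return {**base, **extra}, dupes

merged, dupes = merge_load_args(
    {"torch_dtype": "bfloat16", "cache_dir": "/tmp/hub"},
    {"torch_dtype": "float16", "variant": "fp16"},
)
print(dupes)  # → ['torch_dtype']  (extra wins in the merged dict)
```

A silent duplicate like this is exactly the kind of conflicting load behavior the audit should surface, since the winning value depends on merge order.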
## Audit B: Detect/Routing Parity

Verify model family alignment across:

- `modules/sd_detect.py` detection heuristics
- `modules/sd_models.py` load dispatch branch
- `modules/modeldata.py` reverse classification from the loaded pipeline class

Checks:

- The family is detectable by name/repo conventions
- Dispatch routes to the intended loader
- The loaded pipeline class is classified back to the same model family
- Branch ordering does not cause broad matches to shadow specific families
## Audit C: Reference Catalog Integrity

Verify references for model families intended to appear in the model references.

Checks:

- Correct category file placement by type:
  - base -> `data/reference.json`
  - cloud -> `data/reference-cloud.json`
  - quant -> `data/reference-quant.json`
  - distilled -> `data/reference-distilled.json`
  - nunchaku -> `data/reference-nunchaku.json`
  - community -> `data/reference-community.json`
- Required fields present per entry (`path`, `preview`, `desc` when expected)
- Duplicate repo/path collisions across reference files are intentional or flagged
- Preview filename convention is consistent
- The referenced preview file exists in `models/Reference/` (or is explicitly a placeholder if intentional)
- JSON validity for touched reference files
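
JSON validity plus required-field checks are easy to script. A sketch, assuming entries are keyed by display name and that `path` is mandatory; the sample payload is made up:

```python
import json

def audit_reference(raw: str, required=("path",)):
    """Parse a reference file (raises on invalid JSON) and list entries missing required fields."""
    data = json.loads(raw)
    return [f"{name}: missing {field}"
            for name, entry in data.items()
            for field in required
            if field not in entry]

sample = '{"Good Model": {"path": "org/repo", "preview": "good-model.jpg"}, "Bad Model": {"preview": "bad.jpg"}}'
print(audit_reference(sample))  # → ['Bad Model: missing path']
```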
## Audit D: Pipeline API Contracts

For custom pipelines (`pipelines/<model>/pipeline.py`), verify:

- Inherits from `diffusers.DiffusionPipeline`
- Registers modules correctly
- `from_pretrained` wiring is coherent with the actual artifact layout
- `encode_prompt` semantics are consistent with the tokenizer/text encoder setup
- `__call__` supports the expected public args for its task and does not expose unsupported generic args
- Batch and negative prompt behavior are coherent
- Output conversion aligns with the model output domain (latent vs pixel space)
- `output_type` and `return_dict` behavior are consistent

## Runtime Validation (Preferred)

When feasible:

- Import-level smoke tests for loaders and pipeline modules
- Lightweight loader construction checks without full heavy generation where possible
- One minimal generation/sampling pass for changed model families

If runtime checks are not feasible, report the limitations clearly.

## Reporting Format

Return findings by severity:

1. Blocking integration failures
2. Contract mismatches (load/detect/reference/pipeline)
3. Consistency and quality issues
4. Optional improvements

For each finding include:

- model family
- layer (`loader`, `detect-routing`, `reference`, `pipeline-contract`)
- file location
- mismatch summary
- minimal fix

Also include summary counts:

- loaders checked
- model families checked for detect/routing parity
- reference files checked
- pipeline contracts checked
- runtime checks executed vs skipped

## Pass Criteria

A full pass requires all of the following in the audited scope:

- the loader path is coherent and non-conflicting
- detect/routing/modeldata parity holds
- reference entries are valid, categorized correctly, and have preview files
- custom pipeline contracts are consistent with actual model behavior

If any area is intentionally out of scope, mark it as a partial pass with explicit exclusions.

@@ -0,0 +1,178 @@
---
name: check-processing
description: "Validate txt2img/img2img/control/caption processing workflows from UI submit bindings to backend processing execution and confirm parameter/type/init correctness."
argument-hint: "Optionally focus on txt2img, img2img, control, caption, or process-only and include changed files"
---

# Check Processing Workflow Contracts

Trace generation workflows from UI definitions and submit bindings to backend execution, then validate that parameters are passed, typed, and initialized correctly.

## When To Use

- A change touched UI submit wiring for txt2img, img2img, control, or caption workflows
- Processing code changed and regressions are suspected in argument ordering or defaults
- A new parameter was added to UI or processing classes/functions and needs end-to-end validation
- You want a pre-PR contract audit for generation flow integrity

## Required Workflow Coverage

Start from the UI definitions and follow each workflow to its final implementation:

1. `txt2img`: `modules/ui_txt2img.py` -> `modules/txt2img.py` -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
2. `img2img`: `modules/ui_img2img.py` -> `modules/img2img.py` -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
3. `control/process`: `modules/ui_control.py` -> `modules/control/run.py` (and related control processing entrypoints) -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
4. `caption/process`: `modules/ui_caption.py` -> caption handler module(s) -> `modules/processing.py:process_images` and/or postprocess/caption execution module(s), depending on the selected caption backend

Also validate script hooks when present:

- `modules/scripts_manager.py` (`run`, `before_process`, `process`, `process_images`, `after`)

## Primary Files

- `modules/ui_txt2img.py`
- `modules/txt2img.py`
- `modules/ui_img2img.py`
- `modules/img2img.py`
- `modules/ui_control.py`
- `modules/ui_caption.py` (and `modules/ui_captions.py` if present)
- `modules/control/run.py`
- `modules/processing.py`
- `modules/processing_diffusers.py`
- `modules/scripts_manager.py`

## Audit Goals

For each covered workflow, verify all three dimensions:

1. Parameter pass-through correctness (name, order, semantic meaning)
2. Type correctness (UI component output shape vs function signature expectations)
3. Initialization correctness (defaults, `None` handling, fallback logic, and object state)

## Procedure

### 1. Build End-To-End Call Graph

For each workflow (`txt2img`, `img2img`, `control`, `caption`):

- Locate submit/click bindings in the UI modules.
- Capture the exact `inputs=[...]` list order and target function (`fn=...`).
- Resolve wrappers (`call_queue.wrap_gradio_gpu_call`, queued wrappers) to the actual function signatures.
- Follow the function flow through processing class construction and execution (`processing.process_images`, then `process_diffusers` when applicable).

Produce a normalized mapping table per workflow:

- UI input component name
- UI expected output type
- receiving argument in the submit target
- receiving processing-object field (if applicable)
- downstream consumption point

### 2. Validate Argument Order And Arity

For each submit path:

- Compare the UI `inputs` order against the function's positional parameter order.
- Validate handling of `*args` and script arguments.
- Confirm `state`, task id, mode flags, and tab selections align with function signatures.
- Flag positional drift where adding/removing an argument in one layer is not propagated.
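
Arity drift can be caught with `inspect`. A sketch with a made-up submit target; SD.Next's real handlers take many more arguments, but the check generalizes:

```python
import inspect

def submit_target(task_id, prompt, negative_prompt, steps, *script_args):
    """Hypothetical submit target in the style of the txt2img handler."""
    return task_id

def arity_ok(fn, ui_inputs) -> bool:
    """True if the UI inputs list can satisfy fn's positional parameters."""
    params = list(inspect.signature(fn).parameters.values())
    positional = [p for p in params if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
    required = sum(1 for p in positional if p.default is p.empty)
    if any(p.kind is p.VAR_POSITIONAL for p in params):
        return len(ui_inputs) >= required  # extras flow into *script_args
    return required <= len(ui_inputs) <= len(positional)

print(arity_ok(submit_target, ["id", "prompt", "neg", "steps", "s0"]))  # → True
print(arity_ok(submit_target, ["id", "prompt"]))                        # → False
```

Arity alone does not prove order is correct, so the name/semantic parity checks below are still required.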
### 3. Validate Name And Semantic Parity

Check that semantically related parameters remain coherent across layers:

- sampler fields (`sampler_index`, `hr_sampler_index`, sampler name conversion)
- guidance/cfg fields
- size/resize fields
- denoise/refiner/hires fields
- detailer, hdr, grading fields
- override settings fields

Flag mismatches such as:

- the same concept with divergent naming and the wrong destination
- a field sent but never consumed
- a required field consumed but never provided

### 4. Validate Type Contracts

Audit type compatibility from UI component to processing target:

- `gr.Image`/`gr.File`/`gr.Video` outputs vs expected Python types (`PIL.Image`, bytes, list, path-like, etc.)
- radios returning index/value and the expected downstream representation
- sliders/number inputs (`int` vs `float`) and conversion points
- optional objects (`None`) and `.name`/attribute access safety

Flag ambiguous or unsafe assumptions, especially for optional file inputs and mixed scalar/list values.
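
For the optional-object hazard specifically, a defensive access pattern like this sketch avoids `AttributeError` on empty file inputs; the helper and the stand-in class are illustrative, not SD.Next code:

```python
def safe_file_name(file_obj):
    """Return the .name of an optional upload object, or None when absent."""
    return getattr(file_obj, "name", None)

class FakeUpload:  # stand-in for a gradio file wrapper
    name = "/tmp/input.png"

print(safe_file_name(FakeUpload()))  # → /tmp/input.png
print(safe_file_name(None))          # → None
```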
### 5. Validate Initialization And Defaults

In the target modules (`txt2img.py`, `img2img.py`, `control/run.py`, processing classes):

- verify defaults/fallbacks for invalid or missing inputs
- verify guards for unset model/state and expected error paths
- verify object fields are initialized before first use
- verify flow-specific defaults are not leaking across workflows

Include checks for common regressions:

- `None` passed into required processing fields
- missing fallback for sampler/seed/size values
- stale fields retained from prior job state

### 6. Validate Script Hook Contracts

Where script systems are involved:

- verify `scripts_*.run(...)` fallback behavior to `processing.process_images(...)`
- verify `scripts_*.after(...)` receives a compatible processed object
- ensure script args wiring matches the `setup_ui(...)` order

### 7. Runtime Spot Check (Preferred)

If feasible, run lightweight smoke validation for each workflow:

- one minimal txt2img run
- one minimal img2img run
- one minimal control run
- one minimal caption run

Use very small dimensions/steps to limit runtime.
If runtime checks are not feasible, explicitly report the static-only limitations.

## Reporting Format

Return findings by severity:

1. Blocking contract failures (will break execution)
2. High-risk mismatches (likely wrong behavior)
3. Type/init safety issues
4. Non-blocking consistency issues

For each finding include:

- workflow (`txt2img` | `img2img` | `control` | `caption`)
- layer transition (`ui -> handler`, `handler -> processing`, `processing -> diffusers`)
- file location
- mismatch summary
- minimal fix

Also include summary counts:

- workflows checked
- UI bindings checked
- function signatures checked
- parameter mappings validated
- runtime checks executed vs skipped

## Pass Criteria

A full pass requires all of the following:

- UI submit input order matches the target function signatures
- all required parameters are passed end-to-end with correct semantics
- type expectations are explicit and compatible at each boundary
- initialization/default logic prevents unset/invalid state usage
- the scripts fallback path to `processing.process_images` is coherent

If only part of the workflow scope was checked, report a partial pass with explicit exclusions.

@@ -0,0 +1,149 @@
---
name: check-schedulers
description: "Audit scheduler registrations starting from modules/sd_samplers_diffusers.py and verify class loadability, config validity against scheduler capabilities, and SamplerData correctness."
argument-hint: "Optionally focus on a scheduler subset, such as flow-matching, res4lyf, or parallel schedulers"
---

# Check Scheduler Registry And Contracts

Use `modules/sd_samplers_diffusers.py` as the starting point and verify that scheduler classes, scheduler config, and `SamplerData` mappings are coherent and executable.

## Required Guarantees

The audit must explicitly verify all three:

1. All scheduler classes can be loaded and compiled.
2. All scheduler config entries are valid and match the scheduler capabilities declared in `__init__`.
3. All scheduler classes have valid associated `SamplerData` entries, and the mapping is correct.

## Scope

Primary file:

- `modules/sd_samplers_diffusers.py`

Related files:

- `modules/sd_samplers_common.py` for the `SamplerData` definition and sampler expectations
- `modules/sd_samplers.py` for the sampler selection flow and runtime wiring
- `modules/schedulers/**/*.py` for custom scheduler implementations
- `modules/res4lyf/**/*.py` for Res4Lyf scheduler classes (if installed/enabled)

## What "Loaded And Compiled" Means

Treat this as a two-level check:

1. Import and class resolution: every scheduler class referenced by `SamplerData(... DiffusionSampler(..., SchedulerClass, ...))` resolves without missing-symbol errors.
2. Construction and compile sanity: scheduler instances can be created from their intended config path and survive a lightweight execution-level check (including the compile path when available).

Notes:

- For non-`torch.nn.Module` schedulers, "compiled" means the scheduler integration path is executable in runtime checks (not necessarily `torch.compile`).
- If the environment cannot run compile checks, report this explicitly and still complete the static validation.

## Procedure

### 1. Build Scheduler Inventory

- Enumerate the scheduler classes imported in `modules/sd_samplers_diffusers.py`.
- Enumerate all entries in `samplers_data_diffusers`.
- Enumerate all config keys in `config`.

Create a joined table by sampler name with:

- sampler name
- scheduler class
- config key used
- custom scheduler category (diffusers, SD.Next custom, Res4Lyf)

### 2. Validate Scheduler Class Resolution

For each mapped scheduler class:

- confirm the symbol exists and is importable
- confirm the class object is callable for the construction paths used in SD.Next

Flag missing imports, dead entries, or stale class names.
|
||||
|
||||
### 3. Validate Config Against Scheduler __init__ Capabilities
|
||||
|
||||
For each sampler config entry:
|
||||
|
||||
- inspect scheduler class `__init__` signature and accepted config fields
|
||||
- flag config keys that are unsupported or misspelled
|
||||
- flag required scheduler init/config fields that are absent
|
||||
- verify defaults and overrides are compatible with scheduler family behavior
|
||||
|
||||
Special attention:
|
||||
|
||||
- flow-matching schedulers: shift/base_shift/max_shift/use_dynamic_shifting
|
||||
- DPM families: algorithm_type/solver_order/solver_type/final_sigmas_type
|
||||
- compatibility-only keys that are intentionally ignored should be documented, not silently assumed
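This check can be sketched with `inspect.signature`; the `DummyScheduler` class and its fields below are hypothetical stand-ins, and a real audit would import the actual scheduler classes:

```python
import inspect

def validate_config(scheduler_cls, config):
    """Return (unsupported, missing) config keys for one scheduler class."""
    params = inspect.signature(scheduler_cls.__init__).parameters
    accepts_var_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    # keys not accepted by __init__ (unless it takes **kwargs)
    unsupported = [] if accepts_var_kwargs else sorted(k for k in config if k not in params)
    # required parameters with no default and no value in the config entry
    missing = sorted(
        name for name, p in params.items()
        if name != "self" and p.default is p.empty
        and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        and name not in config
    )
    return unsupported, missing

class DummyScheduler:  # hypothetical stand-in for a real scheduler class
    def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False):
        pass

unsupported, missing = validate_config(DummyScheduler, {"shift": 3.0, "max_shift": 1.15})
# "max_shift" is flagged as unsupported; nothing required is missing
```

Running the same helper over every `(class, config)` pair from the inventory yields the misspelled-key and missing-field findings directly.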
### 4. Validate SamplerData Mapping Correctness

For each `SamplerData` entry:

- the scheduler label matches the config key intent
- the callable builds `DiffusionSampler` with the expected scheduler class
- the mapping does not accidentally point to a different named preset
- no duplicate names with conflicting class/config behavior

Flag mismatches such as a wrong display name, the wrong class wired to a name, or stale aliasing.

### 5. Runtime Smoke Checks (Preferred)

If feasible, run lightweight checks:

- instantiate each scheduler from config
- execute the minimal scheduler setup path (`set_timesteps` and a dummy step where possible)
- verify no immediate runtime contract errors

If runtime checks are not feasible for some schedulers, mark those explicitly as unverified-at-runtime.
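The smoke check can be written as a small harness that classifies each registration; `GoodScheduler` here is a hypothetical stand-in, and a real run would iterate the actual inventory:

```python
def smoke_check(scheduler_cls, config):
    """Classify one scheduler registration: ok / construct-failed / setup-failed / unverified."""
    try:
        sched = scheduler_cls(**config)
    except Exception as e:  # construction failure is a blocking finding
        return "construct-failed", repr(e)
    if not hasattr(sched, "set_timesteps"):
        return "unverified-at-runtime", "no set_timesteps method"
    try:
        sched.set_timesteps(num_inference_steps=4)
    except Exception as e:
        return "setup-failed", repr(e)
    return "ok", ""

class GoodScheduler:  # hypothetical stand-in for a real scheduler
    def __init__(self, num_train_timesteps=1000):
        self.num_train_timesteps = num_train_timesteps
    def set_timesteps(self, num_inference_steps):
        self.timesteps = list(range(num_inference_steps))

print(smoke_check(GoodScheduler, {"num_train_timesteps": 100}))  # ('ok', '')
print(smoke_check(GoodScheduler, {"bogus_key": 1}))              # construct-failed
```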
### 6. Compile Path Validation

Where the scheduler runtime path supports compile-related checks in SD.Next:

- verify the scheduler integration path remains compatible with compile options
- detect obvious compile blockers introduced by signature/config mismatches

Do not mark compile as passed if only static checks were done.

## Reporting Format

Return findings ordered by severity:

1. Blocking scheduler load failures
2. Config/signature contract mismatches
3. `SamplerData` mapping inconsistencies
4. Non-blocking improvements

For each finding include:

- sampler name
- scheduler class
- file location
- mismatch reason
- minimal corrective action

Also include summary counts:

- total scheduler classes discovered
- total `SamplerData` entries checked
- total config entries checked
- runtime-validated count
- compile-path validated count

## Pass Criteria

The check passes only if all are true:

- all referenced scheduler classes resolve
- each scheduler config entry is compatible with scheduler capabilities
- each `SamplerData` entry is correctly mapped and usable
- no blocking runtime or compile-path failures in validated scope

If scope is partial due to environment limitations, report a partial pass with explicit limitations, not a full pass.
@@ -0,0 +1,129 @@
---
name: check-scripts
description: "Audit scripts/*.py and verify Script override contracts (init/title/show) plus ui() output compatibility with run() or process() parameters."
argument-hint: "Optionally focus on a subset of scripts or only run-vs-ui or process-vs-ui checks"
---

# Check Script Class Contracts

Audit all Python scripts in `scripts/*.py` and validate that script class overrides and UI-to-execution parameter contracts are correct.

## When To Use

- New or changed files were added under `scripts/*.py`
- A script crashes when selected or executed from the UI
- A script UI was changed and runtime args no longer match
- You want a pre-PR quality gate for script API compatibility

## Scope

Primary audit scope:

- `scripts/*.py`

Contract references:

- `modules/scripts_manager.py` (`Script` base class contracts for `title`, `show`, `ui`, `run`, `process`)
- `modules/scripts_postprocessing.py` (`ScriptPostprocessing` contracts for `ui` and `process`)

## Required Checks

### A. Standard Overrides: `__init__`, `title`, `show`

For each class in `scripts/*.py` that subclasses `scripts.Script` or `scripts_manager.Script`:

1. `title`:
   - method exists
   - callable signature is valid
   - returns a non-empty string value

2. `show`:
   - method exists
   - signature is compatible with script runner usage (`show(is_img2img)` or permissive `*args/**kwargs`)
   - return behavior is compatible (`bool` or `scripts.AlwaysVisible` / equivalent)

3. `__init__` (if overridden):
   - does not require mandatory constructor args that would break loader instantiation
   - avoids side effects that require runtime-only globals at import time
   - leaves the class in a usable state before `ui()`/`run()`/`process()` are called

Notes:

- `__init__` is optional; do not fail scripts that rely on the inherited constructor.
- For dynamic patterns, flag a warning with rationale instead of a hard fail.

### B. `ui()` Output vs `run()`/`process()` Parameters

For each script class:

1. Determine the execution target:
   - Prefer `run()` if present for generation scripts
   - Use `process()` if present and `run()` is absent or the script is postprocessing-oriented

2. Compare the `ui()` output shape to the target method's parameter expectations:
   - the `ui()` list/tuple output count should match the target's positional argument capacity after the first processing arg (`p` or `pp`), unless the target uses `*args`
   - if the target is strictly positional (no `*args`/`**kwargs`), detect missing/extra UI values
   - if the target uses keyword-driven processing, ensure UI dict keys map to accepted params or `**kwargs`

3. Validate ordering assumptions:
   - UI control order should align with positional parameter order when positional binding is used
   - detect obvious drift when a new UI control was added but the method signature was not updated

4. Validate optionality/defaults:
   - required target parameters should be satisfiable by UI outputs
   - defaulted target params are acceptable even if the UI omits them
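The arity comparison above can be sketched statically with `inspect`; the `DemoScript` class is a deliberately mismatched hypothetical example, and a real audit would import or parse classes from `scripts/*.py`:

```python
import inspect

def check_ui_target_arity(script_cls):
    """Compare ui() output count against run()/process() positional capacity."""
    target = getattr(script_cls, "run", None) or getattr(script_cls, "process", None)
    params = list(inspect.signature(target).parameters.values())[2:]  # skip self and p/pp
    if any(p.kind is p.VAR_POSITIONAL for p in params):
        return "ok", "target accepts *args"
    ui_outputs = script_cls().ui(False)
    required = [p for p in params if p.default is p.empty]
    if len(ui_outputs) < len(required):
        return "fail", f"ui() returns {len(ui_outputs)} values but target requires {len(required)}"
    if len(ui_outputs) > len(params):
        return "fail", f"ui() returns {len(ui_outputs)} values but target accepts at most {len(params)}"
    return "ok", ""

class DemoScript:  # hypothetical mismatched script
    def ui(self, is_img2img):
        return ["steps_slider"]          # one UI control
    def run(self, p, steps, strength):   # but two required runtime args
        pass

status, detail = check_ui_target_arity(DemoScript)
# status == "fail": one ui() value cannot satisfy two required parameters
```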
### C. Runner Compatibility

Confirm script methods align with runner expectations in `modules/scripts_manager.py`:

- `ui()` return type is compatible with runner collection (`list`/`tuple` or a recognized mapping pattern where used)
- `run()`/`process()` receive args in the expected form from runner slices
- no obvious mismatch between `args_from`/`args_to` assumptions and script method arity

For postprocessing-style scripts in `scripts/*.py`:

- verify compatibility with `modules/scripts_postprocessing.py` conventions (`ui()` list/dict, `process(pp, *args, **kwargs)`)

## Procedure

1. Enumerate all classes in `scripts/*.py` and classify them by base class type.
2. For each generation script class, validate `title`, `show`, optional `__init__`, and `ui` -> `run/process` contracts.
3. For each postprocessing script class under `scripts/*.py`, validate `ui` -> `process` mapping semantics.
4. Cross-check ambiguous cases against script runner behavior from `modules/scripts_manager.py` and `modules/scripts_postprocessing.py`.
5. Report concrete mismatches with minimal fixes.

## Reporting Format

Return findings by severity:

1. Blocking script contract failures
2. Likely runtime arg/arity mismatches
3. Signature/type compatibility warnings
4. Style/consistency improvements

For each finding include:

- script file
- class name
- failing contract area (`init`, `title`, `show`, `ui->run`, `ui->process`)
- mismatch summary
- minimal fix

Also include summary counts:

- total `scripts/*.py` files checked
- total script classes checked
- classes with `run` contract checked
- classes with `process` contract checked
- override issues found (`init/title/show`)

## Pass Criteria

A full pass requires all of the following across audited `scripts/*.py` classes:

- `title` and `show` overrides are valid and runner-compatible for generation scripts
- overridden `__init__` methods are safely instantiable
- `ui()` output contracts are compatible with `run()` or `process()` args
- no blocking arity/signature mismatch remains

If a class uses highly dynamic argument routing that cannot be proven statically, mark it as a conditional pass with an explicit runtime validation recommendation.
@@ -0,0 +1,217 @@
---
name: debug-model
description: "Debug a broken SD.Next or Diffusers model integration. Use when a newly added or ported model fails to load, misdetects, crashes during prompt encoding or sampling, or produces incorrect outputs."
argument-hint: "Describe the failing model, where it fails, the error message, and whether the model is upstream Diffusers, custom pipeline, or raw checkpoint based"
---

# Debug SD.Next And Diffusers Model Port

Read the error, identify which integration layer is failing, isolate the smallest reproducible failure point, fix the root cause, and validate the fix without expanding scope.

## When To Use

- A newly added SD.Next model type does not autodetect correctly
- The loader fails to instantiate a pipeline or component
- A custom pipeline imports but fails during `from_pretrained`
- Prompt encoding fails because of a tokenizer, processor, or text encoder mismatch
- Sampling fails due to tensor shape, dtype, device, or scheduler issues
- The model loads but outputs corrupted images, the wrong output type, or obviously incorrect results

## Debugging Order

Always debug from the outside in.

1. Detection and routing
2. Loader arguments and component selection
3. Checkpoint path and artifact layout
4. Weight loading and key mapping
5. Prompt encoding
6. Sampling forward path
7. Output postprocessing and SD.Next task integration

Do not start by rewriting the architecture if the failure is likely in detection, loader wiring, or output handling.

## Files To Check First

- `.github/copilot-instructions.md`
- `.github/instructions/core.instructions.md`
- `modules/sd_detect.py`
- `modules/sd_models.py`
- `modules/modeldata.py`
- `pipelines/model_<name>.py`
- `pipelines/<model>/model.py`
- `pipelines/<model>/pipeline.py`
- `pipelines/generic.py`

If the port is based on a standalone script, compare the failing path against the original reference implementation and identify the first semantic divergence.

## Failure Classification

### 1. Model Not Detected Or Misclassified

Check:

- Filename and repo-name heuristics in `modules/sd_detect.py`
- Loader dispatch branch in `modules/sd_models.py`
- Reverse pipeline classification in `modules/modeldata.py`

Typical symptoms:

- Wrong loader called
- Pipeline classified as a broader family such as `chroma` instead of a custom `zetachroma`
- Task switching behaves incorrectly because the loaded pipeline type is wrong

### 2. Loader Fails Before Pipeline Construction

Check:

- `sd_models.path_to_repo(checkpoint_info)` output
- `generic.load_transformer(...)` and `generic.load_text_encoder(...)` arguments
- Duplicate kwargs such as `torch_dtype`
- Wrong class chosen for the text encoder, tokenizer, or processor
- Whether the source is really a Diffusers repo or only a raw checkpoint

Typical symptoms:

- Missing subfolder errors
- `from_pretrained` argument mismatch
- Component class mismatch

### 3. Raw Checkpoint Load Fails

Check:

- Checkpoint path resolution for local file, local directory, and Hub repo
- State dict load method
- Key remapping logic
- Config inference from tensor shapes
- Missing versus unexpected keys after `load_state_dict`

Typical symptoms:

- Key mismatch explosion
- Wrong inferred head counts, dimensions, or decoder settings
- Silent shape corruption caused by a bad remap
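The missing-versus-unexpected comparison can be isolated in a few lines; the key names and remap below are hypothetical examples, not a real checkpoint layout:

```python
def diff_keys(model_keys, ckpt_keys, remap):
    """Report missing and unexpected keys after applying a remap to checkpoint keys."""
    remapped = {remap.get(k, k) for k in ckpt_keys}
    missing = sorted(set(model_keys) - remapped)      # model expects, checkpoint lacks
    unexpected = sorted(remapped - set(model_keys))   # checkpoint has, model rejects
    return missing, unexpected

model_keys = ["blocks.0.attn.qkv.weight", "blocks.0.mlp.fc1.weight"]
ckpt_keys = ["blocks.0.attention.qkv.weight", "blocks.0.mlp.fc1.weight"]
remap = {"blocks.0.attention.qkv.weight": "blocks.0.attn.qkv.weight"}

print(diff_keys(model_keys, ckpt_keys, remap))  # ([], []) -> remap fully reconciles layouts
print(diff_keys(model_keys, ckpt_keys, {}))     # one missing plus one unexpected key
```

A non-empty result on either side is the signal to fix the remap table, not to force `strict=False`.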
### 4. Prompt Encoding Fails

Check:

- Tokenizer or processor choice
- `trust_remote_code` requirements
- Chat template or custom prompt formatting
- Hidden state index selection
- Padding and batch alignment between positive and negative prompts

Typical symptoms:

- Tokenizer attribute errors
- Hidden state shape mismatch
- CFG failures when negative prompts do not match the prompt batch length
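The batch-alignment fix many pipelines apply before CFG can be sketched as follows; the function name and behavior are illustrative, not SD.Next API:

```python
def align_negative_prompts(prompts, negative_prompts):
    """Repeat a single negative prompt to match the positive batch; fail loudly otherwise."""
    if negative_prompts is None:
        negative_prompts = [""]
    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts]
    if len(negative_prompts) == 1 and len(prompts) > 1:
        negative_prompts = negative_prompts * len(prompts)  # broadcast single negative
    if len(negative_prompts) != len(prompts):
        raise ValueError(
            f"negative batch {len(negative_prompts)} != prompt batch {len(prompts)}"
        )
    return negative_prompts

print(align_negative_prompts(["a cat", "a dog"], "blurry"))  # ['blurry', 'blurry']
```

Raising on a true mismatch is deliberate: a silent shape mismatch here surfaces later as a confusing CFG concatenation error.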
### 5. Sampling Or Forward Pass Fails

Check:

- Input tensor shape and channel count
- Device and dtype alignment across all components
- Scheduler timesteps and the expected timestep convention
- Classifier-free guidance concatenation and split logic
- Pixel-space versus latent-space assumptions

Typical symptoms:

- Shape mismatch in attention or decoder blocks
- Device mismatch between text encoder output and model tensors
- Images exploding to NaNs because timestep semantics are inverted

### 6. Output Is Wrong But No Exception Is Raised

Check:

- Whether the model predicts `x0`, noise, or velocity
- Whether the Euler or other sampler update matches the model objective
- Final scaling and clamp path
- `output_type` handling and `pipe.task_args`
- Whether a VAE is being applied incorrectly to direct pixel-space output

Typical symptoms:

- Black, gray, washed-out, or heavily clipped images
- Output with the correct size but obviously broken semantics
- Correct tensors but wrong SD.Next display behavior because the output type is mismatched

## Minimal Debug Procedure

### 1. Reproduce Narrowly

Capture the smallest failing operation.

- Pure import failure
- Loader-only failure
- `from_pretrained` failure
- Prompt encode failure
- Single forward pass failure
- First sampler step failure

Prefer narrow Python checks before attempting a full generation run.

### 2. Compare Against Working Pattern

Find the closest working in-repo analogue and compare:

- Loader structure
- Registered module names
- Pipeline class name and module registration
- Prompt encoding path
- Output conversion path

### 3. Fix The Root Cause

Examples:

- Add the missing `modeldata` branch instead of patching downstream task handling
- Fix checkpoint remapping rather than forcing `strict=False` and ignoring real mismatches
- Correct the output path for pixel-space models instead of routing through a VAE
- Make config inference fail explicitly when ambiguous instead of guessing silently

### 4. Validate In Layers

After each meaningful fix, validate the narrowest relevant layer first.

- `compileall` or syntax check
- `ruff` on touched files
- Import smoke test
- Loader-only smoke test
- Full run only when the lower layers are stable

## Common Root Causes

- `modules/modeldata.py` not updated after adding a new custom pipeline family
- `modules/sd_detect.py` branch order causes overbroad detection to win first
- Loader passes duplicated keyword args like `torch_dtype`
- Shared text encoder assumptions do not match the actual model variant
- `from_pretrained` assumes `transformer/` or `text_encoder/` subfolders that do not exist
- Key remapping merges QKV in the wrong order
- CFG path concatenates embeddings or latents incorrectly
- Direct pixel-space models are postprocessed like latent-space diffusion outputs
- Negative prompts are not padded or repeated to match the prompt batch shape
- Pipeline class naming collides with broader family checks in `modeldata`

## Validation Checklist

When closing the task, report which of these were completed:

1. Exact failing layer identified
2. Root cause fixed
3. Syntax check passed
4. Focused lint passed
5. Import or loader smoke test passed
6. Real generation tested, or explicitly not tested

## Example Request Shapes

- "The new model port fails in from_pretrained"
- "SD.Next detects my custom pipeline as the wrong model type"
- "The loader works but generation returns black images"
- "This standalone-script port loads weights but crashes in attention"
@@ -0,0 +1,102 @@
---
name: github-features
description: "Read SD.Next GitHub issues with [Feature] in the title and generate a markdown report with short summary, status, and suggested next steps per issue."
argument-hint: "Optionally specify state (open/closed/all), max issues, and whether to include labels/assignees"
---

# Summarize SD.Next [Feature] GitHub Issues

Fetch issues from the SD.Next GitHub repository that contain `[Feature]` in the title, then produce a concise markdown report with one entry per issue.

## When To Use

- The user asks for periodic feature-request triage summaries
- You need an actionable status report for `[Feature]` tracker items
- You want suggested next actions for each matching issue

## Repository

Default target repository:

- owner: `vladmandic`
- name: `sdnext`

## Required Output

Create markdown containing, for each matching issue:

- issue link and title
- short summary (1-3 sentences)
- status (open/closed and relevant labels)
- suggested next steps (1-3 concrete actions)

## Procedure

### 1. Search Matching Issues

Use GitHub search to find issues with `[Feature]` in the title.

Preferred search query template:

- `is:issue in:title "[Feature]" repo:vladmandic/sdnext`

State filters (when requested):

- open only: add `is:open` (default)
- closed only: add `is:closed`

Use `github-pull-request_doSearch` for the search step.
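The query assembly from this step can be sketched as a small helper; the helper itself is hypothetical, and only the query syntax comes from the template above:

```python
def build_issue_query(title_tag="[Feature]", repo="vladmandic/sdnext", state="open"):
    """Assemble a GitHub issue search query with an optional state filter."""
    query = f'is:issue in:title "{title_tag}" repo:{repo}'
    if state in ("open", "closed"):
        query += f" is:{state}"   # "all" adds no state qualifier
    return query

print(build_issue_query())
# is:issue in:title "[Feature]" repo:vladmandic/sdnext is:open
```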
### 2. Fetch Full Issue Details

For each matched issue (within the requested limit):

- fetch details with `github-pull-request_issue_fetch`
- capture body, labels, assignees, state, updated time, and key discussion context

### 3. Build Per-Issue Summary

For each issue, produce:

1. Short summary:
   - describe the feature request in plain language
   - include a current progress signal if present

2. Status:
   - open/closed
   - notable labels (for example: enhancement, planned, blocked, stale)
   - optional assignee and last-update signal

3. Suggested next steps:
   - propose concrete, minimal actions
   - tailor actions to issue state and content
   - avoid generic filler

### 4. Produce Markdown Report

Return a markdown table or bullet report with one row/section per issue.

Recommended table columns:

- Issue
- Summary
- Status
- Suggested Next Steps

If there are many issues, keep summaries short and prioritize clarity.

## Reporting Rules

- Keep each issue summary concise and actionable.
- Do not invent facts not present in issue data.
- If the issue body is sparse, state assumptions explicitly.
- If no matching issues are found, output a clear "no matches" report.

## Pass Criteria

A successful run must:

- search SD.Next issues with `[Feature]` in the title
- include all matched issues in scope (or explicitly mention the applied limit)
- provide summary, status, and suggested next steps for each issue
- return the final result as markdown
@@ -0,0 +1,102 @@
---
name: github-issues
description: "Read SD.Next GitHub issues with [Issues] in the title and generate a markdown report with short summary, status, and suggested next steps per issue."
argument-hint: "Optionally specify state (open/closed/all), max issues, and whether to include labels/assignees"
---

# Summarize SD.Next [Issues] GitHub Issues

Fetch issues from the SD.Next GitHub repository that contain `[Issues]` in the title, then produce a concise markdown report with one entry per issue.

## When To Use

- The user asks for periodic issue triage summaries
- You need an actionable status report for `[Issues]` tracker items
- You want suggested next actions for each matching issue

## Repository

Default target repository:

- owner: `vladmandic`
- name: `sdnext`

## Required Output

Create markdown containing, for each matching issue:

- issue link and title
- short summary (1-3 sentences)
- status (open/closed and relevant labels)
- suggested next steps (1-3 concrete actions)

## Procedure

### 1. Search Matching Issues

Use GitHub search to find issues with `[Issues]` in the title.

Preferred search query template:

- `is:issue in:title "[Issues]" repo:vladmandic/sdnext`

State filters (when requested):

- open only: add `is:open` (default)
- closed only: add `is:closed`

Use `github-pull-request_doSearch` for the search step.

### 2. Fetch Full Issue Details

For each matched issue (within the requested limit):

- fetch details with `github-pull-request_issue_fetch`
- capture body, labels, assignees, state, updated time, and key discussion context

### 3. Build Per-Issue Summary

For each issue, produce:

1. Short summary:
   - describe the problem/request in plain language
   - include a current progress signal if present

2. Status:
   - open/closed
   - notable labels (for example: bug, enhancement, stale, blocked)
   - optional assignee and last-update signal

3. Suggested next steps:
   - propose concrete, minimal actions
   - tailor actions to issue state and content
   - avoid generic filler

### 4. Produce Markdown Report

Return a markdown table or bullet report with one row/section per issue.

Recommended table columns:

- Issue
- Summary
- Status
- Suggested Next Steps

If there are many issues, keep summaries short and prioritize clarity.

## Reporting Rules

- Keep each issue summary concise and actionable.
- Do not invent facts not present in issue data.
- If the issue body is sparse, state assumptions explicitly.
- If no matching issues are found, output a clear "no matches" report.

## Pass Criteria

A successful run must:

- search SD.Next issues with `[Issues]` in the title
- include all matched issues in scope (or explicitly mention the applied limit)
- provide summary, status, and suggested next steps for each issue
- return the final result as markdown
@@ -0,0 +1,247 @@
|
|||
---
|
||||
name: port-model
|
||||
description: "Port or add a model to SD.Next using existing Diffusers and custom pipeline patterns. Use when implementing a new model loader, custom pipeline, checkpoint conversion path, or SD.Next model-type integration."
|
||||
argument-hint: "Describe the source model, target task, checkpoint format, and whether the model already has a Diffusers pipeline"
|
||||
---
|
||||
|
||||
# Port Model To SD.Next And Diffusers
|
||||
|
||||
Read the task, identify the model architecture and artifact layout, choose the narrowest integration path that matches existing SD.Next patterns, implement the loader and pipeline wiring, and validate the result.
|
||||
|
||||
## When To Use
|
||||
|
||||
- The user wants to add a new model family to SD.Next
|
||||
- A standalone inference script needs to become a Diffusers-style pipeline
|
||||
- A raw checkpoint or safetensors repo needs an SD.Next loader
|
||||
- A model already exists in Diffusers but is not yet wired into SD.Next
|
||||
- A custom architecture needs a repo-local `pipelines/<model>` package and loader
|
||||
|
||||
## Core Rule
|
||||
|
||||
Prefer the smallest correct integration path.
|
||||
|
||||
- If upstream Diffusers already supports the model cleanly, reuse the upstream pipeline and only add SD.Next wiring.
|
||||
- If the model requires custom architecture or sampler behavior, add a repo-local pipeline package under `pipelines/<model>`.
|
||||
- If the model ships as a raw checkpoint instead of a Diffusers repo, load or remap weights explicitly instead of pretending it is a standard Diffusers layout.
|
||||
|
||||
## Inputs To Collect First
|
||||
|
||||
Before editing anything, determine these facts:
|
||||
|
||||
- Model task: text-to-image, img2img, inpaint, editing, video, multimodal, or other
|
||||
- Artifact format: Diffusers repo, local folder, single-file safetensors, ckpt, GGUF, or custom layout
|
||||
- Core components: transformer or UNet, VAE or VQ model, text encoder, tokenizer or processor, scheduler, image encoder, adapters
|
||||
- Whether output is latent-space or direct pixel-space
|
||||
- Whether the model family has one fixed architecture or multiple variants
|
||||
- Whether prompt encoding follows a normal tokenizer path or a custom chat/template path
|
||||
- Whether there is an existing in-repo analogue to copy structurally
|
||||
|
||||
If any of these remain unclear after reading the repo and source artifacts, ask concise clarifying questions before implementing.
|
||||
|
||||
## Mandatory Category Question
|
||||
|
||||
Before implementing model-reference updates, explicitly ask the user which category the model belongs to:
|
||||
|
||||
- `base`
|
||||
- `cloud`
|
||||
- `quant`
|
||||
- `distilled`
|
||||
- `nunchaku`
|
||||
- `community`
|
||||
|
||||
Do not guess this category. Use the user answer to decide which reference JSON file(s) to update.
|
||||
|
||||
## Repo Files To Check
|
||||
|
||||
Start by reading the task description, then inspect the closest matching implementations.
|
||||
|
||||
- `.github/copilot-instructions.md`
|
||||
- `.github/instructions/core.instructions.md`
|
||||
- `pipelines/generic.py`
|
||||
- `pipelines/model_*.py` files similar to the target model
|
||||
- `modules/sd_models.py`
|
||||
- `modules/sd_detect.py`
|
||||
- `modules/modeldata.py`
|
||||
|
||||
Useful examples by pattern:

- Existing upstream Diffusers loader: `pipelines/model_chroma.py`, `pipelines/model_z_image.py`
- Custom in-repo pipeline package: `pipelines/f_lite/`
- Shared Qwen loader pattern: `pipelines/model_z_image.py`, `pipelines/model_flux2_klein.py`

## Integration Decision Tree

### 1. Upstream Diffusers Support Exists

Use this path when the model already has a usable Diffusers pipeline and component classes.

Implement:

- `pipelines/model_<name>.py`
- `modules/sd_models.py` dispatch branch
- `modules/sd_detect.py` filename autodetect branch if appropriate
- `modules/modeldata.py` model type detection branch

Reuse:

- `generic.load_transformer(...)`
- `generic.load_text_encoder(...)`
- Existing processor or tokenizer loading patterns
- `sd_hijack_te.init_hijack(pipe)` and `sd_hijack_vae.init_hijack(pipe)` where relevant

### 2. Custom Pipeline Needed

Use this path when the model architecture, sampler, or prompt encoding is not available upstream.

Implement:

- `pipelines/<model>/__init__.py`
- `pipelines/<model>/model.py`
- `pipelines/<model>/pipeline.py`
- `pipelines/model_<name>.py`
- SD.Next registration in `modules/sd_models.py`, `modules/sd_detect.py`, and `modules/modeldata.py`

Model module responsibilities:

- Architecture classes
- Config handling and defaults
- Weight remapping or checkpoint conversion helpers
- Raw checkpoint loading helpers if needed

Pipeline module responsibilities:

- `DiffusionPipeline` subclass
- `__init__`
- `from_pretrained`
- `encode_prompt` or equivalent prompt preparation
- `__call__`
- Output dataclass
- Optional callback handling and output conversion
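The pipeline module surface above can be sketched as a plain-Python skeleton. This is a shape illustration only: a real implementation subclasses `diffusers.DiffusionPipeline` and calls `register_modules(...)`, and `ExamplePipeline` with its stand-in encoding and fixed output is entirely hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ExamplePipelineOutput:
    # output dataclass returned when return_dict=True
    images: list


class ExamplePipeline:
    # shape sketch only: a real pipeline subclasses diffusers.DiffusionPipeline
    # and registers its components via self.register_modules(...) in __init__
    def __init__(self, transformer=None, text_encoder=None, tokenizer=None, scheduler=None):
        self.transformer = transformer
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.scheduler = scheduler

    def encode_prompt(self, prompt: str) -> list:
        # stand-in for real tokenization plus a text-encoder forward pass
        return [float(ord(c)) for c in prompt]

    def __call__(self, prompt: str, num_inference_steps: int = 4, output_type: str = "np", return_dict: bool = True):
        embeds = self.encode_prompt(prompt)
        # stand-in for the denoising loop and output conversion
        images = [[sum(embeds) / max(len(embeds), 1)]]
        if not return_dict:
            return (images,)
        return ExamplePipelineOutput(images=images)
```

The point is the call surface: prompt preparation, the sampling call, and an output dataclass all live inside the pipeline module, with `return_dict` controlling tuple versus dataclass output.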
### 3. Raw Checkpoint Or Single-File Weights

Use this path when the model source is not a normal Diffusers repository.

Requirements:

- Resolve checkpoint path from local file, local directory, or Hub repo
- Load state dict directly
- Remap keys when training format differs from inference format
- Infer config from tensor shapes only if the model family truly varies and no config file exists
- Raise explicit errors for ambiguous or incomplete layouts

Do not fake a `from_pretrained` implementation that silently assumes missing subfolders exist.
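Shape-based config inference with explicit failure paths can be sketched as below. The key name, shape layout, and channel assumption here are hypothetical; the point is that every ambiguous layout raises instead of falling back to a fragile default.

```python
def infer_config_from_shapes(shapes: dict) -> dict:
    # sketch: infer model config from tensor shapes with explicit failure paths;
    # the key name and shape layout here are hypothetical
    key = "x_embedder.weight"
    if key not in shapes:
        raise ValueError(f"cannot infer config: checkpoint is missing '{key}'")
    dim, patch_dim = shapes[key]
    in_channels = 3  # assumed pixel-space RGB input
    if patch_dim % in_channels != 0:
        raise ValueError(f"ambiguous layout: {patch_dim} is not divisible by {in_channels} channels")
    patch_size = int((patch_dim // in_channels) ** 0.5)
    if patch_size * patch_size * in_channels != patch_dim:
        raise ValueError(f"ambiguous layout: {patch_dim} is not a square patch of {in_channels} channels")
    return {"dim": dim, "patch_size": patch_size, "in_channels": in_channels}
```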
## Required SD.Next Touchpoints

Most new model families need all of these:

- `pipelines/model_<name>.py`
  Purpose: SD.Next loader entry point
- `modules/sd_models.py`
  Purpose: route detected model type to the correct loader
- `modules/sd_detect.py`
  Purpose: detect model family from filename or repo name
- `modules/modeldata.py`
  Purpose: classify loaded pipeline instance back into SD.Next model type

Add only what the model actually needs.
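The detect touchpoint is usually a small filename branch. A minimal sketch in the style of `modules/sd_detect.py`, with the ordering rule made explicit (more-specific names must match before broader ones, per the Common Failure Modes section):

```python
def guess_by_name(fn: str, current_guess: str) -> str:
    # sketch of a filename autodetect branch; the specific match order matters:
    # 'zeta-chroma' must be checked before the broader 'chroma' branch
    name = fn.lower()
    if 'zeta-chroma' in name or 'zetachroma' in name:
        return 'ZetaChroma'
    if 'chroma' in name and 'xl' not in name:
        return 'Chroma'
    return current_guess
```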
Reference catalog touchpoints are also required for model ports intended to appear in SD.Next model references:

- `data/reference.json` for `base`
- `data/reference-cloud.json` for `cloud`
- `data/reference-quant.json` for `quant`
- `data/reference-distilled.json` for `distilled`
- `data/reference-nunchaku.json` for `nunchaku`
- `data/reference-community.json` for `community`

If the model belongs to multiple categories, update each corresponding `data/reference*.json` file.
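The multi-category update can be sketched as a small helper. The category-to-file mapping follows the list above; the helper function itself (`add_reference_entry`) and its entry format are illustrative assumptions, not an existing SD.Next API.

```python
import json
from pathlib import Path

# mapping from category names to the reference files listed above
CATEGORY_FILES = {
    "base": "data/reference.json",
    "cloud": "data/reference-cloud.json",
    "quant": "data/reference-quant.json",
    "distilled": "data/reference-distilled.json",
    "nunchaku": "data/reference-nunchaku.json",
    "community": "data/reference-community.json",
}


def add_reference_entry(name: str, entry: dict, categories: list, root: str = ".") -> list:
    # hypothetical helper: write the same entry into every category file the model belongs to
    updated = []
    for category in categories:
        path = Path(root) / CATEGORY_FILES[category]
        data = json.loads(path.read_text()) if path.exists() else {}
        data[name] = entry
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(data, indent=2))
        updated.append(str(path))
    return updated
```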
Possible extra integration points:

- `diffusers.pipelines.auto_pipeline.AUTO_*_PIPELINES_MAPPING` when task switching matters
- `pipe.task_args` when SD.Next needs default runtime kwargs such as `output_type`
- VAE hijack, TE hijack, or task-specific processors
## Loader Conventions

In `pipelines/model_<name>.py`:

- Use `sd_models.path_to_repo(checkpoint_info)`
- Call `sd_models.hf_auth_check(checkpoint_info)`
- Use `model_quant.get_dit_args(...)` for load args
- Respect `devices.dtype`
- Log meaningful loader details
- Reuse `generic.load_transformer(...)` and `generic.load_text_encoder(...)` when possible
- Set `pipe.task_args = {'output_type': 'np'}` when the pipeline should default to numpy output for SD.Next
- Clean up temporary references and call `devices.torch_gc(...)`

Do not hardcode assumptions about CUDA-only execution, local paths, or one-off environment state.
## Pipeline Conventions

When building a custom pipeline:

- Inherit from `diffusers.DiffusionPipeline`
- Register modules with `DiffusionPipeline.register_modules(...)`
- Keep prompt encoding and output conversion inside the pipeline
- Support `prompt`, `negative_prompt`, `generator`, `output_type`, and `return_dict` when the task is text-to-image-like
- Expose only parameters that are part of the model's actual sampling surface
- Preserve direct pixel-space output if the model does not use a VAE
- Use a VAE only if the model genuinely outputs latents

Do not add generic Stable Diffusion arguments that the model does not support.
## Validation Checklist

After implementation, validate in this order:

1. Syntax or bytecode compilation for new files
2. Focused linting on touched files
3. Import-level smoke test for new modules when safe
4. Loader-path validation without doing a full generation if the model is too large
5. One real load and generation pass if feasible

Always report what was validated and what was not.
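Step 1 of the checklist can be run without importing SD.Next at all. A sketch using the standard library, shown on a throwaway stand-in for a new `pipelines/model_<name>.py` (the filename and stub body are hypothetical):

```python
import pathlib
import py_compile
import tempfile

# checklist step 1: syntax/bytecode compilation for a new file,
# demonstrated on a throwaway stand-in loader module
src = pathlib.Path(tempfile.mkdtemp()) / "model_example.py"
src.write_text("def load_example(checkpoint_info, diffusers_load_config=None):\n    return None\n")
py_compile.compile(str(src), doraise=True)  # raises py_compile.PyCompileError on a syntax error
```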
## Reference Asset Update Requirements

When the user asks to add or port a model for references, also perform these steps:

1. Update the correct `data/reference*.json` file(s) based on the user-confirmed category.
2. Create a placeholder zero-byte thumbnail file in `models/Reference` for the new model.

Notes:

- In this repo, the folder is `models/Reference` (capital `R`).
- Use a deterministic filename that matches the model entry naming convention used in the target reference JSON.
- If a real thumbnail already exists, do not overwrite it with a zero-byte file.
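The two thumbnail rules (deterministic name, never overwrite) can be sketched as below. The `'/' → '--'` flattening mirrors preview names seen in the reference JSON (e.g. `lodestones--Zeta-Chroma.jpg`), but treat both the flattening and the `.jpg` extension as assumptions to check against the target file.

```python
from pathlib import Path


def ensure_placeholder_thumbnail(model_name: str, root: str = "models/Reference") -> Path:
    # hypothetical helper: create a zero-byte placeholder thumbnail,
    # never overwriting an existing real thumbnail
    path = Path(root) / (model_name.replace("/", "--") + ".jpg")
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        path.touch()  # zero-byte placeholder
    return path
```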
## Common Failure Modes

- Model type is added to `sd_models.py` but not `sd_detect.py`
- Loader exists but `modules/modeldata.py` still classifies the pipeline incorrectly
- A custom pipeline is named in a way that collides with a broader existing branch such as `Chroma`
- `torch_dtype` or other loader args are passed twice
- The code assumes a Diffusers repo layout when the source is really a single checkpoint file
- Negative prompt handling does not match prompt batch size
- Output is postprocessed as latents even though the model is pixel-space
- Custom config inference uses fragile defaults without clear failure paths
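The negative-prompt batch mismatch is easy to guard against inside `encode_prompt`. A hypothetical helper (not an existing SD.Next function) that broadcasts a single negative prompt and rejects mismatched batches explicitly:

```python
def match_negative_prompts(prompts: list, negative_prompts) -> list:
    # hypothetical guard for the batch-size failure mode:
    # broadcast a single negative prompt, reject mismatched batches
    if negative_prompts is None:
        return [""] * len(prompts)
    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts]
    if len(negative_prompts) == 1 and len(prompts) > 1:
        return negative_prompts * len(prompts)
    if len(negative_prompts) != len(prompts):
        raise ValueError(f"negative prompt batch {len(negative_prompts)} does not match prompt batch {len(prompts)}")
    return negative_prompts
```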
## Output Expectations

When using this skill, the final implementation should usually include:

- A clear integration path choice
- The minimal set of new files required
- SD.Next routing updates
- Targeted validation results
- Any remaining runtime risks called out explicitly

## Example Request Shapes

- "Port this standalone inference script into an SD.Next Diffusers pipeline"
- "Add support for this Hugging Face model repo to SD.Next"
- "Wire this upstream Diffusers pipeline into SD.Next autodetect and loading"
- "Convert this single-file checkpoint model into a custom Diffusers pipeline for SD.Next"
AGENTS.md (62 changes)
@@ -1,60 +1,8 @@
# SD.Next: AGENTS.md Project Guidelines

SD.Next is a complex codebase with specific patterns and conventions.
General app structure is:
- Python backend server
  Uses Torch for model inference, FastAPI for API routes and Gradio for creation of UI components.
- JavaScript/CSS frontend
**SD.Next** is a complex codebase with specific patterns and conventions:
- Read main instructions file `.github/copilot-instructions.md` for general project guidelines, tools, structure, style, and conventions.
- For core tasks, also review instructions `.github/instructions/core.instructions.md`
- For UI tasks, also review instructions `.github/instructions/ui.instructions.md`

## Tools

- `venv` for Python environment management, activated with `source venv/bin/activate` (Linux) or `venv\Scripts\activate` (Windows).
  venv MUST be activated before running any Python commands or scripts to ensure correct dependencies and environment variables.
- `python` 3.10+.
- `pyproject.toml` for Python configuration, including linting and type checking settings.
- `eslint` configured for both core and UI code.
- `pnpm` for managing JavaScript dependencies and scripts, with key commands defined in `package.json`.
- `ruff` and `pylint` for Python linting, with configurations in `pyproject.toml` and executed via `pnpm ruff` and `pnpm pylint`.
- `pre-commit` hooks which also check line-endings and other formatting issues, configured in `.pre-commit-config.yaml`.

## Project Structure

- Entry/startup flow: `webui.sh` -> `launch.py` -> `webui.py` -> modules under `modules/`.
- Install: `installer.py` takes care of installing dependencies and setting up the environment.
- Core runtime state is centralized in `modules/shared.py` (shared.opts, model state, backend/device state).
- API/server routes are under `modules/api/`.
- UI codebase is split between base JS in `javascript/` and actual UI in `extensions-builtin/sdnext-modernui/`.
- Model and pipeline logic is split between `modules/sd_*` and `pipelines/`.
- Additional plug-ins live in `scripts/` and are used only when specified.
- Extensions live in `extensions-builtin/` and `extensions/` and are loaded dynamically.
- Tests and CLI scripts are under `test/` and `cli/`, with some API smoke checks in `test/full-test.sh`.

## Code Style

- Prefer existing project patterns over strict generic style rules;
  this codebase intentionally allows patterns often flagged in default linters such as allowing long lines, etc.

## Build And Test

- Activate environment: `source venv/bin/activate` (always ensure this is active when working with Python code).
- Test startup: `python launch.py --test`
- Full startup: `python launch.py`
- Full lint sequence: `pnpm lint`
- Python checks individually: `pnpm ruff`, `pnpm pylint`
- JS checks: `pnpm eslint` and `pnpm eslint-ui`

## Conventions

- Keep PR-ready changes targeted to `dev` branch.
- Use conventions from `CONTRIBUTING`.
- Do not include unrelated edits or submodule changes when preparing contributions.
- Use existing CLI/API tool patterns in `cli/` and `test/` when adding automation scripts.
- Respect environment-driven behavior (`SD_*` flags and options) instead of hardcoding platform/model assumptions.
- For startup/init edits, preserve error handling and partial-failure tolerance in parallel scans and extension loading.

## Pitfalls

- Initialization order matters: startup paths in `launch.py` and `webui.py` are sensitive to import/load timing.
- Shared mutable global state can create subtle regressions; prefer narrow, explicit changes.
- Device/backend-specific code paths (**CUDA/ROCm/IPEX/DirectML/OpenVINO**) should not assume one platform.
- Scripts and extension loading is dynamic; failures may appear only when specific extensions or models are present.
For specific SKILLS, also review the relevant skill files specified in `.github/skills/README.md` and listed `.github/skills/*/SKILL.md`
CHANGELOG.md (23 changes)
@@ -1,11 +1,19 @@
# Change Log for SD.Next

## Update for 2026-04-09
## Update for 2026-04-10

- **Models**
  - [AiArtLab SDXS-1B](https://huggingface.co/AiArtLab/sdxs-1b) Simple Diffusion XS *(training still in progress)*
    this model combines Qwen3.5-1.8B text encoder with SDXL-style UNET with only 1.6B parameters and custom 32ch VAE
  - [Anima Preview-v3](https://huggingface.co/circlestone-labs/Anima) new version of Anima
- **Agents**
  framework for AI agent based work in `.github/`
  - general instructions: `copilot-instructions.md` (not copilot specific)
  - additional instructions in `instructions/`: `core.instructions.md`, `ui.instructions.md`
  - skills in `skills/README.md`:
    *validation skills*: `check-models`, `check-api`, `check-schedulers`, `check-processing`, `check-scripts`
    *model skills*: `port-model`, `debug-model`, `analyze-model`
    *github skills*: `github-issues`, `github-features`
- **Caption & Prompt Enhance**
  - [Google Gemma 4] in *E2B* and *E4B* variants
- **Compute**

@@ -18,8 +26,19 @@
  - enhanced filename pattern processing
    allows for any *processing* property name (as defined in `modules/processing_class.py` and saved to `ui-config.json`)
    allows for any *settings* property name (as defined in `modules/ui_definitions.py` and saved to `config.json`)
- **Agents**
  created framework for AI agent based work in `.github/`
  *note*: all skills are agent-model agnostic
  - general instructions:
    `AGENTS.md`, `copilot-instructions.md`
  - additional instructions in `instructions/`:
    `core.instructions.md`, `ui.instructions.md`
  - skills in `skills/README.md`:
    *validation skills*: `check-models`, `check-api`, `check-schedulers`, `check-processing`, `check-scripts`
    *model skills*: `port-model`, `debug-model`, `analyze-model`
    *github skills*: `github-issues`, `github-features`
- **Obsoleted**
  - remove *system-info* from *extensions-builtin*
  - removed *system-info* from *extensions-builtin*
- **Internal**
  - additional typing and typechecks, thanks @awsr
  - wrap hf download methods
TODO.md (5 changes)
@@ -46,7 +46,8 @@ TODO: Investigate which models are diffusers-compatible and prioritize!

### Image-Base

- [NucleusMoe](<https://github.com/huggingface/diffusers/pull/13317)
- [Mugen](https://huggingface.co/CabalResearch/Mugen)
- [NucleusMoe](https://github.com/huggingface/diffusers/pull/13317)
- [Chroma Zeta](https://huggingface.co/lodestones/Zeta-Chroma): Image and video generator for creative effects and professional filters
- [Chroma Radiance](https://huggingface.co/lodestones/Chroma1-Radiance): Pixel-space model eliminating VAE artifacts for high visual fidelity
- [Bria FIBO](https://huggingface.co/briaai/FIBO): Fully JSON based

@@ -58,6 +59,7 @@ TODO: Investigate which models are diffusers-compatible and prioritize!

### Image-Edit

- [Tencent HY-WU](https://huggingface.co/tencent/HY-WU)
- [JoyAI Image Edit](https://huggingface.co/jdopensource/JoyAI-Image-Edit)
- [Bria FIBO-Edit](https://huggingface.co/briaai/Fibo-Edit-RMBG): Fully JSON-based instruction-following image editing framework
- [Meituan LongCat-Image-Edit-Turbo](https://huggingface.co/meituan-longcat/LongCat-Image-Edit-Turbo): 6B instruction-following image editing with high visual consistency

@@ -71,6 +73,7 @@ TODO: Investigate which models are diffusers-compatible and prioritize!

### Video

- [HY-OmniWeaving](https://huggingface.co/tencent/HY-OmniWeaving)
- [LTX-Condition](https://github.com/huggingface/diffusers/pull/13058)
- [LTX-Distilled](https://github.com/huggingface/diffusers/pull/12934)
- [OpenMOSS MOVA](https://huggingface.co/OpenMOSS-Team/MOVA-720p): Unified foundation model for synchronized high-fidelity video and audio
@@ -244,6 +244,15 @@
    "size": 26.84,
    "date": "2025 July"
  },
  "lodestones Zeta-Chroma": {
    "path": "lodestones/Zeta-Chroma",
    "preview": "lodestones--Zeta-Chroma.jpg",
    "desc": "Zeta-Chroma is a pixel-space diffusion transformer image model from lodestones that generates images directly in RGB space using a NextDiT-style architecture.",
    "skip": true,
    "extras": "sampler: Default, cfg_scale: 3.0, steps: 30",
    "size": 12.11,
    "date": "2026 April"
  },

  "Meituan LongCat Image": {
    "path": "meituan-longcat/LongCat-Image",
@@ -42,6 +42,8 @@ def get_model_type(pipe):
        model_type = 'sc'
    elif "AuraFlow" in name:
        model_type = 'auraflow'
    elif 'ZetaChroma' in name:
        model_type = 'zetachroma'
    elif 'Chroma' in name:
        model_type = 'chroma'
    elif "Flux2" in name:
@@ -91,6 +91,8 @@ def guess_by_name(fn, current_guess):
        new_guess = 'Stable Diffusion 3'
    elif 'hidream' in fn.lower():
        new_guess = 'HiDream'
    elif 'zeta-chroma' in fn.lower() or 'zetachroma' in fn.lower():
        new_guess = 'ZetaChroma'
    elif 'chroma' in fn.lower() and 'xl' not in fn.lower():
        new_guess = 'Chroma'
    elif 'flux.2' in fn.lower() and 'klein' in fn.lower():
@@ -385,6 +385,10 @@ def load_diffuser_force(detected_model_type, checkpoint_info, diffusers_load_config
        from pipelines.model_flex import load_flex
        sd_model = load_flex(checkpoint_info, diffusers_load_config)
        allow_post_quant = False
    elif model_type in ['ZetaChroma']:
        from pipelines.model_zetachroma import load_zetachroma
        sd_model = load_zetachroma(checkpoint_info, diffusers_load_config)
        allow_post_quant = False
    elif model_type in ['Chroma']:
        from pipelines.model_chroma import load_chroma
        sd_model = load_chroma(checkpoint_info, diffusers_load_config)
@@ -0,0 +1,59 @@
import sys

import diffusers
import transformers
from modules import devices, model_quant, sd_hijack_te, sd_models, shared
from modules.logger import log


TEXT_ENCODER_REPO = "Tongyi-MAI/Z-Image-Turbo"


def load_zetachroma(checkpoint_info, diffusers_load_config=None):
    if diffusers_load_config is None:
        diffusers_load_config = {}
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)

    load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
    load_args.setdefault("torch_dtype", devices.dtype)
    log.debug(
        f'Load model: type=ZetaChroma repo="{repo_id}" config={diffusers_load_config} '
        f'offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}'
    )

    from pipelines import generic, zetachroma

    diffusers.ZetaChromaPipeline = zetachroma.ZetaChromaPipeline
    sys.modules["zetachroma"] = zetachroma

    text_encoder = generic.load_text_encoder(
        TEXT_ENCODER_REPO,
        cls_name=transformers.Qwen3ForCausalLM,
        load_config=diffusers_load_config,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        TEXT_ENCODER_REPO,
        subfolder="tokenizer",
        cache_dir=shared.opts.hfcache_dir,
        trust_remote_code=True,
    )

    pipe = zetachroma.ZetaChromaPipeline.from_pretrained(
        repo_id,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        cache_dir=shared.opts.diffusers_dir,
        trust_remote_code=True,
        **load_args,
    )
    pipe.task_args = {
        "output_type": "np",
    }
    diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["zetachroma"] = zetachroma.ZetaChromaPipeline

    del tokenizer
    del text_encoder
    sd_hijack_te.init_hijack(pipe)
    devices.torch_gc(force=True, reason="load")
    return pipe
@@ -0,0 +1,19 @@
from .model import (
    DEFAULT_TEXT_ENCODER_REPO,
    NextDiTPixelSpace,
    infer_model_config,
    load_zetachroma_transformer,
    remap_checkpoint_keys,
)
from .pipeline import ZetaChromaPipeline, ZetaChromaPipelineOutput


__all__ = [
    "DEFAULT_TEXT_ENCODER_REPO",
    "NextDiTPixelSpace",
    "ZetaChromaPipeline",
    "ZetaChromaPipelineOutput",
    "infer_model_config",
    "load_zetachroma_transformer",
    "remap_checkpoint_keys",
]
@ -0,0 +1,732 @@
|
|||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
DEFAULT_TEXT_ENCODER_REPO = "Tongyi-MAI/Z-Image-Turbo"
|
||||
DEFAULT_CHECKPOINT_FILENAME = "zeta-chroma-base-x0-pixel-dino-distance.safetensors"
|
||||
|
||||
|
||||
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding.to(t.dtype) if torch.is_floating_point(t) else embedding
|
||||
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: float) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=pos.device)
|
||||
omega = 1.0 / (theta ** scale)
|
||||
out = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega)
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
rot = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
return rot.reshape(*out.shape, 2, 2).float()
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: float, axes_dim: list[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(ids.shape[-1])], dim=-3)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def _apply_rope_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out = x_out + freqs_cis[..., 1] * x_[..., 1]
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
|
||||
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
|
||||
return _apply_rope_single(xq, freqs_cis), _apply_rope_single(xk, freqs_cis)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256, output_size: int | None = None):
|
||||
super().__init__()
|
||||
if output_size is None:
|
||||
output_size = hidden_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, output_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
def forward(self, t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
return self.mlp(t_freq)
|
||||
|
||||
|
||||
class JointAttention(nn.Module):
|
||||
def __init__(self, dim: int, n_heads: int, n_kv_heads: int | None = None, qk_norm: bool = True):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
|
||||
self.head_dim = dim // n_heads
|
||||
self.n_rep = n_heads // self.n_kv_heads
|
||||
self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias=False)
|
||||
self.out = nn.Linear(n_heads * self.head_dim, dim, bias=False)
|
||||
|
||||
if qk_norm:
|
||||
self.q_norm = nn.RMSNorm(self.head_dim, elementwise_affine=True)
|
||||
self.k_norm = nn.RMSNorm(self.head_dim, elementwise_affine=True)
|
||||
else:
|
||||
self.q_norm = nn.Identity()
|
||||
self.k_norm = nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = x.shape
|
||||
qkv = self.qkv(x)
|
||||
xq, xk, xv = torch.split(
|
||||
qkv,
|
||||
[
|
||||
self.n_heads * self.head_dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
xq = self.q_norm(xq.view(batch_size, sequence_length, self.n_heads, self.head_dim))
|
||||
xk = self.k_norm(xk.view(batch_size, sequence_length, self.n_kv_heads, self.head_dim))
|
||||
xv = xv.view(batch_size, sequence_length, self.n_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis)
|
||||
xq = xq.transpose(1, 2)
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
|
||||
if self.n_rep > 1:
|
||||
xk = xk.unsqueeze(2).repeat(1, 1, self.n_rep, 1, 1).flatten(1, 2)
|
||||
xv = xv.unsqueeze(2).repeat(1, 1, self.n_rep, 1, 1).flatten(1, 2)
|
||||
|
||||
out = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask)
|
||||
return self.out(out.transpose(1, 2).reshape(batch_size, sequence_length, -1))
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, inner_dim: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.w2 = nn.Linear(inner_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, inner_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class JointTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int | None,
|
||||
ffn_hidden_dim: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation: bool = True,
|
||||
z_image_modulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.modulation = modulation
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
|
||||
self.feed_forward = FeedForward(dim, ffn_hidden_dim)
|
||||
self.attention_norm1 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)
|
||||
self.ffn_norm1 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)
|
||||
self.attention_norm2 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)
|
||||
self.ffn_norm2 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)
|
||||
|
||||
if modulation:
|
||||
if z_image_modulation:
|
||||
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, 256), 4 * dim, bias=True))
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(min(dim, 1024), 4 * dim, bias=True))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
self.attention(self.attention_norm1(x) * (1 + scale_msa.unsqueeze(1)), mask, freqs_cis)
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
self.feed_forward(self.ffn_norm1(x) * (1 + scale_mlp.unsqueeze(1)))
|
||||
)
|
||||
return x
|
||||
|
||||
x = x + self.attention_norm2(self.attention(self.attention_norm1(x), mask, freqs_cis))
|
||||
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
||||
return x
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int = 8):
|
||||
super().__init__()
|
||||
self.max_freqs = max_freqs
|
||||
self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input))
|
||||
self._pos_cache: dict[tuple[int, str, str], torch.Tensor] = {}
|
||||
|
||||
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
cache_key = (patch_size, str(device), str(dtype))
|
||||
if cache_key in self._pos_cache:
|
||||
return self._pos_cache[cache_key]
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)
|
||||
self._pos_cache[cache_key] = dct
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, patch_area, channels = inputs.shape
|
||||
_ = channels
|
||||
patch_size = int(patch_area**0.5)
|
||||
dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype).repeat(batch_size, 1, 1)
|
||||
return self.embedder(torch.cat((inputs, dct), dim=-1))
|
||||
|
||||
|
||||
class PixelResBlock(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.in_ln = nn.LayerNorm(channels, eps=1e-6)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(channels, channels, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(channels, channels, bias=True),
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(channels, 3 * channels, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
|
||||
hidden = self.in_ln(x) * (1 + scale) + shift
|
||||
return x + gate * self.mlp(hidden)
|
||||
|
||||
|
||||
class DCTFinalLayer(nn.Module):
|
||||
def __init__(self, model_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(model_channels, out_channels, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(self.norm_final(x))
|
||||
|
||||
|
||||
class SimpleMLPAdaLN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
out_channels: int,
|
||||
z_channels: int,
|
||||
num_res_blocks: int,
|
||||
max_freqs: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_embed = nn.Linear(z_channels, model_channels)
|
||||
self.input_embedder = NerfEmbedder(in_channels, model_channels, max_freqs)
|
||||
self.res_blocks = nn.ModuleList([PixelResBlock(model_channels) for _ in range(num_res_blocks)])
|
||||
self.final_layer = DCTFinalLayer(model_channels, out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
x = self.input_embedder(x)
|
||||
y = self.cond_embed(c).unsqueeze(1)
|
||||
for block in self.res_blocks:
|
||||
x = block(x, y)
|
||||
return self.final_layer(x)
|
||||
|
||||
|
||||
def pad_to_patch_size(x: torch.Tensor, patch_size: tuple[int, int]) -> torch.Tensor:
|
||||
_, _, height, width = x.shape
|
||||
patch_h, patch_w = patch_size
|
||||
pad_h = (patch_h - height % patch_h) % patch_h
|
||||
pad_w = (patch_w - width % patch_w) % patch_w
|
||||
if pad_h or pad_w:
|
||||
x = F.pad(x, (0, pad_w, 0, pad_h), mode="constant", value=0)
|
||||
return x
|
||||
|
||||
|
||||
def pad_zimage(feats: torch.Tensor, pad_token: torch.Tensor, pad_tokens_multiple: int) -> tuple[torch.Tensor, int]:
|
||||
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
|
||||
if pad_extra > 0:
|
||||
feats = torch.cat(
|
||||
[
|
||||
feats,
|
||||
pad_token.to(device=feats.device, dtype=feats.dtype).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
return feats, pad_extra
|


class NextDiTPixelSpace(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):  # type: ignore[misc]
    @register_to_config
    def __init__(
        self,
        patch_size: int = 32,
        in_channels: int = 3,
        dim: int = 3840,
        n_layers: int = 30,
        n_refiner_layers: int = 2,
        n_heads: int = 30,
        n_kv_heads: int | None = None,
        ffn_hidden_dim: int = 10240,
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        cap_feat_dim: int = 2560,
        axes_dims: list[int] | None = None,
        axes_lens: list[int] | None = None,
        rope_theta: float = 256.0,
        time_scale: float = 1000.0,
        pad_tokens_multiple: int | None = 128,
        decoder_hidden_size: int = 3840,
        decoder_num_res_blocks: int = 4,
        decoder_max_freqs: int = 8,
    ):
        super().__init__()
        axes_dims = axes_dims or [32, 48, 48]
        axes_lens = axes_lens or [1536, 512, 512]
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.dim = dim
        self.n_heads = n_heads
        self.time_scale = time_scale
        self.pad_tokens_multiple = pad_tokens_multiple
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens

        self.x_embedder = nn.Linear(patch_size * patch_size * in_channels, dim, bias=True)
        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim,
                    n_heads,
                    n_kv_heads,
                    ffn_hidden_dim,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    z_image_modulation=True,
                )
                for _ in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim,
                    n_heads,
                    n_kv_heads,
                    ffn_hidden_dim,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                )
                for _ in range(n_refiner_layers)
            ]
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size=min(dim, 1024),
            frequency_embedding_size=256,
            output_size=256,
        )
        self.cap_embedder = nn.Sequential(
            nn.RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True),
            nn.Linear(cap_feat_dim, dim, bias=True),
        )
        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim,
                    n_heads,
                    n_kv_heads,
                    ffn_hidden_dim,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    z_image_modulation=True,
                )
                for _ in range(n_layers)
            ]
        )

        if pad_tokens_multiple is not None:
            self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
            self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))

        assert dim // n_heads == sum(axes_dims)
        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=list(axes_dims))
        dec_in_ch = patch_size**2 * in_channels
        self.dec_net = SimpleMLPAdaLN(
            in_channels=dec_in_ch,
            model_channels=decoder_hidden_size,
            out_channels=dec_in_ch,
            z_channels=dim,
            num_res_blocks=decoder_num_res_blocks,
            max_freqs=decoder_max_freqs,
        )
        self.register_buffer("__x0__", torch.tensor([]))

    def embed_cap(self, cap_feats: torch.Tensor, offset: float = 0, bsz: int = 1, device=None, dtype=None):
        _ = dtype
        cap_feats = self.cap_embedder(cap_feats)
        cap_feats_len = cap_feats.shape[1]
        if self.pad_tokens_multiple is not None:
            cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)

        cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
        cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
        freqs_cis = self.rope_embedder(cap_pos_ids).movedim(1, 2)
        return cap_feats, freqs_cis, cap_feats_len

    def pos_ids_x(self, start_t: float, h_tokens: int, w_tokens: int, batch_size: int, device):
        x_pos_ids = torch.zeros((batch_size, h_tokens * w_tokens, 3), dtype=torch.float32, device=device)
        x_pos_ids[:, :, 0] = start_t
        x_pos_ids[:, :, 1] = (
            torch.arange(h_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, w_tokens).flatten()
        )
        x_pos_ids[:, :, 2] = (
            torch.arange(w_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(h_tokens, 1).flatten()
        )
        return x_pos_ids

    def unpatchify(self, x: torch.Tensor, img_size: list[tuple[int, int]], cap_size: list[int]) -> torch.Tensor:
        patch_h = patch_w = self.patch_size
        imgs = []
        for index in range(x.size(0)):
            height, width = img_size[index]
            begin = cap_size[index]
            end = begin + (height // patch_h) * (width // patch_w)
            imgs.append(
                x[index][begin:end]
                .view(height // patch_h, width // patch_w, patch_h, patch_w, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )
        return torch.stack(imgs, dim=0)

    def _forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, num_tokens: int | None) -> torch.Tensor:
        _ = num_tokens
        timesteps = 1.0 - timesteps
        _, _, original_h, original_w = x.shape
        x = pad_to_patch_size(x, (self.patch_size, self.patch_size))
        batch_size, channels, height, width = x.shape
        t_emb = self.t_embedder(timesteps * self.time_scale, dtype=x.dtype)
        patch_h = patch_w = self.patch_size

        pixel_patches = (
            x.view(batch_size, channels, height // patch_h, patch_h, width // patch_w, patch_w)
            .permute(0, 2, 4, 3, 5, 1)
            .flatten(3)
            .flatten(1, 2)
        )
        num_image_tokens = pixel_patches.shape[1]
        pixel_values = pixel_patches.reshape(batch_size * num_image_tokens, 1, patch_h * patch_w * channels)

        cap_feats_emb, cap_freqs_cis, _ = self.embed_cap(context, offset=0, bsz=batch_size, device=x.device, dtype=x.dtype)

        x_tokens = self.x_embedder(pixel_patches)  # same patch layout as pixel_values above
        cap_feats_len_total = cap_feats_emb.shape[1]
        x_pos_ids = self.pos_ids_x(cap_feats_len_total + 1, height // patch_h, width // patch_w, batch_size, x.device)
        if self.pad_tokens_multiple is not None:
            x_tokens, x_pad_extra = pad_zimage(x_tokens, self.x_pad_token, self.pad_tokens_multiple)
            x_pos_ids = F.pad(x_pos_ids, (0, 0, 0, x_pad_extra))
        x_freqs_cis = self.rope_embedder(x_pos_ids).movedim(1, 2)

        for layer in self.context_refiner:
            cap_feats_emb = layer(cap_feats_emb, None, cap_freqs_cis)

        for layer in self.noise_refiner:
            x_tokens = layer(x_tokens, None, x_freqs_cis, t_emb)

        padded_full_embed = torch.cat([cap_feats_emb, x_tokens], dim=1)
        full_freqs_cis = torch.cat([cap_freqs_cis, x_freqs_cis], dim=1)
        img_len = x_tokens.shape[1]
        cap_size = [padded_full_embed.shape[1] - img_len] * batch_size

        hidden_states = padded_full_embed
        for layer in self.layers:
            hidden_states = layer(hidden_states, None, full_freqs_cis.to(hidden_states.device), t_emb)

        img_hidden = hidden_states[:, cap_size[0] : cap_size[0] + num_image_tokens, :]
        decoder_cond = img_hidden.reshape(batch_size * num_image_tokens, self.dim)
        output = self.dec_net(pixel_values, decoder_cond).reshape(batch_size, num_image_tokens, -1)
        cap_placeholder = torch.zeros(
            batch_size,
            cap_size[0],
            output.shape[-1],
            device=output.device,
            dtype=output.dtype,
        )
        img_size = [(height, width)] * batch_size
        img_out = self.unpatchify(torch.cat([cap_placeholder, output], dim=1), img_size, cap_size)
        return -img_out[:, :, :original_h, :original_w]

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, num_tokens: int | None = None):
        x0_pred = self._forward(x, timesteps, context, num_tokens)
        return (x - x0_pred) / timesteps.view(-1, 1, 1, 1)


def remap_checkpoint_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    new_sd: dict[str, torch.Tensor] = {}
    qkv_parts: dict[str, dict[str, torch.Tensor]] = {}

    for key, value in state_dict.items():
        if key == "__x0__":
            new_sd[key] = value
            continue
        if ".attention.to_q." in key:
            prefix, suffix = key.split(".attention.to_q.", 1)
            qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["q"] = value
            continue
        if ".attention.to_k." in key:
            prefix, suffix = key.split(".attention.to_k.", 1)
            qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["k"] = value
            continue
        if ".attention.to_v." in key:
            prefix, suffix = key.split(".attention.to_v.", 1)
            qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["v"] = value
            continue
        if ".attention.to_out.0." in key:
            new_sd[key.replace(".attention.to_out.0.", ".attention.out.")] = value
            continue
        if ".attention.norm_q." in key:
            new_sd[key.replace(".attention.norm_q.", ".attention.q_norm.")] = value
            continue
        if ".attention.norm_k." in key:
            new_sd[key.replace(".attention.norm_k.", ".attention.k_norm.")] = value
            continue
        new_sd[key] = value

    for combined_key, parts in qkv_parts.items():
        if not {"q", "k", "v"}.issubset(parts):
            raise ValueError(f"Incomplete QKV for {combined_key}: got {list(parts.keys())}")
        prefix, suffix = combined_key.rsplit(".", 1)
        new_sd[f"{prefix}.qkv.{suffix}"] = torch.cat([parts["q"], parts["k"], parts["v"]], dim=0)

    return new_sd
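`remap_checkpoint_keys` fuses the separate `to_q`/`to_k`/`to_v` projections into a single `qkv` tensor per attention block. A hypothetical miniature of that fusion, with plain Python lists standing in for tensors (list concatenation plays the role of `torch.cat` along dim 0):

```python
def fuse_qkv(parts: dict[str, list[float]]) -> list[float]:
    # All three projections must be present before fusing, mirroring the
    # completeness check in remap_checkpoint_keys.
    if not {"q", "k", "v"}.issubset(parts):
        raise ValueError(f"Incomplete QKV: got {sorted(parts)}")
    # Stack q, then k, then v rows, the same order the fused layer expects.
    return parts["q"] + parts["k"] + parts["v"]
```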


def default_axes_dims_for_head_dim(head_dim: int) -> list[int]:
    if head_dim % 8 == 0:
        first = head_dim // 4
        second = (head_dim * 3) // 8
        third = head_dim - first - second
        if first % 2 == 0 and second % 2 == 0 and third % 2 == 0:
            return [first, second, third]
    if head_dim == 128:
        return [32, 48, 48]
    if head_dim % 6 == 0:
        split = head_dim // 3
        return [split, split, split]
    raise ValueError(f"Unable to infer axes_dims for head_dim={head_dim}; provide axes_dims explicitly")
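The preferred branch above allocates 1/4 of the head dimension to the sequence axis and 3/8 to each spatial axis, keeping every component even so the rotary embedding can pair dimensions. A small sketch of just that arithmetic:

```python
def split_head_dim(head_dim: int) -> list[int]:
    # Mirrors the 1/4, 3/8, 3/8 split used above; the remainder goes to the
    # last axis so the three parts always sum back to head_dim.
    first = head_dim // 4
    second = (head_dim * 3) // 8
    return [first, second, head_dim - first - second]
```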


def _count_module_blocks(state_dict: dict[str, torch.Tensor], prefix: str) -> int:
    indices = set()
    needle = f"{prefix}."
    prefix_parts = prefix.count(".") + 1
    for key in state_dict:
        if not key.startswith(needle):
            continue
        parts = key.split(".")
        if len(parts) <= prefix_parts:
            continue
        try:
            indices.add(int(parts[prefix_parts]))
        except ValueError:
            continue
    return max(indices) + 1 if indices else 0
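`_count_module_blocks` infers `ModuleList` sizes from flat checkpoint keys by collecting the integer segment that follows the prefix and taking `max + 1`. A torch-free stand-in over a plain key list:

```python
def count_blocks(keys: list[str], prefix: str) -> int:
    # Collect the integer index right after the prefix in dotted keys like
    # "layers.3.feed_forward.w1.weight"; non-numeric segments are skipped.
    indices = set()
    depth = prefix.count(".") + 1
    for key in keys:
        if not key.startswith(prefix + "."):
            continue
        parts = key.split(".")
        if len(parts) > depth:
            try:
                indices.add(int(parts[depth]))
            except ValueError:
                pass
    return max(indices) + 1 if indices else 0
```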


def infer_model_config(
    state_dict: dict[str, torch.Tensor],
    axes_dims: list[int] | None = None,
    axes_lens: list[int] | None = None,
    rope_theta: float = 256.0,
    time_scale: float = 1000.0,
    pad_tokens_multiple: int | None = 128,
    norm_eps: float = 1e-5,
) -> dict:
    dim = state_dict["x_embedder.weight"].shape[0]
    patch_channels = state_dict["x_embedder.weight"].shape[1]
    in_channels = 3 if patch_channels % 3 == 0 else 1
    patch_size = int(round((patch_channels / in_channels) ** 0.5))
    head_dim = state_dict["layers.0.attention.q_norm.weight"].shape[0]
    n_heads = dim // head_dim
    total_proj_heads = state_dict["layers.0.attention.qkv.weight"].shape[0] // head_dim
    n_kv_heads = (total_proj_heads - n_heads) // 2
    if n_kv_heads == n_heads:
        n_kv_heads = None
    qk_norm = "layers.0.attention.q_norm.weight" in state_dict
    decoder_input = state_dict["dec_net.input_embedder.embedder.0.weight"].shape[1]
    decoder_max_freqs = int(round((decoder_input - patch_channels) ** 0.5))
    config = {
        "patch_size": patch_size,
        "in_channels": in_channels,
        "dim": dim,
        "n_layers": _count_module_blocks(state_dict, "layers"),
        "n_refiner_layers": _count_module_blocks(state_dict, "noise_refiner"),
        "n_heads": n_heads,
        "n_kv_heads": n_kv_heads,
        "ffn_hidden_dim": state_dict["layers.0.feed_forward.w1.weight"].shape[0],
        "norm_eps": norm_eps,
        "qk_norm": qk_norm,
        "cap_feat_dim": state_dict["cap_embedder.1.weight"].shape[1],
        "axes_dims": axes_dims or default_axes_dims_for_head_dim(head_dim),
        "axes_lens": axes_lens or [1536, 512, 512],
        "rope_theta": rope_theta,
        "time_scale": time_scale,
        "pad_tokens_multiple": pad_tokens_multiple,
        "decoder_hidden_size": state_dict["dec_net.cond_embed.weight"].shape[0],
        "decoder_num_res_blocks": _count_module_blocks(state_dict, "dec_net.res_blocks"),
        "decoder_max_freqs": decoder_max_freqs,
    }
    return config


def resolve_checkpoint_path(
    pretrained_model_name_or_path: str,
    checkpoint_filename: str | None = None,
    cache_dir: str | None = None,
    local_files_only: bool | None = None,
) -> str:
    if os.path.isfile(pretrained_model_name_or_path):
        return pretrained_model_name_or_path

    if os.path.isdir(pretrained_model_name_or_path):
        if checkpoint_filename is not None:
            candidate = os.path.join(pretrained_model_name_or_path, checkpoint_filename)
            if os.path.isfile(candidate):
                return candidate
        candidates = [
            os.path.join(pretrained_model_name_or_path, name)
            for name in os.listdir(pretrained_model_name_or_path)
            if name.lower().endswith(".safetensors")
        ]
        if len(candidates) == 1:
            return candidates[0]
        if not candidates:
            raise FileNotFoundError(f"No .safetensors checkpoint found in {pretrained_model_name_or_path}")
        raise FileNotFoundError(
            f"Multiple .safetensors checkpoints found in {pretrained_model_name_or_path}; provide checkpoint_filename"
        )

    filename = checkpoint_filename
    if filename is None:
        repo_files = [name for name in list_repo_files(pretrained_model_name_or_path) if name.lower().endswith(".safetensors")]
        if not repo_files and DEFAULT_CHECKPOINT_FILENAME:
            repo_files = [DEFAULT_CHECKPOINT_FILENAME]
        if len(repo_files) == 1:
            filename = repo_files[0]
        elif DEFAULT_CHECKPOINT_FILENAME in repo_files:
            filename = DEFAULT_CHECKPOINT_FILENAME
        else:
            raise FileNotFoundError(
                f"Unable to determine checkpoint filename for {pretrained_model_name_or_path}; provide checkpoint_filename"
            )

    return hf_hub_download(
        repo_id=pretrained_model_name_or_path,
        filename=filename,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
    )


def load_zetachroma_transformer(
    pretrained_model_name_or_path: str,
    checkpoint_filename: str | None = None,
    torch_dtype: torch.dtype | None = None,
    cache_dir: str | None = None,
    local_files_only: bool | None = None,
    model_config: dict | None = None,
    axes_dims: list[int] | None = None,
    axes_lens: list[int] | None = None,
    rope_theta: float = 256.0,
    time_scale: float = 1000.0,
    pad_tokens_multiple: int | None = 128,
) -> NextDiTPixelSpace:
    checkpoint_path = resolve_checkpoint_path(
        pretrained_model_name_or_path,
        checkpoint_filename=checkpoint_filename,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
    )
    raw_state_dict = load_file(checkpoint_path, device="cpu")
    state_dict = remap_checkpoint_keys(raw_state_dict)
    del raw_state_dict
    config = infer_model_config(
        state_dict,
        axes_dims=axes_dims,
        axes_lens=axes_lens,
        rope_theta=rope_theta,
        time_scale=time_scale,
        pad_tokens_multiple=pad_tokens_multiple,
    )
    if model_config is not None:
        config.update(model_config)

    model = NextDiTPixelSpace(**config)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    ignored_missing = {"__x0__"}
    relevant_missing = [key for key in missing if key not in ignored_missing]
    if relevant_missing or unexpected:
        raise ValueError(
            "Zeta-Chroma checkpoint load mismatch: "
            f"missing={relevant_missing[:20]} unexpected={unexpected[:20]}"
        )
    if torch_dtype is not None:
        model = model.to(dtype=torch_dtype)
    return model

@@ -0,0 +1,302 @@
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, Qwen3ForCausalLM

from .model import DEFAULT_TEXT_ENCODER_REPO, NextDiTPixelSpace, load_zetachroma_transformer


logger = logging.getLogger(__name__)


@dataclass
class ZetaChromaPipelineOutput(BaseOutput):
    images: Union[List[Image.Image], np.ndarray, torch.Tensor]


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> torch.Tensor:
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        mu = base_shift + (image_seq_len / 4096) * (max_shift - base_shift)
        timesteps = math.exp(mu) / (math.exp(mu) + (1.0 / timesteps.clamp(min=1e-9) - 1.0))
        timesteps[0] = 1.0
        timesteps[-1] = 0.0
    return timesteps
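`get_schedule` applies the Flux-style time shift t' = e^mu / (e^mu + 1/t - 1), with mu interpolated from the image token count; a larger mu concentrates more of the steps near the high-noise end. A pure-Python sketch of the same transform (assuming, as above, that the endpoint clamps belong with the shifted branch):

```python
import math


def shifted_timesteps(num_steps: int, image_seq_len: int,
                      base_shift: float = 0.5, max_shift: float = 1.15) -> list[float]:
    # Linear 1 -> 0 grid, then the sigmoid-style shift used by get_schedule;
    # mu grows with sequence length, biasing steps toward high noise.
    mu = base_shift + (image_seq_len / 4096) * (max_shift - base_shift)
    ts = [1 - i / num_steps for i in range(num_steps + 1)]
    out = [math.exp(mu) / (math.exp(mu) + (1.0 / max(t, 1e-9) - 1.0)) for t in ts]
    out[0], out[-1] = 1.0, 0.0  # pin exact endpoints
    return out
```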


class ZetaChromaPipeline(DiffusionPipeline):
    model_cpu_offload_seq = "text_encoder->dit_model"

    dit_model: NextDiTPixelSpace
    text_encoder: PreTrainedModel
    tokenizer: PreTrainedTokenizerBase
    _progress_bar_config: Dict[str, Any]

    def __init__(self, dit_model: NextDiTPixelSpace, text_encoder: PreTrainedModel, tokenizer: PreTrainedTokenizerBase):
        super().__init__()
        DiffusionPipeline.register_modules(self, dit_model=dit_model, text_encoder=text_encoder, tokenizer=tokenizer)
        if hasattr(self.text_encoder, "requires_grad_"):
            self.text_encoder.requires_grad_(False)
        if hasattr(self.dit_model, "requires_grad_"):
            self.dit_model.requires_grad_(False)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        dit_model: Optional[NextDiTPixelSpace] = None,
        text_encoder: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        text_encoder_repo: str = DEFAULT_TEXT_ENCODER_REPO,
        checkpoint_filename: Optional[str] = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = True,
        torch_dtype: Optional[torch.dtype] = None,
        local_files_only: Optional[bool] = None,
        model_config: Optional[dict] = None,
        axes_dims: Optional[list[int]] = None,
        axes_lens: Optional[list[int]] = None,
        rope_theta: float = 256.0,
        time_scale: float = 1000.0,
        pad_tokens_multiple: int | None = 128,
        **kwargs,
    ):
        _ = kwargs
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(
                text_encoder_repo,
                subfolder="tokenizer",
                cache_dir=cache_dir,
                trust_remote_code=trust_remote_code,
                local_files_only=local_files_only,
            )
        if text_encoder is None:
            text_encoder = Qwen3ForCausalLM.from_pretrained(
                text_encoder_repo,
                subfolder="text_encoder",
                cache_dir=cache_dir,
                trust_remote_code=trust_remote_code,
                local_files_only=local_files_only,
                torch_dtype=torch_dtype,
            )
        if dit_model is None:
            dit_model = load_zetachroma_transformer(
                pretrained_model_name_or_path,
                checkpoint_filename=checkpoint_filename,
                torch_dtype=torch_dtype,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                model_config=model_config,
                axes_dims=axes_dims,
                axes_lens=axes_lens,
                rope_theta=rope_theta,
                time_scale=time_scale,
                pad_tokens_multiple=pad_tokens_multiple,
            )
        pipe = cls(dit_model=dit_model, text_encoder=text_encoder, tokenizer=tokenizer)
        if torch_dtype is not None:
            pipe.to(torch_dtype=torch_dtype)
        return pipe

    def set_progress_bar_config(self, **kwargs):
        self._progress_bar_config = kwargs

    def progress_bar(self, iterable=None, **kwargs):  # pylint: disable=arguments-differ
        self._progress_bar_config = getattr(self, "_progress_bar_config", None) or {}
        config = {**self._progress_bar_config, **kwargs}
        return tqdm(iterable, **config)

    def to(self, torch_device=None, torch_dtype=None, silence_dtype_warnings=False):  # pylint: disable=arguments-differ
        _ = silence_dtype_warnings
        if hasattr(self, "text_encoder"):
            self.text_encoder.to(device=torch_device, dtype=torch_dtype)
        if hasattr(self, "dit_model"):
            self.dit_model.to(device=torch_device, dtype=torch_dtype)
        return self

    def _encode_single_prompt(
        self,
        prompt: str,
        device: torch.device,
        dtype: torch.dtype,
        max_sequence_length: int,
    ) -> torch.Tensor:
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
        ).to(device)

        outputs = self.text_encoder(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            output_hidden_states=True,
        )
        hidden = outputs.hidden_states[-2].to(dtype)
        hidden = hidden * inputs.attention_mask.unsqueeze(-1).to(dtype)
        actual_len = int(inputs.attention_mask.sum(dim=1).max().item())
        return hidden[:, :actual_len, :]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        max_sequence_length: int = 512,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if isinstance(prompt, str):
            prompt = [prompt]
        if negative_prompt is None:
            negative_prompt = [""] * len(prompt)
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * len(prompt)
        elif len(negative_prompt) != len(prompt):
            raise ValueError("negative_prompt must have the same batch length as prompt")

        device = device or self._execution_device
        dtype = dtype or next(self.text_encoder.parameters()).dtype

        positive = [self._encode_single_prompt(text, device, dtype, max_sequence_length) for text in prompt]
        negative = [self._encode_single_prompt(text, device, dtype, max_sequence_length) for text in negative_prompt]
        max_len = max(max(tensor.shape[1] for tensor in positive), max(tensor.shape[1] for tensor in negative))

        def pad_sequence(sequence: torch.Tensor) -> torch.Tensor:
            if sequence.shape[1] == max_len:
                return sequence
            return torch.cat(
                [
                    sequence,
                    torch.zeros(
                        sequence.shape[0],
                        max_len - sequence.shape[1],
                        sequence.shape[2],
                        dtype=sequence.dtype,
                        device=sequence.device,
                    ),
                ],
                dim=1,
            )

        prompt_embeds = torch.cat([pad_sequence(tensor) for tensor in positive], dim=0)
        negative_embeds = torch.cat([pad_sequence(tensor) for tensor in negative], dim=0)
        return prompt_embeds, negative_embeds

    def _postprocess_images(self, images: torch.Tensor, output_type: str):
        images = images.float().clamp(-1, 1) * 0.5 + 0.5
        if output_type == "pt":
            return images

        images = images.permute(0, 2, 3, 1).cpu().numpy()
        if output_type == "np":
            return images
        if output_type == "pil":
            # numpy_to_pil expects float arrays in [0, 1] and performs the
            # uint8 conversion itself; converting here first would scale twice.
            return self.numpy_to_pil(images)
        raise ValueError(f"Unsupported output_type: {output_type}")

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        num_inference_steps: int = 30,
        guidance_scale: float = 0.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        dtype: Optional[torch.dtype] = None,
        max_sequence_length: int = 512,
        shift_schedule: bool = True,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end=None,
        **kwargs,
    ):
        _ = kwargs
        height = height or 1024
        width = width or 1024
        dtype = dtype or next(self.dit_model.parameters()).dtype
        device = self._execution_device

        prompt_embeds, negative_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            device=device,
            dtype=dtype,
            max_sequence_length=max_sequence_length,
        )
        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        negative_embeds = negative_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        batch_size = prompt_embeds.shape[0]

        latents = randn_tensor((batch_size, 3, height, width), generator=generator, device=device, dtype=dtype)
        image_seq_len = (height // self.dit_model.patch_size) * (width // self.dit_model.patch_size)
        sigmas = get_schedule(
            num_inference_steps,
            image_seq_len,
            base_shift=base_shift,
            max_shift=max_shift,
            shift=shift_schedule,
        ).to(device=device, dtype=dtype)

        self.dit_model.eval()
        do_classifier_free_guidance = guidance_scale > 1.0
        for step, timestep in enumerate(self.progress_bar(range(num_inference_steps), desc="Sampling")):
            _ = timestep
            t_curr = sigmas[step]
            t_next = sigmas[step + 1]
            t_batch = torch.full((batch_size,), t_curr, device=device, dtype=dtype)

            if do_classifier_free_guidance:
                latent_input = torch.cat([latents, latents], dim=0)
                timestep_input = torch.cat([t_batch, t_batch], dim=0)
                context_input = torch.cat([prompt_embeds, negative_embeds], dim=0)
                velocity = self.dit_model(latent_input, timestep_input, context_input, context_input.shape[1])
                v_cond, v_uncond = velocity.chunk(2, dim=0)
                velocity = v_uncond + guidance_scale * (v_cond - v_uncond)
            else:
                velocity = self.dit_model(latents, t_batch, prompt_embeds, prompt_embeds.shape[1])

            latents = latents + (t_next - t_curr) * velocity

            if callback_on_step_end is not None:
                callback_kwargs = {
                    "latents": latents,
                    "prompt_embeds": prompt_embeds,
                    "negative_prompt_embeds": negative_embeds,
                }
                callback_result = callback_on_step_end(self, step, int(t_curr.item() * 1000), callback_kwargs)
                if isinstance(callback_result, dict) and "latents" in callback_result:
                    latents = callback_result["latents"]

        images = self._postprocess_images(latents, output_type=output_type)
        if not return_dict:
            return (images,)
        return ZetaChromaPipelineOutput(images=images)
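The denoising loop above is a plain Euler integrator over the sigma grid: `x += (t_next - t_curr) * v`. A scalar sketch under the rectified-flow assumption of a straight path between noise and data, where the velocity is constant and Euler therefore lands on the endpoint exactly regardless of step placement:

```python
def euler_sample(x1: float, x0: float, sigmas: list[float]) -> float:
    # Straight-line (rectified) flow: x_t = t * x1 + (1 - t) * x0, so the
    # true velocity is the constant v = x1 - x0; `v` below is a stand-in
    # for the network's per-step velocity prediction.
    x = x1
    for t_curr, t_next in zip(sigmas, sigmas[1:]):
        v = x1 - x0
        x = x + (t_next - t_curr) * v  # same update as the pipeline loop
    return x
```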