v25.0.0 release (#3138)

* Add support for custom learning rate scheduler type to the GUI

* Add .webp image extension support to BLIP2 captioning.

* Check for --debug flag for gui command-line args at startup

* Validate GPU ID accelerate input and return error when needed

* Update to latest sd-scripts dev commit

* Fix issue with pip upgrade

* Remove confusing log after command execution.

* piecewise_constant scheduler

* Update to latest sd-scripts dev commit

* fix: fixed docker-compose for passing models via volumes

* Prevent providing the legacy learning_rate if unet or te learning rate is provided

* Fix toml noise offset parameters based on selected type

* Fix adaptive_noise_scale value not properly loading from json config

* Fix prompt.txt location

* Improve "print command" output format

* Use output model name as wandb run name if not provided
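
A minimal sketch of that fallback logic (the function and parameter names here are illustrative, not the actual GUI code):

```python
def resolve_wandb_run_name(wandb_run_name, output_name):
    """Use the explicit wandb run name when given; otherwise fall back to the output model name."""
    return wandb_run_name if wandb_run_name else output_name
```

For example, `resolve_wandb_run_name("", "flux1-lora-v1")` returns `"flux1-lora-v1"`, while an explicitly set run name is passed through unchanged.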

* Update sd-scripts dev release

* Bump crate-ci/typos from 1.21.0 to 1.22.9

Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.21.0 to 1.22.9.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.21.0...v1.22.9)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump docker/build-push-action from 5 to 6

Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 5 to 6.
- [Release notes](https://github.com/docker/build-push-action/releases)
- [Commits](https://github.com/docker/build-push-action/compare/v5...v6)

---
updated-dependencies:
- dependency-name: docker/build-push-action
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Get latest sd3 code

* Adding SD3 GUI elements

* Fix interactivity

* MVP GUI for SD3

* Fix text encoder issue

* Add fork section to readme

* Update sd3 commit

* Merge security-fix

* Update sd-scripts to latest code

* Auto-detect model type for safetensors files

Automatically tick the checkboxes for v2 and SDXL on the common training UI
and LoRA extract/merge utilities.
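
One way to implement such detection is to inspect the tensor names in the safetensors header, which is a length-prefixed JSON blob at the start of the file. This is a sketch; the key-prefix heuristics below are assumptions about typical checkpoint layouts, not necessarily the exact checks the GUI uses:

```python
import json
import struct

def read_safetensors_keys(path):
    # A .safetensors file starts with an 8-byte little-endian header length,
    # followed by that many bytes of JSON describing the tensors.
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return list(header.keys())

def guess_model_type(keys):
    # Heuristic tensor-name prefixes (assumption: typical checkpoint layouts).
    if any(k.startswith("conditioner.embedders.1") for k in keys):
        return "sdxl"   # SDXL checkpoints ship a second text encoder
    if any("double_blocks" in k for k in keys):
        return "flux1"  # Flux.1 uses double/single stream blocks
    return "sd"         # default: SD1.x/2.x-style checkpoint
```

The detected type can then drive the v2/SDXL checkbox defaults instead of asking the user to remember what kind of model they loaded.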

* autodetect-modeltype: remove unused lambda inputs

* rework TE1/TE2 learning rate handling for SDXL dreambooth

SDXL dreambooth apparently trains without the text encoders by default,
requiring the `--train_text_encoder` flag to be passed so that the
learning rates for TE1/TE2 are recognized.

The toml handling now permits 0 to be passed as a learning rate in
order to disable training of one or both text encoders.
This behavior aligns with the description given on the GUI.

TE1/TE2 learning rate parameters can be left blank on the GUI to
not pass a value to the training script.
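
The blank-versus-zero distinction can be sketched like this (names are illustrative, not the actual GUI code):

```python
def te_lr_arg(value):
    """Map a GUI learning-rate field to a toml/CLI value.

    None or "" -> omit the argument entirely (use the training script's default);
    0          -> pass 0 explicitly to disable training that text encoder;
    otherwise  -> pass the numeric value through as a float.
    """
    if value is None or value == "":
        return None  # caller drops the key from the toml output
    return float(value)
```

So leaving the field blank defers to sd-scripts, while typing 0 deliberately turns that text encoder off.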

* dreambooth_gui: fix toml value filtering condition

In Python 3, `0 == False` evaluates to True.
That can cause arg values of 0 to be wrongly eliminated from the toml output.
The conditional must check the type when comparing against False.
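
The pitfall is easy to reproduce, and a type-aware check keeps a legitimate 0 while still filtering out False (sketch; the helper name is illustrative):

```python
# In Python, bool is a subclass of int, so 0 == False evaluates to True.
assert 0 == False

def keep_in_toml(value):
    # Wrong: `value != False` would also drop a legitimate 0.
    # Right: only drop values that are actually the boolean False.
    return not (isinstance(value, bool) and value is False)
```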

* autodetect-modeltype: also do the v2 checkbox in extract_lora

* Update to latest dev branch code

* bring back SDXLConfig accordion for dreambooth gui (#2694)

Co-authored-by: b-fission <b-fission@users.noreply.github.com>

* Update to latest sd3 branch commit

* Fix merge issue

* Update gradio version

* Update to latest flux.1 code

* Add Flux.1 Model checkbox and detection

* Adding LoRA type "Flux1" to dropdown

* Added Flux.1 parameters to GUI

* Update sd-scripts and requirements

* Add missing Flux.1 GUI parameters

* Update to latest sd-scripts sd3 code

* Fix issue with cache_text_encoder_outputs

* Update to latest sd-scripts flux1 code

* Adding new flux.1 options to GUI

* Update to latest sd-scripts version of flux.1

* Adding guidance_scale option

* Update to latest sd3 flux.1 sd-scripts

* Add dreambooth and finetuning support for flux.1

* Update README

* Fix t5xxl path issue in DB

* add missing fp8_base parameter

* Fix issue with guidance scale not being passed as float for values like 1
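
The root cause is that a value like 1 round-trips as an integer; forcing a float representation before emitting the config fixes it (sketch, names illustrative):

```python
def guidance_scale_toml(value):
    # Emit a toml float literal even when the GUI hands us the int 1,
    # so the training script parses guidance_scale as a float.
    return f"guidance_scale = {float(value)}"
```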

* Temporary fix for blockwise_fused_optimizers

* Update to latest sd-scripts Flux.1 code

* Fix blockwise_fused_optimizers typo

* Add mem_eff_save option to GUI for Flux.1

* Added support for Flux.1 LoRA Merge

* Update to latest sd-scripts sd3 branch code

* Add diffusers option to flux.1 merge LoRA utility

* Fix issue with split_mode and train_blocks

* Updating requirements

* Add flux_fused_backward_pass to dreambooth and finetuning

* Update requirements_linux_docker.txt

update accelerate version for linux_docker

* Update to latest sd3 flux code

* Add extract flux lora GUI

* Merged latest sd3 branch code

* Add support for split_qkv

* Add missing network argument for split_qkv

* Add timestep_sampling shift support

* Update to latest sd-scripts flux.1 code

* Add support for fp8_base_unet

* Update requirements as per sd-scripts suggestion

* Upgrade to cu124

* Update IPEX and ROCm

* Fix issue with balancing when a folder with the same name already exists

* Update sd-scripts

* Removed unsupported parameters from flux lora network

* Bump crate-ci/typos from 1.23.6 to 1.24.3

Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.24.3.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.24.3)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update sd-scripts code

* Adding flux_shift option to timestep_sampling

* Update sd-scripts release

* Add support for Train T5-XXL

* Update sd-scripts submodule

* Add support for cpu_offload_checkpointing to GUI

* Force t5xxl_max_token_length to be passed as an integer

* Fix typo for flux_shift

* Update to latest sd-scripts code

* Grouping lora parameters

* Validate that the LoRA type is Flux1 when flux1_checkbox is true

* Improve visual sectioning of parameters for lora

* Add dark mode styles

* Missed one color

* Update sd-scripts and add support for t5xxl LR

* Update transformers and wandb module

* Fix issue with new text_encoder_lr parameter syntax

* Add support for lr_warmup_steps override

* Update lr_warmup_steps code

* Removing stable-diffusion-1.5 default model

* Fix for max_train_steps

* Revert some changes

* Preliminary support for Flux1 OFT

* Fix logic typo

* Update sd-scripts

* Add support for Rank for layers

* Update lora_gui.py

Fixed minor typos of "Regularization"

* Update dreambooth_gui.py

Fixed minor typos of "Regularization"

* Update textual_inversion_gui.py

Fixed minor typos of "Regularization"

* Add support for Blocks to train

* Add missing network params

* Fix issue with old_lr_warmup_steps

* Update sd-scripts

* Add support for ScheduleFree Optimizer Type

* Update sd-scripts

* Update requirements_pytorch_windows.txt

* Update requirements_pytorch_windows.txt

* Update sd-scripts from origin

* Another sd-script update

* Adding support for blocks_to_swap option to gui

* Fix xformers install issue

* feat(docker): mount models folder as a volume

* feat(docker): add models folder to .dockerignore

* Add support for AdEMAMix8bit optimizer

* Bump crate-ci/typos from 1.23.6 to 1.25.0

Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.25.0.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.25.0)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Fix typo on README.md

* Add new --noverify option to skip requirements validation on startup

* Update startup GUI code

* Update setup code

* Update sd-scripts

* Update sd-scripts

* Update Lycoris support

* Allow specifying the tensorboard host via the env var TENSORBOARD_HOST
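
A sketch of how such an override is typically read in shell, assuming a default bind address of 127.0.0.1 (the default value here is an assumption):

```shell
#!/bin/sh
# Fall back to localhost when TENSORBOARD_HOST is not set in the environment.
TENSORBOARD_HOST="${TENSORBOARD_HOST:-127.0.0.1}"
echo "tensorboard binding to ${TENSORBOARD_HOST}"
```

Launching with, e.g., `TENSORBOARD_HOST=0.0.0.0 ./gui.sh ...` would then expose tensorboard on all interfaces.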

* Update sd-scripts version

* Update sd-scripts release

* Update sd-scripts

* Add --skip_cache_check option to GUI

* Fix requirements issue

* Add support for LyCORIS LoRA when training Flux.1

* Pin huggingface-hub version for gradio 5

* Update sd-scripts

* Add support for --save_last_n_epochs_state

* Update sd-scripts to version with Differential Output Preservation

* Increase maximum flux-lora merge strength to 2

* Update to latest sd-scripts

* Update requirements syntax (for windows)

* Update requirements for linux

* Update torch version and validation output

* Fix typo

* Update README

* Fix validation issue on linux

* Update sd-scripts, improve requirements outputs

* Update requirements_runpod.txt

* Update requirements for onnxruntime-gpu

Needed for compatibility with CUDA 12.

* Update onnxruntime-gpu==1.19.2

* Update sd-scripts release

* Add support for save_last_n_epochs

* Update sd-scripts

* Bump crate-ci/typos from 1.23.6 to 1.26.8 (#2940)

Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.26.8.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.26.8)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: bmaltais <bernard@ducourier.com>

* fix 'cached_download' from 'huggingface_hub' (#2947)

Fixes the bug: cannot import name 'cached_download' from 'huggingface_hub'

Applied to all platforms.

Co-authored-by: bmaltais <bernard@ducourier.com>

* Add support for quiet output for linux setup

* Fix quiet issue

* Update sd-scripts

* Update sd-scripts with blocks_to_swap support

* Make blocks_to_swap visible in LoRA tab

* Fix blocks_to_swap not properly working

* Update sd-scripts and allow python 3.10 to 3.12

* Fix issue with max_train_steps

* Fix max_train_steps_info error

* Reverting all changes for max_train_steps

* Update sd-scripts

* Update sd-scripts

* Update to latest sd-scripts

* Add support for RAdamScheduleFree

* Add support for huber_scale

* Add support for fused_backward_pass for sd3 finetuning

* Add support for prodigyplus.ProdigyPlusScheduleFree

* SD3 LoRA training MVP

* Make blocks_to_swap common

* Add support for sd3 lora disable_mmap_load_safetensors

* Add a bunch of missing SD3 parameters

* Fix clip_l issue for missing path

* Fix train_t5xxl issue

* Fix network_module issue

* Add uniform to weighting_scheme

* Bump crate-ci/typos from 1.23.6 to 1.28.1 (#2996)

Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.28.1.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.28.1)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: bmaltais <bernard@ducourier.com>

* Update README.md (#3031)

* Bump crate-ci/typos from 1.23.6 to 1.29.0 (#3029)

Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.29.0.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.29.0)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: bmaltais <bernard@ducourier.com>

* Update sd-scripts version

* Update setup.sh (#3054)

Enter the current directory before executing setup.sh; otherwise the installer might fail to find requirements.txt.
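
The fix amounts to resolving paths relative to the script itself rather than the caller's working directory; a minimal sketch (not the actual setup.sh contents):

```shell
#!/bin/sh
# Change into the directory containing this script so that relative paths
# like ./requirements.txt resolve regardless of where the user invoked it.
SCRIPT_DIR="$(cd -- "$(dirname -- "$0")" && pwd)"
cd -- "$SCRIPT_DIR" || exit 1
echo "installing from ${SCRIPT_DIR}/requirements.txt"
```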

* Removing wrong folder

* Fix issue with SD3 Lora training blocks_to_swap and fused_backward_pass

* Fix dreambooth issue

* Update to latest sd-scripts code

* Run on novita (#3119) (#3120)

* add run on novita

* adjust position

Co-authored-by: hugo <liyiligang@users.noreply.github.com>

* Bump crate-ci/typos from 1.23.6 to 1.30.0 (#3101)

Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.30.0.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.30.0)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: bmaltais <bernard@ducourier.com>

* updated prodigyopt to 1.1.2 and removed duplicated row in requirements.txt (#3065)

* fixed names on LR Scheduler dropdown (#3064)

* Update to latest sd-scripts version

* fixed names on LR Scheduler dropdown (#3064)

* Cleanup venv3

* Fix issue with gradio on new installations
Add support for latest sd-scripts pytorch-optimizer

* Update README for v25.0.0 release

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: b-fission <b-fission@users.noreply.github.com>
Co-authored-by: DevArqSangoi <lucas.sangoi@gmail.com>
Co-authored-by: Кирилл Москвин <retreat.cost@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: b-fission <131207849+b-fission@users.noreply.github.com>
Co-authored-by: eftSharptooth <76253264+eftSharptooth@users.noreply.github.com>
Co-authored-by: Disty0 <disty@disty.xyz>
Co-authored-by: wcole3 <will.cole3@gmail.com>
Co-authored-by: rohitanshu <85547195+iamrohitanshu@users.noreply.github.com>
Co-authored-by: wzgrx <39661556+wzgrx@users.noreply.github.com>
Co-authored-by: Vladimir Sotnikov <vladimir.s@alphakek.ai>
Co-authored-by: bulieme0 <53142287+bulieme@users.noreply.github.com>
Co-authored-by: Nicolas Pereira <41456803+hqnicolas@users.noreply.github.com>
Co-authored-by: ruucm <ruucm.a@gmail.com>
Co-authored-by: CaledoniaProject <CaledoniaProject@users.noreply.github.com>
Co-authored-by: hugo <liyiligang@users.noreply.github.com>
Co-authored-by: Koro <Koronos@users.noreply.github.com>
Commit ed55e81997 (parent a1b16e44f0), tag v25.0.0, by bmaltais on 2025-03-28 11:00:44 -04:00, committed via GitHub.
71 changed files with 5387 additions and 1472 deletions.


@ -3,6 +3,7 @@ cudnn_windows/
bitsandbytes_windows/
bitsandbytes_windows_deprecated/
dataset/
models/
__pycache__/
venv/
**/.hadolint.yml


@ -71,7 +71,7 @@ jobs:
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
id: publish
with:
context: .


@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.23.6
uses: crate-ci/typos@v1.30.0

.gitignore

@ -51,4 +51,6 @@ dataset/**
models
data
config.toml
sd-scripts
sd-scripts
venv
venv*


@ -1 +1 @@
v24.1.7
v25.0.0


@ -48,13 +48,20 @@ The GUI allows you to set the training parameters and generate and run the requi
- [Potential Solutions](#potential-solutions)
- [SDXL training](#sdxl-training)
- [Masked loss](#masked-loss)
- [Guides](#guides)
- [Using Accelerate Lora Tab to Select GPU ID](#using-accelerate-lora-tab-to-select-gpu-id)
- [Starting Accelerate in GUI](#starting-accelerate-in-gui)
- [Running Multiple Instances (linux)](#running-multiple-instances-linux)
- [Monitoring Processes](#monitoring-processes)
- [Interesting Forks](#interesting-forks)
- [Change History](#change-history)
- [v25.0.0](#v2500)
## 🦒 Colab
This Colab notebook was not created or maintained by me; however, it appears to function effectively. The source can be found at: <https://github.com/camenduru/kohya_ss-colab>.
I would like to express my gratitude to camendutu for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository.
I would like to express my gratitude to camenduru for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository.
| Colab | Info |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------ |
@ -71,7 +78,7 @@ To install the necessary dependencies on a Windows system, follow these steps:
1. Install [Python 3.10.11](https://www.python.org/ftp/python/3.10.11/python-3.10.11-amd64.exe).
- During the installation process, ensure that you select the option to add Python to the 'PATH' environment variable.
2. Install [CUDA 11.8 toolkit](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Windows&target_arch=x86_64).
2. Install [CUDA 12.4 toolkit](https://developer.nvidia.com/cuda-12-4-0-download-archive?target_os=Windows&target_arch=x86_64).
3. Install [Git](https://git-scm.com/download/win).
@ -129,7 +136,7 @@ To install the necessary dependencies on a Linux system, ensure that you fulfill
apt install python3.10-venv
```
- Install the CUDA 11.8 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64).
- Install the CUDA 12.4 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-12-4-0-download-archive?target_os=Linux&target_arch=x86_64).
- Make sure you have Python version 3.10.9 or higher (but lower than 3.11.0) installed on your system.
@ -179,7 +186,7 @@ If you choose to use the interactive mode, the default values for the accelerate
To install the necessary components for Runpod and run kohya_ss, follow these steps:
1. Select the Runpod pytorch 2.0.1 template. This is important. Other templates may not work.
1. Select the Runpod pytorch 2.2.0 template. This is important. Other templates may not work.
2. SSH into the Runpod.
@ -222,7 +229,9 @@ To run from a pre-built Runpod template, you can:
3. Once deployed, connect to the Runpod on HTTP 3010 to access the kohya_ss GUI. You can also connect to auto1111 on HTTP 3000.
### Novita
#### Pre-built Novita template
1. Open the Novita template by clicking on <https://novita.ai/gpus-console?templateId=312>.
2. Deploy the template on the desired host.
@ -339,13 +348,27 @@ To upgrade your installation on Linux or macOS, follow these steps:
To launch the GUI service, you can use the provided scripts or run the `kohya_gui.py` script directly. Use the command line arguments listed below to configure the underlying service.
```text
--listen: Specify the IP address to listen on for connections to Gradio.
--username: Set a username for authentication.
--password: Set a password for authentication.
--server_port: Define the port to run the server listener on.
--inbrowser: Open the Gradio UI in a web browser.
--share: Share the Gradio UI.
--language: Set custom language
--help show this help message and exit
--config CONFIG Path to the toml config file for interface defaults
--debug Debug on
--listen LISTEN IP to listen on for connections to Gradio
--username USERNAME Username for authentication
--password PASSWORD Password for authentication
--server_port SERVER_PORT
Port to run the server listener on
--inbrowser Open in browser
--share Share the gradio UI
--headless Is the server headless
--language LANGUAGE Set custom language
--use-ipex Use IPEX environment
--use-rocm Use ROCm environment
--do_not_use_shell Enforce not to use shell=True when running external commands
--do_not_share Do not share the gradio UI
--requirements REQUIREMENTS
requirements file to use for validation
--root_path ROOT_PATH
`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss
--noverify Disable requirements verification
```
### Launching the GUI on Windows
@ -448,6 +471,45 @@ The feature is not fully tested, so there may be bugs. If you find any issues, p
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
## Guides
The following are guides extracted from issues discussions
### Using Accelerate Lora Tab to Select GPU ID
#### Starting Accelerate in GUI
- Open the kohya GUI on your desired port.
- Open the `Accelerate launch` tab
- Ensure the Multi-GPU checkbox is unchecked.
- Set GPU IDs to the desired GPU (like 1).
#### Running Multiple Instances (linux)
- For tracking multiple processes, use separate kohya GUI instances on different ports (e.g., 7860, 7861).
- Start instances using `nohup ./gui.sh --listen 0.0.0.0 --server_port <port> --headless > log.log 2>&1 &`.
#### Monitoring Processes
- Open each GUI in a separate browser tab.
- For terminal access, use SSH and tools like `tmux` or `screen`.
For more details, visit the [GitHub issue](https://github.com/bmaltais/kohya_ss/issues/2577).
## Interesting Forks
To finetune HunyuanDiT models or create LoRAs, visit this [fork](https://github.com/Tencent/HunyuanDiT/tree/main/kohya_ss-hydit)
## Change History
See release information.
### v25.0.0
This is a SIGNIFICANT upgrade. I am going into uncharted territory here because kohya has not merged any of the recent flux.1 and sd3 updates into the main branch of his code yet... but I feel the updates to his code have pretty much dried up, and I think it is probably ready for prime time. So instead of keeping my GUI in the caveman ages, I am opting to move the GUI code with support for flux.1 and sd3 to the main branch of my project. Perhaps this will bite me in the proverbial ass... but for those who would rather stay on the older pre-"flux.1 and sd3" updates, you can always do:
```shell
git checkout v24.1.7
```
after cloning the repo.
For all the info regarding the new flux.1 and sd3 parameters, see <https://github.com/kohya-ss/sd-scripts/blob/sd3/README.md>.


@ -9,6 +9,7 @@ parms="parms"
nin="nin"
extention="extention" # Intentionally left
nd="nd"
pn="pn"
shs="shs"
sts="sts"
scs="scs"


@ -1,4 +1,4 @@
#open_folder_small{
#open_folder_small {
min-width: auto;
flex-grow: 0;
padding-left: 0.25em;
@ -7,14 +7,14 @@
font-size: 1.5em;
}
#open_folder{
#open_folder {
height: auto;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
}
#number_input{
#number_input {
min-width: min-content;
flex-grow: 0.3;
padding-left: 0.75em;
@ -22,7 +22,7 @@
}
.ver-class {
color: #808080;
color: #6d6d6d; /* Neutral dark gray */
font-size: small;
text-align: right;
padding-right: 1em;
@ -35,13 +35,212 @@
}
#myTensorButton {
background: radial-gradient(ellipse, #3a99ff, #52c8ff);
background: #555c66; /* Muted dark gray */
color: white;
border: #296eb8;
border: none;
border-radius: 4px;
padding: 0.5em 1em;
/* box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); Subtle shadow */
/* transition: box-shadow 0.3s ease; */
}
#myTensorButton:hover {
/* box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); Slightly increased shadow on hover */
}
#myTensorButtonStop {
background: radial-gradient(ellipse, #52c8ff, #3a99ff);
color: black;
border: #296eb8;
background: #777d85; /* Lighter muted gray */
color: white;
border: none;
border-radius: 4px;
padding: 0.5em 1em;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
/* transition: box-shadow 0.3s ease; */
}
#myTensorButtonStop:hover {
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
}
.advanced_background {
background: #f4f4f4; /* Light neutral gray */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; /* Added transition for smooth shadow effect */
}
.advanced_background:hover {
background-color: #ebebeb; /* Slightly darker background on hover */
border: 1px solid #ccc; /* Add a subtle border */
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.basic_background {
background: #eaeff1; /* Muted cool gray */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.basic_background:hover {
background-color: #dfe4e7; /* Slightly darker cool gray on hover */
border: 1px solid #ccc;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.huggingface_background {
background: #e0e4e7; /* Light gray with a hint of blue */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.huggingface_background:hover {
background-color: #d6dce0; /* Slightly darker on hover */
border: 1px solid #bbb;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.flux1_background {
background: #ece9e6; /* Light beige tone */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.flux1_background:hover {
background-color: #e2dfdb; /* Slightly darker beige on hover */
border: 1px solid #ccc;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.preset_background {
background: #f0f0f0; /* Light gray */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.preset_background:hover {
background-color: #e6e6e6; /* Slightly darker on hover */
border: 1px solid #ccc;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.samples_background {
background: #d9dde1; /* Soft muted gray-blue */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.samples_background:hover {
background-color: #cfd3d8; /* Slightly darker on hover */
border: 1px solid #bbb;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
/* Dark mode styles */
.dark .advanced_background {
background: #172029; /* Slightly darker gradio dark theme */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; /* Added transition for smooth shadow effect */
}
.dark .advanced_background:hover {
background-color: #121920; /* Slightly darker background on hover */
border: 1px solid #000000; /* Add a subtle border */
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.dark .basic_background {
background: #172029;
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.dark .basic_background:hover {
background-color: #11181e;
border: 1px solid #000000;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.dark .huggingface_background {
background: #131c25;
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.dark .huggingface_background:hover {
background-color: #131c25;
border: 1px solid #000000;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.dark .flux1_background {
background: #131c25;
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.dark .flux1_background:hover {
background-color: #131c25;
border: 1px solid #000000;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.dark .preset_background {
background: #191d25;
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.dark .preset_background:hover {
background-color: #212530;
border: 1px solid #000000;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.dark .samples_background {
background: #101e2c;
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.dark .samples_background:hover {
background-color: #17293a;
border: 1px solid #000000;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.flux1_rank_layers_background {
background: #ece9e6; /* White background for clear theme */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.flux1_rank_layers_background:hover {
background-color: #dddad7; /* Slightly darker on hover */
border: 1px solid #ccc;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
.dark .flux1_rank_layers_background {
background: #131c25; /* Dark background for dark theme */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}
.dark .flux1_rank_layers_background:hover {
background-color: #131c25; /* Slightly darker on hover */
border: 1px solid #000000;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}


@ -48,6 +48,7 @@ learning_rate_te1 = 0.0001 # Learning rate text encoder 1
learning_rate_te2 = 0.0001 # Learning rate text encoder 2
lr_scheduler = "cosine" # LR Scheduler
lr_scheduler_args = "" # LR Scheduler args
lr_scheduler_type = "" # LR Scheduler type
lr_warmup = 0 # LR Warmup (% of total steps)
lr_scheduler_num_cycles = 1 # LR Scheduler num cycles
lr_scheduler_power = 1.0 # LR Scheduler power
@ -150,6 +151,9 @@ sample_prompts = "" # Sample prompts
sample_sampler = "euler_a" # Sampler to use for image sampling
[sdxl]
disable_mmap_load_safetensors = false # Disable mmap load safe tensors
fused_backward_pass = false # Fused backward pass
fused_optimizer_groups = 0 # Fused optimizer groups
sdxl_cache_text_encoder_outputs = false # Cache text encoder outputs
sdxl_no_half_vae = true # No half VAE


@ -20,11 +20,13 @@ services:
- /tmp
volumes:
- /tmp/.X11-unix:/tmp/.X11-unix
- ./models:/app/models
- ./dataset:/dataset
- ./dataset/images:/app/data
- ./dataset/logs:/app/logs
- ./dataset/outputs:/app/outputs
- ./dataset/regularization:/app/regularization
- ./models:/app/models
- ./.cache/config:/app/config
- ./.cache/user:/home/1000/.cache
- ./.cache/triton:/home/1000/.triton


@ -1,32 +1,27 @@
## Updating a Local Branch with the Latest sd-scripts Changes
## Updating a Local Submodule with the Latest sd-scripts Changes
To update your local branch with the most recent changes from kohya/sd-scripts, follow these steps:
1. Add sd-scripts as an alternative remote by executing the following command:
1. When you wish to perform an update of the dev branch, execute the following commands:
```
git remote add sd-scripts https://github.com/kohya-ss/sd-scripts.git
```
2. When you wish to perform an update, execute the following commands:
```
```bash
cd sd-scripts
git fetch
git checkout dev
git pull sd-scripts main
git pull origin dev
cd ..
git add sd-scripts
git commit -m "Update sd-scripts submodule to the latest on dev"
```
Alternatively, if you want to obtain the latest code, even if it may be unstable:
Alternatively, if you want to obtain the latest code from main:
```bash
cd sd-scripts
git fetch
git checkout main
git pull origin main
cd ..
git add sd-scripts
git commit -m "Update sd-scripts submodule to the latest on main"
```
git checkout dev
git pull sd-scripts dev
```
3. If you encounter a conflict with the Readme file, you can resolve it by taking the following steps:
```
git add README.md
git merge --continue
```
This may open a text editor for a commit message, but you can simply save and close it to proceed. Following these steps should resolve the conflict. If you encounter additional merge conflicts, consider them as valuable learning opportunities for personal growth.
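The update steps above can be sketched as a small helper that builds the git command sequence for either branch (a sketch written for this document, not a script from the repo; branch names `dev` and `main` come from the instructions above):

```python
# Sketch: build the git command sequence for updating the sd-scripts
# submodule to the latest commit on a given branch.
def submodule_update_commands(branch: str = "dev") -> list[list[str]]:
    if branch not in ("dev", "main"):
        raise ValueError("branch must be 'dev' or 'main'")
    return [
        ["git", "-C", "sd-scripts", "fetch"],
        ["git", "-C", "sd-scripts", "checkout", branch],
        ["git", "-C", "sd-scripts", "pull", "origin", branch],
        ["git", "add", "sd-scripts"],
        ["git", "commit", "-m", f"Update sd-scripts submodule to the latest on {branch}"],
    ]
```

Each command list could then be run in order with `subprocess.run(cmd, check=True)` from the repository root.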


@ -7,11 +7,13 @@ call .\venv\Scripts\deactivate.bat
:: Activate the virtual environment
call .\venv\Scripts\activate.bat
:: Update pip to latest version
python -m pip install --upgrade pip -q
set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib
:: Validate requirements
python.exe .\setup\validate_requirements.py
if %errorlevel% neq 0 exit /b %errorlevel%
echo Starting the GUI... this might take some time...
:: If the exit code is 0, run the kohya_gui.py script with the command-line arguments
if %errorlevel% equ 0 (

gui.ps1

@ -7,28 +7,18 @@ if ($env:VIRTUAL_ENV) {
# Activate the virtual environment
# Write-Host "Activating the virtual environment..."
& .\venv\Scripts\activate
python.exe -m pip install --upgrade pip -q
$env:PATH += ";$($MyInvocation.MyCommand.Path)\venv\Lib\site-packages\torch\lib"
# Debug info about system
# python.exe .\setup\debug_info.py
Write-Host "Starting the GUI... this might take some time..."
# Validate the requirements and store the exit code
python.exe .\setup\validate_requirements.py
# Check the exit code and stop execution if it is not 0
if ($LASTEXITCODE -ne 0) {
Write-Host "Failed to validate requirements. Exiting script..."
exit $LASTEXITCODE
$argsFromFile = @()
if (Test-Path .\gui_parameters.txt) {
$argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " }
}
$args_combo = $argsFromFile + $args
# Write-Host "The arguments passed to this script were: $args_combo"
python.exe kohya_gui.py $args_combo
# If the exit code is 0, read arguments from gui_parameters.txt (if it exists)
# and run the kohya_gui.py script with the command-line arguments
if ($LASTEXITCODE -eq 0) {
$argsFromFile = @()
if (Test-Path .\gui_parameters.txt) {
$argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " }
}
$args_combo = $argsFromFile + $args
# Write-Host "The arguments passed to this script were: $args_combo"
python.exe kohya_gui.py $args_combo
}
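The PowerShell snippet above reads extra CLI arguments from `gui_parameters.txt`, skipping lines that start with `#` and splitting the rest on spaces. The same parsing expressed in Python (a sketch for illustration; empty tokens are dropped here, a minor difference from PowerShell's `-split`):

```python
def parse_gui_parameters(text: str) -> list[str]:
    """Split gui_parameters.txt-style content into CLI arguments:
    lines starting with '#' are comments, others are split on spaces."""
    args = []
    for line in text.splitlines():
        if line.startswith("#"):
            continue  # skip comment lines, as the PowerShell filter does
        args.extend(tok for tok in line.split(" ") if tok)
    return args
```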

gui.sh

@ -111,10 +111,4 @@ then
STARTUP_CMD=python
fi
# Validate the requirements and run the script if successful
if python "$SCRIPT_DIR/setup/validate_requirements.py" -r "$REQUIREMENTS_FILE"; then
"${STARTUP_CMD}" $STARTUP_CMD_ARGS "$SCRIPT_DIR/kohya_gui.py" "$@"
else
echo "Validation failed. Exiting..."
exit 1
fi
"${STARTUP_CMD}" $STARTUP_CMD_ARGS "$SCRIPT_DIR/kohya_gui.py" "--requirements=""$REQUIREMENTS_FILE" "$@"


@ -1,6 +1,10 @@
import gradio as gr
import os
import sys
import argparse
import subprocess
import contextlib
import gradio as gr
from kohya_gui.class_gui_config import KohyaSSGUIConfig
from kohya_gui.dreambooth_gui import dreambooth_tab
from kohya_gui.finetune_gui import finetune_tab
@ -8,71 +12,43 @@ from kohya_gui.textual_inversion_gui import ti_tab
from kohya_gui.utilities import utilities_tab
from kohya_gui.lora_gui import lora_tab
from kohya_gui.class_lora_tab import LoRATools
from kohya_gui.custom_logging import setup_logging
from kohya_gui.localization_ext import add_javascript
PYTHON = sys.executable
project_dir = os.path.dirname(os.path.abspath(__file__))
def UI(**kwargs):
add_javascript(kwargs.get("language"))
css = ""
# Function to read file content, suppressing any FileNotFoundError
def read_file_content(file_path):
with contextlib.suppress(FileNotFoundError):
with open(file_path, "r", encoding="utf8") as file:
return file.read()
return ""
headless = kwargs.get("headless", False)
log.info(f"headless: {headless}")
# Function to initialize the Gradio UI interface
def initialize_ui_interface(config, headless, use_shell, release_info, readme_content):
# Load custom CSS if available
css = read_file_content("./assets/style.css")
if os.path.exists("./assets/style.css"):
with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
log.debug("Load CSS...")
css += file.read() + "\n"
if os.path.exists("./.release"):
with open(os.path.join("./.release"), "r", encoding="utf8") as file:
release = file.read()
if os.path.exists("./README.md"):
with open(os.path.join("./README.md"), "r", encoding="utf8") as file:
README = file.read()
interface = gr.Blocks(
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
)
config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
if config.is_config_loaded():
log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")
use_shell_flag = True
# if os.name == "posix":
# use_shell_flag = True
use_shell_flag = config.get("settings.use_shell", use_shell_flag)
if kwargs.get("do_not_use_shell", False):
use_shell_flag = False
if use_shell_flag:
log.info("Using shell=True when running external commands...")
with interface:
# Create the main Gradio Blocks interface
ui_interface = gr.Blocks(css=css, title=f"Kohya_ss GUI {release_info}", theme=gr.themes.Default())
with ui_interface:
# Create tabs for different functionalities
with gr.Tab("Dreambooth"):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab(
headless=headless, config=config, use_shell_flag=use_shell_flag
)
) = dreambooth_tab(headless=headless, config=config, use_shell_flag=use_shell)
with gr.Tab("LoRA"):
lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
lora_tab(headless=headless, config=config, use_shell_flag=use_shell)
with gr.Tab("Textual Inversion"):
ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
ti_tab(headless=headless, config=config, use_shell_flag=use_shell)
with gr.Tab("Finetuning"):
finetune_tab(
headless=headless, config=config, use_shell_flag=use_shell_flag
)
finetune_tab(headless=headless, config=config, use_shell_flag=use_shell)
with gr.Tab("Utilities"):
# Utilities tab requires inputs from the Dreambooth tab
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
@ -84,102 +60,97 @@ def UI(**kwargs):
with gr.Tab("LoRA"):
_ = LoRATools(headless=headless)
with gr.Tab("About"):
gr.Markdown(f"kohya_ss GUI release {release}")
# About tab to display release information and README content
gr.Markdown(f"kohya_ss GUI release {release_info}")
with gr.Tab("README"):
gr.Markdown(README)
gr.Markdown(readme_content)
htmlStr = f"""
<html>
<body>
<div class="ver-class">{release}</div>
</body>
</html>
"""
gr.HTML(htmlStr)
# Show the interface
launch_kwargs = {}
username = kwargs.get("username")
password = kwargs.get("password")
server_port = kwargs.get("server_port", 0)
inbrowser = kwargs.get("inbrowser", False)
share = kwargs.get("share", False)
do_not_share = kwargs.get("do_not_share", False)
server_name = kwargs.get("listen")
root_path = kwargs.get("root_path", None)
# Display release information in a div element
gr.Markdown(f"<div class='ver-class'>{release_info}</div>")
launch_kwargs["server_name"] = server_name
if username and password:
launch_kwargs["auth"] = (username, password)
if server_port > 0:
launch_kwargs["server_port"] = server_port
if inbrowser:
launch_kwargs["inbrowser"] = inbrowser
if do_not_share:
launch_kwargs["share"] = False
else:
if share:
launch_kwargs["share"] = share
if root_path:
launch_kwargs["root_path"] = root_path
launch_kwargs["debug"] = True
interface.launch(**launch_kwargs)
return ui_interface
# Function to configure and launch the UI
def UI(**kwargs):
# Add custom JavaScript if specified
add_javascript(kwargs.get("language"))
log.info(f"headless: {kwargs.get('headless', False)}")
if __name__ == "__main__":
# torch.cuda.set_per_process_memory_fraction(0.48)
# Load release and README information
release_info = read_file_content("./.release")
readme_content = read_file_content("./README.md")
# Load configuration from the specified file
config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
if config.is_config_loaded():
log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")
# Determine if shell should be used for running external commands
use_shell = not kwargs.get("do_not_use_shell", False) and config.get("settings.use_shell", True)
if use_shell:
log.info("Using shell=True when running external commands...")
# Initialize the Gradio UI interface
ui_interface = initialize_ui_interface(config, kwargs.get("headless", False), use_shell, release_info, readme_content)
# Construct launch parameters using dictionary comprehension
launch_params = {
"server_name": kwargs.get("listen"),
"auth": (kwargs["username"], kwargs["password"]) if kwargs.get("username") and kwargs.get("password") else None,
"server_port": kwargs.get("server_port", 0) if kwargs.get("server_port", 0) > 0 else None,
"inbrowser": kwargs.get("inbrowser", False),
"share": False if kwargs.get("do_not_share", False) else kwargs.get("share", False),
"root_path": kwargs.get("root_path", None),
"debug": kwargs.get("debug", False),
}
# This line filters out any key-value pairs from `launch_params` where the value is `None`, ensuring only valid parameters are passed to the `launch` function.
launch_params = {k: v for k, v in launch_params.items() if v is not None}
# Launch the Gradio interface with the specified parameters
ui_interface.launch(**launch_params)
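The `launch_params` construction above follows a general pattern: build the full mapping with `None` for "not set", then filter, so optional settings fall back to Gradio's own defaults. A minimal standalone sketch of the pattern (simplified; not the full parameter list used in `UI`):

```python
def build_launch_params(**kwargs) -> dict:
    # Assemble every candidate parameter, using None for "not set".
    params = {
        "server_name": kwargs.get("listen"),
        "auth": (kwargs["username"], kwargs["password"])
        if kwargs.get("username") and kwargs.get("password")
        else None,
        "server_port": kwargs.get("server_port") if kwargs.get("server_port", 0) > 0 else None,
        "share": False if kwargs.get("do_not_share") else kwargs.get("share", False),
    }
    # Drop None entries; note that False values survive the filter.
    return {k: v for k, v in params.items() if v is not None}
```

Note the filter keeps `False` (a meaningful value for `share`) while dropping only `None`.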
# Function to initialize argument parser for command-line arguments
def initialize_arg_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
type=str,
default="./config.toml",
help="Path to the toml config file for interface defaults",
)
parser.add_argument("--config", type=str, default="./config.toml", help="Path to the toml config file for interface defaults")
parser.add_argument("--debug", action="store_true", help="Debug on")
parser.add_argument(
"--listen",
type=str,
default="127.0.0.1",
help="IP to listen on for connections to Gradio",
)
parser.add_argument(
"--username", type=str, default="", help="Username for authentication"
)
parser.add_argument(
"--password", type=str, default="", help="Password for authentication"
)
parser.add_argument(
"--server_port",
type=int,
default=0,
help="Port to run the server listener on",
)
parser.add_argument("--listen", type=str, default="127.0.0.1", help="IP to listen on for connections to Gradio")
parser.add_argument("--username", type=str, default="", help="Username for authentication")
parser.add_argument("--password", type=str, default="", help="Password for authentication")
parser.add_argument("--server_port", type=int, default=0, help="Port to run the server listener on")
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
parser.add_argument(
"--headless", action="store_true", help="Is the server headless"
)
parser.add_argument(
"--language", type=str, default=None, help="Set custom language"
)
parser.add_argument("--headless", action="store_true", help="Is the server headless")
parser.add_argument("--language", type=str, default=None, help="Set custom language")
parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
parser.add_argument("--do_not_use_shell", action="store_true", help="Enforce not to use shell=True when running external commands")
parser.add_argument("--do_not_share", action="store_true", help="Do not share the gradio UI")
parser.add_argument("--requirements", type=str, default=None, help="requirements file to use for validation")
parser.add_argument("--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss")
parser.add_argument("--noverify", action="store_true", help="Disable requirements verification")
return parser
parser.add_argument(
"--do_not_use_shell", action="store_true", help="Enforce not to use shell=True when running external commands"
)
parser.add_argument(
"--do_not_share", action="store_true", help="Do not share the gradio UI"
)
parser.add_argument(
"--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss"
)
if __name__ == "__main__":
# Initialize argument parser and parse arguments
parser = initialize_arg_parser()
args = parser.parse_args()
# Set up logging
# Set up logging based on the debug flag
log = setup_logging(debug=args.debug)
UI(**vars(args))
# Verify requirements unless `noverify` flag is set
if args.noverify:
log.warning("Skipping requirements verification.")
else:
# Run the validation command to verify requirements
validation_command = [PYTHON, os.path.join(project_dir, "setup", "validate_requirements.py")]
if args.requirements is not None:
validation_command.append(f"--requirements={args.requirements}")
subprocess.run(validation_command, check=True)
# Launch the UI with the provided arguments
UI(**vars(args))


@ -102,7 +102,7 @@ def caption_images(
postfix=postfix,
)
# Replace specified text in caption files if find and replace text is provided
if find_text and replace_text:
if find_text:
find_replace(
folder_path=images_dir,
caption_file_ext=caption_ext,


@ -42,7 +42,7 @@ def get_images_in_directory(directory_path):
import os
# List of common image file extensions to look for
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"]
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]
# Generate a list of image file paths in the directory
image_files = [


@ -3,6 +3,10 @@ import os
import shlex
from .class_gui_config import KohyaSSGUIConfig
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
class AccelerateLaunch:
@ -79,12 +83,16 @@ class AccelerateLaunch:
)
self.dynamo_use_fullgraph = gr.Checkbox(
label="Dynamo use fullgraph",
value=self.config.get("accelerate_launch.dynamo_use_fullgraph", False),
value=self.config.get(
"accelerate_launch.dynamo_use_fullgraph", False
),
info="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
)
self.dynamo_use_dynamic = gr.Checkbox(
label="Dynamo use dynamic",
value=self.config.get("accelerate_launch.dynamo_use_dynamic", False),
value=self.config.get(
"accelerate_launch.dynamo_use_dynamic", False
),
info="Whether to enable dynamic shape tracing.",
)
@ -103,6 +111,24 @@ class AccelerateLaunch:
placeholder="example: 0,1",
info=" What GPUs (by id) should be used for training on this machine as a comma-separated list",
)
def validate_gpu_ids(value):
if value == "":
return
for gpu_id in value.split(","):
if not gpu_id.isdigit() or int(gpu_id) < 0 or int(gpu_id) > 128:
log.error("GPU IDs must be integers between 0 and 128, separated by commas")
return
self.gpu_ids.blur(fn=validate_gpu_ids, inputs=self.gpu_ids)
self.main_process_port = gr.Number(
label="Main process port",
value=self.config.get("accelerate_launch.main_process_port", 0),
@ -136,9 +162,14 @@ class AccelerateLaunch:
if "dynamo_use_dynamic" in kwargs and kwargs.get("dynamo_use_dynamic"):
run_cmd.append("--dynamo_use_dynamic")
if "extra_accelerate_launch_args" in kwargs and kwargs["extra_accelerate_launch_args"] != "":
extra_accelerate_launch_args = kwargs["extra_accelerate_launch_args"].replace('"', "")
if (
"extra_accelerate_launch_args" in kwargs
and kwargs["extra_accelerate_launch_args"] != ""
):
extra_accelerate_launch_args = kwargs[
"extra_accelerate_launch_args"
].replace('"', "")
for arg in extra_accelerate_launch_args.split():
run_cmd.append(shlex.quote(arg))


@ -146,7 +146,7 @@ class AdvancedTraining:
with gr.Row():
self.loss_type = gr.Dropdown(
label="Loss type",
choices=["huber", "smooth_l1", "l2"],
choices=["huber", "smooth_l1", "l1", "l2"],
value=self.config.get("advanced.loss_type", "l2"),
info="The type of loss to use and whether it's scheduled based on the timestep",
)
@ -168,6 +168,14 @@ class AdvancedTraining:
step=0.01,
info="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type",
)
self.huber_scale = gr.Number(
label="Huber scale",
value=self.config.get("advanced.huber_scale", 1.0),
minimum=0.0,
maximum=10.0,
step=0.01,
info="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
)
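For reference, the loss that the Huber parameter and the new Huber scale control can be written in a few lines (a sketch of the standard Huber definition; sd-scripts' timestep-based scheduling of the parameter is not reproduced here, and the `scale` argument simply mirrors a multiplicative knob):

```python
def huber_loss(residual: float, delta: float = 1.0, scale: float = 1.0) -> float:
    """Standard Huber loss: quadratic for |r| <= delta, linear beyond."""
    r = abs(residual)
    if r <= delta:
        loss = 0.5 * r * r          # quadratic region near zero
    else:
        loss = delta * (r - 0.5 * delta)  # linear region for outliers
    return scale * loss
```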
with gr.Row():
self.save_every_n_steps = gr.Number(
@ -188,6 +196,18 @@ class AdvancedTraining:
precision=0,
info="(Optional) Save only the specified number of states (old models will be deleted)",
)
self.save_last_n_epochs = gr.Number(
label="Save last N epochs",
value=self.config.get("advanced.save_last_n_epochs", 0),
precision=0,
info="(Optional) Save only the specified number of epochs (old epochs will be deleted)",
)
self.save_last_n_epochs_state = gr.Number(
label="Save last N epochs state",
value=self.config.get("advanced.save_last_n_epochs_state", 0),
precision=0,
info="(Optional) Save only the specified number of epochs states (old models will be deleted)",
)
with gr.Row():
def full_options_update(full_fp16, full_bf16):
@ -228,12 +248,16 @@ class AdvancedTraining:
)
with gr.Row():
if training_type == "lora":
self.fp8_base = gr.Checkbox(
label="fp8 base training (experimental)",
info="U-Net and Text Encoder can be trained with fp8 (experimental)",
value=self.config.get("advanced.fp8_base", False),
)
self.fp8_base = gr.Checkbox(
label="fp8 base",
info="Use fp8 for base model",
value=self.config.get("advanced.fp8_base", False),
)
self.fp8_base_unet = gr.Checkbox(
label="fp8 base unet",
info="Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16.",
value=self.config.get("advanced.fp8_base_unet", False),
)
self.full_fp16 = gr.Checkbox(
label="Full fp16 training (experimental)",
value=self.config.get("advanced.full_fp16", False),
@ -254,6 +278,25 @@ class AdvancedTraining:
inputs=[self.full_fp16, self.full_bf16],
outputs=[self.full_fp16, self.full_bf16],
)
with gr.Row():
self.highvram = gr.Checkbox(
label="highvram",
value=self.config.get("advanced.highvram", False),
info="Disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM)",
interactive=True,
)
self.lowvram = gr.Checkbox(
label="lowvram",
value=self.config.get("advanced.lowvram", False),
info="Enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)",
interactive=True,
)
self.skip_cache_check = gr.Checkbox(
label="Skip cache check",
value=self.config.get("advanced.skip_cache_check", False),
info="Skip cache check for faster training start",
)
with gr.Row():
self.gradient_checkpointing = gr.Checkbox(
@ -450,6 +493,15 @@ class AdvancedTraining:
value=self.config.get("advanced.vae_batch_size", 0),
step=1,
)
self.blocks_to_swap = gr.Slider(
label="Blocks to swap",
value=self.config.get("advanced.blocks_to_swap", 0),
info="The number of blocks to swap. The default is None (no swap). These options must be combined with --fused_backward_pass or --blockwise_fused_optimizers. The recommended maximum value is 36.",
minimum=0,
maximum=57,
step=1,
interactive=True,
)
with gr.Group(), gr.Row():
self.save_state = gr.Checkbox(
label="Save training state",
@ -534,6 +586,11 @@ class AdvancedTraining:
self.current_log_tracker_config_dir = path if not path == "" else "."
return list(list_files(path, exts=[".json"], all=True))
self.log_config = gr.Checkbox(
label="Log config",
value=self.config.get("advanced.log_config", False),
info="Log training parameter to WANDB",
)
self.log_tracker_name = gr.Textbox(
label="Log tracker name",
value=self.config.get("advanced.log_tracker_name", ""),


@ -25,6 +25,7 @@ class BasicTraining:
learning_rate_value: float = "1e-6",
lr_scheduler_value: str = "constant",
lr_warmup_value: float = "0",
lr_warmup_steps_value: int = 0,
finetuning: bool = False,
dreambooth: bool = False,
config: dict = {},
@ -44,10 +45,14 @@ class BasicTraining:
self.learning_rate_value = learning_rate_value
self.lr_scheduler_value = lr_scheduler_value
self.lr_warmup_value = lr_warmup_value
self.lr_warmup_steps_value = lr_warmup_steps_value
self.finetuning = finetuning
self.dreambooth = dreambooth
self.config = config
# Initialize old_lr_warmup and old_lr_warmup_steps with default values
self.old_lr_warmup = 0
self.old_lr_warmup_steps = 0
# Initialize the UI components
self.initialize_ui_components()
@ -162,20 +167,37 @@ class BasicTraining:
"cosine",
"cosine_with_restarts",
"linear",
"piecewise_constant",
"polynomial",
"cosine_with_min_lr",
"inverse_sqrt",
"warmup_stable_decay",
],
value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value),
)
# Initialize the learning rate scheduler type dropdown
self.lr_scheduler_type = gr.Dropdown(
label="LR Scheduler type",
info="(Optional) custom scheduler module name",
choices=[
"",
"CosineAnnealingLR",
],
value=self.config.get("basic.lr_scheduler_type", ""),
allow_custom_value=True,
)
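One of the new scheduler choices, `piecewise_constant`, holds the learning rate at fixed multipliers between step boundaries. The idea in miniature (an illustrative sketch; sd-scripts' actual rule-string syntax for `lr_scheduler_args` is not reproduced here):

```python
def piecewise_constant_lr(step: int, boundaries: list[int], factors: list[float]) -> float:
    """Return the LR multiplier for `step`: factors[i] applies until
    boundaries[i]; the last factor applies after the final boundary."""
    assert len(factors) == len(boundaries) + 1
    for boundary, factor in zip(boundaries, factors):
        if step < boundary:
            return factor
    return factors[-1]
```

For example, boundaries `[100, 200]` with factors `[1.0, 0.5, 0.1]` keep the base LR for the first 100 steps, halve it until step 200, then drop it to a tenth.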
# Initialize the optimizer dropdown
self.optimizer = gr.Dropdown(
label="Optimizer",
choices=[
"AdamW",
"AdamWScheduleFree",
"AdamW8bit",
"Adafactor",
"bitsandbytes.optim.AdEMAMix8bit",
"bitsandbytes.optim.PagedAdEMAMix8bit",
"DAdaptation",
"DAdaptAdaGrad",
"DAdaptAdam",
@ -190,11 +212,15 @@ class BasicTraining:
"PagedAdamW32bit",
"PagedLion8bit",
"Prodigy",
"prodigyplus.ProdigyPlusScheduleFree",
"RAdamScheduleFree",
"SGDNesterov",
"SGDNesterov8bit",
"SGDScheduleFree",
],
value=self.config.get("basic.optimizer", "AdamW8bit"),
interactive=True,
allow_custom_value=True,
)
def init_grad_and_lr_controls(self) -> None:
@ -240,7 +266,7 @@ class BasicTraining:
self.learning_rate = gr.Number(
label=lr_label,
value=self.config.get("basic.learning_rate", self.learning_rate_value),
minimum=0,
minimum=-1,
maximum=1,
info="Set to 0 to not train the Unet",
)
@ -251,7 +277,7 @@ class BasicTraining:
"basic.learning_rate_te", self.learning_rate_value
),
visible=self.finetuning or self.dreambooth,
minimum=0,
minimum=-1,
maximum=1,
info="Set to 0 to not train the Text Encoder",
)
@ -262,7 +288,7 @@ class BasicTraining:
"basic.learning_rate_te1", self.learning_rate_value
),
visible=False,
minimum=0,
minimum=-1,
maximum=1,
info="Set to 0 to not train the Text Encoder 1",
)
@ -273,7 +299,7 @@ class BasicTraining:
"basic.learning_rate_te2", self.learning_rate_value
),
visible=False,
minimum=0,
minimum=-1,
maximum=1,
info="Set to 0 to not train the Text Encoder 2",
)
@ -285,25 +311,37 @@ class BasicTraining:
maximum=100,
step=1,
)
# Initialize the learning rate warmup steps override
self.lr_warmup_steps = gr.Number(
label="LR warmup steps (override)",
value=self.config.get("basic.lr_warmup_steps", self.lr_warmup_steps_value),
minimum=0,
step=1,
)
def lr_scheduler_changed(scheduler, value):
def lr_scheduler_changed(scheduler, value, value_lr_warmup_steps):
if scheduler == "constant":
self.old_lr_warmup = value
self.old_lr_warmup_steps = value_lr_warmup_steps
value = 0
value_lr_warmup_steps = 0
interactive=False
info="Can't use LR warmup with LR Scheduler constant... setting to 0 and disabling field..."
else:
if self.old_lr_warmup != 0:
value = self.old_lr_warmup
self.old_lr_warmup = 0
if self.old_lr_warmup_steps != 0:
value_lr_warmup_steps = self.old_lr_warmup_steps
self.old_lr_warmup_steps = 0
interactive=True
info=""
return gr.Slider(value=value, interactive=interactive, info=info)
return gr.Slider(value=value, interactive=interactive, info=info), gr.Number(value=value_lr_warmup_steps, interactive=interactive, info=info)
self.lr_scheduler.change(
lr_scheduler_changed,
inputs=[self.lr_scheduler, self.lr_warmup],
outputs=self.lr_warmup,
inputs=[self.lr_scheduler, self.lr_warmup, self.lr_warmup_steps],
outputs=[self.lr_warmup, self.lr_warmup_steps],
)
def init_scheduler_controls(self) -> None:


@ -48,7 +48,7 @@ class CommandExecutor:
# Execute the command securely
self.process = subprocess.Popen(run_cmd, **kwargs)
log.info("Command executed.")
log.debug("Command executed.")
def kill_command(self):
"""

kohya_gui/class_flux1.py

@ -0,0 +1,336 @@
import gradio as gr
from typing import Tuple
from .common_gui import (
get_any_file_path,
document_symbol,
)
class flux1Training:
def __init__(
self,
headless: bool = False,
finetuning: bool = False,
training_type: str = "",
config: dict = {},
flux1_checkbox: gr.Checkbox = False,
) -> None:
self.headless = headless
self.finetuning = finetuning
self.training_type = training_type
self.config = config
self.flux1_checkbox = flux1_checkbox
# Define the behavior for changing noise offset type.
def noise_offset_type_change(
noise_offset_type: str,
) -> Tuple[gr.Group, gr.Group]:
if noise_offset_type == "Original":
return (gr.Group(visible=True), gr.Group(visible=False))
else:
return (gr.Group(visible=False), gr.Group(visible=True))
with gr.Accordion(
"Flux.1", open=True, visible=False, elem_classes=["flux1_background"]
) as flux1_accordion:
with gr.Group():
with gr.Row():
self.ae = gr.Textbox(
label="VAE Path",
placeholder="Path to VAE model",
value=self.config.get("flux1.ae", ""),
interactive=True,
)
self.ae_button = gr.Button(
document_symbol,
elem_id="open_folder_small",
visible=(not headless),
interactive=True,
)
self.ae_button.click(
get_any_file_path,
outputs=self.ae,
show_progress=False,
)
self.clip_l = gr.Textbox(
label="CLIP-L Path",
placeholder="Path to CLIP-L model",
value=self.config.get("flux1.clip_l", ""),
interactive=True,
)
self.clip_l_button = gr.Button(
document_symbol,
elem_id="open_folder_small",
visible=(not headless),
interactive=True,
)
self.clip_l_button.click(
get_any_file_path,
outputs=self.clip_l,
show_progress=False,
)
self.t5xxl = gr.Textbox(
label="T5-XXL Path",
placeholder="Path to T5-XXL model",
value=self.config.get("flux1.t5xxl", ""),
interactive=True,
)
self.t5xxl_button = gr.Button(
document_symbol,
elem_id="open_folder_small",
visible=(not headless),
interactive=True,
)
self.t5xxl_button.click(
get_any_file_path,
outputs=self.t5xxl,
show_progress=False,
)
with gr.Row():
self.discrete_flow_shift = gr.Number(
label="Discrete Flow Shift",
value=self.config.get("flux1.discrete_flow_shift", 3.0),
info="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0",
minimum=-1024,
maximum=1024,
step=0.01,
interactive=True,
)
self.model_prediction_type = gr.Dropdown(
label="Model Prediction Type",
choices=["raw", "additive", "sigma_scaled"],
value=self.config.get(
"flux1.model_prediction_type", "sigma_scaled"
),
interactive=True,
)
self.timestep_sampling = gr.Dropdown(
label="Timestep Sampling",
choices=["flux_shift", "sigma", "shift", "sigmoid", "uniform"],
value=self.config.get("flux1.timestep_sampling", "sigma"),
interactive=True,
)
self.apply_t5_attn_mask = gr.Checkbox(
label="Apply T5 Attention Mask",
value=self.config.get("flux1.apply_t5_attn_mask", False),
info="Apply attention mask to T5-XXL encode and FLUX double blocks ",
interactive=True,
)
with gr.Row(visible=True if not finetuning else False):
self.split_mode = gr.Checkbox(
label="Split Mode",
value=self.config.get("flux1.split_mode", False),
info="Split mode for Flux1",
interactive=True,
)
self.train_blocks = gr.Dropdown(
label="Train Blocks",
choices=["all", "double", "single"],
value=self.config.get("flux1.train_blocks", "all"),
interactive=True,
)
self.split_qkv = gr.Checkbox(
label="Split QKV",
value=self.config.get("flux1.split_qkv", False),
info="Split the projection layers of q/k/v/txt in the attention",
interactive=True,
)
self.train_t5xxl = gr.Checkbox(
label="Train T5-XXL",
value=self.config.get("flux1.train_t5xxl", False),
info="Train T5-XXL model",
interactive=True,
)
self.cpu_offload_checkpointing = gr.Checkbox(
label="CPU Offload Checkpointing",
value=self.config.get("flux1.cpu_offload_checkpointing", False),
info="[Experimental] Enable offloading of tensors to CPU during checkpointing",
interactive=True,
)
with gr.Row():
self.guidance_scale = gr.Number(
label="Guidance Scale",
value=self.config.get("flux1.guidance_scale", 3.5),
info="Guidance scale for Flux1",
minimum=0,
maximum=1024,
step=0.1,
interactive=True,
)
self.t5xxl_max_token_length = gr.Number(
label="T5-XXL Max Token Length",
value=self.config.get("flux1.t5xxl_max_token_length", 512),
info="Max token length for T5-XXL",
minimum=0,
maximum=4096,
step=1,
interactive=True,
)
self.enable_all_linear = gr.Checkbox(
label="Enable All Linear",
value=self.config.get("flux1.enable_all_linear", False),
info="(Only applicable to 'Flux1 OFT' LoRA) Target all linear connections in the MLP layer. The default is False, which targets only attention.",
interactive=True,
)
with gr.Row():
self.flux1_cache_text_encoder_outputs = gr.Checkbox(
label="Cache Text Encoder Outputs",
value=self.config.get(
"flux1.cache_text_encoder_outputs", False
),
info="Cache text encoder outputs to speed up inference",
interactive=True,
)
self.flux1_cache_text_encoder_outputs_to_disk = gr.Checkbox(
label="Cache Text Encoder Outputs to Disk",
value=self.config.get(
"flux1.cache_text_encoder_outputs_to_disk", False
),
info="Cache text encoder outputs to disk to speed up inference",
interactive=True,
)
self.mem_eff_save = gr.Checkbox(
label="Memory Efficient Save",
value=self.config.get("flux1.mem_eff_save", False),
info="[Experimental] Enable memory efficient save. We do not recommend using it unless you are familiar with the code.",
interactive=True,
)
with gr.Row():
# self.blocks_to_swap = gr.Slider(
# label="Blocks to swap",
# value=self.config.get("flux1.blocks_to_swap", 0),
# info="The number of blocks to swap. The default is None (no swap). These options must be combined with --fused_backward_pass or --blockwise_fused_optimizers. The recommended maximum value is 36.",
# minimum=0,
# maximum=57,
# step=1,
# interactive=True,
# )
self.single_blocks_to_swap = gr.Slider(
label="Single Blocks to swap (deprecated)",
value=self.config.get("flux1.single_blocks_to_swap", 0),
info="[Experimental] Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes.",
minimum=0,
maximum=19,
step=1,
interactive=True,
)
self.double_blocks_to_swap = gr.Slider(
label="Double Blocks to swap (deprecated)",
value=self.config.get("flux1.double_blocks_to_swap", 0),
info="[Experimental] Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes.",
minimum=0,
maximum=38,
step=1,
interactive=True,
)
with gr.Row(visible=True if finetuning else False):
self.blockwise_fused_optimizers = gr.Checkbox(
label="Blockwise Fused Optimizer",
value=self.config.get(
"flux1.blockwise_fused_optimizers", False
),
info="Enable blockwise optimizers for fused backward pass and optimizer step. Any optimizer can be used.",
interactive=True,
)
self.cpu_offload_checkpointing = gr.Checkbox(
label="CPU Offload Checkpointing",
value=self.config.get("flux1.cpu_offload_checkpointing", False),
info="[Experimental] Enable offloading of tensors to CPU during checkpointing",
interactive=True,
)
self.flux_fused_backward_pass = gr.Checkbox(
label="Fused Backward Pass",
value=self.config.get("flux1.fused_backward_pass", False),
info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
interactive=True,
)
with gr.Accordion(
"Blocks to train",
open=True,
visible=False if finetuning else True,
elem_classes=["flux1_blocks_to_train_background"],
):
with gr.Row():
self.train_double_block_indices = gr.Textbox(
label="train_double_block_indices",
info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of double blocks is 19.",
value=self.config.get("flux1.train_double_block_indices", "all"),
interactive=True,
)
self.train_single_block_indices = gr.Textbox(
label="train_single_block_indices",
info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of single blocks is 38.",
value=self.config.get("flux1.train_single_block_indices", "all"),
interactive=True,
)
with gr.Accordion(
"Rank for layers",
open=False,
visible=False if finetuning else True,
elem_classes=["flux1_rank_layers_background"],
):
with gr.Row():
self.img_attn_dim = gr.Textbox(
label="img_attn_dim",
value=self.config.get("flux1.img_attn_dim", ""),
interactive=True,
)
self.img_mlp_dim = gr.Textbox(
label="img_mlp_dim",
value=self.config.get("flux1.img_mlp_dim", ""),
interactive=True,
)
self.img_mod_dim = gr.Textbox(
label="img_mod_dim",
value=self.config.get("flux1.img_mod_dim", ""),
interactive=True,
)
self.single_dim = gr.Textbox(
label="single_dim",
value=self.config.get("flux1.single_dim", ""),
interactive=True,
)
with gr.Row():
self.txt_attn_dim = gr.Textbox(
label="txt_attn_dim",
value=self.config.get("flux1.txt_attn_dim", ""),
interactive=True,
)
self.txt_mlp_dim = gr.Textbox(
label="txt_mlp_dim",
value=self.config.get("flux1.txt_mlp_dim", ""),
interactive=True,
)
self.txt_mod_dim = gr.Textbox(
label="txt_mod_dim",
value=self.config.get("flux1.txt_mod_dim", ""),
interactive=True,
)
self.single_mod_dim = gr.Textbox(
label="single_mod_dim",
value=self.config.get("flux1.single_mod_dim", ""),
interactive=True,
)
with gr.Row():
self.in_dims = gr.Textbox(
label="in_dims",
value=self.config.get("flux1.in_dims", ""),
placeholder="e.g., [4,2,2,2,4]",
info="Each number corresponds to img_in, time_in, vector_in, guidance_in, txt_in. The example applies LoRA to all conditioning layers, with rank 4 for img_in, 2 for time_in, vector_in, guidance_in, and 4 for txt_in.",
interactive=True,
)
self.flux1_checkbox.change(
lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox),
inputs=[self.flux1_checkbox],
outputs=[flux1_accordion],
)
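Every field above pulls its default from a flat config mapping keyed by dotted strings such as `flux1.blocks_to_swap`. A minimal sketch of that lookup pattern (the `get_setting` helper and the flattened-dict assumption are illustrative, not the repository's `KohyaSSGUIConfig` implementation):

```python
# Sketch of the dotted-key default lookup the GUI fields rely on.
# Assumes the saved TOML has been flattened into a dict keyed by
# "section.option" strings; a stored null also falls back to the default.

def get_setting(config: dict, dotted_key: str, default):
    value = config.get(dotted_key)
    return default if value is None else value

config = {"flux1.blocks_to_swap": 12, "flux1.single_blocks_to_swap": None}
assert get_setting(config, "flux1.blocks_to_swap", 0) == 12
assert get_setting(config, "flux1.single_blocks_to_swap", 0) == 0  # null -> default
assert get_setting(config, "flux1.double_blocks_to_swap", 0) == 0  # missing -> default
```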

View File

@ -4,10 +4,12 @@ from .svd_merge_lora_gui import gradio_svd_merge_lora_tab
from .verify_lora_gui import gradio_verify_lora_tab
from .resize_lora_gui import gradio_resize_lora_tab
from .extract_lora_gui import gradio_extract_lora_tab
from .flux_extract_lora_gui import gradio_flux_extract_lora_tab
from .convert_lcm_gui import gradio_convert_lcm_tab
from .extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
from .extract_lora_from_dylora_gui import gradio_extract_dylora_tab
from .merge_lycoris_gui import gradio_merge_lycoris_tab
from .flux_merge_lora_gui import GradioFluxMergeLoRaTab
class LoRATools:
@ -19,9 +21,11 @@ class LoRATools:
gradio_extract_dylora_tab(headless=headless)
gradio_convert_lcm_tab(headless=headless)
gradio_extract_lora_tab(headless=headless)
gradio_flux_extract_lora_tab(headless=headless)
gradio_extract_lycoris_locon_tab(headless=headless)
gradio_merge_lora_tab = GradioMergeLoRaTab()
gradio_merge_lycoris_tab(headless=headless)
gradio_svd_merge_lora_tab(headless=headless)
gradio_resize_lora_tab(headless=headless)
gradio_verify_lora_tab(headless=headless)
GradioFluxMergeLoRaTab(headless=headless)

View File

@ -28,7 +28,10 @@ def create_prompt_file(sample_prompts, output_dir):
Returns:
str: The path to the prompt file.
"""
sample_prompts_path = os.path.join(output_dir, "prompt.txt")
sample_prompts_path = os.path.join(output_dir, "sample/prompt.txt")
if not os.path.exists(os.path.dirname(sample_prompts_path)):
os.makedirs(os.path.dirname(sample_prompts_path))
with open(sample_prompts_path, "w", encoding="utf-8") as f:
f.write(sample_prompts)
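The hunk above moves `prompt.txt` into a `sample/` subfolder and creates that folder when missing. The same behavior condensed with `os.makedirs(..., exist_ok=True)` (a sketch, not the repository's exact code):

```python
import os

def create_prompt_file(sample_prompts: str, output_dir: str) -> str:
    """Write sample prompts to <output_dir>/sample/prompt.txt and return the path."""
    sample_prompts_path = os.path.join(output_dir, "sample", "prompt.txt")
    # exist_ok=True collapses the explicit os.path.exists() guard above.
    os.makedirs(os.path.dirname(sample_prompts_path), exist_ok=True)
    with open(sample_prompts_path, "w", encoding="utf-8") as f:
        f.write(sample_prompts)
    return sample_prompts_path
```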

kohya_gui/class_sd3.py (new file, 249 lines)
View File

@ -0,0 +1,249 @@
import gradio as gr
from typing import Tuple
from .common_gui import (
get_folder_path,
get_any_file_path,
list_files,
list_dirs,
create_refresh_button,
document_symbol,
)
class sd3Training:
"""
This class configures and initializes the SD3-specific training settings,
including text encoder paths, timestep weighting-scheme parameters, and caching options.
Attributes:
headless (bool): If True, run without the Gradio interface.
finetuning (bool): If True, enables fine-tuning of the model.
training_type (str): Specifies the type of training to perform.
weighting_scheme (gr.Dropdown): Dropdown to select the timestep weighting scheme.
clip_l / clip_g / t5xxl (gr.Textbox): Paths to the SD3 text encoder models.
sd3_cache_text_encoder_outputs (gr.Checkbox): Checkbox to cache text encoder outputs.
"""
def __init__(
self,
headless: bool = False,
finetuning: bool = False,
training_type: str = "",
config: dict = {},
sd3_checkbox: gr.Checkbox = False,
) -> None:
"""
Initializes the sd3Training class with given settings.
Parameters:
headless (bool): Run in headless mode without GUI.
finetuning (bool): Enable model fine-tuning.
training_type (str): The type of training to be performed.
config (dict): Configuration options for the training process.
"""
self.headless = headless
self.finetuning = finetuning
self.training_type = training_type
self.config = config
self.sd3_checkbox = sd3_checkbox
# Define the behavior for changing noise offset type.
def noise_offset_type_change(
noise_offset_type: str,
) -> Tuple[gr.Group, gr.Group]:
"""
Returns a tuple of Gradio Groups with visibility set based on the noise offset type.
Parameters:
noise_offset_type (str): The selected noise offset type.
Returns:
Tuple[gr.Group, gr.Group]: A tuple containing two Gradio Group elements with their visibility set.
"""
if noise_offset_type == "Original":
return (gr.Group(visible=True), gr.Group(visible=False))
else:
return (gr.Group(visible=False), gr.Group(visible=True))
with gr.Accordion(
"SD3", open=False, elem_id="sd3_tab", visible=False
) as sd3_accordion:
with gr.Group():
gr.Markdown("### SD3 Specific Parameters")
with gr.Row():
self.weighting_scheme = gr.Dropdown(
label="Weighting Scheme",
choices=["logit_normal", "sigma_sqrt", "mode", "cosmap", "uniform"],
value=self.config.get("sd3.weighting_scheme", "logit_normal"),
interactive=True,
)
self.logit_mean = gr.Number(
label="Logit Mean",
value=self.config.get("sd3.logit_mean", 0.0),
interactive=True,
)
self.logit_std = gr.Number(
label="Logit Std",
value=self.config.get("sd3.logit_std", 1.0),
interactive=True,
)
self.mode_scale = gr.Number(
label="Mode Scale",
value=self.config.get("sd3.mode_scale", 1.29),
interactive=True,
)
with gr.Row():
self.clip_l = gr.Textbox(
label="CLIP-L Path",
placeholder="Path to CLIP-L model",
value=self.config.get("sd3.clip_l", ""),
interactive=True,
)
self.clip_l_button = gr.Button(
document_symbol,
elem_id="open_folder_small",
visible=(not headless),
interactive=True,
)
self.clip_l_button.click(
get_any_file_path,
outputs=self.clip_l,
show_progress=False,
)
self.clip_g = gr.Textbox(
label="CLIP-G Path",
placeholder="Path to CLIP-G model",
value=self.config.get("sd3.clip_g", ""),
interactive=True,
)
self.clip_g_button = gr.Button(
document_symbol,
elem_id="open_folder_small",
visible=(not headless),
interactive=True,
)
self.clip_g_button.click(
get_any_file_path,
outputs=self.clip_g,
show_progress=False,
)
self.t5xxl = gr.Textbox(
label="T5-XXL Path",
placeholder="Path to T5-XXL model",
value=self.config.get("sd3.t5xxl", ""),
interactive=True,
)
self.t5xxl_button = gr.Button(
document_symbol,
elem_id="open_folder_small",
visible=(not headless),
interactive=True,
)
self.t5xxl_button.click(
get_any_file_path,
outputs=self.t5xxl,
show_progress=False,
)
with gr.Row():
self.save_clip = gr.Checkbox(
label="Save CLIP models",
value=self.config.get("sd3.save_clip", False),
interactive=True,
)
self.save_t5xxl = gr.Checkbox(
label="Save T5-XXL model",
value=self.config.get("sd3.save_t5xxl", False),
interactive=True,
)
with gr.Row():
self.t5xxl_device = gr.Textbox(
label="T5-XXL Device",
placeholder="Device for T5-XXL (e.g., cuda:0)",
value=self.config.get("sd3.t5xxl_device", ""),
interactive=True,
)
self.t5xxl_dtype = gr.Dropdown(
label="T5-XXL Dtype",
choices=["float32", "fp16", "bf16"],
value=self.config.get("sd3.t5xxl_dtype", "bf16"),
interactive=True,
)
self.sd3_text_encoder_batch_size = gr.Number(
label="Text Encoder Batch Size",
value=self.config.get("sd3.text_encoder_batch_size", 1),
minimum=1,
maximum=1024,
step=1,
interactive=True,
)
self.sd3_cache_text_encoder_outputs = gr.Checkbox(
label="Cache Text Encoder Outputs",
value=self.config.get("sd3.cache_text_encoder_outputs", False),
info="Cache text encoder outputs to speed up training",
interactive=True,
)
self.sd3_cache_text_encoder_outputs_to_disk = gr.Checkbox(
label="Cache Text Encoder Outputs to Disk",
value=self.config.get(
"sd3.cache_text_encoder_outputs_to_disk", False
),
info="Cache text encoder outputs to disk to speed up training",
interactive=True,
)
with gr.Row():
self.clip_l_dropout_rate = gr.Number(
label="CLIP-L Dropout Rate",
value=self.config.get("sd3.clip_l_dropout_rate", 0.0),
interactive=True,
minimum=0.0,
info="Dropout rate for CLIP-L encoder"
)
self.clip_g_dropout_rate = gr.Number(
label="CLIP-G Dropout Rate",
value=self.config.get("sd3.clip_g_dropout_rate", 0.0),
interactive=True,
minimum=0.0,
info="Dropout rate for CLIP-G encoder"
)
self.t5_dropout_rate = gr.Number(
label="T5 Dropout Rate",
value=self.config.get("sd3.t5_dropout_rate", 0.0),
interactive=True,
minimum=0.0,
info="Dropout rate for T5-XXL encoder"
)
with gr.Row():
self.sd3_fused_backward_pass = gr.Checkbox(
label="Fused Backward Pass",
value=self.config.get("sd3.fused_backward_pass", False),
info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
interactive=True,
)
self.disable_mmap_load_safetensors = gr.Checkbox(
label="Disable mmap load safe tensors",
info="Disable memory mapping when loading the model's .safetensors in SD3.",
value=self.config.get("sd3.disable_mmap_load_safetensors", False),
)
self.enable_scaled_pos_embed = gr.Checkbox(
label="Enable Scaled Positional Embeddings",
info="Enable scaled positional embeddings in the model.",
value=self.config.get("sd3.enable_scaled_pos_embed", False),
)
self.pos_emb_random_crop_rate = gr.Number(
label="Positional Embedding Random Crop Rate",
value=self.config.get("sd3.pos_emb_random_crop_rate", 0.0),
interactive=True,
minimum=0.0,
info="Random crop rate for positional embeddings"
)
self.sd3_checkbox.change(
lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox),
inputs=[self.sd3_checkbox],
outputs=[sd3_accordion],
)
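`noise_offset_type_change` only branches on the selected type to flip the visibility of two groups. Stripped of the Gradio objects, the rule is (a sketch; the group roles are assumptions based on the choices shown):

```python
def noise_offset_group_visibility(noise_offset_type: str) -> tuple:
    """Return (original_group_visible, alternative_group_visible).
    'Original' shows the first group; any other type shows the second."""
    if noise_offset_type == "Original":
        return (True, False)
    return (False, True)

assert noise_offset_group_visibility("Original") == (True, False)
assert noise_offset_group_visibility("Multires") == (False, True)
```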

View File

@ -7,10 +7,12 @@ class SDXLParameters:
sdxl_checkbox: gr.Checkbox,
show_sdxl_cache_text_encoder_outputs: bool = True,
config: KohyaSSGUIConfig = {},
trainer: str = "",
):
self.sdxl_checkbox = sdxl_checkbox
self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs
self.config = config
self.trainer = trainer
self.initialize_accordion()
@ -30,6 +32,41 @@ class SDXLParameters:
info="Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.",
value=self.config.get("sdxl.sdxl_no_half_vae", False),
)
self.fused_backward_pass = gr.Checkbox(
label="Fused backward pass",
info="Enable fused backward pass. This option is useful to reduce the GPU memory usage. Can't be used if Fused optimizer groups is > 0. Only AdaFactor is supported",
value=self.config.get("sdxl.fused_backward_pass", False),
visible=self.trainer == "finetune" or self.trainer == "dreambooth",
)
self.fused_optimizer_groups = gr.Number(
label="Fused optimizer groups",
info="Number of optimizer groups to fuse. This option is useful to reduce the GPU memory usage. Can't be used if Fused backward pass is enabled. Since the effect is limited to a certain number, it is recommended to specify 4-10.",
value=self.config.get("sdxl.fused_optimizer_groups", 0),
minimum=0,
step=1,
visible=self.trainer == "finetune" or self.trainer == "dreambooth",
)
self.disable_mmap_load_safetensors = gr.Checkbox(
label="Disable mmap load safe tensors",
info="Disable memory mapping when loading the model's .safetensors in SDXL.",
value=self.config.get("sdxl.disable_mmap_load_safetensors", False),
)
self.fused_backward_pass.change(
lambda fused_backward_pass: gr.Number(
interactive=not fused_backward_pass
),
inputs=[self.fused_backward_pass],
outputs=[self.fused_optimizer_groups],
)
self.fused_optimizer_groups.change(
lambda fused_optimizer_groups: gr.Checkbox(
interactive=fused_optimizer_groups == 0
),
inputs=[self.fused_optimizer_groups],
outputs=[self.fused_backward_pass],
)
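The two `.change` handlers above make "Fused backward pass" and "Fused optimizer groups" mutually exclusive: ticking the checkbox disables the number field, and a non-zero group count disables the checkbox. The combined rule as a plain function (a sketch of the interactivity logic, not the Gradio wiring):

```python
def fused_interactivity(fused_backward_pass: bool, fused_optimizer_groups: int):
    """Return (checkbox_interactive, number_interactive)."""
    number_interactive = not fused_backward_pass        # checkbox on -> lock the number
    checkbox_interactive = fused_optimizer_groups == 0  # groups set -> lock the checkbox
    return checkbox_interactive, number_interactive

assert fused_interactivity(True, 0) == (True, False)
assert fused_interactivity(False, 4) == (False, True)
assert fused_interactivity(False, 0) == (True, True)
```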
self.sdxl_checkbox.change(
lambda sdxl_checkbox: gr.Accordion(visible=sdxl_checkbox),

View File

@ -26,7 +26,6 @@ default_models = [
"stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2",
"runwayml/stable-diffusion-v1-5",
"CompVis/stable-diffusion-v1-4",
]
@ -245,19 +244,88 @@ class SourceModel:
with gr.Column():
with gr.Row():
self.v2 = gr.Checkbox(
label="v2", value=False, visible=False, min_width=60
label="v2", value=False, visible=False, min_width=60,
interactive=True,
)
self.v_parameterization = gr.Checkbox(
label="v_parameterization",
label="v_param",
value=False,
visible=False,
min_width=130,
interactive=True,
)
self.sdxl_checkbox = gr.Checkbox(
label="SDXL",
value=False,
visible=False,
min_width=60,
interactive=True,
)
self.sd3_checkbox = gr.Checkbox(
label="SD3",
value=False,
visible=False,
min_width=60,
interactive=True,
)
self.flux1_checkbox = gr.Checkbox(
label="Flux.1",
value=False,
visible=False,
min_width=60,
interactive=True,
)
def toggle_checkboxes(v2, v_parameterization, sdxl_checkbox, sd3_checkbox, flux1_checkbox):
# Check if all checkboxes are unchecked
if not v2 and not v_parameterization and not sdxl_checkbox and not sd3_checkbox and not flux1_checkbox:
# If all unchecked, return new interactive checkboxes
return (
gr.Checkbox(interactive=True), # v2 checkbox
gr.Checkbox(interactive=True), # v_parameterization checkbox
gr.Checkbox(interactive=True), # sdxl_checkbox
gr.Checkbox(interactive=True), # sd3_checkbox
gr.Checkbox(interactive=True), # flux1_checkbox
)
else:
# If any checkbox is checked, return checkboxes with current interactive state
return (
gr.Checkbox(interactive=v2), # v2 checkbox
gr.Checkbox(interactive=v_parameterization), # v_parameterization checkbox
gr.Checkbox(interactive=sdxl_checkbox), # sdxl_checkbox
gr.Checkbox(interactive=sd3_checkbox), # sd3_checkbox
gr.Checkbox(interactive=flux1_checkbox), # flux1_checkbox
)
self.v2.change(
fn=toggle_checkboxes,
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
show_progress=False,
)
self.v_parameterization.change(
fn=toggle_checkboxes,
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
show_progress=False,
)
self.sdxl_checkbox.change(
fn=toggle_checkboxes,
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
show_progress=False,
)
self.sd3_checkbox.change(
fn=toggle_checkboxes,
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
show_progress=False,
)
self.flux1_checkbox.change(
fn=toggle_checkboxes,
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
show_progress=False,
)
with gr.Column():
gr.Group(visible=False)
@ -294,6 +362,8 @@ class SourceModel:
self.v2,
self.v_parameterization,
self.sdxl_checkbox,
self.sd3_checkbox,
self.flux1_checkbox,
],
show_progress=False,
)
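`toggle_checkboxes` keeps the five model-type flags mutually exclusive: while any one is checked only that one stays interactive, and clearing them all re-enables every box. Reduced to booleans (a sketch of the rule, not the Gradio wiring):

```python
def checkbox_interactivity(v2, v_param, sdxl, sd3, flux1):
    flags = (v2, v_param, sdxl, sd3, flux1)
    if not any(flags):
        return (True,) * 5  # nothing selected: all model types selectable
    # Something selected: only the checked flag stays interactive, so the
    # user can untick it but cannot combine model types.
    return flags
```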

View File

@ -20,6 +20,7 @@ from .common_gui import setup_environment
class TensorboardManager:
DEFAULT_TENSORBOARD_PORT = 6006
DEFAULT_TENSORBOARD_HOST = "0.0.0.0"
def __init__(self, logging_dir, headless: bool = False, wait_time=5):
self.logging_dir = logging_dir
@ -29,6 +30,9 @@ class TensorboardManager:
self.tensorboard_port = os.environ.get(
"TENSORBOARD_PORT", self.DEFAULT_TENSORBOARD_PORT
)
self.tensorboard_host = os.environ.get(
"TENSORBOARD_HOST", self.DEFAULT_TENSORBOARD_HOST
)
self.log = setup_logging()
self.thread = None
self.stop_event = Event()
@ -64,7 +68,7 @@ class TensorboardManager:
"--logdir",
logging_dir,
"--host",
"0.0.0.0",
self.tensorboard_host,
"--port",
str(self.tensorboard_port),
]
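The TensorBoard change mirrors the existing port handling for the bind address: an environment variable override with a class-level default. The lookup in isolation (a sketch):

```python
import os

DEFAULT_TENSORBOARD_PORT = 6006
DEFAULT_TENSORBOARD_HOST = "0.0.0.0"

def tensorboard_args(logging_dir: str) -> list:
    """Build the tensorboard CLI args, honoring TENSORBOARD_HOST/PORT if set."""
    host = os.environ.get("TENSORBOARD_HOST", DEFAULT_TENSORBOARD_HOST)
    port = os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT)
    return ["--logdir", logging_dir, "--host", host, "--port", str(port)]
```

For example, exporting `TENSORBOARD_HOST=127.0.0.1` restricts the listener to localhost instead of all interfaces.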

View File

@ -5,6 +5,7 @@ except ImportError:
from easygui import msgbox, ynbox
from typing import Optional
from .custom_logging import setup_logging
from .sd_modeltype import SDModelType
import os
import re
@ -327,7 +328,6 @@ def update_my_data(my_data):
# Convert values to int if they are strings
for key in [
"adaptive_noise_scale",
"clip_skip",
"epoch",
"gradient_accumulation_steps",
@ -379,7 +379,13 @@ def update_my_data(my_data):
my_data[key] = int(75)
# Convert values to float if they are strings, correctly handling float representations
for key in ["noise_offset", "learning_rate", "text_encoder_lr", "unet_lr"]:
for key in [
"adaptive_noise_scale",
"noise_offset",
"learning_rate",
"text_encoder_lr",
"unet_lr",
]:
value = my_data.get(key)
if value is not None:
try:
@ -956,11 +962,15 @@ def set_pretrained_model_name_or_path_input(
v2 = gr.Checkbox(value=False, visible=False)
v_parameterization = gr.Checkbox(value=False, visible=False)
sdxl = gr.Checkbox(value=True, visible=False)
sd3 = gr.Checkbox(value=False, visible=False)
flux1 = gr.Checkbox(value=False, visible=False)
return (
gr.Dropdown(),
v2,
v_parameterization,
sdxl,
sd3,
flux1,
)
# Check if the given pretrained_model_name_or_path is in the list of V2 base models
@ -969,11 +979,15 @@ def set_pretrained_model_name_or_path_input(
v2 = gr.Checkbox(value=True, visible=False)
v_parameterization = gr.Checkbox(value=False, visible=False)
sdxl = gr.Checkbox(value=False, visible=False)
sd3 = gr.Checkbox(value=False, visible=False)
flux1 = gr.Checkbox(value=False, visible=False)
return (
gr.Dropdown(),
v2,
v_parameterization,
sdxl,
sd3,
flux1,
)
# Check if the given pretrained_model_name_or_path is in the list of V parameterization models
@ -984,11 +998,15 @@ def set_pretrained_model_name_or_path_input(
v2 = gr.Checkbox(value=True, visible=False)
v_parameterization = gr.Checkbox(value=True, visible=False)
sdxl = gr.Checkbox(value=False, visible=False)
sd3 = gr.Checkbox(value=False, visible=False)
flux1 = gr.Checkbox(value=False, visible=False)
return (
gr.Dropdown(),
v2,
v_parameterization,
sdxl,
sd3,
flux1,
)
# Check if the given pretrained_model_name_or_path is in the list of V1 models
@ -997,17 +1015,32 @@ def set_pretrained_model_name_or_path_input(
v2 = gr.Checkbox(value=False, visible=False)
v_parameterization = gr.Checkbox(value=False, visible=False)
sdxl = gr.Checkbox(value=False, visible=False)
sd3 = gr.Checkbox(value=False, visible=False)
flux1 = gr.Checkbox(value=False, visible=False)
return (
gr.Dropdown(),
v2,
v_parameterization,
sdxl,
sd3,
flux1,
)
# Check if the model_list is set to 'custom'
v2 = gr.Checkbox(visible=True)
v_parameterization = gr.Checkbox(visible=True)
sdxl = gr.Checkbox(visible=True)
sd3 = gr.Checkbox(visible=True)
flux1 = gr.Checkbox(visible=True)
# Auto-detect model type if safetensors file path is given
if pretrained_model_name_or_path.lower().endswith(".safetensors"):
detect = SDModelType(pretrained_model_name_or_path)
v2 = gr.Checkbox(value=detect.Is_SD2(), visible=True)
sdxl = gr.Checkbox(value=detect.Is_SDXL(), visible=True)
sd3 = gr.Checkbox(value=detect.Is_SD3(), visible=True)
flux1 = gr.Checkbox(value=detect.Is_FLUX1(), visible=True)
#TODO: v_parameterization
# If a refresh method is provided, use it to update the choices for the Dropdown widget
if refresh_method is not None:
@ -1021,6 +1054,8 @@ def set_pretrained_model_name_or_path_input(
v2,
v_parameterization,
sdxl,
sd3,
flux1,
)
@ -1369,7 +1404,11 @@ def validate_file_path(file_path: str) -> bool:
return True
def validate_folder_path(folder_path: str, can_be_written_to: bool = False, create_if_not_exists: bool = False) -> bool:
def validate_folder_path(
folder_path: str,
can_be_written_to: bool = False,
create_if_not_exists: bool = False,
) -> bool:
if folder_path == "":
return True
msg = f"Validating {folder_path} existence{' and writability' if can_be_written_to else ''}..."
@ -1387,6 +1426,7 @@ def validate_folder_path(folder_path: str, can_be_written_to: bool = False, crea
log.info(f"{msg} SUCCESS")
return True
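`validate_folder_path` was reflowed here; its contract is: empty paths pass, missing folders fail unless `create_if_not_exists` is set, and writability is only checked on request. A runnable sketch of that contract (the repository version also logs each step):

```python
import os

def validate_folder_path(folder_path: str,
                         can_be_written_to: bool = False,
                         create_if_not_exists: bool = False) -> bool:
    if folder_path == "":
        return True  # optional setting left blank
    if not os.path.isdir(folder_path):
        if not create_if_not_exists:
            return False
        os.makedirs(folder_path)
    if can_be_written_to and not os.access(folder_path, os.W_OK):
        return False
    return True
```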
def validate_toml_file(file_path: str) -> bool:
if file_path == "":
return True
@ -1394,7 +1434,7 @@ def validate_toml_file(file_path: str) -> bool:
if not os.path.isfile(file_path):
log.error(f"{msg} FAILED: does not exist")
return False
try:
toml.load(file_path)
except:
@ -1425,11 +1465,14 @@ def validate_model_path(pretrained_model_name_or_path: str) -> bool:
log.info(f"{msg} SUCCESS")
else:
# If not one of the default models, check if it's a valid local path
if not validate_file_path(pretrained_model_name_or_path) and not validate_folder_path(pretrained_model_name_or_path):
if not validate_file_path(
pretrained_model_name_or_path
) and not validate_folder_path(pretrained_model_name_or_path):
log.info(f"{msg} FAILURE: not a valid file or folder")
return False
return True
def is_file_writable(file_path: str) -> bool:
"""
Checks if a file is writable.
@ -1450,8 +1493,9 @@ def is_file_writable(file_path: str) -> bool:
pass
# If the file can be opened, it is considered writable
return True
except IOError:
except IOError as e:
# If an IOError occurs, the file cannot be written to
log.info(f"Error: {e}. File '{file_path}' is not writable.")
return False
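The `is_file_writable` change only surfaces the caught exception in the log (`except IOError as e`). The check itself can be sketched as below; note that in Python 3 `IOError` is an alias of `OSError`, so a missing parent directory is caught too (the append-mode open is an assumption about the original):

```python
def is_file_writable(file_path: str) -> bool:
    try:
        # Opening for append succeeds iff the file (or its directory) is
        # writable; it creates the file as a side effect if it is missing.
        with open(file_path, "a"):
            pass
        return True
    except IOError as e:
        print(f"Error: {e}. File '{file_path}' is not writable.")
        return False
```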
@ -1462,7 +1506,7 @@ def print_command_and_toml(run_cmd, tmpfilename):
# Reconstruct the safe command string for display
command_to_run = " ".join(run_cmd)
log.info(command_to_run)
print(command_to_run)
print("")
log.info(f"Showing toml config file: {tmpfilename}")
@ -1489,10 +1533,11 @@ def validate_args_setting(input_string):
)
return False
def setup_environment():
env = os.environ.copy()
env["PYTHONPATH"] = (
fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
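`setup_environment` builds PYTHONPATH with `os.pathsep`, so the same string works on Windows (`;`) and POSIX (`:`). The construction in isolation (`scriptdir` stands in for the real install path):

```python
import os

def build_pythonpath(scriptdir: str, existing: str = "") -> str:
    """Prepend the GUI dir and its sd-scripts submodule to any existing PYTHONPATH."""
    return rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{existing}"

assert build_pythonpath("/opt/kohya_ss", "/site").split(os.pathsep) == [
    "/opt/kohya_ss", "/opt/kohya_ss/sd-scripts", "/site"
]
```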

View File

@ -10,6 +10,11 @@ from .custom_logging import setup_logging
log = setup_logging()
import os
import re
import logging as log
from easygui import msgbox
def dataset_balancing(concept_repeats, folder, insecure):
if not concept_repeats > 0:
@ -78,7 +83,11 @@ def dataset_balancing(concept_repeats, folder, insecure):
old_name = os.path.join(folder, subdir)
new_name = os.path.join(folder, f"{repeats}_{subdir}")
os.rename(old_name, new_name)
# Check if the new folder name already exists
if os.path.exists(new_name):
log.warning(f"Destination folder {new_name} already exists. Skipping...")
else:
os.rename(old_name, new_name)
else:
log.info(
f"Skipping folder {subdir} because it does not match kohya_ss expected syntax..."
@ -87,6 +96,7 @@ def dataset_balancing(concept_repeats, folder, insecure):
msgbox("Dataset balancing completed...")
def warning(insecure):
if insecure:
if boolbox(

View File

@ -17,7 +17,9 @@ from .common_gui import (
SaveConfigFile,
scriptdir,
update_my_data,
validate_file_path, validate_folder_path, validate_model_path,
validate_file_path,
validate_folder_path,
validate_model_path,
validate_args_setting,
setup_environment,
)
@ -27,10 +29,13 @@ from .class_gui_config import KohyaSSGUIConfig
from .class_source_model import SourceModel
from .class_basic_training import BasicTraining
from .class_advanced_training import AdvancedTraining
from .class_sd3 import sd3Training
from .class_folders import Folders
from .class_command_executor import CommandExecutor
from .class_huggingface import HuggingFace
from .class_metadata import MetaData
from .class_sdxl_parameters import SDXLParameters
from .class_flux1 import flux1Training
from .dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
@ -60,6 +65,7 @@ def save_configuration(
v2,
v_parameterization,
sdxl,
flux1_checkbox,
logging_dir,
train_data_dir,
reg_data_dir,
@ -72,6 +78,7 @@ def save_configuration(
learning_rate_te2,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
train_batch_size,
epoch,
save_every_n_epochs,
@ -84,6 +91,7 @@ def save_configuration(
caption_extension,
enable_bucket,
gradient_checkpointing,
fp8_base,
full_fp16,
full_bf16,
no_token_padding,
@ -134,6 +142,7 @@ def save_configuration(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -150,18 +159,28 @@ def save_configuration(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
fused_backward_pass,
fused_optimizer_groups,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
max_timestep,
debiased_estimation_loss,
@ -178,6 +197,44 @@ def save_configuration(
metadata_license,
metadata_tags,
metadata_title,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
logit_std,
mode_scale,
save_clip,
save_t5xxl,
t5xxl,
t5xxl_device,
t5xxl_dtype,
sd3_text_encoder_batch_size,
weighting_scheme,
sd3_checkbox,
# Flux.1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
flux1_clip_l,
flux1_t5xxl,
discrete_flow_shift,
model_prediction_type,
timestep_sampling,
split_mode,
train_blocks,
t5xxl_max_token_length,
guidance_scale,
blockwise_fused_optimizers,
flux_fused_backward_pass,
cpu_offload_checkpointing,
blocks_to_swap,
single_blocks_to_swap,
double_blocks_to_swap,
mem_eff_save,
apply_t5_attn_mask,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -218,6 +275,7 @@ def open_configuration(
v2,
v_parameterization,
sdxl,
flux1_checkbox,
logging_dir,
train_data_dir,
reg_data_dir,
@ -230,6 +288,7 @@ def open_configuration(
learning_rate_te2,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
train_batch_size,
epoch,
save_every_n_epochs,
@ -242,6 +301,7 @@ def open_configuration(
caption_extension,
enable_bucket,
gradient_checkpointing,
fp8_base,
full_fp16,
full_bf16,
no_token_padding,
@ -292,6 +352,7 @@ def open_configuration(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -308,18 +369,28 @@ def open_configuration(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
fused_backward_pass,
fused_optimizer_groups,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
max_timestep,
debiased_estimation_loss,
@ -336,6 +407,44 @@ def open_configuration(
metadata_license,
metadata_tags,
metadata_title,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
logit_std,
mode_scale,
save_clip,
save_t5xxl,
t5xxl,
t5xxl_device,
t5xxl_dtype,
sd3_text_encoder_batch_size,
weighting_scheme,
sd3_checkbox,
# Flux.1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
flux1_clip_l,
flux1_t5xxl,
discrete_flow_shift,
model_prediction_type,
timestep_sampling,
split_mode,
train_blocks,
t5xxl_max_token_length,
guidance_scale,
blockwise_fused_optimizers,
flux_fused_backward_pass,
cpu_offload_checkpointing,
blocks_to_swap,
single_blocks_to_swap,
double_blocks_to_swap,
mem_eff_save,
apply_t5_attn_mask,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -371,6 +480,7 @@ def train_model(
v2,
v_parameterization,
sdxl,
flux1_checkbox,
logging_dir,
train_data_dir,
reg_data_dir,
@ -383,6 +493,7 @@ def train_model(
learning_rate_te2,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
train_batch_size,
epoch,
save_every_n_epochs,
@ -395,6 +506,7 @@ def train_model(
caption_extension,
enable_bucket,
gradient_checkpointing,
fp8_base,
full_fp16,
full_bf16,
no_token_padding,
@ -445,6 +557,7 @@ def train_model(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -461,18 +574,28 @@ def train_model(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
fused_backward_pass,
fused_optimizer_groups,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
max_timestep,
debiased_estimation_loss,
@ -489,6 +612,44 @@ def train_model(
metadata_license,
metadata_tags,
metadata_title,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
logit_std,
mode_scale,
save_clip,
save_t5xxl,
t5xxl,
t5xxl_device,
t5xxl_dtype,
sd3_text_encoder_batch_size,
weighting_scheme,
sd3_checkbox,
# Flux.1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
flux1_clip_l,
flux1_t5xxl,
discrete_flow_shift,
model_prediction_type,
timestep_sampling,
split_mode,
train_blocks,
t5xxl_max_token_length,
guidance_scale,
blockwise_fused_optimizers,
flux_fused_backward_pass,
cpu_offload_checkpointing,
blocks_to_swap,
single_blocks_to_swap,
double_blocks_to_swap,
mem_eff_save,
apply_t5_attn_mask,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -509,61 +670,50 @@ def train_model(
log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return
log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return TRAIN_BUTTON_VISIBLE
#
# Validate paths
#
#
if not validate_file_path(dataset_config):
return TRAIN_BUTTON_VISIBLE
if not validate_file_path(log_tracker_config):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True):
if not validate_folder_path(
logging_dir, can_be_written_to=True, create_if_not_exists=True
):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True):
if not validate_folder_path(
output_dir, can_be_written_to=True, create_if_not_exists=True
):
return TRAIN_BUTTON_VISIBLE
if not validate_model_path(pretrained_model_name_or_path):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(reg_data_dir):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(resume):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(train_data_dir):
return TRAIN_BUTTON_VISIBLE
if not validate_model_path(vae):
return TRAIN_BUTTON_VISIBLE
#
# End of path validation
#
# This function validates file or folder paths. Simply add new variables containing a file or folder path
# to validate below
# if not validate_paths(
# dataset_config=dataset_config,
# headless=headless,
# log_tracker_config=log_tracker_config,
# logging_dir=logging_dir,
# output_dir=output_dir,
# pretrained_model_name_or_path=pretrained_model_name_or_path,
# reg_data_dir=reg_data_dir,
# resume=resume,
# train_data_dir=train_data_dir,
# vae=vae,
# ):
# return TRAIN_BUTTON_VISIBLE
if not print_only and check_if_model_exist(
output_name, output_dir, save_model_as, headless=headless
):
@ -573,15 +723,6 @@ def train_model(
log.info(
"Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
)
if max_train_steps == 0:
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
@ -640,11 +781,11 @@ def train_model(
reg_factor = 1
else:
log.warning(
"Regularization images are used... Will double the number of steps required..."
)
reg_factor = 2
log.info(f"Regularization factor: {reg_factor}")
if max_train_steps == 0:
# calculate max_train_steps
@ -664,13 +805,18 @@ def train_model(
else:
max_train_steps_info = f"Max train steps: {max_train_steps}"
log.info(f"Total steps: {total_steps}")
# Calculate lr_warmup_steps
if lr_warmup_steps > 0:
lr_warmup_steps = int(lr_warmup_steps)
if lr_warmup > 0:
log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
elif lr_warmup != 0:
lr_warmup_steps = lr_warmup / 100
else:
lr_warmup_steps = 0
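The precedence between the two warmup settings above can be summarized in a small standalone sketch (a hypothetical helper, not part of this codebase): an explicit `lr_warmup_steps` count wins over the `lr_warmup` percentage, which sd-scripts accepts as a 0-1 ratio.

```python
def resolve_warmup(lr_warmup: float, lr_warmup_steps: float) -> float:
    """Sketch of the warmup-resolution precedence used above."""
    if lr_warmup_steps > 0:
        # An explicit step count takes priority over the percentage.
        return int(lr_warmup_steps)
    elif lr_warmup != 0:
        # The percentage is converted to a 0-1 ratio for sd-scripts.
        return lr_warmup / 100
    return 0

print(resolve_warmup(10, 0))    # percentage only -> 0.1 ratio
print(resolve_warmup(10, 250))  # explicit step count wins -> 250
```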
log.info(f"Train batch size: {train_batch_size}")
log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
log.info(f"Epoch: {epoch}")
@ -682,7 +828,7 @@ def train_model(
log.error("accelerate not found")
return TRAIN_BUTTON_VISIBLE
run_cmd = [rf"{accelerate_path}", "launch"]
run_cmd = AccelerateLaunch.run_cmd(
run_cmd=run_cmd,
@ -701,10 +847,23 @@ def train_model(
)
if sdxl:
run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py")
elif sd3_checkbox:
run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py")
elif flux1_checkbox:
run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py")
else:
run_cmd.append(rf"{scriptdir}/sd-scripts/train_db.py")
cache_text_encoder_outputs = (
(sdxl and sdxl_cache_text_encoder_outputs)
or (sd3_checkbox and sd3_cache_text_encoder_outputs)
or (flux1_checkbox and flux1_cache_text_encoder_outputs)
)
cache_text_encoder_outputs_to_disk = (
sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk)
no_half_vae = sdxl and sdxl_no_half_vae
if max_data_loader_n_workers == "" or max_data_loader_n_workers is None:
max_data_loader_n_workers = 0
else:
@ -715,6 +874,19 @@ def train_model(
else:
max_train_steps = int(max_train_steps)
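Guards like the `max_data_loader_n_workers` empty-check above are easy to get wrong: a common mistake is writing `v == "" or None`, which parses as `(v == "") or None`, so the `None` arm is always falsy and a `None` input slips through. A minimal illustration:

```python
def is_unset_buggy(v):
    # Pitfall: parses as `(v == "") or None`; the `or None` arm
    # never contributes a truthy value, so None inputs are missed.
    return v == "" or None

def is_unset_fixed(v):
    # Correct form: compare each case explicitly.
    return v == "" or v is None

print(bool(is_unset_buggy(None)))  # False: None input is missed
print(is_unset_fixed(None))        # True
```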
if sdxl:
train_text_encoder = (learning_rate_te1 is not None and learning_rate_te1 > 0) or (
    learning_rate_te2 is not None and learning_rate_te2 > 0
)
fused_backward_pass_value = False
if sd3_checkbox:
fused_backward_pass_value = sd3_fused_backward_pass
elif flux1_checkbox:
fused_backward_pass_value = flux_fused_backward_pass
else:
fused_backward_pass_value = fused_backward_pass
# def save_huggingface_to_toml(self, toml_file_path: str):
config_toml_data = {
# Update the values in the TOML data
@ -724,22 +896,32 @@ def train_model(
"bucket_reso_steps": bucket_reso_steps,
"cache_latents": cache_latents,
"cache_latents_to_disk": cache_latents_to_disk,
"cache_text_encoder_outputs": cache_text_encoder_outputs,
"cache_text_encoder_outputs_to_disk": cache_text_encoder_outputs_to_disk,
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
"caption_dropout_rate": caption_dropout_rate,
"caption_extension": caption_extension,
"clip_l": flux1_clip_l if flux1_checkbox else clip_l if sd3_checkbox else None,
"clip_skip": clip_skip if clip_skip != 0 else None,
"color_aug": color_aug,
"dataset_config": dataset_config,
"debiased_estimation_loss": debiased_estimation_loss,
"disable_mmap_load_safetensors": disable_mmap_load_safetensors,
"dynamo_backend": dynamo_backend,
"enable_bucket": enable_bucket,
"epoch": int(epoch),
"flip_aug": flip_aug,
"fp8_base": fp8_base,
"full_bf16": full_bf16,
"full_fp16": full_fp16,
"fused_backward_pass": fused_backward_pass_value,
"fused_optimizer_groups": (
int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
),
"gradient_accumulation_steps": int(gradient_accumulation_steps),
"gradient_checkpointing": gradient_checkpointing,
"huber_c": huber_c,
"huber_scale": huber_scale,
"huber_schedule": huber_schedule,
"huggingface_path_in_repo": huggingface_path_in_repo,
"huggingface_repo_id": huggingface_repo_id,
@ -750,16 +932,11 @@ def train_model(
"ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
"keep_tokens": int(keep_tokens),
"learning_rate": learning_rate, # both for sd1.5 and sdxl
"learning_rate_te": learning_rate_te if not sdxl else None, # only for sd1.5
"learning_rate_te1": learning_rate_te1 if sdxl else None, # only for sdxl
"learning_rate_te2": learning_rate_te2 if sdxl else None, # only for sdxl
"logging_dir": logging_dir,
"log_config": log_config,
"log_tracker_config": log_tracker_config,
"log_tracker_name": log_tracker_name,
"log_with": log_with,
@ -767,15 +944,20 @@ def train_model(
"lr_scheduler": lr_scheduler,
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
"lr_scheduler_num_cycles": (
int(lr_scheduler_num_cycles)
if lr_scheduler_num_cycles != ""
else int(epoch)
),
"lr_scheduler_power": lr_scheduler_power,
"lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
"lr_warmup_steps": lr_warmup_steps,
"masked_loss": masked_loss,
"max_bucket_reso": max_bucket_reso,
"max_timestep": max_timestep if max_timestep != 0 else None,
"max_token_length": int(max_token_length),
"max_train_epochs": (
int(max_train_epochs) if int(max_train_epochs) != 0 else None
),
"max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
"mem_eff_attn": mem_eff_attn,
"metadata_author": metadata_author,
@ -789,6 +971,7 @@ def train_model(
"mixed_precision": mixed_precision,
"multires_noise_discount": multires_noise_discount,
"multires_noise_iterations": multires_noise_iterations if multires_noise_iterations != 0 else None,
"no_half_vae": no_half_vae,
"no_token_padding": no_token_padding,
"noise_offset": noise_offset if noise_offset != 0 else None,
"noise_offset_random_strength": noise_offset_random_strength,
@ -825,6 +1008,10 @@ def train_model(
"save_last_n_steps_state": (
save_last_n_steps_state if save_last_n_steps_state != 0 else None
),
"save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
"save_last_n_epochs_state": (
save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
),
"save_model_as": save_model_as,
"save_precision": save_precision,
"save_state": save_state,
@ -834,20 +1021,65 @@ def train_model(
"sdpa": True if xformers == "sdpa" else None,
"seed": int(seed) if int(seed) != 0 else None,
"shuffle_caption": shuffle_caption,
"skip_cache_check": skip_cache_check,
"stop_text_encoder_training": (
stop_text_encoder_training if stop_text_encoder_training != 0 else None
),
"t5xxl": t5xxl if sd3_checkbox else flux1_t5xxl if flux1_checkbox else None,
"train_batch_size": train_batch_size,
"train_data_dir": train_data_dir,
"train_text_encoder": train_text_encoder if sdxl else None,
"v2": v2,
"v_parameterization": v_parameterization,
"v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None,
"vae": vae,
"vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
"wandb_api_key": wandb_api_key,
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
"weighted_captions": weighted_captions,
"xformers": True if xformers == "xformers" else None,
# SD3 only Parameters
# "cache_text_encoder_outputs": see previous assignment above for code
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
"clip_g": clip_g if sd3_checkbox else None,
# "clip_l": see previous assignment above for code
"logit_mean": logit_mean if sd3_checkbox else None,
"logit_std": logit_std if sd3_checkbox else None,
"mode_scale": mode_scale if sd3_checkbox else None,
"save_clip": save_clip if sd3_checkbox else None,
"save_t5xxl": save_t5xxl if sd3_checkbox else None,
# "t5xxl": see previous assignment above for code
"t5xxl_device": t5xxl_device if sd3_checkbox else None,
"t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None,
"text_encoder_batch_size": (
sd3_text_encoder_batch_size if sd3_checkbox else None
),
"weighting_scheme": weighting_scheme if sd3_checkbox else None,
# Flux.1 specific parameters
# "cache_text_encoder_outputs": see previous assignment above for code
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
"ae": ae if flux1_checkbox else None,
# "clip_l": see previous assignment above for code
# "t5xxl": see previous assignment above for code
"discrete_flow_shift": discrete_flow_shift if flux1_checkbox else None,
"model_prediction_type": model_prediction_type if flux1_checkbox else None,
"timestep_sampling": timestep_sampling if flux1_checkbox else None,
"split_mode": split_mode if flux1_checkbox else None,
"train_blocks": train_blocks if flux1_checkbox else None,
"t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None,
"guidance_scale": guidance_scale if flux1_checkbox else None,
"blockwise_fused_optimizers": (
blockwise_fused_optimizers if flux1_checkbox else None
),
# "flux_fused_backward_pass": see previous assignment of fused_backward_pass in above code
"cpu_offload_checkpointing": (
cpu_offload_checkpointing if flux1_checkbox else None
),
"blocks_to_swap": blocks_to_swap if flux1_checkbox or sd3_checkbox else None,
"single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
"double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
"mem_eff_save": mem_eff_save if flux1_checkbox else None,
"apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
}
# Given dictionary `config_toml_data`
@ -855,7 +1087,7 @@ def train_model(
config_toml_data = {
key: value
for key, value in config_toml_data.items()
if not any([value == "", value is False, value is None])
}
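The identity-based comparison in the filter above matters because Python's `in` operator uses equality and `bool` is a subclass of `int`, so `0 == False` holds: a plain `value not in ["", False, None]` would silently drop a legitimate `0`. A quick demonstration of the difference:

```python
def keep_old(value):
    # Equality-based membership test: 0 == False, so 0 is wrongly dropped.
    return value not in ["", False, None]

def keep_new(value):
    # Identity checks for False/None keep valid zero values.
    return not any([value == "", value is False, value is None])

print(keep_old(0))  # False: 0 is filtered out by the equality check
print(keep_new(0))  # True: 0 survives the identity-based check
```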
config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)
@ -865,8 +1097,8 @@ def train_model(
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
tmpfilename = rf"{output_dir}/config_dreambooth-{formatted_datetime}.toml"
# Save the updated TOML data back to the file
with open(tmpfilename, "w", encoding="utf-8") as toml_file:
toml.dump(config_toml_data, toml_file)
@ -875,7 +1107,7 @@ def train_model(
log.error(f"Failed to write TOML file: {toml_file.name}")
run_cmd.append("--config_file")
run_cmd.append(rf"{tmpfilename}")
# Initialize a dictionary with always-included keyword arguments
kwargs_for_training = {
@ -981,6 +1213,26 @@ def dreambooth_tab(
config=config,
)
# Add SDXL Parameters
sdxl_params = SDXLParameters(
source_model.sdxl_checkbox,
config=config,
trainer="finetune",
)
# Add FLUX1 Parameters
flux1_training = flux1Training(
headless=headless,
config=config,
flux1_checkbox=source_model.flux1_checkbox,
finetuning=True,
)
# Add SD3 Parameters
sd3_training = sd3Training(
headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
)
with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
advanced_training = AdvancedTraining(headless=headless, config=config)
advanced_training.color_aug.change(
@ -1011,6 +1263,7 @@ def dreambooth_tab(
source_model.v2,
source_model.v_parameterization,
source_model.sdxl_checkbox,
source_model.flux1_checkbox,
folders.logging_dir,
source_model.train_data_dir,
folders.reg_data_dir,
@ -1023,6 +1276,7 @@ def dreambooth_tab(
basic_training.learning_rate_te2,
basic_training.lr_scheduler,
basic_training.lr_warmup,
basic_training.lr_warmup_steps,
basic_training.train_batch_size,
basic_training.epoch,
basic_training.save_every_n_epochs,
@ -1035,6 +1289,7 @@ def dreambooth_tab(
basic_training.caption_extension,
basic_training.enable_bucket,
advanced_training.gradient_checkpointing,
advanced_training.fp8_base,
advanced_training.full_fp16,
advanced_training.full_bf16,
advanced_training.no_token_padding,
@ -1084,6 +1339,7 @@ def dreambooth_tab(
basic_training.optimizer,
basic_training.optimizer_args,
basic_training.lr_scheduler_args,
basic_training.lr_scheduler_type,
advanced_training.noise_offset_type,
advanced_training.noise_offset,
advanced_training.noise_offset_random_strength,
@ -1100,18 +1356,28 @@ def dreambooth_tab(
advanced_training.loss_type,
advanced_training.huber_schedule,
advanced_training.huber_c,
advanced_training.huber_scale,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
advanced_training.weighted_captions,
advanced_training.save_every_n_steps,
advanced_training.save_last_n_steps,
advanced_training.save_last_n_steps_state,
advanced_training.save_last_n_epochs,
advanced_training.save_last_n_epochs_state,
advanced_training.skip_cache_check,
advanced_training.log_with,
advanced_training.wandb_api_key,
advanced_training.wandb_run_name,
advanced_training.log_tracker_name,
advanced_training.log_tracker_config,
advanced_training.log_config,
advanced_training.scale_v_pred_loss_like_noise_pred,
sdxl_params.disable_mmap_load_safetensors,
sdxl_params.fused_backward_pass,
sdxl_params.fused_optimizer_groups,
sdxl_params.sdxl_cache_text_encoder_outputs,
sdxl_params.sdxl_no_half_vae,
advanced_training.min_timestep,
advanced_training.max_timestep,
advanced_training.debiased_estimation_loss,
@ -1128,6 +1394,44 @@ def dreambooth_tab(
metadata.metadata_license,
metadata.metadata_tags,
metadata.metadata_title,
# SD3 Parameters
sd3_training.sd3_cache_text_encoder_outputs,
sd3_training.sd3_cache_text_encoder_outputs_to_disk,
sd3_training.sd3_fused_backward_pass,
sd3_training.clip_g,
sd3_training.clip_l,
sd3_training.logit_mean,
sd3_training.logit_std,
sd3_training.mode_scale,
sd3_training.save_clip,
sd3_training.save_t5xxl,
sd3_training.t5xxl,
sd3_training.t5xxl_device,
sd3_training.t5xxl_dtype,
sd3_training.sd3_text_encoder_batch_size,
sd3_training.weighting_scheme,
source_model.sd3_checkbox,
# Flux1 parameters
flux1_training.flux1_cache_text_encoder_outputs,
flux1_training.flux1_cache_text_encoder_outputs_to_disk,
flux1_training.ae,
flux1_training.clip_l,
flux1_training.t5xxl,
flux1_training.discrete_flow_shift,
flux1_training.model_prediction_type,
flux1_training.timestep_sampling,
flux1_training.split_mode,
flux1_training.train_blocks,
flux1_training.t5xxl_max_token_length,
flux1_training.guidance_scale,
flux1_training.blockwise_fused_optimizers,
flux1_training.flux_fused_backward_pass,
flux1_training.cpu_offload_checkpointing,
advanced_training.blocks_to_swap,
flux1_training.single_blocks_to_swap,
flux1_training.double_blocks_to_swap,
flux1_training.mem_eff_save,
flux1_training.apply_t5_attn_mask,
]
configuration.button_open_config.click(
@ -1181,4 +1485,4 @@ def dreambooth_tab(
folders.reg_data_dir,
folders.output_dir,
folders.logging_dir,
)

View File

@ -12,6 +12,7 @@ from .common_gui import (
)
from .custom_logging import setup_logging
from .sd_modeltype import SDModelType
# Set up logging
log = setup_logging()
@ -337,6 +338,19 @@ def gradio_extract_lora_tab(
outputs=[load_tuned_model_to, load_original_model_to],
)
# Secondary event on model_tuned for auto-detection of v2/SDXL
def change_modeltype_model_tuned(path):
detect = SDModelType(path)
v2 = gr.Checkbox(value=detect.Is_SD2())
sdxl = gr.Checkbox(value=detect.Is_SDXL())
return v2, sdxl
model_tuned.change(
change_modeltype_model_tuned,
inputs=model_tuned,
outputs=[v2, sdxl]
)
extract_button = gr.Button("Extract LoRA model")
extract_button.click(

View File

@ -18,14 +18,18 @@ from .common_gui import (
SaveConfigFile,
scriptdir,
update_my_data,
validate_file_path,
validate_folder_path,
validate_model_path,
validate_args_setting,
setup_environment,
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
from .class_source_model import SourceModel
from .class_basic_training import BasicTraining
from .class_advanced_training import AdvancedTraining
from .class_sd3 import sd3Training
from .class_folders import Folders
from .class_sdxl_parameters import SDXLParameters
from .class_command_executor import CommandExecutor
@ -34,6 +38,7 @@ from .class_sample_images import SampleImages, create_prompt_file
from .class_huggingface import HuggingFace
from .class_metadata import MetaData
from .class_gui_config import KohyaSSGUIConfig
from .class_flux1 import flux1Training
from .custom_logging import setup_logging
@ -65,6 +70,7 @@ def save_configuration(
v2,
v_parameterization,
sdxl_checkbox,
flux1_checkbox,
train_dir,
image_folder,
output_dir,
@ -82,6 +88,7 @@ def save_configuration(
learning_rate,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
dataset_repeats,
train_batch_size,
epoch,
@ -116,6 +123,7 @@ def save_configuration(
save_state_on_train_end,
resume,
gradient_checkpointing,
fp8_base,
gradient_accumulation_steps,
block_lr,
mem_eff_attn,
@ -142,6 +150,7 @@ def save_configuration(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -158,18 +167,26 @@ def save_configuration(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
fused_backward_pass,
fused_optimizer_groups,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
@ -188,6 +205,44 @@ def save_configuration(
metadata_license,
metadata_tags,
metadata_title,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
logit_std,
mode_scale,
save_clip,
save_t5xxl,
t5xxl,
t5xxl_device,
t5xxl_dtype,
sd3_text_encoder_batch_size,
weighting_scheme,
sd3_checkbox,
# Flux.1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
flux1_clip_l,
flux1_t5xxl,
discrete_flow_shift,
model_prediction_type,
timestep_sampling,
split_mode,
train_blocks,
t5xxl_max_token_length,
guidance_scale,
blockwise_fused_optimizers,
flux_fused_backward_pass,
cpu_offload_checkpointing,
blocks_to_swap,
single_blocks_to_swap,
double_blocks_to_swap,
mem_eff_save,
apply_t5_attn_mask,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -231,6 +286,7 @@ def open_configuration(
v2,
v_parameterization,
sdxl_checkbox,
flux1_checkbox,
train_dir,
image_folder,
output_dir,
@ -248,6 +304,7 @@ def open_configuration(
learning_rate,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
dataset_repeats,
train_batch_size,
epoch,
@ -282,6 +339,7 @@ def open_configuration(
save_state_on_train_end,
resume,
gradient_checkpointing,
fp8_base,
gradient_accumulation_steps,
block_lr,
mem_eff_attn,
@ -308,6 +366,7 @@ def open_configuration(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -324,18 +383,26 @@ def open_configuration(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
fused_backward_pass,
fused_optimizer_groups,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
@ -354,6 +421,44 @@ def open_configuration(
metadata_license,
metadata_tags,
metadata_title,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
logit_std,
mode_scale,
save_clip,
save_t5xxl,
t5xxl,
t5xxl_device,
t5xxl_dtype,
sd3_text_encoder_batch_size,
weighting_scheme,
sd3_checkbox,
# Flux.1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
flux1_clip_l,
flux1_t5xxl,
discrete_flow_shift,
model_prediction_type,
timestep_sampling,
split_mode,
train_blocks,
t5xxl_max_token_length,
guidance_scale,
blockwise_fused_optimizers,
flux_fused_backward_pass,
cpu_offload_checkpointing,
blocks_to_swap,
single_blocks_to_swap,
double_blocks_to_swap,
mem_eff_save,
apply_t5_attn_mask,
training_preset,
):
# Get list of function parameters and values
@ -403,6 +508,7 @@ def train_model(
v2,
v_parameterization,
sdxl_checkbox,
flux1_checkbox,
train_dir,
image_folder,
output_dir,
@ -420,6 +526,7 @@ def train_model(
learning_rate,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
dataset_repeats,
train_batch_size,
epoch,
@ -454,6 +561,7 @@ def train_model(
save_state_on_train_end,
resume,
gradient_checkpointing,
fp8_base,
gradient_accumulation_steps,
block_lr,
mem_eff_attn,
@ -480,6 +588,7 @@ def train_model(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -496,18 +605,26 @@ def train_model(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
fused_backward_pass,
fused_optimizer_groups,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
@ -526,6 +643,44 @@ def train_model(
metadata_license,
metadata_tags,
metadata_title,
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
logit_std,
mode_scale,
save_clip,
save_t5xxl,
t5xxl,
t5xxl_device,
t5xxl_dtype,
sd3_text_encoder_batch_size,
weighting_scheme,
sd3_checkbox,
# Flux.1
flux1_cache_text_encoder_outputs,
flux1_cache_text_encoder_outputs_to_disk,
ae,
flux1_clip_l,
flux1_t5xxl,
discrete_flow_shift,
model_prediction_type,
timestep_sampling,
split_mode,
train_blocks,
t5xxl_max_token_length,
guidance_scale,
blockwise_fused_optimizers,
flux_fused_backward_pass,
cpu_offload_checkpointing,
blocks_to_swap,
single_blocks_to_swap,
double_blocks_to_swap,
mem_eff_save,
apply_t5_attn_mask,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -558,44 +713,36 @@ def train_model(
#
# Validate paths
#
if not validate_file_path(dataset_config):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(image_folder):
return TRAIN_BUTTON_VISIBLE
if not validate_file_path(log_tracker_config):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(
logging_dir, can_be_written_to=True, create_if_not_exists=True
):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(
output_dir, can_be_written_to=True, create_if_not_exists=True
):
return TRAIN_BUTTON_VISIBLE
if not validate_model_path(pretrained_model_name_or_path):
return TRAIN_BUTTON_VISIBLE
if not validate_folder_path(resume):
return TRAIN_BUTTON_VISIBLE
#
# End of path validation
#
# if not validate_paths(
# dataset_config=dataset_config,
# finetune_image_folder=image_folder,
# headless=headless,
# log_tracker_config=log_tracker_config,
# logging_dir=logging_dir,
# output_dir=output_dir,
# pretrained_model_name_or_path=pretrained_model_name_or_path,
# resume=resume,
# ):
# return TRAIN_BUTTON_VISIBLE
if not print_only and check_if_model_exist(
output_name, output_dir, save_model_as, headless
@ -727,10 +874,16 @@ def train_model(
log.info(max_train_steps_info)
# Calculate lr_warmup_steps
if lr_warmup_steps > 0:
lr_warmup_steps = int(lr_warmup_steps)
if lr_warmup > 0:
log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
elif lr_warmup != 0:
lr_warmup_steps = lr_warmup / 100
else:
lr_warmup_steps = 0
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
accelerate_path = get_executable_path("accelerate")
@ -738,7 +891,7 @@ def train_model(
log.error("accelerate not found")
return TRAIN_BUTTON_VISIBLE
run_cmd = [rf"{accelerate_path}", "launch"]
run_cmd = AccelerateLaunch.run_cmd(
run_cmd=run_cmd,
@ -758,6 +911,10 @@ def train_model(
if sdxl_checkbox:
run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py")
elif sd3_checkbox:
run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py")
elif flux1_checkbox:
run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py")
else:
run_cmd.append(rf"{scriptdir}/sd-scripts/fine_tune.py")
@ -766,7 +923,14 @@ def train_model(
if use_latent_files == "Yes"
else f"{train_dir}/{caption_metadata_filename}"
)
cache_text_encoder_outputs = (
(sdxl_checkbox and sdxl_cache_text_encoder_outputs)
or (sd3_checkbox and sd3_cache_text_encoder_outputs)
or (flux1_checkbox and flux1_cache_text_encoder_outputs)
)
cache_text_encoder_outputs_to_disk = (
sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk)
no_half_vae = sdxl_checkbox and sdxl_no_half_vae
if max_data_loader_n_workers == "" or max_data_loader_n_workers is None:
@ -791,22 +955,31 @@ def train_model(
"cache_latents": cache_latents,
"cache_latents_to_disk": cache_latents_to_disk,
"cache_text_encoder_outputs": cache_text_encoder_outputs,
"cache_text_encoder_outputs_to_disk": cache_text_encoder_outputs_to_disk,
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
"caption_dropout_rate": caption_dropout_rate,
"caption_extension": caption_extension,
"clip_l": flux1_clip_l if flux1_checkbox else clip_l if sd3_checkbox else None,
"clip_skip": clip_skip if clip_skip != 0 else None,
"color_aug": color_aug,
"dataset_config": dataset_config,
"dataset_repeats": int(dataset_repeats),
"debiased_estimation_loss": debiased_estimation_loss,
"disable_mmap_load_safetensors": disable_mmap_load_safetensors,
"dynamo_backend": dynamo_backend,
"enable_bucket": True,
"flip_aug": flip_aug,
"fp8_base": fp8_base,
"full_bf16": full_bf16,
"full_fp16": full_fp16,
"fused_backward_pass": (
    sd3_fused_backward_pass
    if sd3_checkbox
    else flux_fused_backward_pass if flux1_checkbox else fused_backward_pass
),
"fused_optimizer_groups": (
int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
),
"gradient_accumulation_steps": int(gradient_accumulation_steps),
"gradient_checkpointing": gradient_checkpointing,
"huber_c": huber_c,
"huber_scale": huber_scale,
"huber_schedule": huber_schedule,
"huggingface_repo_id": huggingface_repo_id,
"huggingface_token": huggingface_token,
@ -828,11 +1001,13 @@ def train_model(
learning_rate_te2 if sdxl_checkbox else None
), # only for sdxl
"logging_dir": logging_dir,
"log_config": log_config,
"log_tracker_name": log_tracker_name,
"log_tracker_config": log_tracker_config,
"loss_type": loss_type,
"lr_scheduler": lr_scheduler,
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
"lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
"lr_warmup_steps": lr_warmup_steps,
"masked_loss": masked_loss,
"max_bucket_reso": int(max_bucket_reso),
@ -886,6 +1061,10 @@ def train_model(
"save_last_n_steps_state": (
save_last_n_steps_state if save_last_n_steps_state != 0 else None
),
"save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
"save_last_n_epochs_state": (
save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
),
"save_model_as": save_model_as,
"save_precision": save_precision,
"save_state": save_state,
@ -895,6 +1074,8 @@ def train_model(
"sdpa": True if xformers == "sdpa" else None,
"seed": int(seed) if int(seed) != 0 else None,
"shuffle_caption": shuffle_caption,
"skip_cache_check": skip_cache_check,
"t5xxl": t5xxl if sd3_checkbox else flux1_t5xxl if flux1_checkbox else None,
"train_batch_size": train_batch_size,
"train_data_dir": image_folder,
"train_text_encoder": train_text_encoder,
@ -904,9 +1085,50 @@ def train_model(
"v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None,
"vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
"wandb_api_key": wandb_api_key,
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
"weighted_captions": weighted_captions,
"xformers": True if xformers == "xformers" else None,
# SD3 only Parameters
# "cache_text_encoder_outputs": see previous assignment above for code
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
"clip_g": clip_g if sd3_checkbox else None,
# "clip_l": see previous assignment above for code
"logit_mean": logit_mean if sd3_checkbox else None,
"logit_std": logit_std if sd3_checkbox else None,
"mode_scale": mode_scale if sd3_checkbox else None,
"save_clip": save_clip if sd3_checkbox else None,
"save_t5xxl": save_t5xxl if sd3_checkbox else None,
# "t5xxl": see previous assignment above for code
"t5xxl_device": t5xxl_device if sd3_checkbox else None,
"t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None,
"text_encoder_batch_size": (
sd3_text_encoder_batch_size if sd3_checkbox else None
),
"weighting_scheme": weighting_scheme if sd3_checkbox else None,
# Flux.1 specific parameters
# "cache_text_encoder_outputs": see previous assignment above for code
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
"ae": ae if flux1_checkbox else None,
# "clip_l": see previous assignment above for code
# "t5xxl": see previous assignment above for code
"discrete_flow_shift": discrete_flow_shift if flux1_checkbox else None,
"model_prediction_type": model_prediction_type if flux1_checkbox else None,
"timestep_sampling": timestep_sampling if flux1_checkbox else None,
"split_mode": split_mode if flux1_checkbox else None,
"train_blocks": train_blocks if flux1_checkbox else None,
"t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None,
"guidance_scale": guidance_scale if flux1_checkbox else None,
"blockwise_fused_optimizers": (
blockwise_fused_optimizers if flux1_checkbox else None
),
"cpu_offload_checkpointing": (
cpu_offload_checkpointing if flux1_checkbox else None
),
"blocks_to_swap": blocks_to_swap if flux1_checkbox else None,
"single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
"double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
"mem_eff_save": mem_eff_save if flux1_checkbox else None,
"apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
}
# Given dictionary `config_toml_data`
@ -924,7 +1146,7 @@ def train_model(
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
tmpfilename = rf"{output_dir}/config_finetune-{formatted_datetime}.toml"
# Save the updated TOML data back to the file
with open(tmpfilename, "w", encoding="utf-8") as toml_file:
toml.dump(config_toml_data, toml_file)
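Many keys in `config_toml_data` above are set to `None` when a feature checkbox is off, and TOML has no null type, so those entries have to be dropped before serialization. A minimal sketch of that filtering pass (function name hypothetical, illustration only):

```python
def drop_none(d):
    # TOML cannot represent null, so entries whose value is None
    # must be removed before handing the dict to toml.dump.
    return {k: v for k, v in d.items() if v is not None}

config = {"learning_rate": 5e-06, "clip_skip": None, "xformers": True}
print(drop_none(config))  # {'learning_rate': 5e-06, 'xformers': True}
```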
@ -1090,7 +1312,9 @@ def finetune_tab(
# Add SDXL Parameters
sdxl_params = SDXLParameters(
source_model.sdxl_checkbox,
config=config,
trainer="finetune",
)
with gr.Row():
@ -1099,6 +1323,19 @@ def finetune_tab(
label="Train text encoder", value=True
)
# Add FLUX1 Parameters
flux1_training = flux1Training(
headless=headless,
config=config,
flux1_checkbox=source_model.flux1_checkbox,
finetuning=True,
)
# Add SD3 Parameters
sd3_training = sd3Training(
headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
)
with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
with gr.Row():
gradient_accumulation_steps = gr.Slider(
@ -1146,6 +1383,7 @@ def finetune_tab(
source_model.v2,
source_model.v_parameterization,
source_model.sdxl_checkbox,
source_model.flux1_checkbox,
train_dir,
image_folder,
output_dir,
@ -1163,6 +1401,7 @@ def finetune_tab(
basic_training.learning_rate,
basic_training.lr_scheduler,
basic_training.lr_warmup,
basic_training.lr_warmup_steps,
dataset_repeats,
basic_training.train_batch_size,
basic_training.epoch,
@ -1196,6 +1435,7 @@ def finetune_tab(
advanced_training.save_state_on_train_end,
advanced_training.resume,
advanced_training.gradient_checkpointing,
advanced_training.fp8_base,
gradient_accumulation_steps,
block_lr,
advanced_training.mem_eff_attn,
@ -1222,6 +1462,7 @@ def finetune_tab(
basic_training.optimizer,
basic_training.optimizer_args,
basic_training.lr_scheduler_args,
basic_training.lr_scheduler_type,
advanced_training.noise_offset_type,
advanced_training.noise_offset,
advanced_training.noise_offset_random_strength,
@ -1238,18 +1479,26 @@ def finetune_tab(
advanced_training.loss_type,
advanced_training.huber_schedule,
advanced_training.huber_c,
advanced_training.huber_scale,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
weighted_captions,
advanced_training.save_every_n_steps,
advanced_training.save_last_n_steps,
advanced_training.save_last_n_steps_state,
advanced_training.save_last_n_epochs,
advanced_training.save_last_n_epochs_state,
advanced_training.skip_cache_check,
advanced_training.log_with,
advanced_training.wandb_api_key,
advanced_training.wandb_run_name,
advanced_training.log_tracker_name,
advanced_training.log_tracker_config,
advanced_training.log_config,
advanced_training.scale_v_pred_loss_like_noise_pred,
sdxl_params.disable_mmap_load_safetensors,
sdxl_params.fused_backward_pass,
sdxl_params.fused_optimizer_groups,
sdxl_params.sdxl_cache_text_encoder_outputs,
sdxl_params.sdxl_no_half_vae,
advanced_training.min_timestep,
@ -1268,6 +1517,44 @@ def finetune_tab(
metadata.metadata_license,
metadata.metadata_tags,
metadata.metadata_title,
# SD3 Parameters
sd3_training.sd3_cache_text_encoder_outputs,
sd3_training.sd3_cache_text_encoder_outputs_to_disk,
sd3_training.clip_g,
sd3_training.clip_l,
sd3_training.logit_mean,
sd3_training.logit_std,
sd3_training.mode_scale,
sd3_training.save_clip,
sd3_training.save_t5xxl,
sd3_training.t5xxl,
sd3_training.t5xxl_device,
sd3_training.t5xxl_dtype,
sd3_training.sd3_text_encoder_batch_size,
sd3_training.sd3_fused_backward_pass,
sd3_training.weighting_scheme,
source_model.sd3_checkbox,
# Flux1 parameters
flux1_training.flux1_cache_text_encoder_outputs,
flux1_training.flux1_cache_text_encoder_outputs_to_disk,
flux1_training.ae,
flux1_training.clip_l,
flux1_training.t5xxl,
flux1_training.discrete_flow_shift,
flux1_training.model_prediction_type,
flux1_training.timestep_sampling,
flux1_training.split_mode,
flux1_training.train_blocks,
flux1_training.t5xxl_max_token_length,
flux1_training.guidance_scale,
flux1_training.blockwise_fused_optimizers,
flux1_training.flux_fused_backward_pass,
flux1_training.cpu_offload_checkpointing,
advanced_training.blocks_to_swap,
flux1_training.single_blocks_to_swap,
flux1_training.double_blocks_to_swap,
flux1_training.mem_eff_save,
flux1_training.apply_t5_attn_mask,
]
configuration.button_open_config.click(
@ -1353,4 +1640,4 @@ def finetune_tab(
if os.path.exists(top_level_path):
with open(os.path.join(top_level_path), "r", encoding="utf-8") as file:
guides_top_level = file.read() + "\n"
gr.Markdown(guides_top_level)
gr.Markdown(guides_top_level)


@ -0,0 +1,273 @@
import gradio as gr
import subprocess
import os
import sys
from .common_gui import (
get_saveasfilename_path,
get_file_path,
scriptdir,
list_files,
create_refresh_button,
setup_environment,
)
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
def extract_flux_lora(
model_org,
model_tuned,
save_to,
save_precision,
dim,
device,
clamp_quantile,
no_metadata,
mem_eff_safe_open,
):
# Check for required inputs
if model_org == "" or model_tuned == "" or save_to == "":
log.info(
"Please provide all required inputs: original model, tuned model, and save path."
)
return
# Check if source models exist
if not os.path.isfile(model_org):
log.info("The provided original model is not a file")
return
if not os.path.isfile(model_tuned):
log.info("The provided tuned model is not a file")
return
# Prepare save path
if os.path.dirname(save_to) == "":
save_to = os.path.join(os.path.dirname(model_tuned), save_to)
if os.path.isdir(save_to):
save_to = os.path.join(save_to, "flux_lora.safetensors")
if os.path.normpath(model_tuned) == os.path.normpath(save_to):
path, ext = os.path.splitext(save_to)
save_to = f"{path}_lora{ext}"
run_cmd = [
rf"{PYTHON}",
rf"{scriptdir}/sd-scripts/networks/flux_extract_lora.py",
"--model_org",
rf"{model_org}",
"--model_tuned",
rf"{model_tuned}",
"--save_to",
rf"{save_to}",
"--dim",
str(dim),
"--device",
device,
"--clamp_quantile",
str(clamp_quantile),
]
if save_precision and save_precision != "None":
run_cmd.extend(["--save_precision", save_precision])
if no_metadata:
run_cmd.append("--no_metadata")
if mem_eff_safe_open:
run_cmd.append("--mem_eff_safe_open")
env = setup_environment()
# Reconstruct the safe command string for display
command_to_run = " ".join(run_cmd)
log.info(f"Executing command: {command_to_run}")
# Run the command
subprocess.run(run_cmd, env=env)
def gradio_flux_extract_lora_tab(headless=False):
current_model_dir = os.path.join(scriptdir, "outputs")
current_save_dir = os.path.join(scriptdir, "outputs")
def list_models(path):
return list(list_files(path, exts=[".safetensors"], all=True))
with gr.Tab("Extract Flux LoRA"):
gr.Markdown(
"This utility can extract a LoRA network from a finetuned Flux model."
)
lora_ext = gr.Textbox(value="*.safetensors", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
model_ext = gr.Textbox(value="*.safetensors", visible=False)
model_ext_name = gr.Textbox(value="Model types", visible=False)
with gr.Group(), gr.Row():
model_org = gr.Dropdown(
label="Original Flux model (path to the original model)",
interactive=True,
choices=[""] + list_models(current_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
model_org,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
button_model_org_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_model_org_file.click(
get_file_path,
inputs=[model_org, model_ext, model_ext_name],
outputs=model_org,
show_progress=False,
)
model_tuned = gr.Dropdown(
label="Finetuned Flux model (path to the finetuned model to extract)",
interactive=True,
choices=[""] + list_models(current_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
model_tuned,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
button_model_tuned_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_model_tuned_file.click(
get_file_path,
inputs=[model_tuned, model_ext, model_ext_name],
outputs=model_tuned,
show_progress=False,
)
with gr.Group(), gr.Row():
save_to = gr.Dropdown(
label="Save to (path where to save the extracted LoRA model...)",
interactive=True,
choices=[""] + list_models(current_save_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
save_to,
lambda: None,
lambda: {"choices": list_models(current_save_dir)},
"open_folder_small",
)
button_save_to = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_save_to.click(
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
show_progress=False,
)
save_precision = gr.Dropdown(
label="Save precision",
choices=["None", "float", "fp16", "bf16"],
value="None",
interactive=True,
)
with gr.Row():
dim = gr.Slider(
minimum=1,
maximum=1024,
label="Network Dimension (Rank)",
value=4,
step=1,
interactive=True,
)
device = gr.Dropdown(
label="Device",
choices=["cpu", "cuda"],
value="cuda",
interactive=True,
)
clamp_quantile = gr.Slider(
minimum=0,
maximum=1,
label="Clamp Quantile",
value=0.99,
step=0.01,
interactive=True,
)
with gr.Row():
no_metadata = gr.Checkbox(
label="No metadata (do not save sai modelspec metadata)",
value=False,
interactive=True,
)
mem_eff_safe_open = gr.Checkbox(
label="Memory efficient safe open (experimental feature)",
value=False,
interactive=True,
)
extract_button = gr.Button("Extract Flux LoRA model")
extract_button.click(
extract_flux_lora,
inputs=[
model_org,
model_tuned,
save_to,
save_precision,
dim,
device,
clamp_quantile,
no_metadata,
mem_eff_safe_open,
],
show_progress=False,
)
model_org.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
inputs=model_org,
outputs=model_org,
show_progress=False,
)
model_tuned.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
inputs=model_tuned,
outputs=model_tuned,
show_progress=False,
)
save_to.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
inputs=save_to,
outputs=save_to,
show_progress=False,
)
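The extraction utility above factorizes the weight difference between the tuned and original model into a low-rank pair, which is the core idea behind LoRA extraction. A minimal numpy sketch of that principle (simplified illustration, not the actual `flux_extract_lora.py` implementation):

```python
import numpy as np

def extract_lora(w_org, w_tuned, rank):
    # Factor the weight delta with an SVD and keep only the top
    # `rank` singular directions, giving the LoRA up/down matrices.
    delta = w_tuned - w_org
    u, s, vt = np.linalg.svd(delta, full_matrices=False)
    lora_up = u[:, :rank] * s[:rank]  # shape (out, rank)
    lora_down = vt[:rank, :]          # shape (rank, in)
    return lora_up, lora_down

rng = np.random.default_rng(0)
w_org = rng.normal(size=(8, 8))
# Apply a rank-1 change so a rank-1 extraction can recover it exactly.
w_tuned = w_org + rng.normal(size=(8, 1)) @ rng.normal(size=(1, 8))
up, down = extract_lora(w_org, w_tuned, rank=1)
print(np.allclose(up @ down, w_tuned - w_org))  # True
```

The `dim` slider in the tab corresponds to `rank` here; higher ranks capture more of the delta at the cost of a larger LoRA file.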


@ -0,0 +1,470 @@
# Standard library imports
import os
import subprocess
import sys
import json
# Third-party imports
import gradio as gr
# Local module imports
from .common_gui import (
get_saveasfilename_path,
get_file_path,
scriptdir,
list_files,
create_refresh_button,
setup_environment,
)
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
def check_model(model):
if not model:
return True
if not os.path.isfile(model):
log.info(f"The provided {model} is not a file")
return False
return True
def verify_conditions(flux_model, lora_models):
lora_models_count = sum(1 for model in lora_models if model)
if flux_model and lora_models_count >= 1:
return True
elif not flux_model and lora_models_count >= 2:
return True
return False
class GradioFluxMergeLoRaTab:
def __init__(self, headless=False):
self.headless = headless
self.build_tab()
def save_inputs_to_json(self, file_path, inputs):
with open(file_path, "w", encoding="utf-8") as file:
json.dump(inputs, file)
log.info(f"Saved inputs to {file_path}")
def load_inputs_from_json(self, file_path):
with open(file_path, "r", encoding="utf-8") as file:
inputs = json.load(file)
log.info(f"Loaded inputs from {file_path}")
return inputs
def build_tab(self):
current_flux_model_dir = os.path.join(scriptdir, "outputs")
current_save_dir = os.path.join(scriptdir, "outputs")
current_lora_model_dir = current_flux_model_dir
def list_flux_models(path):
nonlocal current_flux_model_dir
current_flux_model_dir = path
return list(list_files(path, exts=[".safetensors"], all=True))
def list_lora_models(path):
nonlocal current_lora_model_dir
current_lora_model_dir = path
return list(list_files(path, exts=[".safetensors"], all=True))
def list_save_to(path):
nonlocal current_save_dir
current_save_dir = path
return list(list_files(path, exts=[".safetensors"], all=True))
with gr.Tab("Merge FLUX LoRA"):
gr.Markdown(
"This utility can merge up to 4 LoRA into a FLUX model or alternatively merge up to 4 LoRA together."
)
lora_ext = gr.Textbox(value="*.safetensors", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
flux_ext = gr.Textbox(value="*.safetensors", visible=False)
flux_ext_name = gr.Textbox(value="FLUX model types", visible=False)
with gr.Group(), gr.Row():
flux_model = gr.Dropdown(
label="FLUX Model (Optional. FLUX model path, if you want to merge it with LoRA files via the 'concat' method)",
interactive=True,
choices=[""] + list_flux_models(current_flux_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
flux_model,
lambda: None,
lambda: {"choices": list_flux_models(current_flux_model_dir)},
"open_folder_small",
)
flux_model_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
flux_model_file.click(
get_file_path,
inputs=[flux_model, flux_ext, flux_ext_name],
outputs=flux_model,
show_progress=False,
)
flux_model.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_flux_models(path)),
inputs=flux_model,
outputs=flux_model,
show_progress=False,
)
with gr.Group(), gr.Row():
lora_a_model = gr.Dropdown(
label='LoRA model "A" (path to the LoRA A model)',
interactive=True,
choices=[""] + list_lora_models(current_lora_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
lora_a_model,
lambda: None,
lambda: {"choices": list_lora_models(current_lora_model_dir)},
"open_folder_small",
)
button_lora_a_model_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_lora_a_model_file.click(
get_file_path,
inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model,
show_progress=False,
)
lora_b_model = gr.Dropdown(
label='LoRA model "B" (path to the LoRA B model)',
interactive=True,
choices=[""] + list_lora_models(current_lora_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
lora_b_model,
lambda: None,
lambda: {"choices": list_lora_models(current_lora_model_dir)},
"open_folder_small",
)
button_lora_b_model_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_lora_b_model_file.click(
get_file_path,
inputs=[lora_b_model, lora_ext, lora_ext_name],
outputs=lora_b_model,
show_progress=False,
)
lora_a_model.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
inputs=lora_a_model,
outputs=lora_a_model,
show_progress=False,
)
lora_b_model.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
inputs=lora_b_model,
outputs=lora_b_model,
show_progress=False,
)
with gr.Row():
ratio_a = gr.Slider(
label="Model A merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=2,
step=0.01,
value=0.0,
interactive=True,
)
ratio_b = gr.Slider(
label="Model B merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=2,
step=0.01,
value=0.0,
interactive=True,
)
with gr.Group(), gr.Row():
lora_c_model = gr.Dropdown(
label='LoRA model "C" (path to the LoRA C model)',
interactive=True,
choices=[""] + list_lora_models(current_lora_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
lora_c_model,
lambda: None,
lambda: {"choices": list_lora_models(current_lora_model_dir)},
"open_folder_small",
)
button_lora_c_model_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_lora_c_model_file.click(
get_file_path,
inputs=[lora_c_model, lora_ext, lora_ext_name],
outputs=lora_c_model,
show_progress=False,
)
lora_d_model = gr.Dropdown(
label='LoRA model "D" (path to the LoRA D model)',
interactive=True,
choices=[""] + list_lora_models(current_lora_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
lora_d_model,
lambda: None,
lambda: {"choices": list_lora_models(current_lora_model_dir)},
"open_folder_small",
)
button_lora_d_model_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_lora_d_model_file.click(
get_file_path,
inputs=[lora_d_model, lora_ext, lora_ext_name],
outputs=lora_d_model,
show_progress=False,
)
lora_c_model.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
inputs=lora_c_model,
outputs=lora_c_model,
show_progress=False,
)
lora_d_model.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
inputs=lora_d_model,
outputs=lora_d_model,
show_progress=False,
)
with gr.Row():
ratio_c = gr.Slider(
label="Model C merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=2,
step=0.01,
value=0.0,
interactive=True,
)
ratio_d = gr.Slider(
label="Model D merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=2,
step=0.01,
value=0.0,
interactive=True,
)
with gr.Group(), gr.Row():
save_to = gr.Dropdown(
label="Save to (path for the file to save...)",
interactive=True,
choices=[""] + list_save_to(current_save_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(
save_to,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
button_save_to = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_save_to.click(
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
show_progress=False,
)
precision = gr.Radio(
label="Merge precision",
choices=["float", "fp16", "bf16"],
value="float",
interactive=True,
)
save_precision = gr.Radio(
label="Save precision",
choices=["float", "fp16", "bf16", "fp8"],
value="fp16",
interactive=True,
)
save_to.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_save_to(path)),
inputs=save_to,
outputs=save_to,
show_progress=False,
)
with gr.Row():
loading_device = gr.Dropdown(
label="Loading device",
choices=["cpu", "cuda"],
value="cpu",
interactive=True,
)
working_device = gr.Dropdown(
label="Working device",
choices=["cpu", "cuda"],
value="cpu",
interactive=True,
)
with gr.Row():
concat = gr.Checkbox(label="Concat LoRA", value=False)
shuffle = gr.Checkbox(label="Shuffle LoRA weights", value=False)
no_metadata = gr.Checkbox(label="Don't save metadata", value=False)
diffusers = gr.Checkbox(label="Diffusers LoRA", value=False)
merge_button = gr.Button("Merge model")
merge_button.click(
self.merge_flux_lora,
inputs=[
flux_model,
lora_a_model,
lora_b_model,
lora_c_model,
lora_d_model,
ratio_a,
ratio_b,
ratio_c,
ratio_d,
save_to,
precision,
save_precision,
loading_device,
working_device,
concat,
shuffle,
no_metadata,
diffusers,
],
show_progress=False,
)
def merge_flux_lora(
self,
flux_model,
lora_a_model,
lora_b_model,
lora_c_model,
lora_d_model,
ratio_a,
ratio_b,
ratio_c,
ratio_d,
save_to,
precision,
save_precision,
loading_device,
working_device,
concat,
shuffle,
no_metadata,
diffusers,
):
log.info("Merge FLUX LoRA...")
models = [
lora_a_model,
lora_b_model,
lora_c_model,
lora_d_model,
]
lora_models = [model for model in models if model]
ratios = [ratio for model, ratio in zip(models, [ratio_a, ratio_b, ratio_c, ratio_d]) if model]
# if not verify_conditions(flux_model, lora_models):
# log.info(
# "Warning: Either provide at least one LoRA model along with the FLUX model or at least two LoRA models if no FLUX model is provided."
# )
# return
for model in [flux_model] + lora_models:
if not check_model(model):
return
run_cmd = [rf"{PYTHON}", rf"{scriptdir}/sd-scripts/networks/flux_merge_lora.py"]
if flux_model:
run_cmd.extend(["--flux_model", rf"{flux_model}"])
run_cmd.extend([
"--save_precision", save_precision,
"--precision", precision,
"--save_to", rf"{save_to}",
"--loading_device", loading_device,
"--working_device", working_device,
])
if lora_models:
run_cmd.append("--models")
run_cmd.extend(lora_models)
run_cmd.append("--ratios")
run_cmd.extend(map(str, ratios))
if concat:
run_cmd.append("--concat")
if shuffle:
run_cmd.append("--shuffle")
if no_metadata:
run_cmd.append("--no_metadata")
if diffusers:
run_cmd.append("--diffusers")
env = setup_environment()
# Reconstruct the safe command string for display
command_to_run = " ".join(run_cmd)
log.info(f"Executing command: {command_to_run}")
# Run the command in the sd-scripts folder context
subprocess.run(run_cmd, env=env)
log.info("Done merging...")

File diff suppressed because it is too large


@ -16,6 +16,7 @@ from .common_gui import (
create_refresh_button, setup_environment
)
from .custom_logging import setup_logging
from .sd_modeltype import SDModelType
# Set up logging
log = setup_logging()
@ -145,6 +146,13 @@ class GradioMergeLoRaTab:
show_progress=False,
)
#secondary event on sd_model for auto-detection of SDXL
sd_model.change(
lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()),
inputs=sd_model,
outputs=sdxl_model
)
with gr.Group(), gr.Row():
lora_a_model = gr.Dropdown(
label='LoRA model "A" (path to the LoRA A model)',

kohya_gui/sd_modeltype.py Executable file

@ -0,0 +1,65 @@
from os.path import isfile
from safetensors import safe_open
import enum
# methodology is based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/sd_models.py#L379-L403
class ModelType(enum.Enum):
UNKNOWN = 0
SD1 = 1
SD2 = 2
SDXL = 3
SD3 = 4
FLUX1 = 5
class SDModelType:
def __init__(self, safetensors_path):
self.model_type = ModelType.UNKNOWN
if not isfile(safetensors_path):
return
try:
st = safe_open(filename=safetensors_path, framework="numpy", device="cpu")
# print(st.keys())
def hasKeyPrefix(pfx):
return any(k.startswith(pfx) for k in st.keys())
if "model.diffusion_model.x_embedder.proj.weight" in st.keys():
self.model_type = ModelType.SD3
elif (
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
in st.keys()
or "double_blocks.0.img_attn.norm.key_norm.scale" in st.keys()
):
# print("flux1 model detected...")
self.model_type = ModelType.FLUX1
elif hasKeyPrefix("conditioner."):
self.model_type = ModelType.SDXL
elif hasKeyPrefix("cond_stage_model.model."):
self.model_type = ModelType.SD2
elif hasKeyPrefix("model."):
self.model_type = ModelType.SD1
except Exception:
pass
# print(f"Model type: {self.model_type}")
def Is_SD1(self):
return self.model_type == ModelType.SD1
def Is_SD2(self):
return self.model_type == ModelType.SD2
def Is_SDXL(self):
return self.model_type == ModelType.SDXL
def Is_SD3(self):
return self.model_type == ModelType.SD3
def Is_FLUX1(self):
return self.model_type == ModelType.FLUX1
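The detection above keys entirely off state-dict key names, following the AUTOMATIC1111 methodology linked in the comment. The same logic can be sketched on a plain key list, without the safetensors dependency (hypothetical helper, mirroring the prefix checks in `SDModelType`):

```python
def detect_model_type(keys):
    # Classify a checkpoint from its state-dict key names, using the
    # same marker keys and prefixes as SDModelType above.
    keyset = set(keys)

    def has_prefix(pfx):
        return any(k.startswith(pfx) for k in keyset)

    if "model.diffusion_model.x_embedder.proj.weight" in keyset:
        return "SD3"
    if ("model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in keyset
            or "double_blocks.0.img_attn.norm.key_norm.scale" in keyset):
        return "FLUX1"
    if has_prefix("conditioner."):
        return "SDXL"
    if has_prefix("cond_stage_model.model."):
        return "SD2"
    if has_prefix("model."):
        return "SD1"
    return "UNKNOWN"

print(detect_model_type(["conditioner.embedders.0.weight"]))  # SDXL
```

Note the ordering matters: SD3 and FLUX1 keys also start with `model.`, so the more specific checks must run first.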


@ -70,6 +70,7 @@ def save_configuration(
learning_rate,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
train_batch_size,
epoch,
save_every_n_epochs,
@ -135,6 +136,7 @@ def save_configuration(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -151,17 +153,23 @@ def save_configuration(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
min_timestep,
max_timestep,
sdxl_no_half_vae,
@ -229,6 +237,7 @@ def open_configuration(
learning_rate,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
train_batch_size,
epoch,
save_every_n_epochs,
@ -294,6 +303,7 @@ def open_configuration(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -310,17 +320,23 @@ def open_configuration(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
min_timestep,
max_timestep,
sdxl_no_half_vae,
@ -381,6 +397,7 @@ def train_model(
learning_rate,
lr_scheduler,
lr_warmup,
lr_warmup_steps,
train_batch_size,
epoch,
save_every_n_epochs,
@ -446,6 +463,7 @@ def train_model(
optimizer,
optimizer_args,
lr_scheduler_args,
lr_scheduler_type,
noise_offset_type,
noise_offset,
noise_offset_random_strength,
@ -462,17 +480,23 @@ def train_model(
loss_type,
huber_schedule,
huber_c,
huber_scale,
vae_batch_size,
min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
save_last_n_epochs,
save_last_n_epochs_state,
skip_cache_check,
log_with,
wandb_api_key,
wandb_run_name,
log_tracker_name,
log_tracker_config,
log_config,
scale_v_pred_loss_like_noise_pred,
disable_mmap_load_safetensors,
min_timestep,
max_timestep,
sdxl_no_half_vae,
@ -549,20 +573,6 @@ def train_model(
# End of path validation
#
# if not validate_paths(
# dataset_config=dataset_config,
# headless=headless,
# log_tracker_config=log_tracker_config,
# logging_dir=logging_dir,
# output_dir=output_dir,
# pretrained_model_name_or_path=pretrained_model_name_or_path,
# reg_data_dir=reg_data_dir,
# resume=resume,
# train_data_dir=train_data_dir,
# vae=vae,
# ):
# return TRAIN_BUTTON_VISIBLE
if token_string == "":
output_message(msg="Token string is missing", headless=headless)
return TRAIN_BUTTON_VISIBLE
@ -588,13 +598,6 @@ def train_model(
stop_text_encoder_training = math.ceil(
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
)
if lr_warmup != 0:
lr_warmup_steps = round(
float(int(lr_warmup) * int(max_train_steps) / 100)
)
else:
lr_warmup_steps = 0
else:
stop_text_encoder_training = 0
lr_warmup_steps = 0
@ -657,11 +660,11 @@ def train_model(
reg_factor = 1
else:
log.warning(
"Regularisation images are used... Will double the number of steps required..."
"Regularization images are used... Will double the number of steps required..."
)
reg_factor = 2
log.info(f"Regulatization factor: {reg_factor}")
log.info(f"Regularization factor: {reg_factor}")
if max_train_steps == 0:
# calculate max_train_steps
@ -689,13 +692,18 @@ def train_model(
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
)
log.info(f"Total steps: {total_steps}")
# Calculate lr_warmup_steps
if lr_warmup_steps > 0:
lr_warmup_steps = int(lr_warmup_steps)
if lr_warmup > 0:
log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
elif lr_warmup != 0:
lr_warmup_steps = lr_warmup / 100
else:
lr_warmup_steps = 0
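The precedence in the new warmup code can be sketched as a pure function (name hypothetical; it assumes, as the diff does, that a fractional `lr_warmup_steps` between 0 and 1 is interpreted downstream as a ratio of total training steps):

```python
def resolve_warmup(lr_warmup_steps, lr_warmup):
    # Explicit lr_warmup_steps wins; otherwise the lr_warmup percentage
    # is converted to a 0-1 ratio of total steps; otherwise no warmup.
    if lr_warmup_steps > 0:
        return int(lr_warmup_steps)
    if lr_warmup != 0:
        return lr_warmup / 100  # e.g. 10 (percent) -> 0.1
    return 0

print(resolve_warmup(0, 10))  # 0.1
```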
log.info(f"Train batch size: {train_batch_size}")
log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
log.info(f"Epoch: {epoch}")
@ -757,6 +765,7 @@ def train_model(
"clip_skip": clip_skip if clip_skip != 0 else None,
"color_aug": color_aug,
"dataset_config": dataset_config,
"disable_mmap_load_safetensors": disable_mmap_load_safetensors,
"dynamo_backend": dynamo_backend,
"enable_bucket": enable_bucket,
"epoch": int(epoch),
@ -765,6 +774,7 @@ def train_model(
"gradient_accumulation_steps": int(gradient_accumulation_steps),
"gradient_checkpointing": gradient_checkpointing,
"huber_c": huber_c,
"huber_scale": huber_scale,
"huber_schedule": huber_schedule,
"huggingface_repo_id": huggingface_repo_id,
"huggingface_token": huggingface_token,
@ -777,6 +787,7 @@ def train_model(
"keep_tokens": int(keep_tokens),
"learning_rate": learning_rate,
"logging_dir": logging_dir,
"log_config": log_config,
"log_tracker_name": log_tracker_name,
"log_tracker_config": log_tracker_config,
"loss_type": loss_type,
@ -786,6 +797,7 @@ def train_model(
int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
),
"lr_scheduler_power": lr_scheduler_power,
"lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
"lr_warmup_steps": lr_warmup_steps,
"max_bucket_reso": max_bucket_reso,
"max_timestep": max_timestep if max_timestep != 0 else None,
@ -840,6 +852,10 @@ def train_model(
"save_last_n_steps_state": (
save_last_n_steps_state if save_last_n_steps_state != 0 else None
),
"save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
"save_last_n_epochs_state": (
save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
),
"save_model_as": save_model_as,
"save_precision": save_precision,
"save_state": save_state,
@ -849,6 +865,7 @@ def train_model(
"sdpa": True if xformers == "sdpa" else None,
"seed": int(seed) if int(seed) != 0 else None,
"shuffle_caption": shuffle_caption,
"skip_cache_check": skip_cache_check,
"stop_text_encoder_training": (
stop_text_encoder_training if stop_text_encoder_training != 0 else None
),
@ -862,8 +879,8 @@ def train_model(
"vae": vae,
"vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
"wandb_api_key": wandb_api_key,
"wandb_run_name": wandb_run_name,
"weigts": weights,
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
"weights": weights,
"use_object_template": True if template == "object template" else None,
"use_style_template": True if template == "style template" else None,
"xformers": True if xformers == "xformers" else None,
@ -1130,6 +1147,7 @@ def ti_tab(
basic_training.learning_rate,
basic_training.lr_scheduler,
basic_training.lr_warmup,
basic_training.lr_warmup_steps,
basic_training.train_batch_size,
basic_training.epoch,
basic_training.save_every_n_epochs,
@ -1194,6 +1212,7 @@ def ti_tab(
basic_training.optimizer,
basic_training.optimizer_args,
basic_training.lr_scheduler_args,
basic_training.lr_scheduler_type,
advanced_training.noise_offset_type,
advanced_training.noise_offset,
advanced_training.noise_offset_random_strength,
@ -1210,17 +1229,23 @@ def ti_tab(
advanced_training.loss_type,
advanced_training.huber_schedule,
advanced_training.huber_c,
advanced_training.huber_scale,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
advanced_training.save_every_n_steps,
advanced_training.save_last_n_steps,
advanced_training.save_last_n_steps_state,
advanced_training.save_last_n_epochs,
advanced_training.save_last_n_epochs_state,
advanced_training.skip_cache_check,
advanced_training.log_with,
advanced_training.wandb_api_key,
advanced_training.wandb_run_name,
advanced_training.log_tracker_name,
advanced_training.log_tracker_config,
advanced_training.log_config,
advanced_training.scale_v_pred_loss_like_noise_pred,
sdxl_params.disable_mmap_load_safetensors,
advanced_training.min_timestep,
advanced_training.max_timestep,
sdxl_params.sdxl_no_half_vae,
@ -1289,4 +1314,4 @@ def ti_tab(
folders.reg_data_dir,
folders.output_dir,
folders.logging_dir,
)
)


@ -0,0 +1,146 @@
{
"adaptive_noise_scale": 0,
"additional_parameters": "",
"async_upload": false,
"bucket_no_upscale": true,
"bucket_reso_steps": 64,
"cache_latents": true,
"cache_latents_to_disk": true,
"caption_dropout_every_n_epochs": 0,
"caption_dropout_rate": 0,
"caption_extension": ".txt",
"clip_g": "H:/ComfyUI2/models/clip/clip_g.safetensors",
"clip_l": "H:/ComfyUI2/models/clip/clip_l.safetensors",
"clip_skip": 1,
"color_aug": false,
"dataset_config": "",
"debiased_estimation_loss": false,
"disable_mmap_load_safetensors": false,
"dynamo_backend": "no",
"dynamo_mode": "default",
"dynamo_use_dynamic": false,
"dynamo_use_fullgraph": false,
"enable_bucket": true,
"epoch": 8,
"extra_accelerate_launch_args": "",
"flip_aug": false,
"full_bf16": false,
"full_fp16": false,
"fused_backward_pass": false,
"fused_optimizer_groups": 0,
"gpu_ids": "",
"gradient_accumulation_steps": 1,
"gradient_checkpointing": true,
"huber_c": 0.1,
"huber_schedule": "snr",
"huggingface_path_in_repo": "",
"huggingface_repo_id": "",
"huggingface_repo_type": "",
"huggingface_repo_visibility": "",
"huggingface_token": "",
"ip_noise_gamma": 0,
"ip_noise_gamma_random_strength": false,
"keep_tokens": 0,
"learning_rate": 5e-06,
"learning_rate_te": 0,
"learning_rate_te1": 1e-05,
"learning_rate_te2": 1e-05,
"log_config": false,
"log_tracker_config": "",
"log_tracker_name": "",
"log_with": "",
"logging_dir": "C:/Users/berna/Downloads/martini/logs/sd3",
"logit_mean": 0,
"logit_std": 1,
"loss_type": "l2",
"lr_scheduler": "cosine",
"lr_scheduler_args": "",
"lr_scheduler_num_cycles": 1,
"lr_scheduler_power": 1,
"lr_scheduler_type": "",
"lr_warmup": 10,
"main_process_port": 0,
"masked_loss": false,
"max_bucket_reso": 1536,
"max_data_loader_n_workers": 0,
"max_resolution": "512,512",
"max_timestep": 1000,
"max_token_length": 225,
"max_train_epochs": 8,
"max_train_steps": 1600,
"mem_eff_attn": false,
"metadata_author": "",
"metadata_description": "",
"metadata_license": "",
"metadata_tags": "",
"metadata_title": "",
"min_bucket_reso": 256,
"min_snr_gamma": 0,
"min_timestep": 0,
"mixed_precision": "bf16",
"mode_scale": 1.29,
"model_list": "custom",
"multi_gpu": false,
"multires_noise_discount": 0.3,
"multires_noise_iterations": 0,
"no_token_padding": false,
"noise_offset": 0,
"noise_offset_random_strength": false,
"noise_offset_type": "Original",
"num_cpu_threads_per_process": 2,
"num_machines": 1,
"num_processes": 1,
"optimizer": "PagedAdamW8bit",
"optimizer_args": "weight_decay=0.1 betas=.9,.95",
"output_dir": "E:/models/sd3",
"output_name": "sd3",
"persistent_data_loader_workers": false,
"pretrained_model_name_or_path": "E:/models/sd3/sd3_medium.safetensors",
"prior_loss_weight": 1,
"random_crop": false,
"reg_data_dir": "",
"resume": "",
"resume_from_huggingface": "",
"sample_every_n_epochs": 0,
"sample_every_n_steps": 0,
"sample_prompts": "",
"sample_sampler": "euler_a",
"save_as_bool": false,
"save_clip": false,
"save_every_n_epochs": 0,
"save_every_n_steps": 0,
"save_last_n_steps": 0,
"save_last_n_steps_state": 0,
"save_model_as": "safetensors",
"save_precision": "fp16",
"save_state": false,
"save_state_on_train_end": false,
"save_state_to_huggingface": false,
"save_t5xxl": false,
"scale_v_pred_loss_like_noise_pred": false,
"sd3_cache_text_encoder_outputs": true,
"sd3_cache_text_encoder_outputs_to_disk": true,
"sd3_checkbox": true,
"sd3_text_encoder_batch_size": 1,
"sdxl": false,
"sdxl_cache_text_encoder_outputs": false,
"sdxl_no_half_vae": false,
"seed": 1026,
"shuffle_caption": false,
"stop_text_encoder_training": 0,
"t5xxl": "H:/ComfyUI2/models/clip/t5xxl_fp8_e4m3fn.safetensors",
"t5xxl_device": "",
"t5xxl_dtype": "bf16",
"train_batch_size": 1,
"train_data_dir": "C:/Users/berna/Downloads/martini/img2",
"v2": false,
"v_parameterization": false,
"v_pred_like_loss": 0,
"vae": "",
"vae_batch_size": 0,
"wandb_api_key": "",
"wandb_run_name": "",
"weighted_captions": false,
"weighting_scheme": "logit_normal",
"xformers": "sdpa"
}
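A saved config like the one above ultimately has to become sd-scripts command-line flags. Below is a minimal, hypothetical sketch of that translation; the `GUI_ONLY_KEYS` set and the helper name are my own illustration, not the GUI's actual filtering logic.

```python
# Hypothetical set of keys kept only for the GUI itself (illustrative, not
# the real list used by kohya_ss).
GUI_ONLY_KEYS = {"sd3_checkbox", "model_list", "save_as_bool"}

def config_to_cli_args(config: dict) -> list:
    """Convert a saved config dict into a flat --key value argument list."""
    args = []
    for key, value in sorted(config.items()):
        if key in GUI_ONLY_KEYS or value in ("", None):
            continue  # GUI-only and empty entries are not forwarded
        if isinstance(value, bool):
            if value:
                args.append(f"--{key}")  # true booleans become bare flags
        else:
            args.extend([f"--{key}", str(value)])
    return args

sample = {"learning_rate": 5e-06, "cache_latents": True, "vae": "", "sd3_checkbox": True}
print(config_to_cli_args(sample))  # ['--cache_latents', '--learning_rate', '5e-06']
```

Sorting the keys keeps the generated command stable across runs, which makes "print command" output easier to diff.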

View File

@@ -0,0 +1,146 @@
{
"adaptive_noise_scale": 0,
"additional_parameters": "",
"async_upload": false,
"bucket_no_upscale": true,
"bucket_reso_steps": 64,
"cache_latents": true,
"cache_latents_to_disk": true,
"caption_dropout_every_n_epochs": 0,
"caption_dropout_rate": 0,
"caption_extension": ".txt",
"clip_g": "H:/ComfyUI2/models/clip/clip_g.safetensors",
"clip_l": "H:/ComfyUI2/models/clip/clip_l.safetensors",
"clip_skip": 1,
"color_aug": false,
"dataset_config": "",
"debiased_estimation_loss": false,
"disable_mmap_load_safetensors": false,
"dynamo_backend": "no",
"dynamo_mode": "default",
"dynamo_use_dynamic": false,
"dynamo_use_fullgraph": false,
"enable_bucket": true,
"epoch": 8,
"extra_accelerate_launch_args": "",
"flip_aug": false,
"full_bf16": false,
"full_fp16": false,
"fused_backward_pass": false,
"fused_optimizer_groups": 0,
"gpu_ids": "",
"gradient_accumulation_steps": 1,
"gradient_checkpointing": true,
"huber_c": 0.1,
"huber_schedule": "snr",
"huggingface_path_in_repo": "",
"huggingface_repo_id": "",
"huggingface_repo_type": "",
"huggingface_repo_visibility": "",
"huggingface_token": "",
"ip_noise_gamma": 0,
"ip_noise_gamma_random_strength": false,
"keep_tokens": 0,
"learning_rate": 5e-06,
"learning_rate_te": 0,
"learning_rate_te1": 1e-05,
"learning_rate_te2": 1e-05,
"log_config": false,
"log_tracker_config": "",
"log_tracker_name": "",
"log_with": "",
"logging_dir": "C:/Users/berna/Downloads/martini/logs/sd3",
"logit_mean": 0,
"logit_std": 1,
"loss_type": "l2",
"lr_scheduler": "cosine",
"lr_scheduler_args": "",
"lr_scheduler_num_cycles": 1,
"lr_scheduler_power": 1,
"lr_scheduler_type": "",
"lr_warmup": 10,
"main_process_port": 0,
"masked_loss": false,
"max_bucket_reso": 1536,
"max_data_loader_n_workers": 0,
"max_resolution": "512,512",
"max_timestep": 1000,
"max_token_length": 150,
"max_train_epochs": 8,
"max_train_steps": 1600,
"mem_eff_attn": false,
"metadata_author": "",
"metadata_description": "",
"metadata_license": "",
"metadata_tags": "",
"metadata_title": "",
"min_bucket_reso": 256,
"min_snr_gamma": 0,
"min_timestep": 0,
"mixed_precision": "bf16",
"mode_scale": 1.29,
"model_list": "custom",
"multi_gpu": false,
"multires_noise_discount": 0.3,
"multires_noise_iterations": 0,
"no_token_padding": false,
"noise_offset": 0,
"noise_offset_random_strength": false,
"noise_offset_type": "Original",
"num_cpu_threads_per_process": 2,
"num_machines": 1,
"num_processes": 1,
"optimizer": "PagedAdamW8bit",
"optimizer_args": "weight_decay=0.1 betas=.9,.95",
"output_dir": "E:/models/sd3",
"output_name": "sd3_v2",
"persistent_data_loader_workers": false,
"pretrained_model_name_or_path": "E:/models/sd3/sd3_medium.safetensors",
"prior_loss_weight": 1,
"random_crop": false,
"reg_data_dir": "",
"resume": "",
"resume_from_huggingface": "",
"sample_every_n_epochs": 0,
"sample_every_n_steps": 0,
"sample_prompts": "",
"sample_sampler": "euler_a",
"save_as_bool": false,
"save_clip": false,
"save_every_n_epochs": 0,
"save_every_n_steps": 0,
"save_last_n_steps": 0,
"save_last_n_steps_state": 0,
"save_model_as": "safetensors",
"save_precision": "fp16",
"save_state": false,
"save_state_on_train_end": false,
"save_state_to_huggingface": false,
"save_t5xxl": false,
"scale_v_pred_loss_like_noise_pred": false,
"sd3_cache_text_encoder_outputs": true,
"sd3_cache_text_encoder_outputs_to_disk": true,
"sd3_checkbox": true,
"sd3_text_encoder_batch_size": 1,
"sdxl": false,
"sdxl_cache_text_encoder_outputs": false,
"sdxl_no_half_vae": false,
"seed": 1026,
"shuffle_caption": false,
"stop_text_encoder_training": 0,
"t5xxl": "H:/ComfyUI2/models/clip/t5xxl_fp8_e4m3fn.safetensors",
"t5xxl_device": "",
"t5xxl_dtype": "bf16",
"train_batch_size": 1,
"train_data_dir": "C:/Users/berna/Downloads/martini/img",
"v2": false,
"v_parameterization": false,
"v_pred_like_loss": 0,
"vae": "",
"vae_batch_size": 0,
"wandb_api_key": "",
"wandb_run_name": "",
"weighted_captions": false,
"weighting_scheme": "logit_normal",
"xformers": "sdpa"
}

View File

@@ -0,0 +1,182 @@
{
"LoRA_type": "Flux1",
"LyCORIS_preset": "full",
"adaptive_noise_scale": 0,
"additional_parameters": "",
"ae": "put the full path to ae.sft here",
"apply_t5_attn_mask": true,
"async_upload": false,
"block_alphas": "",
"block_dims": "",
"block_lr_zero_threshold": "",
"bucket_no_upscale": true,
"bucket_reso_steps": 64,
"bypass_mode": false,
"cache_latents": true,
"cache_latents_to_disk": true,
"caption_dropout_every_n_epochs": 0,
"caption_dropout_rate": 0,
"caption_extension": ".txt",
"clip_l": "put the full path to clip_l.safetensors here",
"clip_skip": 1,
"color_aug": false,
"constrain": 0,
"conv_alpha": 1,
"conv_block_alphas": "",
"conv_block_dims": "",
"conv_dim": 1,
"dataset_config": "",
"debiased_estimation_loss": false,
"decompose_both": false,
"dim_from_weights": false,
"discrete_flow_shift": 3,
"dora_wd": false,
"down_lr_weight": "",
"dynamo_backend": "no",
"dynamo_mode": "default",
"dynamo_use_dynamic": false,
"dynamo_use_fullgraph": false,
"enable_bucket": true,
"epoch": 1,
"extra_accelerate_launch_args": "",
"factor": -1,
"flip_aug": false,
"flux1_cache_text_encoder_outputs": true,
"flux1_cache_text_encoder_outputs_to_disk": true,
"flux1_checkbox": true,
"fp8_base": true,
"full_bf16": true,
"full_fp16": false,
"gpu_ids": "",
"gradient_accumulation_steps": 1,
"gradient_checkpointing": true,
"guidance_scale": 1,
"highvram": false,
"huber_c": 0.1,
"huber_schedule": "snr",
"huggingface_path_in_repo": "",
"huggingface_repo_id": "",
"huggingface_repo_type": "",
"huggingface_repo_visibility": "",
"huggingface_token": "",
"ip_noise_gamma": 0,
"ip_noise_gamma_random_strength": false,
"keep_tokens": 0,
"learning_rate": 0.0003,
"log_config": false,
"log_tracker_config": "",
"log_tracker_name": "",
"log_with": "",
"logging_dir": "./test/logs-saruman",
"loraplus_lr_ratio": 0,
"loraplus_text_encoder_lr_ratio": 0,
"loraplus_unet_lr_ratio": 0,
"loss_type": "l2",
"lowvram": false,
"lr_scheduler": "constant",
"lr_scheduler_args": "",
"lr_scheduler_num_cycles": 1,
"lr_scheduler_power": 1,
"lr_scheduler_type": "",
"lr_warmup": 0,
"main_process_port": 0,
"masked_loss": false,
"max_bucket_reso": 2048,
"max_data_loader_n_workers": 0,
"max_grad_norm": 1,
"max_resolution": "512,512",
"max_timestep": 1000,
"max_token_length": 75,
"max_train_epochs": 0,
"max_train_steps": 1000,
"mem_eff_attn": false,
"mem_eff_save": false,
"metadata_author": "",
"metadata_description": "",
"metadata_license": "",
"metadata_tags": "",
"metadata_title": "",
"mid_lr_weight": "",
"min_bucket_reso": 256,
"min_snr_gamma": 7,
"min_timestep": 0,
"mixed_precision": "bf16",
"model_list": "custom",
"model_prediction_type": "raw",
"module_dropout": 0,
"multi_gpu": false,
"multires_noise_discount": 0.3,
"multires_noise_iterations": 0,
"network_alpha": 16,
"network_dim": 16,
"network_dropout": 0,
"network_weights": "",
"noise_offset": 0.05,
"noise_offset_random_strength": false,
"noise_offset_type": "Original",
"num_cpu_threads_per_process": 2,
"num_machines": 1,
"num_processes": 1,
"optimizer": "AdamW8bit",
"optimizer_args": "",
"output_dir": "put the full path to output folder here",
"output_name": "Flux.my-super-duper-model-name-goes-here-v1.0",
"persistent_data_loader_workers": false,
"pretrained_model_name_or_path": "put the full path to flux1-dev.safetensors here",
"prior_loss_weight": 1,
"random_crop": false,
"rank_dropout": 0,
"rank_dropout_scale": false,
"reg_data_dir": "",
"rescaled": false,
"resume": "",
"resume_from_huggingface": "",
"sample_every_n_epochs": 0,
"sample_every_n_steps": 0,
"sample_prompts": "saruman posing under a stormy lightning sky, photorealistic --w 832 --h 1216 --s 20 --l 4 --d 42",
"sample_sampler": "euler",
"save_as_bool": false,
"save_every_n_epochs": 1,
"save_every_n_steps": 50,
"save_last_n_steps": 0,
"save_last_n_steps_state": 0,
"save_model_as": "safetensors",
"save_precision": "bf16",
"save_state": false,
"save_state_on_train_end": false,
"save_state_to_huggingface": false,
"scale_v_pred_loss_like_noise_pred": false,
"scale_weight_norms": 0,
"sdxl": false,
"sdxl_cache_text_encoder_outputs": true,
"sdxl_no_half_vae": true,
"seed": 42,
"shuffle_caption": false,
"split_mode": false,
"stop_text_encoder_training": 0,
"t5xxl": "put the full path to the file here. Use the fp16 version",
"t5xxl_max_token_length": 512,
"text_encoder_lr": 0,
"timestep_sampling": "sigmoid",
"train_batch_size": 1,
"train_blocks": "all",
"train_data_dir": "put your image folder here",
"train_norm": false,
"train_on_input": true,
"training_comment": "",
"unet_lr": 0.0003,
"unit": 1,
"up_lr_weight": "",
"use_cp": false,
"use_scalar": false,
"use_tucker": false,
"v2": false,
"v_parameterization": false,
"v_pred_like_loss": 0,
"vae": "",
"vae_batch_size": 0,
"wandb_api_key": "",
"wandb_run_name": "",
"weighted_captions": false,
"xformers": "sdpa"
}
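The Flux preset above ships placeholder strings ("put the full path ... here") instead of real paths, so a launcher could refuse to start training until they are replaced. A minimal sketch, with the marker strings taken from the preset text and the function name purely illustrative:

```python
# Marker substrings copied from the preset's placeholder values.
PLACEHOLDER_MARKERS = ("put the full path", "put your image folder")

def find_unresolved_placeholders(config: dict) -> list:
    """Return config keys whose values still contain preset placeholder text."""
    unresolved = []
    for key, value in config.items():
        if isinstance(value, str) and any(m in value for m in PLACEHOLDER_MARKERS):
            unresolved.append(key)
    return sorted(unresolved)

preset = {
    "ae": "put the full path to ae.sft here",
    "train_data_dir": "put your image folder here",
    "output_name": "Flux.my-super-duper-model-name-goes-here-v1.0",
}
print(find_unresolved_placeholders(preset))  # ['ae', 'train_data_dir']
```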

View File

@@ -1,35 +1,38 @@
accelerate==0.25.0
accelerate==0.33.0
aiofiles==23.2.1
altair==4.2.2
dadaptation==3.1
dadaptation==3.2
diffusers[torch]==0.25.0
easygui==0.98.3
einops==0.7.0
fairscale==0.4.13
ftfy==6.1.1
gradio==4.43.0
huggingface-hub==0.20.1
gradio==5.23.1
huggingface-hub==0.29.3
imagesize==1.4.1
invisible-watermark==0.2.0
lion-pytorch==0.0.6
lycoris_lora==2.2.0.post3
lycoris_lora==3.1.0
omegaconf==2.3.0
onnx==1.16.1
prodigyopt==1.0
prodigyopt==1.1.2
protobuf==3.20.3
open-clip-torch==2.20.0
opencv-python==4.7.0.68
prodigyopt==1.0
opencv-python==4.10.0.84
prodigy-plus-schedule-free==1.8.0
pytorch-lightning==1.9.0
pytorch-optimizer==3.5.0
rich>=13.7.1
safetensors==0.4.2
safetensors==0.4.4
schedulefree==1.4
scipy==1.11.4
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
timm==0.6.12
tk==0.1.0
toml==0.10.2
transformers==4.38.0
transformers==4.44.2
voluptuous==0.13.1
wandb==0.15.11
scipy==1.11.4
# for kohya_ss library
-e ./sd-scripts # no_verify leave this to specify not checking this at the verification stage
wandb==0.18.0
# for kohya_ss sd-scripts library
-e ./sd-scripts

View File

@@ -1,5 +1,13 @@
torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
bitsandbytes==0.43.0
tensorboard==2.15.2 tensorflow==2.15.0.post1
onnxruntime-gpu==1.17.1
# Custom index URL for specific packages
--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.5.0+cu124
torchvision==0.20.0+cu124
xformers==0.0.28.post2
bitsandbytes==0.44.0
tensorboard==2.15.2
tensorflow==2.15.0.post1
onnxruntime-gpu==1.19.2
-r requirements.txt
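Requirements files like the one above mix pip option lines (`--extra-index-url`, `-r`) with exact version pins. When inspecting such files programmatically, separating the two is useful; a minimal sketch (the helper name is illustrative, and the comment stripping is deliberately naive):

```python
def split_requirements(lines):
    """Split requirements-file lines into (option_lines, pinned_packages)."""
    options, pins = [], []
    for line in lines:
        line = line.split("#", 1)[0].strip()  # drop trailing comments/whitespace
        if not line:
            continue  # skip blank and comment-only lines
        if line.startswith("-"):  # e.g. --extra-index-url ..., -r requirements.txt
            options.append(line)
        else:
            pins.append(line)
    return options, pins

sample = [
    "# Custom index URL for specific packages",
    "--extra-index-url https://download.pytorch.org/whl/cu124",
    "torch==2.5.0+cu124",
    "-r requirements.txt",
]
options, pins = split_requirements(sample)
print(options)  # ['--extra-index-url https://download.pytorch.org/whl/cu124', '-r requirements.txt']
print(pins)     # ['torch==2.5.0+cu124']
```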

View File

@ -1,4 +1,4 @@
xformers>=0.0.20
bitsandbytes==0.43.0
accelerate==0.25.0
tensorboard
bitsandbytes==0.44.0
accelerate==0.33.0
tensorboard

View File

@@ -1,5 +1,17 @@
torch==2.1.0.post0+cxx11.abi torchvision==0.16.0.post0+cxx11.abi intel-extension-for-pytorch==2.1.20+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
tensorboard==2.15.2 tensorflow==2.15.0 intel-extension-for-tensorflow[xpu]==2.15.0.0
mkl==2024.1.0 mkl-dpcpp==2024.1.0 oneccl-devel==2021.12.0 impi-devel==2021.12.0
onnxruntime-openvino==1.17.1
# Custom index URL for specific packages
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
torch==2.1.0.post3+cxx11.abi
torchvision==0.16.0.post3+cxx11.abi
intel-extension-for-pytorch==2.1.40+xpu
oneccl_bind_pt==2.1.400+xpu
tensorflow==2.15.1
intel-extension-for-tensorflow[xpu]==2.15.0.1
mkl==2024.2.0
mkl-dpcpp==2024.2.0
oneccl-devel==2021.13.0
impi-devel==2021.13.0
onnxruntime-openvino==1.18.0
-r requirements.txt

View File

@@ -1,4 +1,13 @@
torch==2.3.0+rocm6.0 torchvision==0.18.0+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0
tensorboard==2.14.1 tensorflow-rocm==2.14.0.600
onnxruntime-training --pre --index-url https://pypi.lsh.sh/60/ --extra-index-url https://pypi.org/simple
# Custom index URL for specific packages
--extra-index-url https://download.pytorch.org/whl/rocm6.1
torch==2.5.0+rocm6.1
torchvision==0.20.0+rocm6.1
tensorboard==2.14.1
tensorflow-rocm==2.14.0.600
# Custom index URL for specific packages
--extra-index-url https://pypi.lsh.sh/60/
onnxruntime-training --pre
-r requirements.txt

View File

@@ -1,5 +1,5 @@
torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
xformers bitsandbytes==0.41.1
xformers bitsandbytes==0.43.3
tensorflow-macos tensorboard==2.14.1
onnxruntime==1.17.1
-r requirements.txt

View File

@@ -1,5 +1,5 @@
torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
xformers bitsandbytes==0.41.1
xformers bitsandbytes==0.43.3
tensorflow-macos tensorflow-metal tensorboard==2.14.1
onnxruntime==1.17.1
-r requirements.txt

View File

@@ -1,3 +1,8 @@
torch==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118
torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118
xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118
# Custom index URL for specific packages
--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.5.0+cu124
torchvision==0.20.0+cu124
xformers==0.0.28.post2
-r requirements_windows.txt

View File

@@ -1,6 +1,13 @@
torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this at the verification stage
bitsandbytes==0.43.0
tensorboard==2.14.1 tensorflow==2.14.0 wheel
--extra-index-url https://download.pytorch.org/whl/cu124
torch==2.5.0+cu124
torchvision==0.20.0+cu124
xformers==0.0.28.post2
bitsandbytes==0.44.0
tensorboard==2.14.1
tensorflow==2.14.0
wheel
tensorrt
onnxruntime-gpu==1.17.1
onnxruntime-gpu==1.19.2
-r requirements.txt

View File

@@ -1,5 +1,6 @@
bitsandbytes==0.43.0
bitsandbytes==0.44.0
tensorboard
tensorflow>=2.16.1
onnxruntime-gpu==1.17.1
onnxruntime-gpu==1.19.2
-r requirements.txt

@@ -1 +1 @@
Subproject commit b8896aad400222c8c4441b217fda0f9bb0807ffd
Subproject commit 8ebe858f896340d698f03fc33d99ca010131320a

View File

@@ -2,7 +2,7 @@
IF NOT EXIST venv (
echo Creating venv...
py -3.10 -m venv venv
py -3.10.11 -m venv venv
)
:: Create the directory if it doesn't exist
@@ -13,6 +13,9 @@ call .\venv\Scripts\deactivate.bat
call .\venv\Scripts\activate.bat
REM first make sure we have setuptools available in the venv
python -m pip install --require-virtualenv --no-input -q -q setuptools
REM Check if the batch was started via double-click
IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" (
REM echo This script was started by double clicking.

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env bash
cd "$(dirname "$0")"
# Function to display help information
display_help() {
@@ -23,6 +24,7 @@ Options:
-i, --interactive Interactively configure accelerate instead of using default config file.
-n, --no-git-update Do not update kohya_ss repo. No git pull or clone operations.
-p, --public Expose public URL in runpod mode. Won't have an effect in other modes.
-q, --quiet Suppress all output except errors.
-r, --runpod Forces a runpod installation. Useful if detection fails for any reason.
-s, --skip-space-check Skip the 10Gb minimum storage space check.
-u, --no-gui Skips launching the GUI.
@@ -91,6 +93,7 @@ PARENT_DIR=""
VENV_DIR=""
USE_IPEX=false
USE_ROCM=false
QUIET="--show_stdout"
# Function to get the distro name
get_distro_name() {
@@ -206,20 +209,20 @@ install_python_dependencies() {
case "$OSTYPE" in
"lin"*)
if [ "$RUNPOD" = true ]; then
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt $QUIET
elif [ "$USE_IPEX" = true ]; then
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt $QUIET
elif [ "$USE_ROCM" = true ] || [ -x "$(command -v rocminfo)" ] || [ -f "/opt/rocm/bin/rocminfo" ]; then
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt $QUIET
else
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt $QUIET
fi
;;
"darwin"*)
if [[ "$(uname -m)" == "arm64" ]]; then
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_arm64.txt
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_arm64.txt $QUIET
else
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_amd64.txt
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_amd64.txt $QUIET
fi
;;
esac
@@ -307,7 +310,7 @@ update_kohya_ss() {
# Section: Command-line options parsing
while getopts ":vb:d:g:inprus-:" opt; do
while getopts ":vb:d:g:inpqrus-:" opt; do
# support long options: https://stackoverflow.com/a/28466267/519360
if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG
opt="${OPTARG%%=*}" # extract long option name
@@ -322,6 +325,7 @@ while getopts ":vb:d:g:inprus-:" opt; do
i | interactive) INTERACTIVE=true ;;
n | no-git-update) SKIP_GIT_UPDATE=true ;;
p | public) PUBLIC=true ;;
q | quiet) QUIET="" ;;
r | runpod) RUNPOD=true ;;
s | skip-space-check) SKIP_SPACE_CHECK=true ;;
u | no-gui) SKIP_GUI=true ;;

View File

@@ -1,363 +1,321 @@
import subprocess
import os
import re
import sys
import logging
import shutil
import datetime
import subprocess
import re
import pkg_resources
errors = 0 # Define the 'errors' variable before using it
log = logging.getLogger('sd')
log = logging.getLogger("sd")
# Constants
MIN_PYTHON_VERSION = (3, 10, 9)
MAX_PYTHON_VERSION = (3, 13, 0)
LOG_DIR = "../logs/setup/"
LOG_LEVEL = "INFO" # Set to "INFO" or "WARNING" for less verbose logging
def check_python_version():
"""
Check if the current Python version is within the acceptable range.
Returns:
bool: True if the current Python version is valid, False otherwise.
bool: True if the current Python version is valid, False otherwise.
"""
min_version = (3, 10, 9)
max_version = (3, 11, 0)
from packaging import version
log.debug("Checking Python version...")
try:
current_version = sys.version_info
log.info(f"Python version is {sys.version}")
if not (min_version <= current_version < max_version):
log.error(f"The current version of python ({current_version}) is not appropriate to run Kohya_ss GUI")
log.error("The python version needs to be greater or equal to 3.10.9 and less than 3.11.0")
if not (MIN_PYTHON_VERSION <= current_version < MAX_PYTHON_VERSION):
log.error(
f"The current version of python ({sys.version}) is not supported."
)
log.error("The Python version must be >= 3.10.9 and < 3.13.0.")
return False
return True
except Exception as e:
log.error(f"Failed to verify Python version. Error: {e}")
return False
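The version gate above works because `sys.version_info` compares element-wise like an ordinary tuple, so a single chained comparison covers both bounds. A minimal demonstration of the same range test (the helper name is illustrative):

```python
import sys

MIN_PYTHON_VERSION = (3, 10, 9)
MAX_PYTHON_VERSION = (3, 13, 0)

def version_in_range(version_info) -> bool:
    """True when the given version falls inside the supported window."""
    # Tuples compare lexicographically, so this checks major, then minor,
    # then micro; the upper bound is exclusive.
    return MIN_PYTHON_VERSION <= tuple(version_info)[:3] < MAX_PYTHON_VERSION

print(version_in_range((3, 10, 11)))  # True
print(version_in_range((3, 13, 0)))   # False: upper bound is exclusive
print(version_in_range(sys.version_info))
```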
def update_submodule(quiet=True):
"""
Ensure the submodule is initialized and updated.
This function uses the Git command line interface to initialize and update
the specified submodule recursively. Errors during the Git operation
or if Git is not found are caught and logged.
Parameters:
- quiet: If True, suppresses the output of the Git command.
"""
log.debug("Updating submodule...")
git_command = ["git", "submodule", "update", "--init", "--recursive"]
if quiet:
git_command.append("--quiet")
try:
# Initialize and update the submodule
subprocess.run(git_command, check=True)
log.info("Submodule initialized and updated.")
except subprocess.CalledProcessError as e:
# Log the error if the Git operation fails
log.error(f"Error during Git operation: {e}")
except FileNotFoundError as e:
# Log the error if the file is not found
log.error(e)
# def read_tag_version_from_file(file_path):
# """
# Read the tag version from a given file.
# Parameters:
# - file_path: The path to the file containing the tag version.
# Returns:
# The tag version as a string.
# """
# with open(file_path, 'r') as file:
# # Read the first line and strip whitespace
# tag_version = file.readline().strip()
# return tag_version
def clone_or_checkout(repo_url, branch_or_tag, directory_name):
"""
Clone a repo or checkout a specific branch or tag if the repo already exists.
For branches, it updates to the latest version before checking out.
Suppresses detached HEAD advice for tags or specific commits.
Restores the original working directory after operations.
Parameters:
- repo_url: The URL of the Git repository.
- branch_or_tag: The name of the branch or tag to clone or checkout.
- directory_name: The name of the directory to clone into or where the repo already exists.
"""
original_dir = os.getcwd() # Store the original directory
log.debug(
f"Cloning or checking out repository: {repo_url}, branch/tag: {branch_or_tag}, directory: {directory_name}"
)
original_dir = os.getcwd()
try:
if not os.path.exists(directory_name):
# Directory does not exist, clone the repo quietly
# Construct the command as a string for logging
# run_cmd = f"git clone --branch {branch_or_tag} --single-branch --quiet {repo_url} {directory_name}"
run_cmd = ["git", "clone", "--branch", branch_or_tag, "--single-branch", "--quiet", repo_url, directory_name]
# Log the command
log.debug(run_cmd)
# Run the command
process = subprocess.Popen(
run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
output, error = process.communicate()
if error and not error.startswith("Note: switching to"):
log.warning(error)
else:
log.info(f"Successfully cloned sd-scripts {branch_or_tag}")
run_cmd = [
"git",
"clone",
"--branch",
branch_or_tag,
"--single-branch",
"--quiet",
repo_url,
directory_name,
]
log.debug(f"Cloning repository: {run_cmd}")
subprocess.run(run_cmd, check=True)
log.info(f"Successfully cloned {repo_url} ({branch_or_tag})")
else:
os.chdir(directory_name)
log.debug("Fetching all branches and tags...")
subprocess.run(["git", "fetch", "--all", "--quiet"], check=True)
subprocess.run(["git", "config", "advice.detachedHead", "false"], check=True)
# Get the current branch or commit hash
current_branch_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
tag_branch_hash = subprocess.check_output(["git", "rev-parse", branch_or_tag]).strip().decode()
if current_branch_hash != tag_branch_hash:
run_cmd = f"git checkout {branch_or_tag} --quiet"
# Log the command
log.debug(run_cmd)
# Execute the checkout command
process = subprocess.Popen(run_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
output, error = process.communicate()
if error:
log.warning(error.decode())
else:
log.info(f"Checked out sd-scripts {branch_or_tag} successfully.")
subprocess.run(
["git", "config", "advice.detachedHead", "false"], check=True
)
current_branch_hash = (
subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
)
target_branch_hash = (
subprocess.check_output(["git", "rev-parse", branch_or_tag])
.strip()
.decode()
)
if current_branch_hash != target_branch_hash:
log.debug(f"Checking out branch/tag: {branch_or_tag}")
subprocess.run(
["git", "checkout", branch_or_tag, "--quiet"], check=True
)
log.info(f"Checked out {branch_or_tag} successfully.")
else:
log.info(f"Current branch of sd-scripts is already at the required release {branch_or_tag}.")
log.info(f"Already at required branch/tag: {branch_or_tag}")
except subprocess.CalledProcessError as e:
log.error(f"Error during Git operation: {e}")
finally:
os.chdir(original_dir) # Restore the original directory
os.chdir(original_dir)
# setup console and file logging
def setup_logging(clean=False):
#
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
#
def setup_logging():
"""
Set up logging to file and console.
"""
log.debug("Setting up logging...")
from rich.theme import Theme
from rich.logging import RichHandler
from rich.console import Console
from rich.pretty import install as pretty_install
from rich.traceback import install as traceback_install
console = Console(
log_time=True,
log_time_format='%H:%M:%S-%f',
theme=Theme(
{
'traceback.border': 'black',
'traceback.border.syntax_error': 'black',
'inspect.value.border': 'black',
}
),
log_time_format="%H:%M:%S-%f",
theme=Theme({"traceback.border": "black", "inspect.value.border": "black"}),
)
# logging.getLogger("urllib3").setLevel(logging.ERROR)
# logging.getLogger("httpx").setLevel(logging.ERROR)
current_datetime = datetime.datetime.now()
current_datetime_str = current_datetime.strftime('%Y%m%d-%H%M%S')
current_datetime_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_file = os.path.join(
os.path.dirname(__file__),
f'../logs/setup/kohya_ss_gui_{current_datetime_str}.log',
os.path.dirname(__file__), f"{LOG_DIR}kohya_ss_gui_{current_datetime_str}.log"
)
os.makedirs(os.path.dirname(log_file), exist_ok=True)
# Create directories if they don't exist
log_directory = os.path.dirname(log_file)
os.makedirs(log_directory, exist_ok=True)
level = logging.INFO
logging.basicConfig(
level=logging.ERROR,
format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
format="%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s",
filename=log_file,
filemode='a',
encoding='utf-8',
filemode="a",
encoding="utf-8",
force=True,
)
log.setLevel(
logging.DEBUG
) # log to file is always at level debug for facility `sd`
pretty_install(console=console)
traceback_install(
console=console,
extra_lines=1,
width=console.width,
word_wrap=False,
indent_guides=False,
suppress=[],
)
rh = RichHandler(
show_time=True,
omit_repeated_times=False,
show_level=True,
show_path=False,
markup=False,
rich_tracebacks=True,
log_time_format='%H:%M:%S-%f',
level=level,
console=console,
)
rh.set_name(level)
while log.hasHandlers() and len(log.handlers) > 0:
log.removeHandler(log.handlers[0])
log.addHandler(rh)
log_level = os.getenv("LOG_LEVEL", LOG_LEVEL).upper()
log.setLevel(getattr(logging, log_level, logging.DEBUG))
rich_handler = RichHandler(console=console)
# Replace existing handlers with the rich handler
log.handlers.clear()
log.addHandler(rich_handler)
log.debug("Logging setup complete.")
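The rewritten `setup_logging()` resolves its level from a `LOG_LEVEL` environment variable via `getattr`, falling back to `DEBUG` when the name is unknown. A minimal sketch of just that lookup, with the environment read factored out so the behavior is easy to test (the helper name is mine):

```python
import logging

def resolve_log_level(name: str, fallback: int = logging.DEBUG) -> int:
    """Map a level name like 'INFO' to its logging constant.

    Unknown or misspelled names fall back to DEBUG instead of raising,
    mirroring the getattr(logging, level, logging.DEBUG) pattern above.
    """
    return getattr(logging, name.upper(), fallback)

print(logging.getLevelName(resolve_log_level("info")))      # INFO
print(logging.getLevelName(resolve_log_level("nonsense")))  # DEBUG
```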
def install_requirements_inbulk(requirements_file, show_stdout=True, optional_parm="", upgrade = False):
def install_requirements_inbulk(
requirements_file, show_stdout=True, optional_parm="", upgrade=False
):
log.debug(f"Installing requirements in bulk from: {requirements_file}")
if not os.path.exists(requirements_file):
log.error(f'Could not find the requirements file in {requirements_file}.')
log.error(f"Could not find the requirements file in {requirements_file}.")
return
log.info(f'Installing requirements from {requirements_file}...')
log.info(f"Installing/Validating requirements from {requirements_file}...")
# Build the command as a list
cmd = ["pip", "install", "-r", requirements_file]
if upgrade:
optional_parm += " -U"
cmd.append("--upgrade")
if not show_stdout:
cmd.append("--quiet")
if optional_parm:
cmd.extend(optional_parm.split())
if show_stdout:
run_cmd(f'pip install -r {requirements_file} {optional_parm}')
else:
run_cmd(f'pip install -r {requirements_file} {optional_parm} --quiet')
log.info(f'Requirements from {requirements_file} installed.')
try:
# Run the command and filter output in real-time
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True
)
for line in process.stdout:
if "Requirement already satisfied" not in line:
log.info(line.strip()) if show_stdout else None
# Capture and log any errors
_, stderr = process.communicate()
if process.returncode != 0:
log.error(f"Failed to install requirements: {stderr.strip()}")
except subprocess.CalledProcessError as e:
log.error(f"An error occurred while installing requirements: {e}")
def configure_accelerate(run_accelerate=False):
#
# This function was taken and adapted from code written by jstayco
#
log.debug("Configuring accelerate...")
from pathlib import Path
def env_var_exists(var_name):
return var_name in os.environ and os.environ[var_name] != ''
return var_name in os.environ and os.environ[var_name] != ""
log.info("Configuring accelerate...")
log.info('Configuring accelerate...')
source_accelerate_config_file = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'..',
'config_files',
'accelerate',
'default_config.yaml',
"..",
"config_files",
"accelerate",
"default_config.yaml",
)
if not os.path.exists(source_accelerate_config_file):
log.warning(
f"Could not find the accelerate configuration file in {source_accelerate_config_file}."
)
if run_accelerate:
log.debug("Running accelerate configuration command...")
run_cmd([sys.executable, "-m", "accelerate", "config"])
else:
log.warning(
"Please configure accelerate manually by running the option in the menu."
)
return
log.debug(f"Source accelerate config location: {source_accelerate_config_file}")
target_config_location = None
log.debug(
f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, "
f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, "
f"USERPROFILE: {os.environ.get('USERPROFILE')}"
)
env_vars = {
"HF_HOME": Path(os.environ.get("HF_HOME", "")),
"LOCALAPPDATA": Path(
os.environ.get("LOCALAPPDATA", ""),
"huggingface",
"accelerate",
"default_config.yaml",
),
"USERPROFILE": Path(
os.environ.get("USERPROFILE", ""),
".cache",
"huggingface",
"accelerate",
"default_config.yaml",
),
}
for var, path in env_vars.items():
if env_var_exists(var):
target_config_location = path
break
log.debug(f"Target config location: {target_config_location}")
if target_config_location:
if not target_config_location.is_file():
log.debug(
f"Creating target config directory: {target_config_location.parent}"
)
target_config_location.parent.mkdir(parents=True, exist_ok=True)
log.debug(
f"Copying config file to target location: {target_config_location}"
)
shutil.copyfile(source_accelerate_config_file, target_config_location)
log.info(f"Copied accelerate config file to: {target_config_location}")
elif run_accelerate:
log.debug("Running accelerate configuration command...")
run_cmd([sys.executable, "-m", "accelerate", "config"])
else:
log.warning(
"Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
)
elif run_accelerate:
log.debug("Running accelerate configuration command...")
run_cmd([sys.executable, "-m", "accelerate", "config"])
else:
log.warning(
"Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
)
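The config-path lookup order used above (HF_HOME first, then LOCALAPPDATA, then USERPROFILE) can be sketched as a pure function. `resolve_config_path` and its injected `env` mapping are illustrative assumptions, not part of the repo:

```python
from pathlib import Path
from typing import Optional

def resolve_config_path(env: dict) -> Optional[Path]:
    """Return the first candidate accelerate config path whose env var is set."""
    candidates = {
        "HF_HOME": ("accelerate", "default_config.yaml"),
        "LOCALAPPDATA": ("huggingface", "accelerate", "default_config.yaml"),
        "USERPROFILE": (".cache", "huggingface", "accelerate", "default_config.yaml"),
    }
    # dict preserves insertion order, so iteration encodes the priority.
    for var, suffix in candidates.items():
        if env.get(var):
            return Path(env[var], *suffix)
    return None
```

Injecting the environment as a parameter makes the priority logic testable without touching `os.environ`.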
def check_torch():
log.debug("Checking Torch installation...")
#
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
#
# Check for toolkit
if shutil.which("nvidia-smi") is not None or os.path.exists(
os.path.join(
os.environ.get("SystemRoot") or r"C:\Windows",
"System32",
"nvidia-smi.exe",
)
):
log.info("nVidia toolkit detected")
elif shutil.which("rocminfo") is not None or os.path.exists(
"/opt/rocm/bin/rocminfo"
):
log.info("AMD toolkit detected")
elif (
shutil.which("sycl-ls") is not None
or os.environ.get("ONEAPI_ROOT") is not None
or os.path.exists("/opt/intel/oneapi")
):
log.info("Intel OneAPI toolkit detected")
else:
log.info("Using CPU-only Torch")
try:
import torch
log.debug("Torch module imported successfully.")
try:
# Import IPEX / XPU support
import intel_extension_for_pytorch as ipex
log.debug("Intel extension for PyTorch imported successfully.")
except Exception as e:
log.warning(f"Failed to import intel_extension_for_pytorch: {e}")
log.info(f"Torch {torch.__version__}")
if torch.cuda.is_available():
if torch.version.cuda:
@ -367,33 +325,33 @@ def check_torch():
)
elif torch.version.hip:
# Log AMD ROCm HIP version
log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
else:
log.warning("Unknown Torch backend")
# Log information about detected GPUs
for device in [
torch.cuda.device(i) for i in range(torch.cuda.device_count())
]:
log.info(
f"Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}"
)
# Check if XPU is available
elif hasattr(torch, "xpu") and torch.xpu.is_available():
# Log Intel IPEX version
log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
for device in [
torch.xpu.device(i) for i in range(torch.xpu.device_count())
]:
log.info(
f"Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}"
)
else:
log.warning("Torch reports GPU not available")
return int(torch.__version__[0])
except Exception as e:
log.error(f"Could not load torch: {e}")
return 0
@ -404,17 +362,19 @@ def check_repo_version():
in the current directory. If the file exists, it reads the release version from the file and logs it.
If the file does not exist, it logs a debug message indicating that the release could not be read.
"""
log.debug("Checking repository version...")
if os.path.exists(".release"):
try:
with open(os.path.join("./.release"), "r", encoding="utf8") as file:
release = file.read()
log.info(f"Kohya_ss GUI version: {release}")
except Exception as e:
log.error(f"Could not read release: {e}")
else:
log.debug("Could not read release...")
# execute git command
def git(arg: str, folder: str = None, ignore: bool = False):
"""
@ -433,22 +393,31 @@ def git(arg: str, folder: str = None, ignore: bool = False):
If set to True, errors will not be logged.
Note:
This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
"""
git_cmd = os.environ.get("GIT", "git")
log.debug(f"Running git command: {git_cmd} {arg} in folder: {folder or '.'}")
result = subprocess.run(
[git_cmd, *arg.split()],
check=False,
shell=False,
env=os.environ,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=folder or ".",
)
txt = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
encoding="utf8", errors="ignore"
)
txt = txt.strip()
if result.returncode != 0 and not ignore:
global errors
errors += 1
log.error(f"Error running git: {folder} / {arg}")
if "or stash them" in txt:
log.error("Local changes detected: check log for details...")
log.debug(f"Git output: {txt}")
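Passing a list of argv tokens with `shell=False` avoids the shell-quoting pitfalls of the old string-based invocation. `build_git_argv` is an illustrative helper (not in the repo), using `shlex` so that quoted arguments survive as single tokens:

```python
import shlex

def build_git_argv(arg: str, git_cmd: str = "git") -> list:
    """Split an argument string into argv tokens, respecting shell-style quotes."""
    return [git_cmd, *shlex.split(arg)]

# A quoted commit message stays a single token:
argv = build_git_argv('commit -m "msg one"')
```

A plain `arg.split()` would break the quoted message into two tokens; `shlex.split` honors the quotes.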
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
@ -473,31 +442,42 @@ def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool =
Returns:
- The output of the pip command as a string, or None if the 'show_stdout' flag is set.
"""
# arg = arg.replace('>=', '==')
log.debug(f"Running pip command: {arg}")
if not quiet:
log.info(
f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace("  ", " ").strip()}'
)
pip_cmd = [rf"{sys.executable}", "-m", "pip"] + arg.split(" ")
log.debug(f"Running pip: {pip_cmd}")
if show_stdout:
subprocess.run(pip_cmd, shell=False, check=False, env=os.environ)
else:
result = subprocess.run(
pip_cmd,
shell=False,
check=False,
env=os.environ,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
txt = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
encoding="utf8", errors="ignore"
)
txt = txt.strip()
if result.returncode != 0 and not ignore:
global errors # pylint: disable=global-statement
errors += 1
log.error(f"Error running pip: {arg}")
log.error(f"Pip output: {txt}")
return txt
def installed(package, friendly: str = None):
"""
Checks if the specified package(s) are installed with the correct version.
This function can handle package specifications with or without version constraints,
and can also filter out command-line options and URLs when a 'friendly' string is provided.
Parameters:
- package: A string that specifies one or more packages with optional version constraints.
- friendly: An optional string used to provide a cleaner version of the package string
@ -505,43 +485,39 @@ def installed(package, friendly: str = None):
Returns:
- True if all specified packages are installed with the correct versions, False otherwise.
Note:
This function was adapted from code written by vladimandic.
"""
log.debug(f"Checking if package is installed: {package}")
# Remove any optional features specified in brackets (e.g., "package[option]==version" becomes "package==version")
package = re.sub(r"\[.*?\]", "", package)
try:
if friendly:
# If a 'friendly' version of the package string is provided, split it into components
pkgs = friendly.split()
# Filter out command-line options and URLs from the package specification
pkgs = [
p for p in package.split() if not p.startswith("--") and "://" not in p
]
else:
# Split the package string into components, excluding '-' and '=' prefixed items
pkgs = [
p
for p in package.split()
if not p.startswith("-") and not p.startswith("=")
]
# For each package component, extract the package name, excluding any URLs
pkgs = [p.split("/")[-1] for p in pkgs]
for pkg in pkgs:
# Parse the package name and version based on the version specifier used
if ">=" in pkg:
pkg_name, pkg_version = [x.strip() for x in pkg.split(">=")]
elif "==" in pkg:
pkg_name, pkg_version = [x.strip() for x in pkg.split("==")]
else:
pkg_name, pkg_version = pkg.strip(), None
@ -552,38 +528,41 @@ def installed(package, friendly: str = None):
spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None)
if spec is None:
# Try replacing underscores with dashes
spec = pkg_resources.working_set.by_key.get(
pkg_name.replace("_", "-"), None
)
if spec is not None:
# Package is found, check version
version = pkg_resources.get_distribution(pkg_name).version
log.debug(f"Package version found: {pkg_name} {version}")
if pkg_version is not None:
# Verify if the installed version meets the specified constraints
if ">=" in pkg:
ok = version >= pkg_version
else:
ok = version == pkg_version
if not ok:
# Version mismatch, log warning and return False
log.warning(
f"Package wrong version: {pkg_name} {version} required {pkg_version}"
)
return False
else:
# Package not found, log debug message and return False
log.debug(f"Package version not found: {pkg_name}")
return False
# All specified packages are installed with the correct versions
return True
except ModuleNotFoundError:
# One or more packages are not installed, log debug message and return False
log.debug(f"Package not installed: {pkgs}")
return False
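Note that the `ok = version >= pkg_version` check above compares strings lexicographically, which misorders versions such as "2.10.0" vs "2.9.0". A numeric comparison is a sketch of the fix; `parse_version`/`version_satisfied` are illustrative names, the parse is naive (plain dotted-integer versions only, no pre-release tags), and production code would use `packaging.version` instead:

```python
def parse_version(v: str) -> tuple:
    """Naively parse a dotted version string into a tuple of ints."""
    return tuple(int(part) for part in v.split("."))

def version_satisfied(installed_version: str, required: str, op: str = ">=") -> bool:
    # Tuples of ints compare element-wise, so 2.10.0 correctly beats 2.9.0.
    iv, rv = parse_version(installed_version), parse_version(required)
    return iv >= rv if op == ">=" else iv == rv
```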
# install package using pip if not already installed
def install(
package,
@ -595,7 +574,7 @@ def install(
"""
Installs or upgrades a Python package using pip, with options to ignore errors,
reinstall packages, and display outputs.
Parameters:
- package (str): The name of the package to be installed or upgraded. Can include
version specifiers. Anything after a '#' in the package name will be ignored.
@ -611,103 +590,98 @@ def install(
Returns:
None. The function performs operations that affect the environment but does not return
any value.
Note:
If `reinstall` is True, it disables any mechanism that allows for skipping installations
when the package is already present, forcing a fresh install.
"""
log.debug(f"Installing package: {package}")
# Remove anything after '#' in the package variable
package = package.split("#")[0].strip()
if reinstall:
global quick_allowed # pylint: disable=global-statement
quick_allowed = False
if reinstall or not installed(package, friendly):
pip(f"install --upgrade {package}", ignore=ignore, show_stdout=show_stdout)
def process_requirements_line(line, show_stdout: bool = False):
log.debug(f"Processing requirements line: {line}")
# Remove brackets and their contents from the line using regular expressions
# e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
package_name = re.sub(r"\[.*?\]", "", line)
install(line, package_name, show_stdout=show_stdout)
def install_requirements(
requirements_file, check_no_verify_flag=False, show_stdout: bool = False
):
"""
Install or verify modules from a requirements file.
Parameters:
- requirements_file (str): Path to the requirements file.
- check_no_verify_flag (bool): If True, verify modules installation status without installing.
- show_stdout (bool): If True, show the standard output of the installation process.
"""
log.debug(f"Installing requirements from file: {requirements_file}")
action = "Verifying" if check_no_verify_flag else "Installing"
log.info(f"{action} modules from {requirements_file}...")
with open(requirements_file, "r", encoding="utf8") as f:
lines = [
line.strip()
for line in f.readlines()
if line.strip() and not line.startswith("#") and "no_verify" not in line
]
for line in lines:
if line.startswith("-r"):
included_file = line[2:].strip()
log.debug(f"Processing included requirements file: {included_file}")
install_requirements(
included_file,
check_no_verify_flag=check_no_verify_flag,
show_stdout=show_stdout,
)
else:
process_requirements_line(line, show_stdout=show_stdout)
def ensure_base_requirements():
try:
import rich # pylint: disable=unused-import
except ImportError:
install("--upgrade rich", "rich")
try:
import packaging
except ImportError:
install("packaging")
def run_cmd(run_cmd):
"""
Execute a command using subprocess.
"""
log.debug(f"Running command: {run_cmd}")
try:
subprocess.run(run_cmd, shell=True, check=True, env=os.environ)
log.debug(f"Command executed successfully: {run_cmd}")
except subprocess.CalledProcessError as e:
log.error(f"Error occurred while running command: {run_cmd}")
log.error(f"Error: {e}")
def clear_screen():
"""
Clear the terminal screen.
"""
log.debug("Attempting to clear the terminal screen")
try:
os.system("cls" if os.name == "nt" else "clear")
log.info("Terminal screen cleared successfully")
except Exception as e:
log.error("Error occurred while clearing the terminal screen")
log.error(f"Error: {e}")

View File

@ -19,7 +19,10 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce
# Upgrade pip if needed
setup_common.install('pip')
setup_common.install_requirements_inbulk(
platform_requirements_file, show_stdout=show_stdout,
)
# setup_common.install_requirements(platform_requirements_file, check_no_verify_flag=False, show_stdout=show_stdout)
if not no_run_accelerate:
setup_common.configure_accelerate(run_accelerate=False)
@ -31,10 +34,6 @@ if __name__ == '__main__':
exit(1)
setup_common.update_submodule()
parser = argparse.ArgumentParser()
parser.add_argument('--platform-requirements-file', dest='platform_requirements_file', default='requirements_linux.txt', help='Path to the platform-specific requirements file')

View File

@ -54,7 +54,10 @@ def main_menu(platform_requirements_file):
# Upgrade pip if needed
setup_common.install('pip')
setup_common.install_requirements_inbulk(
platform_requirements_file, show_stdout=True,
)
configure_accelerate()

View File

@ -123,12 +123,13 @@ def install_kohya_ss_torch2(headless: bool = False):
# )
setup_common.install_requirements_inbulk(
"requirements_pytorch_windows.txt", show_stdout=True,
# optional_parm="--index-url https://download.pytorch.org/whl/cu124"
)
# setup_common.install_requirements_inbulk(
# "requirements_windows.txt", show_stdout=True, upgrade=True
# )
setup_common.run_cmd("accelerate config default")

View File

@ -5,12 +5,11 @@ import argparse
import setup_common
# Get the absolute path of the current file's directory (Kohua_SS project directory)
project_directory = os.path.dirname(os.path.abspath(__file__))
project_directory = (
os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if "setup" in os.path.dirname(os.path.abspath(__file__))
else os.path.dirname(os.path.abspath(__file__))
)
# Add the project directory to the beginning of the Python search path
sys.path.insert(0, project_directory)
@ -19,115 +18,178 @@ from kohya_gui.custom_logging import setup_logging
# Set up logging
log = setup_logging()
log.debug(f"Project directory set to: {project_directory}")
def check_path_with_space():
"""Check if the current working directory contains a space."""
cwd = os.getcwd()
log.debug(f"Current working directory: {cwd}")
if " " in cwd:
# Log an error if the current working directory contains spaces
log.error(
"The path in which this python code is executed contains one or many spaces. This is not supported for running kohya_ss GUI."
)
log.error(
"Please move the repo to a path without spaces, delete the venv folder, and run setup.sh again."
)
log.error(f"The current working directory is: {cwd}")
raise RuntimeError("Invalid path: contains spaces.")
def detect_toolkit():
"""Detect the available toolkit (NVIDIA, AMD, or Intel) and log the information."""
log.debug("Detecting available toolkit...")
# Check for NVIDIA toolkit by looking for nvidia-smi executable
if shutil.which("nvidia-smi") or os.path.exists(
os.path.join(
os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvidia-smi.exe"
)
):
log.debug("nVidia toolkit detected")
return "nVidia"
# Check for AMD toolkit by looking for rocminfo executable
elif shutil.which("rocminfo") or os.path.exists("/opt/rocm/bin/rocminfo"):
log.debug("AMD toolkit detected")
return "AMD"
# Check for Intel toolkit by looking for SYCL or OneAPI indicators
elif (
shutil.which("sycl-ls")
or os.environ.get("ONEAPI_ROOT")
or os.path.exists("/opt/intel/oneapi")
):
log.debug("Intel toolkit detected")
return "Intel"
# Default to CPU if no toolkit is detected
else:
log.debug("No specific GPU toolkit detected, defaulting to CPU")
return "CPU"
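The detection order above (NVIDIA first, then AMD, then Intel, else CPU) becomes testable once the probes are injected. `detect_toolkit_pure` and its parameters are illustrative, not part of the repo:

```python
def detect_toolkit_pure(which, exists, env):
    """Mirror the probe order above using injected callables and env mapping."""
    if which("nvidia-smi") or exists(r"C:\Windows\System32\nvidia-smi.exe"):
        return "nVidia"
    if which("rocminfo") or exists("/opt/rocm/bin/rocminfo"):
        return "AMD"
    if which("sycl-ls") or env.get("ONEAPI_ROOT") or exists("/opt/intel/oneapi"):
        return "Intel"
    return "CPU"
```

Because NVIDIA is probed first, a machine exposing both `nvidia-smi` and `rocminfo` reports as NVIDIA, matching the original function.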
def check_torch():
"""Check if torch is available and log the relevant information."""
# Detect the available toolkit (e.g., NVIDIA, AMD, Intel, or CPU)
toolkit = detect_toolkit()
log.info(f"{toolkit} toolkit detected")
try:
# Import PyTorch
log.debug("Importing PyTorch...")
import torch
ipex = None
# Attempt to import Intel Extension for PyTorch if Intel toolkit is detected
if toolkit == "Intel":
try:
log.debug("Attempting to import Intel Extension for PyTorch (IPEX)...")
import intel_extension_for_pytorch as ipex
log.debug("Intel Extension for PyTorch (IPEX) imported successfully")
except ImportError:
log.warning("Intel Extension for PyTorch (IPEX) not found.")
# Log the PyTorch version
log.info(f"Torch {torch.__version__}")
# Check if CUDA (NVIDIA GPU) is available
if torch.cuda.is_available():
if torch.version.cuda:
log.debug("CUDA is available, logging CUDA info...")
log_cuda_info(torch)
# Check if XPU (Intel GPU) is available
elif hasattr(torch, "xpu") and torch.xpu.is_available():
log.debug("XPU is available, logging XPU info...")
log_xpu_info(torch, ipex)
# Log a warning if no GPU is available
else:
log.warning("Torch reports GPU not available")
# Return the major version of PyTorch
return int(torch.__version__[0])
except ImportError as e:
# Log an error if PyTorch cannot be loaded
log.error(f"Could not load torch: {e}")
sys.exit(1)
except Exception as e:
# Log an unexpected error
log.error(f"Unexpected error while checking torch: {e}")
sys.exit(1)
def log_cuda_info(torch):
"""Log information about CUDA-enabled GPUs."""
# Log the CUDA and cuDNN versions if available
if torch.version.cuda:
log.info(
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
)
# Log the ROCm HIP version if using AMD GPU
elif torch.version.hip:
log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
else:
log.warning("Unknown Torch backend")
# Log information about each detected CUDA-enabled GPU
for device in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(device)
log.info(
f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}"
)
def log_xpu_info(torch, ipex):
"""Log information about Intel XPU-enabled GPUs."""
# Log the Intel Extension for PyTorch (IPEX) version if available
if ipex:
log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
# Log information about each detected XPU-enabled GPU
for device in range(torch.xpu.device_count()):
props = torch.xpu.get_device_properties(device)
log.info(
f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Compute Units {props.max_compute_units}"
)
def main():
# Check the repository version to ensure compatibility
log.debug("Checking repository version...")
setup_common.check_repo_version()
# Check if the current path contains spaces, which are not supported
log.debug("Checking if the current path contains spaces...")
check_path_with_space()
# Parse command line arguments
log.debug("Parsing command line arguments...")
parser = argparse.ArgumentParser(
description="Validate that requirements are satisfied."
)
parser.add_argument(
"-r", "--requirements", type=str, help="Path to the requirements file."
)
parser.add_argument("--debug", action="store_true", help="Debug on")
args = parser.parse_args()
# Update git submodules if necessary
log.debug("Updating git submodules...")
setup_common.update_submodule()
# Check if PyTorch is installed and log relevant information
log.debug("Checking if PyTorch is installed...")
torch_ver = check_torch()
# Check if the Python version is compatible
log.debug("Checking Python version...")
if not setup_common.check_python_version():
sys.exit(1)
# Install required packages from the specified requirements file
requirements_file = args.requirements or "requirements_pytorch_windows.txt"
log.debug(f"Installing requirements from: {requirements_file}")
setup_common.install_requirements_inbulk(
requirements_file, show_stdout=True,
# optional_parm="--index-url https://download.pytorch.org/whl/cu124"
)
# setup_common.install_requirements(requirements_file, check_no_verify_flag=True)
# log.debug("Installing additional requirements from: requirements_windows.txt")
# setup_common.install_requirements(
# "requirements_windows.txt", check_no_verify_flag=True
# )
if __name__ == "__main__":
log.debug("Starting main function...")
main()
log.debug("Main function finished.")
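The fallback for the requirements file can be seen in isolation. This is a minimal sketch of the same `argparse` setup, fed an explicit argument list so it runs without a real command line; the default file name matches the one used above.

```python
import argparse

# Rebuild the parser from the setup script above.
parser = argparse.ArgumentParser(description="Validate that requirements are satisfied.")
parser.add_argument("-r", "--requirements", type=str, help="Path to the requirements file.")
parser.add_argument("--debug", action="store_true", help="Debug on")

# Simulate launching with only --debug: -r is omitted, so the
# Windows PyTorch requirements file is chosen as the default.
args = parser.parse_args(["--debug"])
requirements_file = args.requirements or "requirements_pytorch_windows.txt"
print(requirements_file)  # requirements_pytorch_windows.txt
print(args.debug)         # True
```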


@ -0,0 +1,125 @@
{
"adaptive_noise_scale": 0,
"additional_parameters": "",
"async_upload": false,
"bucket_no_upscale": true,
"bucket_reso_steps": 1,
"cache_latents": true,
"cache_latents_to_disk": false,
"caption_dropout_every_n_epochs": 0,
"caption_dropout_rate": 0.05,
"caption_extension": "",
"clip_skip": 2,
"color_aug": false,
"dataset_config": "",
"dynamo_backend": "no",
"dynamo_mode": "default",
"dynamo_use_dynamic": false,
"dynamo_use_fullgraph": false,
"enable_bucket": true,
"epoch": 8,
"extra_accelerate_launch_args": "",
"flip_aug": false,
"full_fp16": false,
"gpu_ids": "",
"gradient_accumulation_steps": 1,
"gradient_checkpointing": false,
"huber_c": 0.1,
"huber_schedule": "snr",
"huggingface_path_in_repo": "",
"huggingface_repo_id": "False",
"huggingface_repo_type": "",
"huggingface_repo_visibility": "",
"huggingface_token": "",
"init_word": "*",
"ip_noise_gamma": 0.1,
"ip_noise_gamma_random_strength": true,
"keep_tokens": 0,
"learning_rate": 0.0001,
"log_config": false,
"log_tracker_config": "",
"log_tracker_name": "",
"log_with": "",
"logging_dir": "./test/logs",
"loss_type": "l2",
"lr_scheduler": "cosine",
"lr_scheduler_args": "",
"lr_scheduler_num_cycles": 1,
"lr_scheduler_power": 1,
"lr_scheduler_type": "",
"lr_warmup": 0,
"main_process_port": 0,
"max_bucket_reso": 2048,
"max_data_loader_n_workers": 0,
"max_resolution": "1024,1024",
"max_timestep": 0,
"max_token_length": 75,
"max_train_epochs": 0,
"max_train_steps": 0,
"mem_eff_attn": false,
"metadata_author": "False",
"metadata_description": "",
"metadata_license": "",
"metadata_tags": "",
"metadata_title": "",
"min_bucket_reso": 256,
"min_snr_gamma": 10,
"min_timestep": 0,
"mixed_precision": "bf16",
"model_list": "custom",
"multi_gpu": false,
"multires_noise_discount": 0.2,
"multires_noise_iterations": 8,
"no_token_padding": false,
"noise_offset": 0.05,
"noise_offset_random_strength": true,
"noise_offset_type": "Original",
"num_cpu_threads_per_process": 2,
"num_machines": 1,
"num_processes": 1,
"num_vectors_per_token": 8,
"optimizer": "AdamW8bit",
"optimizer_args": "",
"output_dir": "./test/output",
"output_name": "TI-Adamw8bit-SDXL",
"persistent_data_loader_workers": false,
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
"prior_loss_weight": 1,
"random_crop": false,
"reg_data_dir": "",
"resume": "",
"resume_from_huggingface": "False",
"sample_every_n_epochs": 0,
"sample_every_n_steps": 20,
"sample_prompts": "a painting of man wearing a gas mask , by darius kawasaki",
"sample_sampler": "euler_a",
"save_as_bool": false,
"save_every_n_epochs": 1,
"save_every_n_steps": 0,
"save_last_n_steps": 0,
"save_last_n_steps_state": 0,
"save_model_as": "safetensors",
"save_precision": "fp16",
"save_state": false,
"save_state_on_train_end": false,
"save_state_to_huggingface": false,
"scale_v_pred_loss_like_noise_pred": false,
"sdxl": true,
"sdxl_no_half_vae": true,
"seed": 1234,
"shuffle_caption": false,
"stop_text_encoder_training": 0,
"template": "style template",
"token_string": "zxc",
"train_batch_size": 4,
"train_data_dir": "./test/img",
"v2": false,
"v_parameterization": false,
"v_pred_like_loss": 0,
"vae": "",
"vae_batch_size": 0,
"wandb_api_key": "",
"wandb_run_name": "",
"weights": "",
"xformers": "xformers"
}
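One of the fixes in this release concerns `adaptive_noise_scale` not loading correctly from a JSON config. A hedged sketch of the issue, with a hypothetical `load_numeric` helper (not the actual GUI code): values that were saved as strings need to be coerced back to numbers when the config is read.

```python
import json

# Hypothetical helper: read a numeric config field defensively, coercing
# string-typed values (e.g. "0.005") back to floats and falling back to a
# default when the value is missing or unparseable.
def load_numeric(config: dict, key: str, default: float = 0.0) -> float:
    value = config.get(key, default)
    try:
        return float(value)
    except (TypeError, ValueError):
        return default

# A value mistakenly stored as a string still loads as a number.
config = json.loads('{"adaptive_noise_scale": "0.005", "noise_offset": 0.05}')
print(load_numeric(config, "adaptive_noise_scale"))  # 0.005
print(load_numeric(config, "noise_offset"))          # 0.05
```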


@ -0,0 +1,40 @@
[general]
# define common settings here
flip_aug = true
color_aug = false
keep_tokens_separator = "|||"
shuffle_caption = false
caption_tag_dropout_rate = 0
caption_extension = ".txt"
min_bucket_reso = 64
max_bucket_reso = 2048
[[datasets]]
# define the first resolution here
batch_size = 1
enable_bucket = true
resolution = [1024, 1024]
[[datasets.subsets]]
image_dir = "./test/img/10_darius kawasaki person"
num_repeats = 10
[[datasets]]
# define the second resolution here
batch_size = 1
enable_bucket = true
resolution = [768, 768]
[[datasets.subsets]]
image_dir = "./test/img/10_darius kawasaki person"
num_repeats = 10
[[datasets]]
# define the third resolution here
batch_size = 1
enable_bucket = true
resolution = [512, 512]
[[datasets.subsets]]
image_dir = "./test/img/10_darius kawasaki person"
num_repeats = 10


@ -1,49 +1,75 @@
{
"adaptive_noise_scale": 0,
"additional_parameters": "",
"async_upload": false,
"bucket_no_upscale": true,
"bucket_reso_steps": 64,
"cache_latents": true,
"cache_latents_to_disk": false,
"caption_dropout_every_n_epochs": 0,
"caption_dropout_rate": 0.05,
"caption_extension": "",
"clip_skip": 2,
"color_aug": false,
"dataset_config": "./test/config/dataset.toml",
"debiased_estimation_loss": false,
"disable_mmap_load_safetensors": false,
"dynamo_backend": "no",
"dynamo_mode": "default",
"dynamo_use_dynamic": false,
"dynamo_use_fullgraph": false,
"enable_bucket": true,
"epoch": 1,
"extra_accelerate_launch_args": "",
"flip_aug": false,
"full_bf16": false,
"full_fp16": false,
"fused_backward_pass": false,
"fused_optimizer_groups": 0,
"gpu_ids": "",
"gradient_accumulation_steps": 1,
"gradient_checkpointing": false,
"huber_c": 0.1,
"huber_schedule": "snr",
"huggingface_path_in_repo": "",
"huggingface_repo_id": "",
"huggingface_repo_type": "",
"huggingface_repo_visibility": "",
"huggingface_token": "",
"ip_noise_gamma": 0,
"ip_noise_gamma_random_strength": false,
"keep_tokens": 0,
"learning_rate": 5e-05,
"learning_rate_te": 1e-05,
"learning_rate_te1": 1e-05,
"learning_rate_te2": 1e-05,
"log_config": false,
"log_tracker_config": "",
"log_tracker_name": "",
"log_with": "",
"logging_dir": "./test/logs",
"loss_type": "l2",
"lr_scheduler": "constant",
"lr_scheduler_args": "T_max=100",
"lr_scheduler_num_cycles": 1,
"lr_scheduler_power": 1,
"lr_scheduler_type": "CosineAnnealingLR",
"lr_warmup": 0,
"main_process_port": 12345,
"masked_loss": false,
"max_bucket_reso": 2048,
"max_data_loader_n_workers": 0,
"max_resolution": "512,512",
"max_timestep": 1000,
"max_token_length": 75,
"max_train_epochs": 0,
"max_train_steps": 0,
"mem_eff_attn": false,
"metadata_author": "",
"metadata_description": "",
"metadata_license": "",
"metadata_tags": "",
"metadata_title": "",
"min_bucket_reso": 256,
"min_snr_gamma": 0,
"min_timestep": 0,
@ -65,14 +91,16 @@
"output_name": "db-AdamW8bit-toml",
"persistent_data_loader_workers": false,
"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
"prior_loss_weight": 1,
"random_crop": false,
"reg_data_dir": "",
"resume": "",
"resume_from_huggingface": "",
"sample_every_n_epochs": 0,
"sample_every_n_steps": 25,
"sample_prompts": "a painting of a gas mask , by darius kawasaki",
"sample_sampler": "euler_a",
"save_as_bool": false,
"save_every_n_epochs": 1,
"save_every_n_steps": 0,
"save_last_n_steps": 0,
@ -81,14 +109,16 @@
"save_precision": "fp16",
"save_state": false,
"save_state_on_train_end": false,
"save_state_to_huggingface": false,
"scale_v_pred_loss_like_noise_pred": false,
"sdxl": false,
"sdxl_cache_text_encoder_outputs": false,
"sdxl_no_half_vae": false,
"seed": 1234,
"shuffle_caption": false,
"stop_text_encoder_training": 0,
"train_batch_size": 4,
"train_data_dir": "",
"use_wandb": false,
"v2": false,
"v_parameterization": false,
"v_pred_like_loss": 0,
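This config exercises the new custom learning rate scheduler support: `lr_scheduler_type` names a class (`CosineAnnealingLR`) and `lr_scheduler_args` passes its keyword arguments as a `key=value` string (`T_max=100`). A hedged sketch of how such a string can be turned into a kwargs dict (the exact parsing in sd-scripts may differ):

```python
import ast

# Hypothetical parser: split space-separated key=value pairs and evaluate
# each value as a Python literal (int, float, list, bool), keeping anything
# unparseable as a plain string.
def parse_scheduler_args(arg_string: str) -> dict:
    kwargs = {}
    for pair in arg_string.split():
        key, _, raw = pair.partition("=")
        try:
            kwargs[key] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            kwargs[key] = raw  # e.g. mode=min stays the string "min"
    return kwargs

print(parse_scheduler_args("T_max=100 eta_min=1e-6"))
# {'T_max': 100, 'eta_min': 1e-06}
```

The resulting dict could then be splatted into the scheduler constructor, e.g. `CosineAnnealingLR(optimizer, **kwargs)`.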


@ -1 +0,0 @@
solo,simple background,teeth,grey background,from side,no humans,mask,1other,science fiction,cable,gas mask,tube,steampunk,machine


@ -1 +0,0 @@
no humans,what


@ -1 +0,0 @@
1girl,solo,nude,colored skin,monster,blue skin


@ -1 +0,0 @@
solo,upper body,horns,from side,no humans,blood,1other


@ -1 +0,0 @@
solo,1boy,male focus,mask,instrument,science fiction,realistic,music,gas mask


@ -1 +0,0 @@
solo,no humans,mask,helmet,robot,mecha,1other,science fiction,damaged,gas mask,steampunk


@ -1 +0,0 @@
solo,from side,no humans,mask,moon,helmet,portrait,1other,ambiguous gender,gas mask


@ -1 +0,0 @@
outdoors,sky,cloud,no humans,monster,realistic,desert