mirror of https://github.com/bmaltais/kohya_ss
v25.0.0 release (#3138)
* Add support for custom learning rate scheduler type to the GUI
* Add .webp image extension support to BLIP2 captioning
* Check for the --debug flag in gui command-line args at startup
* Validate the accelerate GPU ID input and return an error when needed
* Update to latest sd-scripts dev commit
* Fix issue with pip upgrade
* Remove confusing log after command execution
* Add piecewise_constant scheduler
* Update to latest sd-scripts dev commit
* fix: docker-compose for passing models via volumes
* Prevent providing the legacy learning_rate if a unet or te learning rate is provided
* Fix toml noise offset parameters based on selected type
* Fix adaptive_noise_scale value not properly loading from json config
* Fix prompt.txt location
* Improve "print command" output format
* Use output model name as wandb run name if not provided
* Update sd-scripts dev release
* Bump crate-ci/typos from 1.21.0 to 1.22.9 ([release notes](https://github.com/crate-ci/typos/releases), [changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md), [commits](https://github.com/crate-ci/typos/compare/v1.21.0...v1.22.9))
* Bump docker/build-push-action from 5 to 6 ([release notes](https://github.com/docker/build-push-action/releases), [commits](https://github.com/docker/build-push-action/compare/v5...v6))
* Get latest sd3 code
* Add SD3 GUI elements
* Fix interactivity
* MVP GUI for SD3
* Fix text encoder issue
* Add fork section to readme
* Update sd3 commit
* Merge security fix
* Update sd-scripts to latest code
* Auto-detect model type for safetensors files: automatically tick the v2 and SDXL checkboxes on the common training UI and the LoRA extract/merge utilities
* autodetect-modeltype: remove unused lambda inputs
* Rework TE1/TE2 learning rate handling for SDXL dreambooth. SDXL dreambooth trains without the text encoders by default, requiring the `--train_text_encoder` flag to be passed so that the TE1/TE2 learning rates are recognized. The toml handling now permits 0 to be passed as a learning rate to disable training of one or both text encoders, matching the description given in the GUI. TE1/TE2 learning rate parameters can be left blank in the GUI to not pass a value to the training script.
* dreambooth_gui: fix toml value filtering condition. In Python 3, `0 == False` evaluates to True, which can cause argument values of 0 to be wrongly eliminated from the toml output; the conditional must check the type when comparing against False.
* autodetect-modeltype: also handle the v2 checkbox in extract_lora
* Update to latest dev branch code
* Bring back the SDXLConfig accordion for the dreambooth gui (#2694)
* Update to latest sd3 branch commit
* Fix merge issue
* Update gradio version
* Update to latest flux.1 code
* Add Flux.1 model checkbox and detection
* Add LoRA type "Flux1" to the dropdown
* Add Flux.1 parameters to the GUI
* Update sd-scripts and requirements
* Add missing Flux.1 GUI parameters
* Update to latest sd-scripts sd3 code
* Fix issue with cache_text_encoder_outputs
* Update to latest sd-scripts flux1 code
* Add new flux.1 options to the GUI
* Update to latest sd-scripts version of flux.1
* Add guidance_scale option
* Update to latest sd3 flux.1 sd-scripts
* Add dreambooth and finetuning support for flux.1
* Update README
* Fix t5xxl path issue in DB
* Add missing fp8_base parameter
* Fix issue with guidance scale not being passed as float for values like 1
* Temporary fix for blockwise_fused_optimizers
* Update to latest sd-scripts Flux.1 code
* Fix blockwise_fused_optimizers typo
* Add mem_eff_save option to the GUI for Flux.1
* Add support for Flux.1 LoRA merge
* Update to latest sd-scripts sd3 branch code
* Add diffusers option to the flux.1 merge LoRA utility
* Fix issue with split_mode and train_blocks
* Update requirements
* Add flux_fused_backward_pass to dreambooth and finetuning
* Update requirements_linux_docker.txt: update accelerate version for linux docker
* Update to latest sd3 flux code
* Add extract flux lora GUI
* Merge latest sd3 branch code
* Add support for split_qkv
* Add missing network argument for split_qkv
* Add timestep_sampling shift support
* Update to latest sd-scripts flux.1 code
* Add support for fp8_base_unet
* Update requirements as per sd-scripts suggestion
* Upgrade to cu124
* Update IPEX and ROCm
* Fix issue with balancing when a folder with the same name already exists
* Update sd-scripts
* Remove unsupported parameters from flux lora network
* Bump crate-ci/typos from 1.23.6 to 1.24.3
* Update sd-scripts code
* Add flux_shift option to timestep_sampling
* Update sd-scripts release
* Add support for Train T5-XXL
* Update sd-scripts submodule
* Add support for cpu_offload_checkpointing to the GUI
* Force t5xxl_max_token_length to be passed as an integer
* Fix typo for flux_shift
* Update to latest sd-scripts code
* Group lora parameters
* Validate that the lora type is Flux1 when flux1_checkbox is true
* Improve visual sectioning of parameters for lora
* Add dark mode styles
* Missed one color
* Update sd-scripts and add support for t5xxl LR
* Update transformers and wandb modules
* Fix issue with new text_encoder_lr parameter syntax
* Add support for lr_warmup_steps override
* Update lr_warmup_steps code
* Remove stable-diffusion-1.5 default model
* Fix for max_train_steps
* Revert some changes
* Preliminary support for Flux1 OFT
* Fix logic typo
* Update sd-scripts
* Add support for Rank for layers
* Update lora_gui.py: fix minor typos of "Regularization"
* Update dreambooth_gui.py: fix minor typos of "Regularization"
* Update textual_inversion_gui.py: fix minor typos of "Regularization"
* Add support for Blocks to train
* Add missing network params
* Fix issue with old_lr_warmup_steps
* Update sd-scripts
* Add support for ScheduleFree optimizer type
* Update sd-scripts
* Update requirements_pytorch_windows.txt (twice)
* Update sd-scripts from origin
* Another sd-scripts update
* Add support for blocks_to_swap option to the gui
* Fix xformers install issue
* feat(docker): mount models folder as a volume
* feat(docker): add models folder to .dockerignore
* Add support for AdEMAMix8bit optimizer
* Bump crate-ci/typos from 1.23.6 to 1.25.0
* Fix typo in README.md
* Add new --noverify option to skip requirements validation on startup
* Update startup GUI code
* Update setup code
* Update sd-scripts
* Update sd-scripts
* Update Lycoris support
* Allow specifying the tensorboard host via the TENSORBOARD_HOST env var
* Update sd-scripts version
* Update sd-scripts release
* Update sd-scripts
* Add --skip_cache_check option to the GUI
* Fix requirements issue
* Add support for LyCORIS LoRA when training Flux.1
* Pin huggingface-hub version for gradio 5
* Update sd-scripts
* Add support for --save_last_n_epochs_state
* Update sd-scripts to version with Differential Output Preservation
* Increase maximum flux-lora merge strength to 2
* Update to latest sd-scripts
* Update requirements syntax (for windows)
* Update requirements for linux
* Update torch version and validation output
* Fix typo
* Update README
* Fix validation issue on linux
* Update sd-scripts, improve requirements outputs
* Update requirements_runpod.txt
* Update requirements for onnxruntime-gpu (needed for compatibility with CUDA 12)
* Update onnxruntime-gpu==1.19.2
* Update sd-scripts release
* Add support for save_last_n_epochs
* Update sd-scripts
* Bump crate-ci/typos from 1.23.6 to 1.26.8 (#2940)
* Fix "cannot import name 'cached_download' from 'huggingface_hub'" (#2947); applied on all platforms
* Add support for quiet output for linux setup
* Fix quiet issue
* Update sd-scripts
* Update sd-scripts with blocks_to_swap support
* Make blocks_to_swap visible in the LoRA tab
* Fix blocks_to_swap not properly working
* Update sd-scripts and allow python 3.10 to 3.12
* Fix issue with max_train_steps
* Fix max_train_steps_info error
* Revert all changes for max_train_steps
* Update sd-scripts
* Update sd-scripts
* Update to latest sd-scripts
* Add support for RAdamScheduleFree
* Add support for huber_scale
* Add support for fused_backward_pass for sd3 finetuning
* Add support for prodigyplus.ProdigyPlusScheduleFree
* SD3 LoRA training MVP
* Make blocks_to_swap common
* Add support for sd3 lora disable_mmap_load_safetensors
* Add a bunch of missing SD3 parameters
* Fix clip_l issue for missing path
* Fix train_t5xxl issue
* Fix network_module issue
* Add uniform to weighting_scheme
* Bump crate-ci/typos from 1.23.6 to 1.28.1 (#2996)
* Update README.md (#3031)
* Bump crate-ci/typos from 1.23.6 to 1.29.0 (#3029)
* Update sd-scripts version
* Update setup.sh (#3054): enter the current directory before executing setup.sh, otherwise the installer might fail to find requirements.txt
* Remove wrong folder
* Fix issue with SD3 LoRA training blocks_to_swap and fused_backward_pass
* Fix dreambooth issue
* Update to latest sd-scripts code
* Run on novita (#3119) (#3120): add run on novita; adjust position
* Bump crate-ci/typos from 1.23.6 to 1.30.0 (#3101)
* Update prodigyopt to 1.1.2 and remove duplicated row in requirements.txt (#3065)
* Fix names on LR Scheduler dropdown (#3064)
* Update to latest sd-scripts version
* Fix names on LR Scheduler dropdown (#3064)
* Cleanup venv3
* Fix issue with gradio on new installations; add support for latest sd-scripts pytorch-optimizer
* Update README for v25.0.0 release

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: b-fission <b-fission@users.noreply.github.com>
Co-authored-by: DevArqSangoi <lucas.sangoi@gmail.com>
Co-authored-by: Кирилл Москвин <retreat.cost@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: b-fission <131207849+b-fission@users.noreply.github.com>
Co-authored-by: eftSharptooth <76253264+eftSharptooth@users.noreply.github.com>
Co-authored-by: Disty0 <disty@disty.xyz>
Co-authored-by: wcole3 <will.cole3@gmail.com>
Co-authored-by: rohitanshu <85547195+iamrohitanshu@users.noreply.github.com>
Co-authored-by: wzgrx <39661556+wzgrx@users.noreply.github.com>
Co-authored-by: Vladimir Sotnikov <vladimir.s@alphakek.ai>
Co-authored-by: bulieme0 <53142287+bulieme@users.noreply.github.com>
Co-authored-by: Nicolas Pereira <41456803+hqnicolas@users.noreply.github.com>
Co-authored-by: ruucm <ruucm.a@gmail.com>
Co-authored-by: CaledoniaProject <CaledoniaProject@users.noreply.github.com>
Co-authored-by: hugo <liyiligang@users.noreply.github.com>
Co-authored-by: Koro <Koronos@users.noreply.github.com>
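The dreambooth_gui toml-filtering fix described above hinges on a Python quirk: `bool` is a subclass of `int`, so `0 == False` is True. A minimal sketch of the pitfall and the type-checking fix (hypothetical helper names, not the project's actual code):

```python
# In Python 3, bool subclasses int, so 0 == False evaluates to True.
# A naive "drop False values" filter therefore also drops legitimate 0s,
# e.g. a learning rate of 0 meant to disable a text encoder.

def filter_args_buggy(args: dict) -> dict:
    # Drops 0 along with False, because 0 == False.
    return {k: v for k, v in args.items() if v != False}  # noqa: E712

def filter_args_fixed(args: dict) -> dict:
    # Only drops genuine booleans that are False; an int 0 survives.
    return {k: v for k, v in args.items()
            if not (isinstance(v, bool) and v is False)}

args = {"text_encoder_lr": 0, "train_text_encoder": False}
print(filter_args_buggy(args))  # {} -- the 0 was wrongly removed
print(filter_args_fixed(args))  # {'text_encoder_lr': 0}
```

The `isinstance(v, bool)` guard is the key: it distinguishes a boolean flag from a numeric parameter that merely compares equal to False.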
parent: a1b16e44f0
commit: ed55e81997
@@ -3,6 +3,7 @@ cudnn_windows/
 bitsandbytes_windows/
 bitsandbytes_windows_deprecated/
 dataset/
+models/
 __pycache__/
 venv/
 **/.hadolint.yml
@@ -71,7 +71,7 @@ jobs:
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Build and push
-       uses: docker/build-push-action@v5
+       uses: docker/build-push-action@v6
        id: publish
        with:
          context: .
@@ -18,4 +18,4 @@ jobs:
      - uses: actions/checkout@v4

      - name: typos-action
-       uses: crate-ci/typos@v1.23.6
+       uses: crate-ci/typos@v1.30.0
@@ -51,4 +51,6 @@ dataset/**
 models
 data
 config.toml
 sd-scripts
+venv
+venv*
README.md (86 changes)
@@ -48,13 +48,20 @@ The GUI allows you to set the training parameters and generate and run the requi
     - [Potential Solutions](#potential-solutions)
   - [SDXL training](#sdxl-training)
   - [Masked loss](#masked-loss)
+  - [Guides](#guides)
+    - [Using Accelerate Lora Tab to Select GPU ID](#using-accelerate-lora-tab-to-select-gpu-id)
+      - [Starting Accelerate in GUI](#starting-accelerate-in-gui)
+      - [Running Multiple Instances (linux)](#running-multiple-instances-linux)
+      - [Monitoring Processes](#monitoring-processes)
+  - [Interesting Forks](#interesting-forks)
   - [Change History](#change-history)
+    - [v25.0.0](#v2500)

 ## 🦒 Colab

 This Colab notebook was not created or maintained by me; however, it appears to function effectively. The source can be found at: <https://github.com/camenduru/kohya_ss-colab>.

-I would like to express my gratitude to camendutu for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository.
+I would like to express my gratitude to camenduru for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository.

 | Colab | Info |
 | ----- | ---- |
@@ -71,7 +78,7 @@ To install the necessary dependencies on a Windows system, follow these steps:
 1. Install [Python 3.10.11](https://www.python.org/ftp/python/3.10.11/python-3.10.11-amd64.exe).
    - During the installation process, ensure that you select the option to add Python to the 'PATH' environment variable.

-2. Install [CUDA 11.8 toolkit](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Windows&target_arch=x86_64).
+2. Install [CUDA 12.4 toolkit](https://developer.nvidia.com/cuda-12-4-0-download-archive?target_os=Windows&target_arch=x86_64).

 3. Install [Git](https://git-scm.com/download/win).
@@ -129,7 +136,7 @@ To install the necessary dependencies on a Linux system, ensure that you fulfill
   apt install python3.10-venv
   ```

-- Install the CUDA 11.8 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64).
+- Install the CUDA 12.4 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-12-4-0-download-archive?target_os=Linux&target_arch=x86_64).

 - Make sure you have Python version 3.10.9 or higher (but lower than 3.11.0) installed on your system.
@@ -179,7 +186,7 @@ If you choose to use the interactive mode, the default values for the accelerate

 To install the necessary components for Runpod and run kohya_ss, follow these steps:

-1. Select the Runpod pytorch 2.0.1 template. This is important. Other templates may not work.
+1. Select the Runpod pytorch 2.2.0 template. This is important. Other templates may not work.

 2. SSH into the Runpod.
@@ -222,7 +229,9 @@ To run from a pre-built Runpod template, you can:
 3. Once deployed, connect to the Runpod on HTTP 3010 to access the kohya_ss GUI. You can also connect to auto1111 on HTTP 3000.

 ### Novita

 #### Pre-built Novita template

 1. Open the Novita template by clicking on <https://novita.ai/gpus-console?templateId=312>.

 2. Deploy the template on the desired host.
@@ -339,13 +348,27 @@ To upgrade your installation on Linux or macOS, follow these steps:
 To launch the GUI service, you can use the provided scripts or run the `kohya_gui.py` script directly. Use the command line arguments listed below to configure the underlying service.

 ```text
---listen: Specify the IP address to listen on for connections to Gradio.
---username: Set a username for authentication.
---password: Set a password for authentication.
---server_port: Define the port to run the server listener on.
---inbrowser: Open the Gradio UI in a web browser.
---share: Share the Gradio UI.
---language: Set custom language
+  --help                show this help message and exit
+  --config CONFIG       Path to the toml config file for interface defaults
+  --debug               Debug on
+  --listen LISTEN       IP to listen on for connections to Gradio
+  --username USERNAME   Username for authentication
+  --password PASSWORD   Password for authentication
+  --server_port SERVER_PORT
+                        Port to run the server listener on
+  --inbrowser           Open in browser
+  --share               Share the gradio UI
+  --headless            Is the server headless
+  --language LANGUAGE   Set custom language
+  --use-ipex            Use IPEX environment
+  --use-rocm            Use ROCm environment
+  --do_not_use_shell    Enforce not to use shell=True when running external commands
+  --do_not_share        Do not share the gradio UI
+  --requirements REQUIREMENTS
+                        requirements file to use for validation
+  --root_path ROOT_PATH
+                        `root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss
+  --noverify            Disable requirements verification
 ```
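As a quick illustration of the launch options above, here is a small Python sketch that assembles a headless invocation for a given port. The `build_gui_command` helper is hypothetical; only the flag names come from the help text above:

```python
import shlex

def build_gui_command(port: int, listen: str = "0.0.0.0", headless: bool = True) -> list[str]:
    """Assemble a kohya_gui.py command line using the flags documented above."""
    cmd = ["python", "kohya_gui.py", "--listen", listen, "--server_port", str(port)]
    if headless:
        cmd.append("--headless")
    return cmd

# Render the list as a copy-pasteable shell command.
print(shlex.join(build_gui_command(7860)))
# python kohya_gui.py --listen 0.0.0.0 --server_port 7860 --headless
```

Passing the command as a list (rather than a single shell string) avoids quoting pitfalls, which is also the motivation behind the `--do_not_use_shell` option.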
 ### Launching the GUI on Windows
@@ -448,6 +471,45 @@ The feature is not fully tested, so there may be bugs. If you find any issues, p

 ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).

+## Guides
+
+The following are guides extracted from issues discussions.
+
+### Using Accelerate Lora Tab to Select GPU ID
+
+#### Starting Accelerate in GUI
+
+- Open the kohya GUI on your desired port.
+- Open the `Accelerate launch` tab.
+- Ensure the Multi-GPU checkbox is unchecked.
+- Set GPU IDs to the desired GPU (like 1).
+
+#### Running Multiple Instances (linux)
+
+- For tracking multiple processes, use separate kohya GUI instances on different ports (e.g., 7860, 7861).
+- Start instances using `nohup ./gui.sh --listen 0.0.0.0 --server_port <port> --headless > log.log 2>&1 &`.
+
+#### Monitoring Processes
+
+- Open each GUI in a separate browser tab.
+- For terminal access, use SSH and tools like `tmux` or `screen`.
+
+For more details, visit the [GitHub issue](https://github.com/bmaltais/kohya_ss/issues/2577).
+
+## Interesting Forks
+
+To finetune HunyuanDiT models or create LoRAs, visit this [fork](https://github.com/Tencent/HunyuanDiT/tree/main/kohya_ss-hydit).
+
 ## Change History

-See release information.
+### v25.0.0
+
+This is a SIGNIFICANT upgrade. I am going into uncharted territory here because kohya has not merged any of the recent flux.1 and sd3 updates into his main branch yet... but I feel the updates to his code have pretty much dried up, and I think his code is probably ready for prime time. So instead of keeping my GUI in the cave man ages, I am opting to move the GUI code with support for flux.1 and sd3 to the main branch of my project. Perhaps this will bite me in the proverbial ass... but for those who would rather stay on the older pre "flux.1 and sd3" updates, you can always do:
+
+```shell
+git checkout v24.1.7
+```
+
+after cloning the repo.
+
+For all the info regarding the new flux.1 and sd3 parameters, see <https://github.com/kohya-ss/sd-scripts/blob/sd3/README.md> for more details.
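The mask semantics in the ControlNet-dataset paragraph above reduce to a linear 0-255 to 0.0-1.0 mapping of the R channel. A tiny sketch (illustrative only, not sd-scripts' actual loss code):

```python
# Convert an R-channel mask pixel value (0-255) to a per-pixel loss weight
# (0.0-1.0), as described above: 255 -> full weight, 0 -> no weight,
# 128 -> roughly half weight.
def mask_weight(pixel: int) -> float:
    if not 0 <= pixel <= 255:
        raise ValueError("R-channel value must be in 0..255")
    return pixel / 255.0

print(mask_weight(255))           # 1.0
print(round(mask_weight(128), 3)) # 0.502
```

Note that 128/255 is approximately 0.502, not exactly 0.5, which is why the text says 128 is treated as "the half weight" only informally.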
@@ -9,6 +9,7 @@ parms="parms"
 nin="nin"
 extention="extention" # Intentionally left
 nd="nd"
+pn="pn"
 shs="shs"
 sts="sts"
 scs="scs"
assets/style.css (217 changes)
@@ -1,4 +1,4 @@
-#open_folder_small{
+#open_folder_small {
   min-width: auto;
   flex-grow: 0;
   padding-left: 0.25em;
@@ -7,14 +7,14 @@
   font-size: 1.5em;
 }

-#open_folder{
+#open_folder {
   height: auto;
   flex-grow: 0;
   padding-left: 0.25em;
   padding-right: 0.25em;
 }

-#number_input{
+#number_input {
   min-width: min-content;
   flex-grow: 0.3;
   padding-left: 0.75em;
@@ -22,7 +22,7 @@
 }

 .ver-class {
-  color: #808080;
+  color: #6d6d6d; /* Neutral dark gray */
   font-size: small;
   text-align: right;
   padding-right: 1em;
@@ -35,13 +35,212 @@
 }

 #myTensorButton {
-  background: radial-gradient(ellipse, #3a99ff, #52c8ff);
+  background: #555c66; /* Muted dark gray */
   color: white;
-  border: #296eb8;
+  border: none;
+  border-radius: 4px;
+  padding: 0.5em 1em;
+  /* box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); Subtle shadow */
+  /* transition: box-shadow 0.3s ease; */
+}
+
+#myTensorButton:hover {
+  /* box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); Slightly increased shadow on hover */
 }

 #myTensorButtonStop {
-  background: radial-gradient(ellipse, #52c8ff, #3a99ff);
-  color: black;
-  border: #296eb8;
+  background: #777d85; /* Lighter muted gray */
+  color: white;
+  border: none;
+  border-radius: 4px;
+  padding: 0.5em 1em;
+  box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
+  /* transition: box-shadow 0.3s ease; */
+}
+
+#myTensorButtonStop:hover {
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
+}
+
+.advanced_background {
+  background: #f4f4f4; /* Light neutral gray */
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; /* Added transition for smooth shadow effect */
+}
+
+.advanced_background:hover {
+  background-color: #ebebeb; /* Slightly darker background on hover */
+  border: 1px solid #ccc; /* Add a subtle border */
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
+}
+
+.basic_background {
+  background: #eaeff1; /* Muted cool gray */
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
+}
+
+.basic_background:hover {
+  background-color: #dfe4e7; /* Slightly darker cool gray on hover */
+  border: 1px solid #ccc;
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
+}
+
+.huggingface_background {
+  background: #e0e4e7; /* Light gray with a hint of blue */
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
+}
+
+.huggingface_background:hover {
+  background-color: #d6dce0; /* Slightly darker on hover */
+  border: 1px solid #bbb;
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
+}
+
+.flux1_background {
+  background: #ece9e6; /* Light beige tone */
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
+}
+
+.flux1_background:hover {
+  background-color: #e2dfdb; /* Slightly darker beige on hover */
+  border: 1px solid #ccc;
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
+}
+
+.preset_background {
+  background: #f0f0f0; /* Light gray */
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
+}
+
+.preset_background:hover {
+  background-color: #e6e6e6; /* Slightly darker on hover */
+  border: 1px solid #ccc;
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
+}
+
+.samples_background {
+  background: #d9dde1; /* Soft muted gray-blue */
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
+}
+
+.samples_background:hover {
+  background-color: #cfd3d8; /* Slightly darker on hover */
+  border: 1px solid #bbb;
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
+}
+
+/* Dark mode styles */
+.dark .advanced_background {
+  background: #172029; /* Slightly darker gradio dark theme */
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; /* Added transition for smooth shadow effect */
+}
+
+.dark .advanced_background:hover {
+  background-color: #121920; /* Slightly darker background on hover */
+  border: 1px solid #000000; /* Add a subtle border */
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
+}
+
+.dark .basic_background {
+  background: #172029;
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
+}
+
+.dark .basic_background:hover {
+  background-color: #11181e;
+  border: 1px solid #000000;
+  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
+}
+
+.dark .huggingface_background {
+  background: #131c25;
+  padding: 1em;
+  border-radius: 8px;
+  transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .huggingface_background:hover {
|
||||||
|
background-color: #131c25;
|
||||||
|
border: 1px solid #000000;
|
||||||
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .flux1_background {
|
||||||
|
background: #131c25;
|
||||||
|
padding: 1em;
|
||||||
|
border-radius: 8px;
|
||||||
|
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .flux1_background:hover {
|
||||||
|
background-color: #131c25;
|
||||||
|
border: 1px solid #000000;
|
||||||
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .preset_background {
|
||||||
|
background: #191d25;
|
||||||
|
padding: 1em;
|
||||||
|
border-radius: 8px;
|
||||||
|
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .preset_background:hover {
|
||||||
|
background-color: #212530;
|
||||||
|
border: 1px solid #000000;
|
||||||
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .samples_background {
|
||||||
|
background: #101e2c;
|
||||||
|
padding: 1em;
|
||||||
|
border-radius: 8px;
|
||||||
|
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .samples_background:hover {
|
||||||
|
background-color: #17293a;
|
||||||
|
border: 1px solid #000000;
|
||||||
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||||
|
}
|
||||||
|
|
||||||
|
.flux1_rank_layers_background {
|
||||||
|
background: #ece9e6; /* White background for clear theme */
|
||||||
|
padding: 1em;
|
||||||
|
border-radius: 8px;
|
||||||
|
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.flux1_rank_layers_background:hover {
|
||||||
|
background-color: #dddad7; /* Slightly darker on hover */
|
||||||
|
border: 1px solid #ccc;
|
||||||
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .flux1_rank_layers_background {
|
||||||
|
background: #131c25; /* Dark background for dark theme */
|
||||||
|
padding: 1em;
|
||||||
|
border-radius: 8px;
|
||||||
|
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .flux1_rank_layers_background:hover {
|
||||||
|
background-color: #131c25; /* Slightly darker on hover */
|
||||||
|
border: 1px solid #000000;
|
||||||
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||||
}
|
}
|
||||||
|
|
@@ -48,6 +48,7 @@ learning_rate_te1 = 0.0001 # Learning rate text encoder 1
 learning_rate_te2 = 0.0001 # Learning rate text encoder 2
 lr_scheduler = "cosine" # LR Scheduler
 lr_scheduler_args = "" # LR Scheduler args
+lr_scheduler_type = "" # LR Scheduler type
 lr_warmup = 0 # LR Warmup (% of total steps)
 lr_scheduler_num_cycles = 1 # LR Scheduler num cycles
 lr_scheduler_power = 1.0 # LR Scheduler power

@@ -150,6 +151,9 @@ sample_prompts = "" # Sample prompts
 sample_sampler = "euler_a" # Sampler to use for image sampling
 
 [sdxl]
+disable_mmap_load_safetensors = false # Disable mmap load safe tensors
+fused_backward_pass = false # Fused backward pass
+fused_optimizer_groups = 0 # Fused optimizer groups
 sdxl_cache_text_encoder_outputs = false # Cache text encoder outputs
 sdxl_no_half_vae = true # No half VAE
@@ -20,11 +20,13 @@ services:
       - /tmp
     volumes:
       - /tmp/.X11-unix:/tmp/.X11-unix
+      - ./models:/app/models
       - ./dataset:/dataset
       - ./dataset/images:/app/data
       - ./dataset/logs:/app/logs
       - ./dataset/outputs:/app/outputs
       - ./dataset/regularization:/app/regularization
+      - ./models:/app/models
      - ./.cache/config:/app/config
      - ./.cache/user:/home/1000/.cache
      - ./.cache/triton:/home/1000/.triton
@@ -1,32 +1,27 @@
-## Updating a Local Branch with the Latest sd-scripts Changes
+## Updating a Local Submodule with the Latest sd-scripts Changes
 
 To update your local branch with the most recent changes from kohya/sd-scripts, follow these steps:
 
-1. Add sd-scripts as an alternative remote by executing the following command:
+1. When you wish to perform an update of the dev branch, execute the following commands:
 
-```
-git remote add sd-scripts https://github.com/kohya-ss/sd-scripts.git
-```
-
-2. When you wish to perform an update, execute the following commands:
-
-```
+```bash
+cd sd-scripts
+git fetch
 git checkout dev
-git pull sd-scripts main
+git pull origin dev
+cd ..
+git add sd-scripts
+git commit -m "Update sd-scripts submodule to the latest on dev"
 ```
 
-Alternatively, if you want to obtain the latest code, even if it may be unstable:
+Alternatively, if you want to obtain the latest code from main:
 
+```bash
+cd sd-scripts
+git fetch
+git checkout main
+git pull origin main
+cd ..
+git add sd-scripts
+git commit -m "Update sd-scripts submodule to the latest on main"
 ```
-git checkout dev
-git pull sd-scripts dev
-```
-
-3. If you encounter a conflict with the Readme file, you can resolve it by taking the following steps:
-
-```
-git add README.md
-git merge --continue
-```
-
-This may open a text editor for a commit message, but you can simply save and close it to proceed. Following these steps should resolve the conflict. If you encounter additional merge conflicts, consider them as valuable learning opportunities for personal growth.
gui.bat (8 changes)
@@ -7,11 +7,13 @@ call .\venv\Scripts\deactivate.bat
 :: Activate the virtual environment
 call .\venv\Scripts\activate.bat
 
+:: Update pip to latest version
+python -m pip install --upgrade pip -q
+
 set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib
 
-:: Validate requirements
-python.exe .\setup\validate_requirements.py
-if %errorlevel% neq 0 exit /b %errorlevel%
+echo Starting the GUI... this might take some time...
 
 :: If the exit code is 0, run the kohya_gui.py script with the command-line arguments
 if %errorlevel% equ 0 (
gui.ps1 (30 changes)
@@ -7,28 +7,18 @@ if ($env:VIRTUAL_ENV) {
 # Activate the virtual environment
 # Write-Host "Activating the virtual environment..."
 & .\venv\Scripts\activate
 
+python.exe -m pip install --upgrade pip -q
+
 $env:PATH += ";$($MyInvocation.MyCommand.Path)\venv\Lib\site-packages\torch\lib"
 
-# Debug info about system
-# python.exe .\setup\debug_info.py
+Write-Host "Starting the GUI... this might take some time..."
 
-# Validate the requirements and store the exit code
-python.exe .\setup\validate_requirements.py
-
-# Check the exit code and stop execution if it is not 0
-if ($LASTEXITCODE -ne 0) {
-    Write-Host "Failed to validate requirements. Exiting script..."
-    exit $LASTEXITCODE
+$argsFromFile = @()
+if (Test-Path .\gui_parameters.txt) {
+    $argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " }
 }
+$args_combo = $argsFromFile + $args
+# Write-Host "The arguments passed to this script were: $args_combo"
+python.exe kohya_gui.py $args_combo
 
-# If the exit code is 0, read arguments from gui_parameters.txt (if it exists)
-# and run the kohya_gui.py script with the command-line arguments
-if ($LASTEXITCODE -eq 0) {
-    $argsFromFile = @()
-    if (Test-Path .\gui_parameters.txt) {
-        $argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " }
-    }
-    $args_combo = $argsFromFile + $args
-    # Write-Host "The arguments passed to this script were: $args_combo"
-    python.exe kohya_gui.py $args_combo
-}
gui.sh (8 changes)
@@ -111,10 +111,4 @@ then
     STARTUP_CMD=python
 fi
 
-# Validate the requirements and run the script if successful
-if python "$SCRIPT_DIR/setup/validate_requirements.py" -r "$REQUIREMENTS_FILE"; then
-    "${STARTUP_CMD}" $STARTUP_CMD_ARGS "$SCRIPT_DIR/kohya_gui.py" "$@"
-else
-    echo "Validation failed. Exiting..."
-    exit 1
-fi
+"${STARTUP_CMD}" $STARTUP_CMD_ARGS "$SCRIPT_DIR/kohya_gui.py" "--requirements=""$REQUIREMENTS_FILE" "$@"
kohya_gui.py (237 changes)
@@ -1,6 +1,10 @@
-import gradio as gr
 import os
+import sys
 import argparse
+import subprocess
+import contextlib
+import gradio as gr
 
 from kohya_gui.class_gui_config import KohyaSSGUIConfig
 from kohya_gui.dreambooth_gui import dreambooth_tab
 from kohya_gui.finetune_gui import finetune_tab
@@ -8,71 +12,43 @@ from kohya_gui.textual_inversion_gui import ti_tab
 from kohya_gui.utilities import utilities_tab
 from kohya_gui.lora_gui import lora_tab
 from kohya_gui.class_lora_tab import LoRATools
 
 from kohya_gui.custom_logging import setup_logging
 from kohya_gui.localization_ext import add_javascript
 
+PYTHON = sys.executable
+project_dir = os.path.dirname(os.path.abspath(__file__))
 
-def UI(**kwargs):
-    add_javascript(kwargs.get("language"))
-    css = ""
+# Function to read file content, suppressing any FileNotFoundError
+def read_file_content(file_path):
+    with contextlib.suppress(FileNotFoundError):
+        with open(file_path, "r", encoding="utf8") as file:
+            return file.read()
+    return ""
 
-    headless = kwargs.get("headless", False)
-    log.info(f"headless: {headless}")
+# Function to initialize the Gradio UI interface
+def initialize_ui_interface(config, headless, use_shell, release_info, readme_content):
+    # Load custom CSS if available
+    css = read_file_content("./assets/style.css")
 
-    if os.path.exists("./assets/style.css"):
-        with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
-            log.debug("Load CSS...")
-            css += file.read() + "\n"
-
-    if os.path.exists("./.release"):
-        with open(os.path.join("./.release"), "r", encoding="utf8") as file:
-            release = file.read()
-
-    if os.path.exists("./README.md"):
-        with open(os.path.join("./README.md"), "r", encoding="utf8") as file:
-            README = file.read()
-
-    interface = gr.Blocks(
-        css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
-    )
-
-    config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
-
-    if config.is_config_loaded():
-        log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")
-
-    use_shell_flag = True
-    # if os.name == "posix":
-    #     use_shell_flag = True
-
-    use_shell_flag = config.get("settings.use_shell", use_shell_flag)
-
-    if kwargs.get("do_not_use_shell", False):
-        use_shell_flag = False
-
-    if use_shell_flag:
-        log.info("Using shell=True when running external commands...")
-
-    with interface:
+    # Create the main Gradio Blocks interface
+    ui_interface = gr.Blocks(css=css, title=f"Kohya_ss GUI {release_info}", theme=gr.themes.Default())
+    with ui_interface:
+        # Create tabs for different functionalities
         with gr.Tab("Dreambooth"):
             (
                 train_data_dir_input,
                 reg_data_dir_input,
                 output_dir_input,
                 logging_dir_input,
-            ) = dreambooth_tab(
-                headless=headless, config=config, use_shell_flag=use_shell_flag
-            )
+            ) = dreambooth_tab(headless=headless, config=config, use_shell_flag=use_shell)
         with gr.Tab("LoRA"):
-            lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
+            lora_tab(headless=headless, config=config, use_shell_flag=use_shell)
         with gr.Tab("Textual Inversion"):
-            ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
+            ti_tab(headless=headless, config=config, use_shell_flag=use_shell)
         with gr.Tab("Finetuning"):
-            finetune_tab(
-                headless=headless, config=config, use_shell_flag=use_shell_flag
-            )
+            finetune_tab(headless=headless, config=config, use_shell_flag=use_shell)
         with gr.Tab("Utilities"):
+            # Utilities tab requires inputs from the Dreambooth tab
             utilities_tab(
                 train_data_dir_input=train_data_dir_input,
                 reg_data_dir_input=reg_data_dir_input,
@@ -84,102 +60,97 @@ def UI(**kwargs):
         with gr.Tab("LoRA"):
             _ = LoRATools(headless=headless)
         with gr.Tab("About"):
-            gr.Markdown(f"kohya_ss GUI release {release}")
+            # About tab to display release information and README content
+            gr.Markdown(f"kohya_ss GUI release {release_info}")
             with gr.Tab("README"):
-                gr.Markdown(README)
+                gr.Markdown(readme_content)
 
-        htmlStr = f"""
-        <html>
-            <body>
-                <div class="ver-class">{release}</div>
-            </body>
-        </html>
-        """
-        gr.HTML(htmlStr)
+        # Display release information in a div element
+        gr.Markdown(f"<div class='ver-class'>{release_info}</div>")
 
-    # Show the interface
-    launch_kwargs = {}
-    username = kwargs.get("username")
-    password = kwargs.get("password")
-    server_port = kwargs.get("server_port", 0)
-    inbrowser = kwargs.get("inbrowser", False)
-    share = kwargs.get("share", False)
-    do_not_share = kwargs.get("do_not_share", False)
-    server_name = kwargs.get("listen")
-    root_path = kwargs.get("root_path", None)
-
-    launch_kwargs["server_name"] = server_name
-    if username and password:
-        launch_kwargs["auth"] = (username, password)
-    if server_port > 0:
-        launch_kwargs["server_port"] = server_port
-    if inbrowser:
-        launch_kwargs["inbrowser"] = inbrowser
-    if do_not_share:
-        launch_kwargs["share"] = False
-    else:
-        if share:
-            launch_kwargs["share"] = share
-    if root_path:
-        launch_kwargs["root_path"] = root_path
-    launch_kwargs["debug"] = True
-    interface.launch(**launch_kwargs)
+    return ui_interface
 
 
-if __name__ == "__main__":
-    # torch.cuda.set_per_process_memory_fraction(0.48)
+# Function to configure and launch the UI
+def UI(**kwargs):
+    # Add custom JavaScript if specified
+    add_javascript(kwargs.get("language"))
+    log.info(f"headless: {kwargs.get('headless', False)}")
+
+    # Load release and README information
+    release_info = read_file_content("./.release")
+    readme_content = read_file_content("./README.md")
+
+    # Load configuration from the specified file
+    config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
+    if config.is_config_loaded():
+        log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")
+
+    # Determine if shell should be used for running external commands
+    use_shell = not kwargs.get("do_not_use_shell", False) and config.get("settings.use_shell", True)
+    if use_shell:
+        log.info("Using shell=True when running external commands...")
+
+    # Initialize the Gradio UI interface
+    ui_interface = initialize_ui_interface(config, kwargs.get("headless", False), use_shell, release_info, readme_content)
+
+    # Construct launch parameters using dictionary comprehension
+    launch_params = {
+        "server_name": kwargs.get("listen"),
+        "auth": (kwargs["username"], kwargs["password"]) if kwargs.get("username") and kwargs.get("password") else None,
+        "server_port": kwargs.get("server_port", 0) if kwargs.get("server_port", 0) > 0 else None,
+        "inbrowser": kwargs.get("inbrowser", False),
+        "share": False if kwargs.get("do_not_share", False) else kwargs.get("share", False),
+        "root_path": kwargs.get("root_path", None),
+        "debug": kwargs.get("debug", False),
+    }
+
+    # This line filters out any key-value pairs from `launch_params` where the value is `None`, ensuring only valid parameters are passed to the `launch` function.
+    launch_params = {k: v for k, v in launch_params.items() if v is not None}
+
+    # Launch the Gradio interface with the specified parameters
+    ui_interface.launch(**launch_params)
+
+
+# Function to initialize argument parser for command-line arguments
+def initialize_arg_parser():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--config",
-        type=str,
-        default="./config.toml",
-        help="Path to the toml config file for interface defaults",
-    )
+    parser.add_argument("--config", type=str, default="./config.toml", help="Path to the toml config file for interface defaults")
     parser.add_argument("--debug", action="store_true", help="Debug on")
-    parser.add_argument(
-        "--listen",
-        type=str,
-        default="127.0.0.1",
-        help="IP to listen on for connections to Gradio",
-    )
-    parser.add_argument(
-        "--username", type=str, default="", help="Username for authentication"
-    )
-    parser.add_argument(
-        "--password", type=str, default="", help="Password for authentication"
-    )
-    parser.add_argument(
-        "--server_port",
-        type=int,
-        default=0,
-        help="Port to run the server listener on",
-    )
+    parser.add_argument("--listen", type=str, default="127.0.0.1", help="IP to listen on for connections to Gradio")
+    parser.add_argument("--username", type=str, default="", help="Username for authentication")
+    parser.add_argument("--password", type=str, default="", help="Password for authentication")
+    parser.add_argument("--server_port", type=int, default=0, help="Port to run the server listener on")
     parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
     parser.add_argument("--share", action="store_true", help="Share the gradio UI")
-    parser.add_argument(
-        "--headless", action="store_true", help="Is the server headless"
-    )
-    parser.add_argument(
-        "--language", type=str, default=None, help="Set custom language"
-    )
+    parser.add_argument("--headless", action="store_true", help="Is the server headless")
+    parser.add_argument("--language", type=str, default=None, help="Set custom language")
     parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
     parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
-    parser.add_argument(
-        "--do_not_use_shell", action="store_true", help="Enforce not to use shell=True when running external commands"
-    )
-    parser.add_argument(
-        "--do_not_share", action="store_true", help="Do not share the gradio UI"
-    )
-    parser.add_argument(
-        "--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss"
-    )
+    parser.add_argument("--do_not_use_shell", action="store_true", help="Enforce not to use shell=True when running external commands")
+    parser.add_argument("--do_not_share", action="store_true", help="Do not share the gradio UI")
+    parser.add_argument("--requirements", type=str, default=None, help="requirements file to use for validation")
+    parser.add_argument("--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss")
+    parser.add_argument("--noverify", action="store_true", help="Disable requirements verification")
+    return parser
 
+
+if __name__ == "__main__":
+    # Initialize argument parser and parse arguments
+    parser = initialize_arg_parser()
     args = parser.parse_args()
 
-    # Set up logging
+    # Set up logging based on the debug flag
     log = setup_logging(debug=args.debug)
 
-    UI(**vars(args))
+    # Verify requirements unless `noverify` flag is set
+    if args.noverify:
+        log.warning("Skipping requirements verification.")
+    else:
+        # Run the validation command to verify requirements
+        validation_command = [PYTHON, os.path.join(project_dir, "setup", "validate_requirements.py")]
+        if args.requirements is not None:
+            validation_command.append(f"--requirements={args.requirements}")
+        subprocess.run(validation_command, check=True)
+
+    # Launch the UI with the provided arguments
+    UI(**vars(args))
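The refactor above centralizes file reading in `read_file_content`, which uses `contextlib.suppress` so a missing `.release` or `README.md` degrades to an empty string instead of raising. A minimal standalone sketch of that pattern (file names below are illustrative, not part of the repo):

```python
import contextlib


def read_file_content(file_path):
    # Return the file's text, or "" when the file does not exist.
    # contextlib.suppress swallows only FileNotFoundError; other I/O
    # errors (e.g. permission denied) still propagate.
    with contextlib.suppress(FileNotFoundError):
        with open(file_path, "r", encoding="utf8") as file:
            return file.read()
    return ""
```

Because the `return` inside the suppressed block exits early on success, the trailing `return ""` only runs when the open failed.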
@@ -102,7 +102,7 @@ def caption_images(
         postfix=postfix,
     )
     # Replace specified text in caption files if find and replace text is provided
-    if find_text and replace_text:
+    if find_text:
         find_replace(
             folder_path=images_dir,
             caption_file_ext=caption_ext,
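The relaxed condition above means an empty `replace_text` no longer skips the replacement, so a phrase can be stripped from captions outright. A sketch of the behavioral difference on a single caption string (the helper name is hypothetical; the GUI's `find_replace` works on files):

```python
def apply_find_replace(caption: str, find_text: str, replace_text: str) -> str:
    # Old gate (`if find_text and replace_text`) skipped this branch when
    # replace_text was "", making find-and-delete impossible.
    # New gate (`if find_text`) allows replacing with the empty string.
    if find_text:
        return caption.replace(find_text, replace_text)
    return caption
```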
@@ -42,7 +42,7 @@ def get_images_in_directory(directory_path):
     import os
 
     # List of common image file extensions to look for
-    image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"]
+    image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]
 
     # Generate a list of image file paths in the directory
     image_files = [
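The `.webp` addition above is a plain extension-list check; the usual pitfall is case sensitivity. A small sketch of a case-insensitive variant (the `is_image_file` helper is an illustration, not the repo's function):

```python
import os

# Matches the extended list from the diff, including the new .webp entry.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]


def is_image_file(filename: str) -> bool:
    # Compare the lowercased extension so "PIC.JPG" is also accepted.
    return os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS
```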
@@ -3,6 +3,10 @@ import os
 import shlex
 
 from .class_gui_config import KohyaSSGUIConfig
+from .custom_logging import setup_logging
+
+# Set up logging
+log = setup_logging()
 
 
 class AccelerateLaunch:
@@ -79,12 +83,16 @@ class AccelerateLaunch:
             )
             self.dynamo_use_fullgraph = gr.Checkbox(
                 label="Dynamo use fullgraph",
-                value=self.config.get("accelerate_launch.dynamo_use_fullgraph", False),
+                value=self.config.get(
+                    "accelerate_launch.dynamo_use_fullgraph", False
+                ),
                 info="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
             )
             self.dynamo_use_dynamic = gr.Checkbox(
                 label="Dynamo use dynamic",
-                value=self.config.get("accelerate_launch.dynamo_use_dynamic", False),
+                value=self.config.get(
+                    "accelerate_launch.dynamo_use_dynamic", False
+                ),
                 info="Whether to enable dynamic shape tracing.",
             )
@@ -103,6 +111,24 @@ class AccelerateLaunch:
                 placeholder="example: 0,1",
                 info=" What GPUs (by id) should be used for training on this machine as a comma-separated list",
             )
+
+            def validate_gpu_ids(value):
+                if value == "":
+                    return
+                if not (
+                    value.isdigit() and int(value) >= 0 and int(value) <= 128
+                ):
+                    log.error("GPU IDs must be an integer between 0 and 128")
+                    return
+                else:
+                    for id in value.split(","):
+                        if not id.isdigit() or int(id) < 0 or int(id) > 128:
+                            log.error(
+                                "GPU IDs must be an integer between 0 and 128"
+                            )
+
+            self.gpu_ids.blur(fn=validate_gpu_ids, inputs=self.gpu_ids)
+
             self.main_process_port = gr.Number(
                 label="Main process port",
                 value=self.config.get("accelerate_launch.main_process_port", 0),
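The GUI-side validator above only logs; the intended rule (every comma-separated GPU id is an integer between 0 and 128, empty input allowed) can be captured as a standalone predicate for testing. A sketch (`is_valid_gpu_ids` and its boolean return are assumptions for illustration, not the repo's API):

```python
def is_valid_gpu_ids(value: str) -> bool:
    # Empty input means "use the default device"; nothing to validate.
    if value == "":
        return True
    # Every comma-separated id must be an integer in the range 0..128.
    return all(
        part.isdigit() and 0 <= int(part) <= 128
        for part in value.split(",")
    )
```

Note that `str.isdigit` already rejects signs and whitespace, so `"-1"` and `"0, 1"` fail without extra checks.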
@@ -136,9 +162,14 @@ class AccelerateLaunch:

         if "dynamo_use_dynamic" in kwargs and kwargs.get("dynamo_use_dynamic"):
             run_cmd.append("--dynamo_use_dynamic")

-        if "extra_accelerate_launch_args" in kwargs and kwargs["extra_accelerate_launch_args"] != "":
-            extra_accelerate_launch_args = kwargs["extra_accelerate_launch_args"].replace('"', "")
+        if (
+            "extra_accelerate_launch_args" in kwargs
+            and kwargs["extra_accelerate_launch_args"] != ""
+        ):
+            extra_accelerate_launch_args = kwargs[
+                "extra_accelerate_launch_args"
+            ].replace('"', "")
             for arg in extra_accelerate_launch_args.split():
                 run_cmd.append(shlex.quote(arg))
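The hunk above splits the user-supplied extra launch arguments on whitespace and shell-quotes each token with `shlex.quote` before appending it to the command list. A minimal sketch of that step (the function name is illustrative, not from the GUI):

```python
import shlex


def quote_extra_args(extra: str) -> list:
    """Strip double quotes, split on whitespace, and shell-quote each token."""
    return [shlex.quote(arg) for arg in extra.replace('"', "").split()]
```

Tokens made of safe characters pass through unchanged, while anything containing shell metacharacters gets wrapped in single quotes, so it can never be interpreted as a separate command.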
@@ -146,7 +146,7 @@ class AdvancedTraining:
         with gr.Row():
             self.loss_type = gr.Dropdown(
                 label="Loss type",
-                choices=["huber", "smooth_l1", "l2"],
+                choices=["huber", "smooth_l1", "l1", "l2"],
                 value=self.config.get("advanced.loss_type", "l2"),
                 info="The type of loss to use and whether it's scheduled based on the timestep",
             )
@@ -168,6 +168,14 @@ class AdvancedTraining:
                 step=0.01,
                 info="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type",
             )
+            self.huber_scale = gr.Number(
+                label="Huber scale",
+                value=self.config.get("advanced.huber_scale", 1.0),
+                minimum=0.0,
+                maximum=10.0,
+                step=0.01,
+                info="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+            )

         with gr.Row():
             self.save_every_n_steps = gr.Number(
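For reference, the Huber loss these `huber_c` / `huber_scale` fields parameterize is quadratic near zero and linear beyond a threshold. This is the textbook per-residual form (a sketch; sd-scripts applies its own timestep-scheduled variant):

```python
def huber_loss(residual: float, c: float = 0.1) -> float:
    """Quadratic for |residual| <= c, linear beyond; c plays the role of huber_c."""
    a = abs(residual)
    if a <= c:
        return 0.5 * a * a
    # Linear tail, continuous with the quadratic part at |residual| == c.
    return c * (a - 0.5 * c)
```

The linear tail is what makes Huber training less sensitive to outlier noise than plain L2.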
@@ -188,6 +196,18 @@ class AdvancedTraining:
                 precision=0,
                 info="(Optional) Save only the specified number of states (old models will be deleted)",
             )
+            self.save_last_n_epochs = gr.Number(
+                label="Save last N epochs",
+                value=self.config.get("advanced.save_last_n_epochs", 0),
+                precision=0,
+                info="(Optional) Save only the specified number of epochs (old epochs will be deleted)",
+            )
+            self.save_last_n_epochs_state = gr.Number(
+                label="Save last N epochs state",
+                value=self.config.get("advanced.save_last_n_epochs_state", 0),
+                precision=0,
+                info="(Optional) Save only the specified number of epochs states (old models will be deleted)",
+            )
         with gr.Row():

         def full_options_update(full_fp16, full_bf16):
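The `save_last_n_epochs` fields above describe a retention policy: once more than N epoch saves exist, the oldest are deleted. A hypothetical sketch of that pruning decision (names are illustrative; the real deletion happens inside sd-scripts):

```python
def prune_old_saves(saves: list, keep_last_n: int) -> list:
    """Return the saves to delete, keeping only the newest keep_last_n.

    saves is assumed oldest-first; keep_last_n == 0 (the GUI default) keeps everything.
    """
    if keep_last_n <= 0:
        return []
    return saves[:-keep_last_n] if len(saves) > keep_last_n else []
```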
@@ -228,12 +248,16 @@ class AdvancedTraining:
             )

         with gr.Row():
-            if training_type == "lora":
-                self.fp8_base = gr.Checkbox(
-                    label="fp8 base",
-                    info="Use fp8 for base model",
-                    value=self.config.get("advanced.fp8_base", False),
-                )
+            self.fp8_base = gr.Checkbox(
+                label="fp8 base training (experimental)",
+                info="U-Net and Text Encoder can be trained with fp8 (experimental)",
+                value=self.config.get("advanced.fp8_base", False),
+            )
+            self.fp8_base_unet = gr.Checkbox(
+                label="fp8 base unet",
+                info="Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16.",
+                value=self.config.get("advanced.fp8_base_unet", False),
+            )
             self.full_fp16 = gr.Checkbox(
                 label="Full fp16 training (experimental)",
                 value=self.config.get("advanced.full_fp16", False),
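`full_fp16` and `full_bf16` are mutually exclusive, which is what the `full_options_update` callback wired to both checkboxes enforces. The core rule might look like this (a sketch of one plausible policy; the real callback returns `gr.Checkbox` updates, and its exact behavior may differ):

```python
def full_options_update(full_fp16: bool, full_bf16: bool):
    """Each full-precision override stays interactive only while the other is off."""
    fp16_interactive = not full_bf16
    bf16_interactive = not full_fp16
    return fp16_interactive, bf16_interactive
```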
@@ -254,6 +278,25 @@ class AdvancedTraining:
             inputs=[self.full_fp16, self.full_bf16],
             outputs=[self.full_fp16, self.full_bf16],
         )

+        with gr.Row():
+            self.highvram = gr.Checkbox(
+                label="highvram",
+                value=self.config.get("advanced.highvram", False),
+                info="Disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM)",
+                interactive=True,
+            )
+            self.lowvram = gr.Checkbox(
+                label="lowvram",
+                value=self.config.get("advanced.lowvram", False),
+                info="Enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)",
+                interactive=True,
+            )
+            self.skip_cache_check = gr.Checkbox(
+                label="Skip cache check",
+                value=self.config.get("advanced.skip_cache_check", False),
+                info="Skip cache check for faster training start",
+            )
+
         with gr.Row():
             self.gradient_checkpointing = gr.Checkbox(
@@ -450,6 +493,15 @@ class AdvancedTraining:
                 value=self.config.get("advanced.vae_batch_size", 0),
                 step=1,
             )
+            self.blocks_to_swap = gr.Slider(
+                label="Blocks to swap",
+                value=self.config.get("advanced.blocks_to_swap", 0),
+                info="The number of blocks to swap. The default is None (no swap). These options must be combined with --fused_backward_pass or --blockwise_fused_optimizers. The recommended maximum value is 36.",
+                minimum=0,
+                maximum=57,
+                step=1,
+                interactive=True,
+            )
         with gr.Group(), gr.Row():
             self.save_state = gr.Checkbox(
                 label="Save training state",
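Throughout these hunks, defaults are read with calls like `self.config.get("advanced.blocks_to_swap", 0)`, i.e. flat dotted keys rather than nested dicts. A minimal config wrapper with that interface (an assumption about the GUI's config object, shown here backed by a plain dict):

```python
class FlatConfig:
    """Flat key/value store addressed by dotted keys like 'advanced.blocks_to_swap'."""

    def __init__(self, data: dict):
        self.data = data

    def get(self, key: str, default=None):
        # The dotted key is a single dictionary key, not a nested path.
        return self.data.get(key, default)


cfg = FlatConfig({"advanced.blocks_to_swap": 12})
```

Keeping keys flat means every widget can state its own default inline, and a missing config file degrades gracefully to those defaults.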
@@ -534,6 +586,11 @@ class AdvancedTraining:
             self.current_log_tracker_config_dir = path if not path == "" else "."
             return list(list_files(path, exts=[".json"], all=True))

+        self.log_config = gr.Checkbox(
+            label="Log config",
+            value=self.config.get("advanced.log_config", False),
+            info="Log training parameter to WANDB",
+        )
         self.log_tracker_name = gr.Textbox(
             label="Log tracker name",
             value=self.config.get("advanced.log_tracker_name", ""),
@@ -25,6 +25,7 @@ class BasicTraining:
         learning_rate_value: float = "1e-6",
         lr_scheduler_value: str = "constant",
         lr_warmup_value: float = "0",
+        lr_warmup_steps_value: int = 0,
         finetuning: bool = False,
         dreambooth: bool = False,
         config: dict = {},
@@ -44,10 +45,14 @@ class BasicTraining:
         self.learning_rate_value = learning_rate_value
         self.lr_scheduler_value = lr_scheduler_value
         self.lr_warmup_value = lr_warmup_value
+        self.lr_warmup_steps_value = lr_warmup_steps_value
         self.finetuning = finetuning
         self.dreambooth = dreambooth
         self.config = config

+        # Initialize old_lr_warmup and old_lr_warmup_steps with default values
         self.old_lr_warmup = 0
+        self.old_lr_warmup_steps = 0

         # Initialize the UI components
         self.initialize_ui_components()
@@ -162,20 +167,37 @@ class BasicTraining:
                 "cosine",
                 "cosine_with_restarts",
                 "linear",
+                "piecewise_constant",
                 "polynomial",
+                "cosine_with_min_lr",
+                "inverse_sqrt",
+                "warmup_stable_decay",
             ],
             value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value),
         )

+        # Initialize the learning rate scheduler type dropdown
+        self.lr_scheduler_type = gr.Dropdown(
+            label="LR Scheduler type",
+            info="(Optional) custom scheduler module name",
+            choices=[
+                "",
+                "CosineAnnealingLR",
+            ],
+            value=self.config.get("basic.lr_scheduler_type", ""),
+            allow_custom_value=True,
+        )
+
         # Initialize the optimizer dropdown
         self.optimizer = gr.Dropdown(
             label="Optimizer",
             choices=[
                 "AdamW",
+                "AdamWScheduleFree",
                 "AdamW8bit",
                 "Adafactor",
+                "bitsandbytes.optim.AdEMAMix8bit",
+                "bitsandbytes.optim.PagedAdEMAMix8bit",
                 "DAdaptation",
                 "DAdaptAdaGrad",
                 "DAdaptAdam",
@@ -190,11 +212,15 @@ class BasicTraining:
                 "PagedAdamW32bit",
                 "PagedLion8bit",
                 "Prodigy",
+                "prodigyplus.ProdigyPlusScheduleFree",
+                "RAdamScheduleFree",
                 "SGDNesterov",
                 "SGDNesterov8bit",
+                "SGDScheduleFree",
             ],
             value=self.config.get("basic.optimizer", "AdamW8bit"),
             interactive=True,
+            allow_custom_value=True,
         )

     def init_grad_and_lr_controls(self) -> None:
@@ -240,7 +266,7 @@ class BasicTraining:
         self.learning_rate = gr.Number(
             label=lr_label,
             value=self.config.get("basic.learning_rate", self.learning_rate_value),
-            minimum=0,
+            minimum=-1,
             maximum=1,
             info="Set to 0 to not train the Unet",
         )
@@ -251,7 +277,7 @@ class BasicTraining:
                 "basic.learning_rate_te", self.learning_rate_value
             ),
             visible=self.finetuning or self.dreambooth,
-            minimum=0,
+            minimum=-1,
             maximum=1,
             info="Set to 0 to not train the Text Encoder",
         )
@@ -262,7 +288,7 @@ class BasicTraining:
                 "basic.learning_rate_te1", self.learning_rate_value
             ),
             visible=False,
-            minimum=0,
+            minimum=-1,
             maximum=1,
             info="Set to 0 to not train the Text Encoder 1",
         )
@@ -273,7 +299,7 @@ class BasicTraining:
                 "basic.learning_rate_te2", self.learning_rate_value
             ),
             visible=False,
-            minimum=0,
+            minimum=-1,
             maximum=1,
             info="Set to 0 to not train the Text Encoder 2",
         )
@@ -285,25 +311,37 @@ class BasicTraining:
             maximum=100,
             step=1,
         )
+        # Initialize the learning rate warmup steps override
+        self.lr_warmup_steps = gr.Number(
+            label="LR warmup steps (override)",
+            value=self.config.get("basic.lr_warmup_steps", self.lr_warmup_steps_value),
+            minimum=0,
+            step=1,
+        )

-        def lr_scheduler_changed(scheduler, value):
+        def lr_scheduler_changed(scheduler, value, value_lr_warmup_steps):
             if scheduler == "constant":
                 self.old_lr_warmup = value
+                self.old_lr_warmup_steps = value_lr_warmup_steps
                 value = 0
+                value_lr_warmup_steps = 0
                 interactive = False
                 info = "Can't use LR warmup with LR Scheduler constant... setting to 0 and disabling field..."
             else:
                 if self.old_lr_warmup != 0:
                     value = self.old_lr_warmup
                     self.old_lr_warmup = 0
+                if self.old_lr_warmup_steps != 0:
+                    value_lr_warmup_steps = self.old_lr_warmup_steps
+                    self.old_lr_warmup_steps = 0
                 interactive = True
                 info = ""
-            return gr.Slider(value=value, interactive=interactive, info=info)
+            return gr.Slider(value=value, interactive=interactive, info=info), gr.Number(value=value_lr_warmup_steps, interactive=interactive, info=info)

         self.lr_scheduler.change(
             lr_scheduler_changed,
-            inputs=[self.lr_scheduler, self.lr_warmup],
-            outputs=self.lr_warmup,
+            inputs=[self.lr_scheduler, self.lr_warmup, self.lr_warmup_steps],
+            outputs=[self.lr_warmup, self.lr_warmup_steps],
         )

     def init_scheduler_controls(self) -> None:
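The callback above zeroes and disables the warmup fields when the scheduler is `constant`, stashing the previous value so it can be restored when the user switches back. That stash-and-restore state machine, extracted as a plain class for clarity (a sketch; the real callback returns `gr.Slider`/`gr.Number` updates instead of tuples):

```python
class WarmupState:
    """Remembers the warmup value while the constant scheduler has it forced to 0."""

    def __init__(self):
        self.old = 0

    def on_scheduler_change(self, scheduler: str, value: int):
        """Return (field value, interactive flag) for the warmup input."""
        if scheduler == "constant":
            self.old = value       # stash the user's value
            return 0, False        # force 0 and disable the field
        if self.old != 0:
            value, self.old = self.old, 0  # restore the stashed value once
        return value, True
```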
@@ -48,7 +48,7 @@ class CommandExecutor:

         # Execute the command securely
         self.process = subprocess.Popen(run_cmd, **kwargs)
-        log.info("Command executed.")
+        log.debug("Command executed.")

     def kill_command(self):
         """
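The executor hands `run_cmd` to `Popen` as an argument list without `shell=True`, so the shell never interprets user-controlled values; combined with the `shlex.quote` step earlier, argument boundaries stay intact. A minimal demonstration of list-style invocation:

```python
import subprocess
import sys

# An argument list bypasses the shell, so ';' and spaces stay literal data.
result = subprocess.run(
    [sys.executable, "-c", "print('hello; world')"],
    capture_output=True,
    text=True,
)
output = result.stdout.strip()
```

Had the same string been passed through a shell, the `;` could have terminated the command and started a new one.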
@@ -0,0 +1,336 @@
+import gradio as gr
+from typing import Tuple
+from .common_gui import (
+    get_any_file_path,
+    document_symbol,
+)
+
+
+class flux1Training:
+    def __init__(
+        self,
+        headless: bool = False,
+        finetuning: bool = False,
+        training_type: str = "",
+        config: dict = {},
+        flux1_checkbox: gr.Checkbox = False,
+    ) -> None:
+        self.headless = headless
+        self.finetuning = finetuning
+        self.training_type = training_type
+        self.config = config
+        self.flux1_checkbox = flux1_checkbox
+
+        # Define the behavior for changing noise offset type.
+        def noise_offset_type_change(
+            noise_offset_type: str,
+        ) -> Tuple[gr.Group, gr.Group]:
+            if noise_offset_type == "Original":
+                return (gr.Group(visible=True), gr.Group(visible=False))
+            else:
+                return (gr.Group(visible=False), gr.Group(visible=True))
+
+        with gr.Accordion(
+            "Flux.1", open=True, visible=False, elem_classes=["flux1_background"]
+        ) as flux1_accordion:
+            with gr.Group():
+                with gr.Row():
+                    self.ae = gr.Textbox(
+                        label="VAE Path",
+                        placeholder="Path to VAE model",
+                        value=self.config.get("flux1.ae", ""),
+                        interactive=True,
+                    )
+                    self.ae_button = gr.Button(
+                        document_symbol,
+                        elem_id="open_folder_small",
+                        visible=(not headless),
+                        interactive=True,
+                    )
+                    self.ae_button.click(
+                        get_any_file_path,
+                        outputs=self.ae,
+                        show_progress=False,
+                    )
+
+                    self.clip_l = gr.Textbox(
+                        label="CLIP-L Path",
+                        placeholder="Path to CLIP-L model",
+                        value=self.config.get("flux1.clip_l", ""),
+                        interactive=True,
+                    )
+                    self.clip_l_button = gr.Button(
+                        document_symbol,
+                        elem_id="open_folder_small",
+                        visible=(not headless),
+                        interactive=True,
+                    )
+                    self.clip_l_button.click(
+                        get_any_file_path,
+                        outputs=self.clip_l,
+                        show_progress=False,
+                    )
+
+                    self.t5xxl = gr.Textbox(
+                        label="T5-XXL Path",
+                        placeholder="Path to T5-XXL model",
+                        value=self.config.get("flux1.t5xxl", ""),
+                        interactive=True,
+                    )
+                    self.t5xxl_button = gr.Button(
+                        document_symbol,
+                        elem_id="open_folder_small",
+                        visible=(not headless),
+                        interactive=True,
+                    )
+                    self.t5xxl_button.click(
+                        get_any_file_path,
+                        outputs=self.t5xxl,
+                        show_progress=False,
+                    )
+
+                with gr.Row():
+                    self.discrete_flow_shift = gr.Number(
+                        label="Discrete Flow Shift",
+                        value=self.config.get("flux1.discrete_flow_shift", 3.0),
+                        info="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0",
+                        minimum=-1024,
+                        maximum=1024,
+                        step=0.01,
+                        interactive=True,
+                    )
+                    self.model_prediction_type = gr.Dropdown(
+                        label="Model Prediction Type",
+                        choices=["raw", "additive", "sigma_scaled"],
+                        value=self.config.get(
+                            "flux1.timestep_sampling", "sigma_scaled"
+                        ),
+                        interactive=True,
+                    )
+                    self.timestep_sampling = gr.Dropdown(
+                        label="Timestep Sampling",
+                        choices=["flux_shift", "sigma", "shift", "sigmoid", "uniform"],
+                        value=self.config.get("flux1.timestep_sampling", "sigma"),
+                        interactive=True,
+                    )
+                    self.apply_t5_attn_mask = gr.Checkbox(
+                        label="Apply T5 Attention Mask",
+                        value=self.config.get("flux1.apply_t5_attn_mask", False),
+                        info="Apply attention mask to T5-XXL encode and FLUX double blocks",
+                        interactive=True,
+                    )
+                with gr.Row(visible=True if not finetuning else False):
+                    self.split_mode = gr.Checkbox(
+                        label="Split Mode",
+                        value=self.config.get("flux1.split_mode", False),
+                        info="Split mode for Flux1",
+                        interactive=True,
+                    )
+                    self.train_blocks = gr.Dropdown(
+                        label="Train Blocks",
+                        choices=["all", "double", "single"],
+                        value=self.config.get("flux1.train_blocks", "all"),
+                        interactive=True,
+                    )
+                    self.split_qkv = gr.Checkbox(
+                        label="Split QKV",
+                        value=self.config.get("flux1.split_qkv", False),
+                        info="Split the projection layers of q/k/v/txt in the attention",
+                        interactive=True,
+                    )
+                    self.train_t5xxl = gr.Checkbox(
+                        label="Train T5-XXL",
+                        value=self.config.get("flux1.train_t5xxl", False),
+                        info="Train T5-XXL model",
+                        interactive=True,
+                    )
+                    self.cpu_offload_checkpointing = gr.Checkbox(
+                        label="CPU Offload Checkpointing",
+                        value=self.config.get("flux1.cpu_offload_checkpointing", False),
+                        info="[Experimental] Enable offloading of tensors to CPU during checkpointing",
+                        interactive=True,
+                    )
+                with gr.Row():
+                    self.guidance_scale = gr.Number(
+                        label="Guidance Scale",
+                        value=self.config.get("flux1.guidance_scale", 3.5),
+                        info="Guidance scale for Flux1",
+                        minimum=0,
+                        maximum=1024,
+                        step=0.1,
+                        interactive=True,
+                    )
+                    self.t5xxl_max_token_length = gr.Number(
+                        label="T5-XXL Max Token Length",
+                        value=self.config.get("flux1.t5xxl_max_token_length", 512),
+                        info="Max token length for T5-XXL",
+                        minimum=0,
+                        maximum=4096,
+                        step=1,
+                        interactive=True,
+                    )
+                    self.enable_all_linear = gr.Checkbox(
+                        label="Enable All Linear",
+                        value=self.config.get("flux1.enable_all_linear", False),
+                        info="(Only applicable to 'Flux1 OFT' LoRA) Target all linear connections in the MLP layer. The default is False, which targets only attention.",
+                        interactive=True,
+                    )
+
+                with gr.Row():
+                    self.flux1_cache_text_encoder_outputs = gr.Checkbox(
+                        label="Cache Text Encoder Outputs",
+                        value=self.config.get(
+                            "flux1.cache_text_encoder_outputs", False
+                        ),
+                        info="Cache text encoder outputs to speed up inference",
+                        interactive=True,
+                    )
+                    self.flux1_cache_text_encoder_outputs_to_disk = gr.Checkbox(
+                        label="Cache Text Encoder Outputs to Disk",
+                        value=self.config.get(
+                            "flux1.cache_text_encoder_outputs_to_disk", False
+                        ),
+                        info="Cache text encoder outputs to disk to speed up inference",
+                        interactive=True,
+                    )
+                    self.mem_eff_save = gr.Checkbox(
+                        label="Memory Efficient Save",
+                        value=self.config.get("flux1.mem_eff_save", False),
+                        info="[Experimental] Enable memory efficient save. We do not recommend using it unless you are familiar with the code.",
+                        interactive=True,
+                    )
+
+                with gr.Row():
+                    # self.blocks_to_swap = gr.Slider(
+                    #     label="Blocks to swap",
+                    #     value=self.config.get("flux1.blocks_to_swap", 0),
+                    #     info="The number of blocks to swap. The default is None (no swap). These options must be combined with --fused_backward_pass or --blockwise_fused_optimizers. The recommended maximum value is 36.",
+                    #     minimum=0,
+                    #     maximum=57,
+                    #     step=1,
+                    #     interactive=True,
+                    # )
+                    self.single_blocks_to_swap = gr.Slider(
+                        label="Single Blocks to swap (deprecated)",
+                        value=self.config.get("flux1.single_blocks_to_swap", 0),
+                        info="[Experimental] Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes.",
+                        minimum=0,
+                        maximum=19,
+                        step=1,
+                        interactive=True,
+                    )
+                    self.double_blocks_to_swap = gr.Slider(
+                        label="Double Blocks to swap (deprecated)",
+                        value=self.config.get("flux1.double_blocks_to_swap", 0),
+                        info="[Experimental] Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes.",
+                        minimum=0,
+                        maximum=38,
+                        step=1,
+                        interactive=True,
+                    )
+
+                with gr.Row(visible=True if finetuning else False):
+                    self.blockwise_fused_optimizers = gr.Checkbox(
+                        label="Blockwise Fused Optimizer",
+                        value=self.config.get(
+                            "flux1.blockwise_fused_optimizers", False
+                        ),
+                        info="Enable blockwise optimizers for fused backward pass and optimizer step. Any optimizer can be used.",
+                        interactive=True,
+                    )
+                    self.cpu_offload_checkpointing = gr.Checkbox(
+                        label="CPU Offload Checkpointing",
+                        value=self.config.get("flux1.cpu_offload_checkpointing", False),
+                        info="[Experimental] Enable offloading of tensors to CPU during checkpointing",
+                        interactive=True,
+                    )
+                    self.flux_fused_backward_pass = gr.Checkbox(
+                        label="Fused Backward Pass",
+                        value=self.config.get("flux1.fused_backward_pass", False),
+                        info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
+                        interactive=True,
+                    )
+
+                with gr.Accordion(
+                    "Blocks to train",
+                    open=True,
+                    visible=False if finetuning else True,
+                    elem_classes=["flux1_blocks_to_train_background"],
+                ):
+                    with gr.Row():
+                        self.train_double_block_indices = gr.Textbox(
+                            label="train_double_block_indices",
+                            info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of double blocks is 19.",
+                            value=self.config.get("flux1.train_double_block_indices", "all"),
+                            interactive=True,
+                        )
+                        self.train_single_block_indices = gr.Textbox(
+                            label="train_single_block_indices",
+                            info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of single blocks is 38.",
+                            value=self.config.get("flux1.train_single_block_indices", "all"),
+                            interactive=True,
+                        )
+
+                with gr.Accordion(
+                    "Rank for layers",
+                    open=False,
+                    visible=False if finetuning else True,
+                    elem_classes=["flux1_rank_layers_background"],
+                ):
+                    with gr.Row():
+                        self.img_attn_dim = gr.Textbox(
+                            label="img_attn_dim",
+                            value=self.config.get("flux1.img_attn_dim", ""),
+                            interactive=True,
+                        )
+                        self.img_mlp_dim = gr.Textbox(
+                            label="img_mlp_dim",
+                            value=self.config.get("flux1.img_mlp_dim", ""),
+                            interactive=True,
+                        )
+                        self.img_mod_dim = gr.Textbox(
+                            label="img_mod_dim",
+                            value=self.config.get("flux1.img_mod_dim", ""),
+                            interactive=True,
+                        )
+                        self.single_dim = gr.Textbox(
+                            label="single_dim",
+                            value=self.config.get("flux1.single_dim", ""),
+                            interactive=True,
+                        )
+                    with gr.Row():
+                        self.txt_attn_dim = gr.Textbox(
+                            label="txt_attn_dim",
+                            value=self.config.get("flux1.txt_attn_dim", ""),
+                            interactive=True,
+                        )
+                        self.txt_mlp_dim = gr.Textbox(
+                            label="txt_mlp_dim",
+                            value=self.config.get("flux1.txt_mlp_dim", ""),
+                            interactive=True,
+                        )
+                        self.txt_mod_dim = gr.Textbox(
+                            label="txt_mod_dim",
+                            value=self.config.get("flux1.txt_mod_dim", ""),
+                            interactive=True,
+                        )
+                        self.single_mod_dim = gr.Textbox(
+                            label="single_mod_dim",
+                            value=self.config.get("flux1.single_mod_dim", ""),
+                            interactive=True,
+                        )
+                    with gr.Row():
+                        self.in_dims = gr.Textbox(
+                            label="in_dims",
+                            value=self.config.get("flux1.in_dims", ""),
+                            placeholder="e.g., [4,0,0,0,4]",
+                            info="Each number corresponds to img_in, time_in, vector_in, guidance_in, txt_in. The above example applies LoRA to all conditioning layers, with rank 4 for img_in, 2 for time_in, vector_in, guidance_in, and 4 for txt_in.",
+                            interactive=True,
+                        )
+
+            self.flux1_checkbox.change(
+                lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox),
+                inputs=[self.flux1_checkbox],
+                outputs=[flux1_accordion],
+            )
@@ -4,10 +4,12 @@ from .svd_merge_lora_gui import gradio_svd_merge_lora_tab
 from .verify_lora_gui import gradio_verify_lora_tab
 from .resize_lora_gui import gradio_resize_lora_tab
 from .extract_lora_gui import gradio_extract_lora_tab
+from .flux_extract_lora_gui import gradio_flux_extract_lora_tab
 from .convert_lcm_gui import gradio_convert_lcm_tab
 from .extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
 from .extract_lora_from_dylora_gui import gradio_extract_dylora_tab
 from .merge_lycoris_gui import gradio_merge_lycoris_tab
+from .flux_merge_lora_gui import GradioFluxMergeLoRaTab


 class LoRATools:
@@ -19,9 +21,11 @@ class LoRATools:
         gradio_extract_dylora_tab(headless=headless)
         gradio_convert_lcm_tab(headless=headless)
         gradio_extract_lora_tab(headless=headless)
+        gradio_flux_extract_lora_tab(headless=headless)
         gradio_extract_lycoris_locon_tab(headless=headless)
         gradio_merge_lora_tab = GradioMergeLoRaTab()
         gradio_merge_lycoris_tab(headless=headless)
         gradio_svd_merge_lora_tab(headless=headless)
         gradio_resize_lora_tab(headless=headless)
         gradio_verify_lora_tab(headless=headless)
+        GradioFluxMergeLoRaTab(headless=headless)
@@ -28,7 +28,10 @@ def create_prompt_file(sample_prompts, output_dir):
     Returns:
         str: The path to the prompt file.
     """
-    sample_prompts_path = os.path.join(output_dir, "prompt.txt")
+    sample_prompts_path = os.path.join(output_dir, "sample/prompt.txt")
+
+    if not os.path.exists(os.path.dirname(sample_prompts_path)):
+        os.makedirs(os.path.dirname(sample_prompts_path))

     with open(sample_prompts_path, "w", encoding="utf-8") as f:
         f.write(sample_prompts)
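The fix above creates the `sample/` subdirectory before writing `prompt.txt`. The exists-check-then-create pair can also be collapsed into `os.makedirs(..., exist_ok=True)`, which is idempotent and race-free; a sketch using a temporary directory:

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as output_dir:
    sample_prompts_path = os.path.join(output_dir, "sample", "prompt.txt")
    # exist_ok=True makes the call safe whether or not the directory exists.
    os.makedirs(os.path.dirname(sample_prompts_path), exist_ok=True)
    with open(sample_prompts_path, "w", encoding="utf-8") as f:
        f.write("a photo of sks dog")
    with open(sample_prompts_path, encoding="utf-8") as f:
        written = f.read()
```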
@@ -0,0 +1,249 @@
+import gradio as gr
+from typing import Tuple
+from .common_gui import (
+    get_folder_path,
+    get_any_file_path,
+    list_files,
+    list_dirs,
+    create_refresh_button,
+    document_symbol,
+)
+
+
+class sd3Training:
+    """
+    This class configures and initializes the SD3-specific training settings,
+    including options for headless operation, fine-tuning, training type selection, and default directory paths.
+
+    Attributes:
+        headless (bool): If True, run without the Gradio interface.
+        finetuning (bool): If True, enables fine-tuning of the model.
+        training_type (str): Specifies the type of training to perform.
+        config (dict): Configuration options for the training process.
+        sd3_checkbox (gr.Checkbox): Checkbox that toggles the visibility of the SD3 settings.
+    """
+
+    def __init__(
+        self,
+        headless: bool = False,
+        finetuning: bool = False,
+        training_type: str = "",
+        config: dict = {},
+        sd3_checkbox: gr.Checkbox = False,
+    ) -> None:
+        """
+        Initializes the sd3Training class with given settings.
+
+        Parameters:
+            headless (bool): Run in headless mode without GUI.
+            finetuning (bool): Enable model fine-tuning.
+            training_type (str): The type of training to be performed.
+            config (dict): Configuration options for the training process.
+            sd3_checkbox (gr.Checkbox): Checkbox that toggles the SD3 settings.
+        """
+        self.headless = headless
+        self.finetuning = finetuning
+        self.training_type = training_type
+        self.config = config
+        self.sd3_checkbox = sd3_checkbox
+
+        # Define the behavior for changing noise offset type.
+        def noise_offset_type_change(
+            noise_offset_type: str,
+        ) -> Tuple[gr.Group, gr.Group]:
+            """
+            Returns a tuple of Gradio Groups with visibility set based on the noise offset type.
+
+            Parameters:
+                noise_offset_type (str): The selected noise offset type.
+
+            Returns:
+                Tuple[gr.Group, gr.Group]: A tuple containing two Gradio Group elements with their visibility set.
+            """
+            if noise_offset_type == "Original":
+                return (gr.Group(visible=True), gr.Group(visible=False))
+            else:
+                return (gr.Group(visible=False), gr.Group(visible=True))
+
+        with gr.Accordion(
+            "SD3", open=False, elem_id="sd3_tab", visible=False
+        ) as sd3_accordion:
+            with gr.Group():
+                gr.Markdown("### SD3 Specific Parameters")
+                with gr.Row():
+                    self.weighting_scheme = gr.Dropdown(
+                        label="Weighting Scheme",
+                        choices=["logit_normal", "sigma_sqrt", "mode", "cosmap", "uniform"],
+                        value=self.config.get("sd3.weighting_scheme", "logit_normal"),
+                        interactive=True,
+                    )
+                    self.logit_mean = gr.Number(
+                        label="Logit Mean",
+                        value=self.config.get("sd3.logit_mean", 0.0),
+                        interactive=True,
+                    )
+                    self.logit_std = gr.Number(
+                        label="Logit Std",
+                        value=self.config.get("sd3.logit_std", 1.0),
+                        interactive=True,
+                    )
+                    self.mode_scale = gr.Number(
+                        label="Mode Scale",
+                        value=self.config.get("sd3.mode_scale", 1.29),
+                        interactive=True,
+                    )
+
+                with gr.Row():
+                    self.clip_l = gr.Textbox(
+                        label="CLIP-L Path",
+                        placeholder="Path to CLIP-L model",
+                        value=self.config.get("sd3.clip_l", ""),
+                        interactive=True,
+                    )
+                    self.clip_l_button = gr.Button(
+                        document_symbol,
+                        elem_id="open_folder_small",
+                        visible=(not headless),
+                        interactive=True,
+                    )
+                    self.clip_l_button.click(
+                        get_any_file_path,
+                        outputs=self.clip_l,
+                        show_progress=False,
+                    )
+
+                    self.clip_g = gr.Textbox(
+                        label="CLIP-G Path",
+                        placeholder="Path to CLIP-G model",
+                        value=self.config.get("sd3.clip_g", ""),
+                        interactive=True,
+                    )
+                    self.clip_g_button = gr.Button(
+                        document_symbol,
+                        elem_id="open_folder_small",
+                        visible=(not headless),
+                        interactive=True,
+                    )
+                    self.clip_g_button.click(
+                        get_any_file_path,
+                        outputs=self.clip_g,
+                        show_progress=False,
+                    )
+
+                    self.t5xxl = gr.Textbox(
+                        label="T5-XXL Path",
+                        placeholder="Path to T5-XXL model",
+                        value=self.config.get("sd3.t5xxl", ""),
+                        interactive=True,
+                    )
+                    self.t5xxl_button = gr.Button(
+                        document_symbol,
+                        elem_id="open_folder_small",
+                        visible=(not headless),
+                        interactive=True,
+                    )
+                    self.t5xxl_button.click(
+                        get_any_file_path,
+                        outputs=self.t5xxl,
+                        show_progress=False,
+                    )
+
+                with gr.Row():
+                    self.save_clip = gr.Checkbox(
+                        label="Save CLIP models",
+                        value=self.config.get("sd3.save_clip", False),
+                        interactive=True,
+                    )
+                    self.save_t5xxl = gr.Checkbox(
+                        label="Save T5-XXL model",
+                        value=self.config.get("sd3.save_t5xxl", False),
+                        interactive=True,
+                    )
+
+                with gr.Row():
+                    self.t5xxl_device = gr.Textbox(
+                        label="T5-XXL Device",
+                        placeholder="Device for T5-XXL (e.g., cuda:0)",
+                        value=self.config.get("sd3.t5xxl_device", ""),
+                        interactive=True,
+                    )
+                    self.t5xxl_dtype = gr.Dropdown(
+                        label="T5-XXL Dtype",
+                        choices=["float32", "fp16", "bf16"],
+                        value=self.config.get("sd3.t5xxl_dtype", "bf16"),
+                        interactive=True,
+                    )
+                    self.sd3_text_encoder_batch_size = gr.Number(
+                        label="Text Encoder Batch Size",
+                        value=self.config.get("sd3.text_encoder_batch_size", 1),
+                        minimum=1,
+                        maximum=1024,
+                        step=1,
+                        interactive=True,
+                    )
+                    self.sd3_cache_text_encoder_outputs = gr.Checkbox(
+                        label="Cache Text Encoder Outputs",
+                        value=self.config.get("sd3.cache_text_encoder_outputs", False),
+                        info="Cache text encoder outputs to speed up inference",
+                        interactive=True,
+                    )
+                    self.sd3_cache_text_encoder_outputs_to_disk = gr.Checkbox(
+                        label="Cache Text Encoder Outputs to Disk",
+                        value=self.config.get(
+                            "sd3.cache_text_encoder_outputs_to_disk", False
+                        ),
+                        info="Cache text encoder outputs to disk to speed up inference",
+                        interactive=True,
+                    )
+                with gr.Row():
+                    self.clip_l_dropout_rate = gr.Number(
+                        label="CLIP-L Dropout Rate",
+                        value=self.config.get("sd3.clip_l_dropout_rate", 0.0),
+                        interactive=True,
+                        minimum=0.0,
+                        info="Dropout rate for CLIP-L encoder",
+                    )
+                    self.clip_g_dropout_rate = gr.Number(
+                        label="CLIP-G Dropout Rate",
+                        value=self.config.get("sd3.clip_g_dropout_rate", 0.0),
+                        interactive=True,
+                        minimum=0.0,
+                        info="Dropout rate for CLIP-G encoder",
+                    )
+                    self.t5_dropout_rate = gr.Number(
+                        label="T5 Dropout Rate",
+                        value=self.config.get("sd3.t5_dropout_rate", 0.0),
+                        interactive=True,
+                        minimum=0.0,
+                        info="Dropout rate for T5-XXL encoder",
+                    )
+                with gr.Row():
+                    self.sd3_fused_backward_pass = gr.Checkbox(
+                        label="Fused Backward Pass",
+                        value=self.config.get("sd3.fused_backward_pass", False),
+                        info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
+                        interactive=True,
+                    )
+                    self.disable_mmap_load_safetensors = gr.Checkbox(
+                        label="Disable mmap load safe tensors",
+                        info="Disable memory mapping when loading the model's .safetensors in SD3.",
+                        value=self.config.get("sd3.disable_mmap_load_safetensors", False),
+                    )
+                    self.enable_scaled_pos_embed = gr.Checkbox(
+                        label="Enable Scaled Positional Embeddings",
+                        info="Enable scaled positional embeddings in the model.",
+                        value=self.config.get("sd3.enable_scaled_pos_embed", False),
+                    )
+                    self.pos_emb_random_crop_rate = gr.Number(
+                        label="Positional Embedding Random Crop Rate",
+                        value=self.config.get("sd3.pos_emb_random_crop_rate", 0.0),
+                        interactive=True,
+                        minimum=0.0,
+                        info="Random crop rate for positional embeddings",
+                    )
+
+        self.sd3_checkbox.change(
+            lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox),
+            inputs=[self.sd3_checkbox],
+            outputs=[sd3_accordion],
+        )
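The `noise_offset_type_change` helper in the new file returns two `gr.Group` updates so that exactly one parameter group is shown at a time. A pure-Python analog of that logic, with booleans standing in for the Gradio visibility updates ("Multires" is assumed to be the other dropdown choice, as in the rest of the GUI):

```python
def noise_offset_visibility(noise_offset_type: str) -> tuple:
    # Analog of noise_offset_type_change: the "Original" group and the
    # multires group are mutually exclusive; any non-"Original" selection
    # falls through to the second group, mirroring the if/else in the class.
    original_visible = noise_offset_type == "Original"
    return (original_visible, not original_visible)

print(noise_offset_visibility("Original"))  # (True, False)
print(noise_offset_visibility("Multires"))  # (False, True)
```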
@@ -7,10 +7,12 @@ class SDXLParameters:
         sdxl_checkbox: gr.Checkbox,
         show_sdxl_cache_text_encoder_outputs: bool = True,
         config: KohyaSSGUIConfig = {},
+        trainer: str = "",
     ):
         self.sdxl_checkbox = sdxl_checkbox
         self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs
         self.config = config
+        self.trainer = trainer
 
         self.initialize_accordion()
@@ -30,6 +32,41 @@ class SDXLParameters:
             info="Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.",
             value=self.config.get("sdxl.sdxl_no_half_vae", False),
         )
+        self.fused_backward_pass = gr.Checkbox(
+            label="Fused backward pass",
+            info="Enable fused backward pass. This option is useful to reduce the GPU memory usage. Can't be used if Fused optimizer groups is > 0. Only AdaFactor is supported",
+            value=self.config.get("sdxl.fused_backward_pass", False),
+            visible=self.trainer == "finetune" or self.trainer == "dreambooth",
+        )
+        self.fused_optimizer_groups = gr.Number(
+            label="Fused optimizer groups",
+            info="Number of optimizer groups to fuse. This option is useful to reduce the GPU memory usage. Can't be used if Fused backward pass is enabled. Since the effect is limited to a certain number, it is recommended to specify 4-10.",
+            value=self.config.get("sdxl.fused_optimizer_groups", 0),
+            minimum=0,
+            step=1,
+            visible=self.trainer == "finetune" or self.trainer == "dreambooth",
+        )
+        self.disable_mmap_load_safetensors = gr.Checkbox(
+            label="Disable mmap load safe tensors",
+            info="Disable memory mapping when loading the model's .safetensors in SDXL.",
+            value=self.config.get("sdxl.disable_mmap_load_safetensors", False),
+        )
+
+        self.fused_backward_pass.change(
+            lambda fused_backward_pass: gr.Number(
+                interactive=not fused_backward_pass
+            ),
+            inputs=[self.fused_backward_pass],
+            outputs=[self.fused_optimizer_groups],
+        )
+        self.fused_optimizer_groups.change(
+            lambda fused_optimizer_groups: gr.Checkbox(
+                interactive=fused_optimizer_groups == 0
+            ),
+            inputs=[self.fused_optimizer_groups],
+            outputs=[self.fused_backward_pass],
+        )
+
         self.sdxl_checkbox.change(
             lambda sdxl_checkbox: gr.Accordion(visible=sdxl_checkbox),
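The two `.change()` callbacks above implement a mutual-exclusion rule: enabling the fused backward pass greys out the optimizer-groups field, and any non-zero group count greys out the fused-backward checkbox. A small sketch of that rule as a plain function (the function name and the returned dict keys are illustrative, not part of the GUI code):

```python
def fused_options_interactivity(fused_backward_pass: bool, fused_optimizer_groups: int) -> dict:
    # Mirror of the two Gradio .change() handlers: each option is only
    # interactive while the other one is effectively off.
    return {
        "fused_optimizer_groups_interactive": not fused_backward_pass,
        "fused_backward_pass_interactive": fused_optimizer_groups == 0,
    }

print(fused_options_interactivity(True, 0))
print(fused_options_interactivity(False, 4))
```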
@@ -26,7 +26,6 @@ default_models = [
     "stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned",
     "stabilityai/stable-diffusion-2-1",
     "stabilityai/stable-diffusion-2",
-    "runwayml/stable-diffusion-v1-5",
     "CompVis/stable-diffusion-v1-4",
 ]
@@ -245,19 +244,88 @@ class SourceModel:
                 with gr.Column():
                     with gr.Row():
                         self.v2 = gr.Checkbox(
-                            label="v2", value=False, visible=False, min_width=60
+                            label="v2", value=False, visible=False, min_width=60,
+                            interactive=True,
                         )
                         self.v_parameterization = gr.Checkbox(
-                            label="v_parameterization",
+                            label="v_param",
                             value=False,
                             visible=False,
                             min_width=130,
+                            interactive=True,
                         )
                         self.sdxl_checkbox = gr.Checkbox(
                             label="SDXL",
                             value=False,
                             visible=False,
                             min_width=60,
+                            interactive=True,
+                        )
+                        self.sd3_checkbox = gr.Checkbox(
+                            label="SD3",
+                            value=False,
+                            visible=False,
+                            min_width=60,
+                            interactive=True,
+                        )
+                        self.flux1_checkbox = gr.Checkbox(
+                            label="Flux.1",
+                            value=False,
+                            visible=False,
+                            min_width=60,
+                            interactive=True,
+                        )
+
+                        def toggle_checkboxes(v2, v_parameterization, sdxl_checkbox, sd3_checkbox, flux1_checkbox):
+                            # Check if all checkboxes are unchecked
+                            if not v2 and not v_parameterization and not sdxl_checkbox and not sd3_checkbox and not flux1_checkbox:
+                                # If all unchecked, return new interactive checkboxes
+                                return (
+                                    gr.Checkbox(interactive=True),  # v2 checkbox
+                                    gr.Checkbox(interactive=True),  # v_parameterization checkbox
+                                    gr.Checkbox(interactive=True),  # sdxl_checkbox
+                                    gr.Checkbox(interactive=True),  # sd3_checkbox
+                                    gr.Checkbox(interactive=True),  # flux1_checkbox
+                                )
+                            else:
+                                # If any checkbox is checked, return checkboxes with current interactive state
+                                return (
+                                    gr.Checkbox(interactive=v2),  # v2 checkbox
+                                    gr.Checkbox(interactive=v_parameterization),  # v_parameterization checkbox
+                                    gr.Checkbox(interactive=sdxl_checkbox),  # sdxl_checkbox
+                                    gr.Checkbox(interactive=sd3_checkbox),  # sd3_checkbox
+                                    gr.Checkbox(interactive=flux1_checkbox),  # flux1_checkbox
+                                )
+
+                        self.v2.change(
+                            fn=toggle_checkboxes,
+                            inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            show_progress=False,
+                        )
+                        self.v_parameterization.change(
+                            fn=toggle_checkboxes,
+                            inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            show_progress=False,
+                        )
+                        self.sdxl_checkbox.change(
+                            fn=toggle_checkboxes,
+                            inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            show_progress=False,
+                        )
+                        self.sd3_checkbox.change(
+                            fn=toggle_checkboxes,
+                            inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            show_progress=False,
+                        )
+                        self.flux1_checkbox.change(
+                            fn=toggle_checkboxes,
+                            inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
+                            show_progress=False,
+                        )
                 with gr.Column():
                     gr.Group(visible=False)
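The `toggle_checkboxes` interlock above makes the model-type checkboxes effectively mutually exclusive: with nothing ticked every box is clickable, and once one is ticked only the ticked boxes remain interactive. A pure-Python analog of that rule, with plain booleans standing in for the `gr.Checkbox(interactive=...)` updates:

```python
def toggle_interactivity(*flags: bool) -> tuple:
    # Analog of toggle_checkboxes: if no model-type flag is set, everything
    # stays interactive; otherwise only the set flags remain interactive,
    # locking the user out of picking a second, conflicting model type.
    if not any(flags):
        return tuple(True for _ in flags)
    return tuple(bool(f) for f in flags)

# Order of flags follows the GUI: v2, v_param, SDXL, SD3, Flux.1
print(toggle_interactivity(False, False, False, False, False))
print(toggle_interactivity(False, False, True, False, False))
```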
@@ -294,6 +362,8 @@ class SourceModel:
                 self.v2,
                 self.v_parameterization,
                 self.sdxl_checkbox,
+                self.sd3_checkbox,
+                self.flux1_checkbox,
             ],
             show_progress=False,
         )
@@ -20,6 +20,7 @@ from .common_gui import setup_environment
 
 class TensorboardManager:
     DEFAULT_TENSORBOARD_PORT = 6006
+    DEFAULT_TENSORBOARD_HOST = "0.0.0.0"
 
     def __init__(self, logging_dir, headless: bool = False, wait_time=5):
         self.logging_dir = logging_dir
@@ -29,6 +30,9 @@ class TensorboardManager:
         self.tensorboard_port = os.environ.get(
             "TENSORBOARD_PORT", self.DEFAULT_TENSORBOARD_PORT
         )
+        self.tensorboard_host = os.environ.get(
+            "TENSORBOARD_HOST", self.DEFAULT_TENSORBOARD_HOST
+        )
         self.log = setup_logging()
         self.thread = None
         self.stop_event = Event()
@@ -64,7 +68,7 @@ class TensorboardManager:
             "--logdir",
             logging_dir,
             "--host",
-            "0.0.0.0",
+            self.tensorboard_host,
             "--port",
             str(self.tensorboard_port),
         ]
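The three TensorBoard hunks above make both the bind host and the port overridable via environment variables, falling back to class defaults. A small sketch of the resolution logic (the function name is illustrative; note the port may come back from the environment as a string, which is why the launch command wraps it in `str()`):

```python
def resolve_tensorboard_binding(environ: dict) -> tuple:
    # Sketch of the new host/port resolution in TensorboardManager:
    # environment variables win, class-level defaults are the fallback.
    DEFAULT_TENSORBOARD_PORT = 6006
    DEFAULT_TENSORBOARD_HOST = "0.0.0.0"
    port = environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT)
    host = environ.get("TENSORBOARD_HOST", DEFAULT_TENSORBOARD_HOST)
    return host, str(port)

print(resolve_tensorboard_binding({}))  # ('0.0.0.0', '6006')
print(resolve_tensorboard_binding({"TENSORBOARD_HOST": "127.0.0.1", "TENSORBOARD_PORT": "7007"}))
```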
@@ -5,6 +5,7 @@ except ImportError:
     from easygui import msgbox, ynbox
 from typing import Optional
 from .custom_logging import setup_logging
+from .sd_modeltype import SDModelType
 
 import os
 import re
@@ -327,7 +328,6 @@ def update_my_data(my_data):
 
     # Convert values to int if they are strings
     for key in [
-        "adaptive_noise_scale",
         "clip_skip",
         "epoch",
         "gradient_accumulation_steps",
@@ -379,7 +379,13 @@ def update_my_data(my_data):
             my_data[key] = int(75)
 
     # Convert values to float if they are strings, correctly handling float representations
-    for key in ["noise_offset", "learning_rate", "text_encoder_lr", "unet_lr"]:
+    for key in [
+        "adaptive_noise_scale",
+        "noise_offset",
+        "learning_rate",
+        "text_encoder_lr",
+        "unet_lr",
+    ]:
         value = my_data.get(key)
         if value is not None:
             try:
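The two `update_my_data` hunks above move `adaptive_noise_scale` from the int-coercion loop to the float-coercion loop: values like `"0.005"` arrive as strings from older JSON configs, and `int("0.005")` raises, which is the "adaptive_noise_scale value not properly loading" bug. A minimal sketch of the float pass (the helper name `coerce_float` is hypothetical):

```python
def coerce_float(my_data: dict, key: str) -> None:
    # Sketch of the float-coercion pass in update_my_data: coerce string
    # values to float in place; int() would raise on "0.005", which is why
    # adaptive_noise_scale belongs in this loop rather than the int loop.
    value = my_data.get(key)
    if value is not None:
        try:
            my_data[key] = float(value)
        except ValueError:
            pass  # leave non-numeric values untouched

cfg = {"adaptive_noise_scale": "0.005"}
coerce_float(cfg, "adaptive_noise_scale")
print(cfg["adaptive_noise_scale"])  # 0.005
```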
@@ -956,11 +962,15 @@ def set_pretrained_model_name_or_path_input(
         v2 = gr.Checkbox(value=False, visible=False)
         v_parameterization = gr.Checkbox(value=False, visible=False)
         sdxl = gr.Checkbox(value=True, visible=False)
+        sd3 = gr.Checkbox(value=False, visible=False)
+        flux1 = gr.Checkbox(value=False, visible=False)
         return (
             gr.Dropdown(),
             v2,
             v_parameterization,
             sdxl,
+            sd3,
+            flux1,
         )
 
     # Check if the given pretrained_model_name_or_path is in the list of V2 base models
@@ -969,11 +979,15 @@ def set_pretrained_model_name_or_path_input(
         v2 = gr.Checkbox(value=True, visible=False)
         v_parameterization = gr.Checkbox(value=False, visible=False)
         sdxl = gr.Checkbox(value=False, visible=False)
+        sd3 = gr.Checkbox(value=False, visible=False)
+        flux1 = gr.Checkbox(value=False, visible=False)
         return (
             gr.Dropdown(),
             v2,
             v_parameterization,
             sdxl,
+            sd3,
+            flux1,
         )
 
     # Check if the given pretrained_model_name_or_path is in the list of V parameterization models
@@ -984,11 +998,15 @@ def set_pretrained_model_name_or_path_input(
         v2 = gr.Checkbox(value=True, visible=False)
         v_parameterization = gr.Checkbox(value=True, visible=False)
         sdxl = gr.Checkbox(value=False, visible=False)
+        sd3 = gr.Checkbox(value=False, visible=False)
+        flux1 = gr.Checkbox(value=False, visible=False)
         return (
             gr.Dropdown(),
             v2,
             v_parameterization,
             sdxl,
+            sd3,
+            flux1,
         )
 
     # Check if the given pretrained_model_name_or_path is in the list of V1 models
@@ -997,17 +1015,32 @@ def set_pretrained_model_name_or_path_input(
         v2 = gr.Checkbox(value=False, visible=False)
         v_parameterization = gr.Checkbox(value=False, visible=False)
         sdxl = gr.Checkbox(value=False, visible=False)
+        sd3 = gr.Checkbox(value=False, visible=False)
+        flux1 = gr.Checkbox(value=False, visible=False)
         return (
             gr.Dropdown(),
             v2,
             v_parameterization,
             sdxl,
+            sd3,
+            flux1,
         )
 
     # Check if the model_list is set to 'custom'
     v2 = gr.Checkbox(visible=True)
     v_parameterization = gr.Checkbox(visible=True)
     sdxl = gr.Checkbox(visible=True)
+    sd3 = gr.Checkbox(visible=True)
+    flux1 = gr.Checkbox(visible=True)
+
+    # Auto-detect model type if safetensors file path is given
+    if pretrained_model_name_or_path.lower().endswith(".safetensors"):
+        detect = SDModelType(pretrained_model_name_or_path)
+        v2 = gr.Checkbox(value=detect.Is_SD2(), visible=True)
+        sdxl = gr.Checkbox(value=detect.Is_SDXL(), visible=True)
+        sd3 = gr.Checkbox(value=detect.Is_SD3(), visible=True)
+        flux1 = gr.Checkbox(value=detect.Is_FLUX1(), visible=True)
+        # TODO: v_parameterization
 
     # If a refresh method is provided, use it to update the choices for the Dropdown widget
     if refresh_method is not None:
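The auto-detect branch above hands the `.safetensors` path to `SDModelType`, whose internals are not shown in this diff. The usual way such detection works is to read only the file's JSON header (the `.safetensors` format starts with an 8-byte little-endian length followed by that many bytes of JSON) and inspect the tensor names; a self-contained sketch of that approach, exercised against a toy file built in-place (the detection heuristics themselves are an assumption, not `SDModelType`'s actual code):

```python
import json
import os
import struct
import tempfile

def read_safetensors_keys(path: str) -> list:
    # A .safetensors file begins with an 8-byte little-endian header length,
    # followed by that many bytes of JSON mapping tensor names to metadata.
    # Inspecting the names lets a detector tell checkpoint families apart
    # without loading any weights.
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return [k for k in header if k != "__metadata__"]

# Build a toy file to demonstrate the header layout.
toy_path = os.path.join(tempfile.mkdtemp(), "toy.safetensors")
header_bytes = json.dumps(
    {"model.diffusion_model.x": {"dtype": "F32", "shape": [1], "data_offsets": [0, 4]}}
).encode("utf-8")
with open(toy_path, "wb") as f:
    f.write(struct.pack("<Q", len(header_bytes)))
    f.write(header_bytes)
    f.write(b"\x00" * 4)

print(read_safetensors_keys(toy_path))  # ['model.diffusion_model.x']
```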
@@ -1021,6 +1054,8 @@ def set_pretrained_model_name_or_path_input(
         v2,
         v_parameterization,
         sdxl,
+        sd3,
+        flux1,
     )
@@ -1369,7 +1404,11 @@ def validate_file_path(file_path: str) -> bool:
     return True
 
 
-def validate_folder_path(folder_path: str, can_be_written_to: bool = False, create_if_not_exists: bool = False) -> bool:
+def validate_folder_path(
+    folder_path: str,
+    can_be_written_to: bool = False,
+    create_if_not_exists: bool = False,
+) -> bool:
     if folder_path == "":
         return True
     msg = f"Validating {folder_path} existence{' and writability' if can_be_written_to else ''}..."
@@ -1387,6 +1426,7 @@ def validate_folder_path(folder_path: str, can_be_written_to: bool = False, crea
     log.info(f"{msg} SUCCESS")
     return True
 
+
 def validate_toml_file(file_path: str) -> bool:
     if file_path == "":
         return True
@@ -1394,7 +1434,7 @@ def validate_toml_file(file_path: str) -> bool:
     if not os.path.isfile(file_path):
         log.error(f"{msg} FAILED: does not exist")
         return False
 
     try:
         toml.load(file_path)
     except:
@@ -1425,11 +1465,14 @@ def validate_model_path(pretrained_model_name_or_path: str) -> bool:
         log.info(f"{msg} SUCCESS")
     else:
         # If not one of the default models, check if it's a valid local path
-        if not validate_file_path(pretrained_model_name_or_path) and not validate_folder_path(pretrained_model_name_or_path):
+        if not validate_file_path(
+            pretrained_model_name_or_path
+        ) and not validate_folder_path(pretrained_model_name_or_path):
             log.info(f"{msg} FAILURE: not a valid file or folder")
             return False
     return True
 
 
 def is_file_writable(file_path: str) -> bool:
     """
     Checks if a file is writable.
@@ -1450,8 +1493,9 @@ def is_file_writable(file_path: str) -> bool:
             pass
         # If the file can be opened, it is considered writable
         return True
-    except IOError:
+    except IOError as e:
         # If an IOError occurs, the file cannot be written to
+        log.info(f"Error: {e}. File '{file_path}' is not writable.")
         return False
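The hunk above surfaces the previously swallowed `IOError` in a log line. A standalone sketch of the surrounding check, assuming (as the `pass` body suggests) that the probe opens the file in append mode, which proves writability without truncating existing content; the function name `is_writable` and the `print` in place of `log.info` are simplifications:

```python
import os
import tempfile

def is_writable(file_path: str) -> bool:
    # Probe writability by opening in append mode: "a" neither truncates an
    # existing file nor fails on a missing one (it creates it), while "w"
    # would destroy existing content.
    try:
        with open(file_path, "a"):
            pass
        return True
    except IOError as e:
        # Report the reason instead of failing silently, as in the patch.
        print(f"Error: {e}. File '{file_path}' is not writable.")
        return False

probe = os.path.join(tempfile.mkdtemp(), "ok.txt")
print(is_writable(probe))  # True
```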
@@ -1462,7 +1506,7 @@ def print_command_and_toml(run_cmd, tmpfilename):
     # Reconstruct the safe command string for display
     command_to_run = " ".join(run_cmd)
 
-    log.info(command_to_run)
+    print(command_to_run)
     print("")
 
     log.info(f"Showing toml config file: {tmpfilename}")
@@ -1489,10 +1533,11 @@ def validate_args_setting(input_string):
         )
         return False
 
+
 def setup_environment():
     env = os.environ.copy()
     env["PYTHONPATH"] = (
-        fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
     )
     env["TF_ENABLE_ONEDNN_OPTS"] = "0"
@@ -10,6 +10,11 @@ from .custom_logging import setup_logging
 log = setup_logging()
 
 
+import os
+import re
+import logging as log
+from easygui import msgbox
+
 def dataset_balancing(concept_repeats, folder, insecure):
 
     if not concept_repeats > 0:
@@ -78,7 +83,11 @@ def dataset_balancing(concept_repeats, folder, insecure):
             old_name = os.path.join(folder, subdir)
             new_name = os.path.join(folder, f"{repeats}_{subdir}")
 
-            os.rename(old_name, new_name)
+            # Check if the new folder name already exists
+            if os.path.exists(new_name):
+                log.warning(f"Destination folder {new_name} already exists. Skipping...")
+            else:
+                os.rename(old_name, new_name)
         else:
             log.info(
                 f"Skipping folder {subdir} because it does not match kohya_ss expected syntax..."
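The rename guard above keeps `dataset_balancing` from crashing when the target `{repeats}_{subdir}` folder already exists. The same pattern as a small self-contained function (the name `safe_rename` and the boolean return are illustrative additions, and `print` stands in for `log.warning`):

```python
import os
import tempfile

def safe_rename(old_name: str, new_name: str) -> bool:
    # Skip (rather than raise) when the destination already exists,
    # reporting whether a rename actually happened.
    if os.path.exists(new_name):
        print(f"Destination folder {new_name} already exists. Skipping...")
        return False
    os.rename(old_name, new_name)
    return True

root = tempfile.mkdtemp()
src = os.path.join(root, "dog")
dst = os.path.join(root, "5_dog")
os.mkdir(src)
print(safe_rename(src, dst))  # True: dst did not exist
os.mkdir(src)                 # recreate the source folder
print(safe_rename(src, dst))  # False: dst now exists, rename skipped
```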
@@ -87,6 +96,7 @@ def dataset_balancing(concept_repeats, folder, insecure):
     msgbox("Dataset balancing completed...")
 
 
+
 def warning(insecure):
     if insecure:
         if boolbox(
@@ -17,7 +17,9 @@ from .common_gui import (
     SaveConfigFile,
     scriptdir,
     update_my_data,
-    validate_file_path, validate_folder_path, validate_model_path,
+    validate_file_path,
+    validate_folder_path,
+    validate_model_path,
     validate_args_setting,
     setup_environment,
 )
@@ -27,10 +29,13 @@ from .class_gui_config import KohyaSSGUIConfig
 from .class_source_model import SourceModel
 from .class_basic_training import BasicTraining
 from .class_advanced_training import AdvancedTraining
+from .class_sd3 import sd3Training
 from .class_folders import Folders
 from .class_command_executor import CommandExecutor
 from .class_huggingface import HuggingFace
 from .class_metadata import MetaData
+from .class_sdxl_parameters import SDXLParameters
+from .class_flux1 import flux1Training
 
 from .dreambooth_folder_creation_gui import (
     gradio_dreambooth_folder_creation_tab,
@@ -60,6 +65,7 @@ def save_configuration(
     v2,
     v_parameterization,
     sdxl,
+    flux1_checkbox,
     logging_dir,
     train_data_dir,
     reg_data_dir,
@@ -72,6 +78,7 @@ def save_configuration(
     learning_rate_te2,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     train_batch_size,
     epoch,
     save_every_n_epochs,
@@ -84,6 +91,7 @@ def save_configuration(
     caption_extension,
     enable_bucket,
     gradient_checkpointing,
+    fp8_base,
     full_fp16,
     full_bf16,
     no_token_padding,
@@ -134,6 +142,7 @@ def save_configuration(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,
@@ -150,18 +159,28 @@ def save_configuration(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     weighted_captions,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
+    fused_backward_pass,
+    fused_optimizer_groups,
+    sdxl_cache_text_encoder_outputs,
+    sdxl_no_half_vae,
     min_timestep,
     max_timestep,
     debiased_estimation_loss,
@@ -178,6 +197,44 @@ def save_configuration(
     metadata_license,
     metadata_tags,
     metadata_title,
+    # SD3 parameters
+    sd3_cache_text_encoder_outputs,
+    sd3_cache_text_encoder_outputs_to_disk,
+    sd3_fused_backward_pass,
+    clip_g,
+    clip_l,
+    logit_mean,
+    logit_std,
+    mode_scale,
+    save_clip,
+    save_t5xxl,
+    t5xxl,
+    t5xxl_device,
+    t5xxl_dtype,
+    sd3_text_encoder_batch_size,
+    weighting_scheme,
+    sd3_checkbox,
+    # Flux.1
+    flux1_cache_text_encoder_outputs,
+    flux1_cache_text_encoder_outputs_to_disk,
+    ae,
+    flux1_clip_l,
+    flux1_t5xxl,
+    discrete_flow_shift,
+    model_prediction_type,
+    timestep_sampling,
+    split_mode,
+    train_blocks,
+    t5xxl_max_token_length,
+    guidance_scale,
+    blockwise_fused_optimizers,
+    flux_fused_backward_pass,
+    cpu_offload_checkpointing,
+    blocks_to_swap,
+    single_blocks_to_swap,
+    double_blocks_to_swap,
+    mem_eff_save,
+    apply_t5_attn_mask,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -218,6 +275,7 @@ def open_configuration(
     v2,
     v_parameterization,
     sdxl,
+    flux1_checkbox,
     logging_dir,
     train_data_dir,
     reg_data_dir,
@@ -230,6 +288,7 @@ def open_configuration(
     learning_rate_te2,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     train_batch_size,
     epoch,
     save_every_n_epochs,
@@ -242,6 +301,7 @@ def open_configuration(
     caption_extension,
     enable_bucket,
     gradient_checkpointing,
+    fp8_base,
     full_fp16,
     full_bf16,
     no_token_padding,
@@ -292,6 +352,7 @@ def open_configuration(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,
@@ -308,18 +369,28 @@ def open_configuration(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     weighted_captions,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
+    fused_backward_pass,
+    fused_optimizer_groups,
+    sdxl_cache_text_encoder_outputs,
+    sdxl_no_half_vae,
     min_timestep,
     max_timestep,
     debiased_estimation_loss,
@@ -336,6 +407,44 @@ def open_configuration(
     metadata_license,
     metadata_tags,
     metadata_title,
+    # SD3 parameters
+    sd3_cache_text_encoder_outputs,
+    sd3_cache_text_encoder_outputs_to_disk,
+    sd3_fused_backward_pass,
+    clip_g,
+    clip_l,
+    logit_mean,
+    logit_std,
+    mode_scale,
+    save_clip,
+    save_t5xxl,
+    t5xxl,
+    t5xxl_device,
+    t5xxl_dtype,
+    sd3_text_encoder_batch_size,
+    weighting_scheme,
+    sd3_checkbox,
+    # Flux.1
+    flux1_cache_text_encoder_outputs,
+    flux1_cache_text_encoder_outputs_to_disk,
+    ae,
+    flux1_clip_l,
+    flux1_t5xxl,
+    discrete_flow_shift,
+    model_prediction_type,
+    timestep_sampling,
+    split_mode,
+    train_blocks,
+    t5xxl_max_token_length,
+    guidance_scale,
+    blockwise_fused_optimizers,
+    flux_fused_backward_pass,
+    cpu_offload_checkpointing,
+    blocks_to_swap,
+    single_blocks_to_swap,
+    double_blocks_to_swap,
+    mem_eff_save,
+    apply_t5_attn_mask,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -371,6 +480,7 @@ def train_model(
     v2,
     v_parameterization,
     sdxl,
+    flux1_checkbox,
     logging_dir,
     train_data_dir,
     reg_data_dir,
@@ -383,6 +493,7 @@ def train_model(
     learning_rate_te2,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     train_batch_size,
     epoch,
     save_every_n_epochs,
@@ -395,6 +506,7 @@ def train_model(
     caption_extension,
     enable_bucket,
     gradient_checkpointing,
+    fp8_base,
     full_fp16,
     full_bf16,
     no_token_padding,
@@ -445,6 +557,7 @@ def train_model(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,
@@ -461,18 +574,28 @@ def train_model(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     weighted_captions,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
+    fused_backward_pass,
+    fused_optimizer_groups,
+    sdxl_cache_text_encoder_outputs,
+    sdxl_no_half_vae,
     min_timestep,
     max_timestep,
     debiased_estimation_loss,
@@ -489,6 +612,44 @@ def train_model(
     metadata_license,
     metadata_tags,
     metadata_title,
+    # SD3 parameters
+    sd3_cache_text_encoder_outputs,
+    sd3_cache_text_encoder_outputs_to_disk,
+    sd3_fused_backward_pass,
+    clip_g,
+    clip_l,
+    logit_mean,
+    logit_std,
+    mode_scale,
+    save_clip,
+    save_t5xxl,
+    t5xxl,
+    t5xxl_device,
+    t5xxl_dtype,
+    sd3_text_encoder_batch_size,
+    weighting_scheme,
+    sd3_checkbox,
+    # Flux.1
+    flux1_cache_text_encoder_outputs,
+    flux1_cache_text_encoder_outputs_to_disk,
+    ae,
+    flux1_clip_l,
+    flux1_t5xxl,
+    discrete_flow_shift,
+    model_prediction_type,
+    timestep_sampling,
+    split_mode,
+    train_blocks,
+    t5xxl_max_token_length,
+    guidance_scale,
+    blockwise_fused_optimizers,
+    flux_fused_backward_pass,
+    cpu_offload_checkpointing,
+    blocks_to_swap,
+    single_blocks_to_swap,
+    double_blocks_to_swap,
+    mem_eff_save,
+    apply_t5_attn_mask,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -509,61 +670,50 @@ def train_model(
     log.info(f"Validating lr scheduler arguments...")
     if not validate_args_setting(lr_scheduler_args):
         return
 
     log.info(f"Validating optimizer arguments...")
     if not validate_args_setting(optimizer_args):
         return TRAIN_BUTTON_VISIBLE
 
     #
     # Validate paths
     #
 
     if not validate_file_path(dataset_config):
         return TRAIN_BUTTON_VISIBLE
 
     if not validate_file_path(log_tracker_config):
         return TRAIN_BUTTON_VISIBLE
 
-    if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True):
+    if not validate_folder_path(
+        logging_dir, can_be_written_to=True, create_if_not_exists=True
+    ):
         return TRAIN_BUTTON_VISIBLE
 
-    if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True):
+    if not validate_folder_path(
+        output_dir, can_be_written_to=True, create_if_not_exists=True
+    ):
         return TRAIN_BUTTON_VISIBLE
 
     if not validate_model_path(pretrained_model_name_or_path):
         return TRAIN_BUTTON_VISIBLE
 
     if not validate_folder_path(reg_data_dir):
         return TRAIN_BUTTON_VISIBLE
 
     if not validate_folder_path(resume):
         return TRAIN_BUTTON_VISIBLE
 
     if not validate_folder_path(train_data_dir):
         return TRAIN_BUTTON_VISIBLE
 
     if not validate_model_path(vae):
         return TRAIN_BUTTON_VISIBLE
 
     #
     # End of path validation
     #
 
-    # This function validates files or folder paths. Simply add new variables containing file of folder path
-    # to validate below
-    # if not validate_paths(
-    #     dataset_config=dataset_config,
-    #     headless=headless,
-    #     log_tracker_config=log_tracker_config,
-    #     logging_dir=logging_dir,
-    #     output_dir=output_dir,
-    #     pretrained_model_name_or_path=pretrained_model_name_or_path,
-    #     reg_data_dir=reg_data_dir,
-    #     resume=resume,
-    #     train_data_dir=train_data_dir,
-    #     vae=vae,
-    # ):
-    #     return TRAIN_BUTTON_VISIBLE
 
     if not print_only and check_if_model_exist(
         output_name, output_dir, save_model_as, headless=headless
     ):
@@ -573,15 +723,6 @@ def train_model(
         log.info(
             "Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
         )
-        if max_train_steps > 0:
-            if lr_warmup != 0:
-                lr_warmup_steps = round(
-                    float(int(lr_warmup) * int(max_train_steps) / 100)
-                )
-            else:
-                lr_warmup_steps = 0
-        else:
-            lr_warmup_steps = 0
 
         if max_train_steps == 0:
             max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
@@ -640,11 +781,11 @@ def train_model(
             reg_factor = 1
         else:
             log.warning(
-                "Regularisation images are used... Will double the number of steps required..."
+                "Regularization images are used... Will double the number of steps required..."
             )
             reg_factor = 2
 
-        log.info(f"Regulatization factor: {reg_factor}")
+        log.info(f"Regularization factor: {reg_factor}")
 
         if max_train_steps == 0:
             # calculate max_train_steps
@@ -664,13 +805,18 @@ def train_model(
         else:
             max_train_steps_info = f"Max train steps: {max_train_steps}"
 
-        if lr_warmup != 0:
-            lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
-        else:
-            lr_warmup_steps = 0
 
     log.info(f"Total steps: {total_steps}")
 
+    # Calculate lr_warmup_steps
+    if lr_warmup_steps > 0:
+        lr_warmup_steps = int(lr_warmup_steps)
+        if lr_warmup > 0:
+            log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
+    elif lr_warmup != 0:
+        lr_warmup_steps = lr_warmup / 100
+    else:
+        lr_warmup_steps = 0
 
     log.info(f"Train batch size: {train_batch_size}")
     log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
     log.info(f"Epoch: {epoch}")
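The replacement logic above gives an explicit `lr_warmup_steps` value precedence over the percentage-based `lr_warmup` field. A standalone sketch of that precedence (the `resolve_warmup` helper is hypothetical; in the GUI the returned fraction is resolved against `max_train_steps` downstream):

```python
def resolve_warmup(lr_warmup: float, lr_warmup_steps: int):
    # Explicit step count wins over the percentage field.
    if lr_warmup_steps > 0:
        if lr_warmup > 0:
            print("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
        return int(lr_warmup_steps)
    elif lr_warmup != 0:
        # The percentage becomes a fraction of total steps, resolved later.
        return lr_warmup / 100
    return 0
```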
@@ -682,7 +828,7 @@ def train_model(
         log.error("accelerate not found")
         return TRAIN_BUTTON_VISIBLE
 
-    run_cmd = [rf'{accelerate_path}', "launch"]
+    run_cmd = [rf"{accelerate_path}", "launch"]
 
     run_cmd = AccelerateLaunch.run_cmd(
         run_cmd=run_cmd,
@@ -701,10 +847,23 @@ def train_model(
     )
 
     if sdxl:
-        run_cmd.append(rf'{scriptdir}/sd-scripts/sdxl_train.py')
+        run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py")
+    elif sd3_checkbox:
+        run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py")
+    elif flux1_checkbox:
+        run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py")
     else:
         run_cmd.append(rf"{scriptdir}/sd-scripts/train_db.py")
 
+    cache_text_encoder_outputs = (
+        (sdxl and sdxl_cache_text_encoder_outputs)
+        or (sd3_checkbox and sd3_cache_text_encoder_outputs)
+        or (flux1_checkbox and flux1_cache_text_encoder_outputs)
+    )
+    cache_text_encoder_outputs_to_disk = (
+        sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
+    ) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk)
+    no_half_vae = sdxl and sdxl_no_half_vae
     if max_data_loader_n_workers == "" or None:
         max_data_loader_n_workers = 0
     else:
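The branch above now dispatches to a per-architecture sd-scripts trainer. The selection order can be sketched as follows (the `pick_training_script` helper is hypothetical, not part of the repo):

```python
def pick_training_script(sdxl: bool, sd3: bool, flux1: bool) -> str:
    # Same branch order as the hunk above: SDXL first, then SD3, then
    # Flux.1, falling back to the classic DreamBooth trainer.
    if sdxl:
        return "sdxl_train.py"
    elif sd3:
        return "sd3_train.py"
    elif flux1:
        return "flux_train.py"
    return "train_db.py"
```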
@@ -715,6 +874,19 @@ def train_model(
     else:
         max_train_steps = int(max_train_steps)
 
+    if sdxl:
+        train_text_encoder = (learning_rate_te1 != None and learning_rate_te1 > 0) or (
+            learning_rate_te2 != None and learning_rate_te2 > 0
+        )
+
+    fused_backward_pass_value = False
+    if sd3_checkbox:
+        fused_backward_pass_value = sd3_fused_backward_pass
+    elif flux1_checkbox:
+        fused_backward_pass_value = flux_fused_backward_pass
+    else:
+        fused_backward_pass_value = fused_backward_pass
+
    # def save_huggingface_to_toml(self, toml_file_path: str):
    config_toml_data = {
        # Update the values in the TOML data
@@ -724,22 +896,32 @@ def train_model(
         "bucket_reso_steps": bucket_reso_steps,
         "cache_latents": cache_latents,
         "cache_latents_to_disk": cache_latents_to_disk,
+        "cache_text_encoder_outputs": cache_text_encoder_outputs,
+        "cache_text_encoder_outputs_to_disk": cache_text_encoder_outputs_to_disk,
         "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
         "caption_dropout_rate": caption_dropout_rate,
         "caption_extension": caption_extension,
+        "clip_l": flux1_clip_l if flux1_checkbox else clip_l if sd3_checkbox else None,
         "clip_skip": clip_skip if clip_skip != 0 else None,
         "color_aug": color_aug,
         "dataset_config": dataset_config,
         "debiased_estimation_loss": debiased_estimation_loss,
+        "disable_mmap_load_safetensors": disable_mmap_load_safetensors,
         "dynamo_backend": dynamo_backend,
         "enable_bucket": enable_bucket,
         "epoch": int(epoch),
         "flip_aug": flip_aug,
+        "fp8_base": fp8_base,
         "full_bf16": full_bf16,
         "full_fp16": full_fp16,
+        "fused_backward_pass": fused_backward_pass_value,
+        "fused_optimizer_groups": (
+            int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
+        ),
         "gradient_accumulation_steps": int(gradient_accumulation_steps),
         "gradient_checkpointing": gradient_checkpointing,
         "huber_c": huber_c,
+        "huber_scale": huber_scale,
         "huber_schedule": huber_schedule,
         "huggingface_path_in_repo": huggingface_path_in_repo,
         "huggingface_repo_id": huggingface_repo_id,
@@ -750,16 +932,11 @@ def train_model(
         "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
         "keep_tokens": int(keep_tokens),
         "learning_rate": learning_rate,  # both for sd1.5 and sdxl
-        "learning_rate_te": (
-            learning_rate_te if not sdxl and not 0 else None
-        ),  # only for sd1.5 and not 0
-        "learning_rate_te1": (
-            learning_rate_te1 if sdxl and not 0 else None
-        ),  # only for sdxl and not 0
-        "learning_rate_te2": (
-            learning_rate_te2 if sdxl and not 0 else None
-        ),  # only for sdxl and not 0
+        "learning_rate_te": learning_rate_te if not sdxl else None,  # only for sd1.5
+        "learning_rate_te1": learning_rate_te1 if sdxl else None,  # only for sdxl
+        "learning_rate_te2": learning_rate_te2 if sdxl else None,  # only for sdxl
         "logging_dir": logging_dir,
+        "log_config": log_config,
         "log_tracker_config": log_tracker_config,
         "log_tracker_name": log_tracker_name,
         "log_with": log_with,
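Per the release notes, the SDXL TOML handling now lets a learning rate of 0 disable a text encoder, while a blank GUI field (None) omits the value entirely; `--train_text_encoder` is implied only when at least one TE rate is positive. A sketch of that gating (hypothetical helper, names not from the repo):

```python
def needs_train_text_encoder(te1, te2) -> bool:
    # True only when at least one TE learning rate is a positive number;
    # 0 explicitly disables that encoder, and None (a blank GUI field)
    # passes nothing through to the training script.
    return (te1 is not None and te1 > 0) or (te2 is not None and te2 > 0)
```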
@@ -767,15 +944,20 @@ def train_model(
         "lr_scheduler": lr_scheduler,
         "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
         "lr_scheduler_num_cycles": (
-            int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
+            int(lr_scheduler_num_cycles)
+            if lr_scheduler_num_cycles != ""
+            else int(epoch)
         ),
         "lr_scheduler_power": lr_scheduler_power,
+        "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
         "lr_warmup_steps": lr_warmup_steps,
         "masked_loss": masked_loss,
         "max_bucket_reso": max_bucket_reso,
         "max_timestep": max_timestep if max_timestep != 0 else None,
         "max_token_length": int(max_token_length),
-        "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
+        "max_train_epochs": (
+            int(max_train_epochs) if int(max_train_epochs) != 0 else None
+        ),
         "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
         "mem_eff_attn": mem_eff_attn,
         "metadata_author": metadata_author,
@@ -789,6 +971,7 @@ def train_model(
         "mixed_precision": mixed_precision,
         "multires_noise_discount": multires_noise_discount,
         "multires_noise_iterations": multires_noise_iterations if not 0 else None,
+        "no_half_vae": no_half_vae,
         "no_token_padding": no_token_padding,
         "noise_offset": noise_offset if not 0 else None,
         "noise_offset_random_strength": noise_offset_random_strength,
@@ -825,6 +1008,10 @@ def train_model(
         "save_last_n_steps_state": (
             save_last_n_steps_state if save_last_n_steps_state != 0 else None
         ),
+        "save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
+        "save_last_n_epochs_state": (
+            save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
+        ),
         "save_model_as": save_model_as,
         "save_precision": save_precision,
         "save_state": save_state,
@@ -834,20 +1021,65 @@ def train_model(
         "sdpa": True if xformers == "sdpa" else None,
         "seed": int(seed) if int(seed) != 0 else None,
         "shuffle_caption": shuffle_caption,
+        "skip_cache_check": skip_cache_check,
         "stop_text_encoder_training": (
             stop_text_encoder_training if stop_text_encoder_training != 0 else None
         ),
+        "t5xxl": t5xxl if sd3_checkbox else flux1_t5xxl if flux1_checkbox else None,
         "train_batch_size": train_batch_size,
         "train_data_dir": train_data_dir,
+        "train_text_encoder": train_text_encoder if sdxl else None,
         "v2": v2,
         "v_parameterization": v_parameterization,
         "v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None,
         "vae": vae,
         "vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
         "wandb_api_key": wandb_api_key,
-        "wandb_run_name": wandb_run_name,
+        "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
         "weighted_captions": weighted_captions,
         "xformers": True if xformers == "xformers" else None,
+        # SD3 only Parameters
+        # "cache_text_encoder_outputs": see previous assignment above for code
+        # "cache_text_encoder_outputs_to_disk": see previous assignment above for code
+        "clip_g": clip_g if sd3_checkbox else None,
+        # "clip_l": see previous assignment above for code
+        "logit_mean": logit_mean if sd3_checkbox else None,
+        "logit_std": logit_std if sd3_checkbox else None,
+        "mode_scale": mode_scale if sd3_checkbox else None,
+        "save_clip": save_clip if sd3_checkbox else None,
+        "save_t5xxl": save_t5xxl if sd3_checkbox else None,
+        # "t5xxl": see previous assignment above for code
+        "t5xxl_device": t5xxl_device if sd3_checkbox else None,
+        "t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None,
+        "text_encoder_batch_size": (
+            sd3_text_encoder_batch_size if sd3_checkbox else None
+        ),
+        "weighting_scheme": weighting_scheme if sd3_checkbox else None,
+        # Flux.1 specific parameters
+        # "cache_text_encoder_outputs": see previous assignment above for code
+        # "cache_text_encoder_outputs_to_disk": see previous assignment above for code
+        "ae": ae if flux1_checkbox else None,
+        # "clip_l": see previous assignment above for code
+        # "t5xxl": see previous assignment above for code
+        "discrete_flow_shift": discrete_flow_shift if flux1_checkbox else None,
+        "model_prediction_type": model_prediction_type if flux1_checkbox else None,
+        "timestep_sampling": timestep_sampling if flux1_checkbox else None,
+        "split_mode": split_mode if flux1_checkbox else None,
+        "train_blocks": train_blocks if flux1_checkbox else None,
+        "t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None,
+        "guidance_scale": guidance_scale if flux1_checkbox else None,
+        "blockwise_fused_optimizers": (
+            blockwise_fused_optimizers if flux1_checkbox else None
+        ),
+        # "flux_fused_backward_pass": see previous assignment of fused_backward_pass in above code
+        "cpu_offload_checkpointing": (
+            cpu_offload_checkpointing if flux1_checkbox else None
+        ),
+        "blocks_to_swap": blocks_to_swap if flux1_checkbox or sd3_checkbox else None,
+        "single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
+        "double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
+        "mem_eff_save": mem_eff_save if flux1_checkbox else None,
+        "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
     }
 
     # Given dictionary `config_toml_data`
@ -855,7 +1087,7 @@ def train_model(
|
||||||
config_toml_data = {
|
config_toml_data = {
|
||||||
key: value
|
key: value
|
||||||
for key, value in config_toml_data.items()
|
for key, value in config_toml_data.items()
|
||||||
if value not in ["", False, None]
|
if not any([value == "", value is False, value is None])
|
||||||
}
|
}
|
||||||
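The rewritten filter condition matters because of Python's bool/int unification: `False == 0` is `True`, so the old `in` test (which compares with `==`) silently dropped a legitimate value of 0 from the TOML output, while the identity check `value is False` keeps it. A minimal sketch of the difference (the `filter_*` helpers and the sample config are illustrative, not from the repo):

```python
def filter_loose(cfg: dict) -> dict:
    # Old condition: `in` compares with ==, and 0 == False is True,
    # so an integer 0 is wrongly treated as False and dropped.
    return {k: v for k, v in cfg.items() if v not in ["", False, None]}


def filter_strict(cfg: dict) -> dict:
    # New condition: `is False` is an identity check, so 0 survives
    # while a real False is still filtered out.
    return {k: v for k, v in cfg.items() if not any([v == "", v is False, v is None])}


cfg = {"clip_skip": 0, "xformers": False, "vae": "", "seed": 42}
print(filter_loose(cfg))   # clip_skip is lost
print(filter_strict(cfg))  # clip_skip survives, False/""/None still removed
```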

     config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)

@@ -865,8 +1097,8 @@ def train_model(

     current_datetime = datetime.now()
     formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
-    tmpfilename = fr"{output_dir}/config_dreambooth-{formatted_datetime}.toml"
+    tmpfilename = rf"{output_dir}/config_dreambooth-{formatted_datetime}.toml"

     # Save the updated TOML data back to the file
     with open(tmpfilename, "w", encoding="utf-8") as toml_file:
         toml.dump(config_toml_data, toml_file)

@@ -875,7 +1107,7 @@ def train_model(
         log.error(f"Failed to write TOML file: {toml_file.name}")

     run_cmd.append(f"--config_file")
-    run_cmd.append(rf'{tmpfilename}')
+    run_cmd.append(rf"{tmpfilename}")

     # Initialize a dictionary with always-included keyword arguments
     kwargs_for_training = {

@@ -981,6 +1213,26 @@ def dreambooth_tab(
         config=config,
     )

+    # Add SDXL Parameters
+    sdxl_params = SDXLParameters(
+        source_model.sdxl_checkbox,
+        config=config,
+        trainer="finetune",
+    )
+
+    # Add FLUX1 Parameters
+    flux1_training = flux1Training(
+        headless=headless,
+        config=config,
+        flux1_checkbox=source_model.flux1_checkbox,
+        finetuning=True,
+    )
+
+    # Add SD3 Parameters
+    sd3_training = sd3Training(
+        headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
+    )
+
     with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
         advanced_training = AdvancedTraining(headless=headless, config=config)
         advanced_training.color_aug.change(

@@ -1011,6 +1263,7 @@ def dreambooth_tab(
         source_model.v2,
         source_model.v_parameterization,
         source_model.sdxl_checkbox,
+        source_model.flux1_checkbox,
         folders.logging_dir,
         source_model.train_data_dir,
         folders.reg_data_dir,

@@ -1023,6 +1276,7 @@ def dreambooth_tab(
         basic_training.learning_rate_te2,
         basic_training.lr_scheduler,
         basic_training.lr_warmup,
+        basic_training.lr_warmup_steps,
         basic_training.train_batch_size,
         basic_training.epoch,
         basic_training.save_every_n_epochs,

@@ -1035,6 +1289,7 @@ def dreambooth_tab(
         basic_training.caption_extension,
         basic_training.enable_bucket,
         advanced_training.gradient_checkpointing,
+        advanced_training.fp8_base,
         advanced_training.full_fp16,
         advanced_training.full_bf16,
         advanced_training.no_token_padding,

@@ -1084,6 +1339,7 @@ def dreambooth_tab(
         basic_training.optimizer,
         basic_training.optimizer_args,
         basic_training.lr_scheduler_args,
+        basic_training.lr_scheduler_type,
         advanced_training.noise_offset_type,
         advanced_training.noise_offset,
         advanced_training.noise_offset_random_strength,

@@ -1100,18 +1356,28 @@ def dreambooth_tab(
         advanced_training.loss_type,
         advanced_training.huber_schedule,
         advanced_training.huber_c,
+        advanced_training.huber_scale,
         advanced_training.vae_batch_size,
         advanced_training.min_snr_gamma,
         advanced_training.weighted_captions,
         advanced_training.save_every_n_steps,
         advanced_training.save_last_n_steps,
         advanced_training.save_last_n_steps_state,
+        advanced_training.save_last_n_epochs,
+        advanced_training.save_last_n_epochs_state,
+        advanced_training.skip_cache_check,
         advanced_training.log_with,
         advanced_training.wandb_api_key,
         advanced_training.wandb_run_name,
         advanced_training.log_tracker_name,
         advanced_training.log_tracker_config,
+        advanced_training.log_config,
         advanced_training.scale_v_pred_loss_like_noise_pred,
+        sdxl_params.disable_mmap_load_safetensors,
+        sdxl_params.fused_backward_pass,
+        sdxl_params.fused_optimizer_groups,
+        sdxl_params.sdxl_cache_text_encoder_outputs,
+        sdxl_params.sdxl_no_half_vae,
         advanced_training.min_timestep,
         advanced_training.max_timestep,
         advanced_training.debiased_estimation_loss,

@@ -1128,6 +1394,44 @@ def dreambooth_tab(
         metadata.metadata_license,
         metadata.metadata_tags,
         metadata.metadata_title,
+        # SD3 Parameters
+        sd3_training.sd3_cache_text_encoder_outputs,
+        sd3_training.sd3_cache_text_encoder_outputs_to_disk,
+        sd3_training.sd3_fused_backward_pass,
+        sd3_training.clip_g,
+        sd3_training.clip_l,
+        sd3_training.logit_mean,
+        sd3_training.logit_std,
+        sd3_training.mode_scale,
+        sd3_training.save_clip,
+        sd3_training.save_t5xxl,
+        sd3_training.t5xxl,
+        sd3_training.t5xxl_device,
+        sd3_training.t5xxl_dtype,
+        sd3_training.sd3_text_encoder_batch_size,
+        sd3_training.weighting_scheme,
+        source_model.sd3_checkbox,
+        # Flux1 parameters
+        flux1_training.flux1_cache_text_encoder_outputs,
+        flux1_training.flux1_cache_text_encoder_outputs_to_disk,
+        flux1_training.ae,
+        flux1_training.clip_l,
+        flux1_training.t5xxl,
+        flux1_training.discrete_flow_shift,
+        flux1_training.model_prediction_type,
+        flux1_training.timestep_sampling,
+        flux1_training.split_mode,
+        flux1_training.train_blocks,
+        flux1_training.t5xxl_max_token_length,
+        flux1_training.guidance_scale,
+        flux1_training.blockwise_fused_optimizers,
+        flux1_training.flux_fused_backward_pass,
+        flux1_training.cpu_offload_checkpointing,
+        advanced_training.blocks_to_swap,
+        flux1_training.single_blocks_to_swap,
+        flux1_training.double_blocks_to_swap,
+        flux1_training.mem_eff_save,
+        flux1_training.apply_t5_attn_mask,
     ]

     configuration.button_open_config.click(

@@ -1181,4 +1485,4 @@ def dreambooth_tab(
             folders.reg_data_dir,
             folders.output_dir,
             folders.logging_dir,
         )

@@ -12,6 +12,7 @@ from .common_gui import (
 )

 from .custom_logging import setup_logging
+from .sd_modeltype import SDModelType

 # Set up logging
 log = setup_logging()

@@ -337,6 +338,19 @@ def gradio_extract_lora_tab(
         outputs=[load_tuned_model_to, load_original_model_to],
     )

+    #secondary event on model_tuned for auto-detection of v2/SDXL
+    def change_modeltype_model_tuned(path):
+        detect = SDModelType(path)
+        v2 = gr.Checkbox(value=detect.Is_SD2())
+        sdxl = gr.Checkbox(value=detect.Is_SDXL())
+        return v2, sdxl
+
+    model_tuned.change(
+        change_modeltype_model_tuned,
+        inputs=model_tuned,
+        outputs=[v2, sdxl]
+    )
+
     extract_button = gr.Button("Extract LoRA model")

     extract_button.click(
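The auto-detection change above leans on the SDModelType helper. For orientation, the safetensors container makes this kind of sniffing cheap: a file starts with an 8-byte little-endian header length followed by a JSON header keyed by tensor names, so a detector only has to read the header, never the weights. The sketch below is a hypothetical stand-in for that helper; the key prefixes are common SD1/SD2/SDXL checkpoint conventions, not necessarily the exact checks SDModelType performs:

```python
import json
import struct


def read_safetensors_keys(path: str) -> list[str]:
    # safetensors layout: u64 little-endian header length, then a JSON
    # header whose top-level keys are tensor names (plus "__metadata__").
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return [k for k in header if k != "__metadata__"]


def guess_model_type(keys: list[str]) -> str:
    # Illustrative heuristics on well-known tensor-name prefixes:
    # SDXL ships a second text encoder under conditioner.embedders.1,
    # SD2 uses the open_clip layout under cond_stage_model.model.
    if any(k.startswith("conditioner.embedders.1.") for k in keys):
        return "sdxl"
    if any(k.startswith("cond_stage_model.model.") for k in keys):
        return "sd2"
    return "sd1"
```

Because only the header is parsed, the check is fast enough to run from a `.change` event on the model path textbox, which is exactly how the hunk above wires it up.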
@@ -18,14 +18,18 @@ from .common_gui import (
     SaveConfigFile,
     scriptdir,
     update_my_data,
-    validate_file_path, validate_folder_path, validate_model_path,
-    validate_args_setting, setup_environment,
+    validate_file_path,
+    validate_folder_path,
+    validate_model_path,
+    validate_args_setting,
+    setup_environment,
 )
 from .class_accelerate_launch import AccelerateLaunch
 from .class_configuration_file import ConfigurationFile
 from .class_source_model import SourceModel
 from .class_basic_training import BasicTraining
 from .class_advanced_training import AdvancedTraining
+from .class_sd3 import sd3Training
 from .class_folders import Folders
 from .class_sdxl_parameters import SDXLParameters
 from .class_command_executor import CommandExecutor

@@ -34,6 +38,7 @@ from .class_sample_images import SampleImages, create_prompt_file
 from .class_huggingface import HuggingFace
 from .class_metadata import MetaData
 from .class_gui_config import KohyaSSGUIConfig
+from .class_flux1 import flux1Training

 from .custom_logging import setup_logging

@@ -65,6 +70,7 @@ def save_configuration(
     v2,
     v_parameterization,
     sdxl_checkbox,
+    flux1_checkbox,
     train_dir,
     image_folder,
     output_dir,

@@ -82,6 +88,7 @@ def save_configuration(
     learning_rate,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     dataset_repeats,
     train_batch_size,
     epoch,

@@ -116,6 +123,7 @@ def save_configuration(
     save_state_on_train_end,
     resume,
     gradient_checkpointing,
+    fp8_base,
     gradient_accumulation_steps,
     block_lr,
     mem_eff_attn,

@@ -142,6 +150,7 @@ def save_configuration(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,

@@ -158,18 +167,26 @@ def save_configuration(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     weighted_captions,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
+    fused_backward_pass,
+    fused_optimizer_groups,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
     min_timestep,

@@ -188,6 +205,44 @@ def save_configuration(
     metadata_license,
     metadata_tags,
     metadata_title,
+    # SD3 parameters
+    sd3_cache_text_encoder_outputs,
+    sd3_cache_text_encoder_outputs_to_disk,
+    sd3_fused_backward_pass,
+    clip_g,
+    clip_l,
+    logit_mean,
+    logit_std,
+    mode_scale,
+    save_clip,
+    save_t5xxl,
+    t5xxl,
+    t5xxl_device,
+    t5xxl_dtype,
+    sd3_text_encoder_batch_size,
+    weighting_scheme,
+    sd3_checkbox,
+    # Flux.1
+    flux1_cache_text_encoder_outputs,
+    flux1_cache_text_encoder_outputs_to_disk,
+    ae,
+    flux1_clip_l,
+    flux1_t5xxl,
+    discrete_flow_shift,
+    model_prediction_type,
+    timestep_sampling,
+    split_mode,
+    train_blocks,
+    t5xxl_max_token_length,
+    guidance_scale,
+    blockwise_fused_optimizers,
+    flux_fused_backward_pass,
+    cpu_offload_checkpointing,
+    blocks_to_swap,
+    single_blocks_to_swap,
+    double_blocks_to_swap,
+    mem_eff_save,
+    apply_t5_attn_mask,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())

@@ -231,6 +286,7 @@ def open_configuration(
     v2,
     v_parameterization,
     sdxl_checkbox,
+    flux1_checkbox,
     train_dir,
     image_folder,
     output_dir,

@@ -248,6 +304,7 @@ def open_configuration(
     learning_rate,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     dataset_repeats,
     train_batch_size,
     epoch,

@@ -282,6 +339,7 @@ def open_configuration(
     save_state_on_train_end,
     resume,
     gradient_checkpointing,
+    fp8_base,
     gradient_accumulation_steps,
     block_lr,
     mem_eff_attn,

@@ -308,6 +366,7 @@ def open_configuration(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,

@@ -324,18 +383,26 @@ def open_configuration(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     weighted_captions,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
+    fused_backward_pass,
+    fused_optimizer_groups,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
     min_timestep,

@@ -354,6 +421,44 @@ def open_configuration(
     metadata_license,
     metadata_tags,
     metadata_title,
+    # SD3 parameters
+    sd3_cache_text_encoder_outputs,
+    sd3_cache_text_encoder_outputs_to_disk,
+    sd3_fused_backward_pass,
+    clip_g,
+    clip_l,
+    logit_mean,
+    logit_std,
+    mode_scale,
+    save_clip,
+    save_t5xxl,
+    t5xxl,
+    t5xxl_device,
+    t5xxl_dtype,
+    sd3_text_encoder_batch_size,
+    weighting_scheme,
+    sd3_checkbox,
+    # Flux.1
+    flux1_cache_text_encoder_outputs,
+    flux1_cache_text_encoder_outputs_to_disk,
+    ae,
+    flux1_clip_l,
+    flux1_t5xxl,
+    discrete_flow_shift,
+    model_prediction_type,
+    timestep_sampling,
+    split_mode,
+    train_blocks,
+    t5xxl_max_token_length,
+    guidance_scale,
+    blockwise_fused_optimizers,
+    flux_fused_backward_pass,
+    cpu_offload_checkpointing,
+    blocks_to_swap,
+    single_blocks_to_swap,
+    double_blocks_to_swap,
+    mem_eff_save,
+    apply_t5_attn_mask,
     training_preset,
 ):
     # Get list of function parameters and values

@@ -403,6 +508,7 @@ def train_model(
     v2,
     v_parameterization,
     sdxl_checkbox,
+    flux1_checkbox,
     train_dir,
     image_folder,
     output_dir,

@@ -420,6 +526,7 @@ def train_model(
     learning_rate,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     dataset_repeats,
     train_batch_size,
     epoch,

@@ -454,6 +561,7 @@ def train_model(
     save_state_on_train_end,
     resume,
     gradient_checkpointing,
+    fp8_base,
     gradient_accumulation_steps,
     block_lr,
     mem_eff_attn,

@@ -480,6 +588,7 @@ def train_model(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,

@@ -496,18 +605,26 @@ def train_model(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     weighted_captions,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
+    fused_backward_pass,
+    fused_optimizer_groups,
     sdxl_cache_text_encoder_outputs,
     sdxl_no_half_vae,
     min_timestep,

@@ -526,6 +643,44 @@ def train_model(
     metadata_license,
     metadata_tags,
     metadata_title,
+    # SD3 parameters
+    sd3_cache_text_encoder_outputs,
+    sd3_cache_text_encoder_outputs_to_disk,
+    sd3_fused_backward_pass,
+    clip_g,
+    clip_l,
+    logit_mean,
+    logit_std,
+    mode_scale,
+    save_clip,
+    save_t5xxl,
+    t5xxl,
+    t5xxl_device,
+    t5xxl_dtype,
+    sd3_text_encoder_batch_size,
+    weighting_scheme,
+    sd3_checkbox,
+    # Flux.1
+    flux1_cache_text_encoder_outputs,
+    flux1_cache_text_encoder_outputs_to_disk,
+    ae,
+    flux1_clip_l,
+    flux1_t5xxl,
+    discrete_flow_shift,
+    model_prediction_type,
+    timestep_sampling,
+    split_mode,
+    train_blocks,
+    t5xxl_max_token_length,
+    guidance_scale,
+    blockwise_fused_optimizers,
+    flux_fused_backward_pass,
+    cpu_offload_checkpointing,
+    blocks_to_swap,
+    single_blocks_to_swap,
+    double_blocks_to_swap,
+    mem_eff_save,
+    apply_t5_attn_mask,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())

@@ -558,44 +713,36 @@ def train_model(

     #
     # Validate paths
     #

     if not validate_file_path(dataset_config):
         return TRAIN_BUTTON_VISIBLE

     if not validate_folder_path(image_folder):
         return TRAIN_BUTTON_VISIBLE

     if not validate_file_path(log_tracker_config):
         return TRAIN_BUTTON_VISIBLE

-    if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True):
+    if not validate_folder_path(
+        logging_dir, can_be_written_to=True, create_if_not_exists=True
+    ):
         return TRAIN_BUTTON_VISIBLE

-    if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True):
+    if not validate_folder_path(
+        output_dir, can_be_written_to=True, create_if_not_exists=True
+    ):
         return TRAIN_BUTTON_VISIBLE

     if not validate_model_path(pretrained_model_name_or_path):
         return TRAIN_BUTTON_VISIBLE

     if not validate_folder_path(resume):
         return TRAIN_BUTTON_VISIBLE

     #
     # End of path validation
     #

-    # if not validate_paths(
-    #     dataset_config=dataset_config,
-    #     finetune_image_folder=image_folder,
-    #     headless=headless,
-    #     log_tracker_config=log_tracker_config,
-    #     logging_dir=logging_dir,
-    #     output_dir=output_dir,
-    #     pretrained_model_name_or_path=pretrained_model_name_or_path,
-    #     resume=resume,
-    # ):
-    #     return TRAIN_BUTTON_VISIBLE
-
     if not print_only and check_if_model_exist(
         output_name, output_dir, save_model_as, headless

@@ -727,10 +874,16 @@ def train_model(

     log.info(max_train_steps_info)

-    if max_train_steps != 0:
-        lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
+    # Calculate lr_warmup_steps
+    if lr_warmup_steps > 0:
+        lr_warmup_steps = int(lr_warmup_steps)
+        if lr_warmup > 0:
+            log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
+    elif lr_warmup != 0:
+        lr_warmup_steps = lr_warmup / 100
     else:
         lr_warmup_steps = 0

     log.info(f"lr_warmup_steps = {lr_warmup_steps}")
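The warmup hunk above changes precedence: an explicit `lr_warmup_steps` now wins over the legacy `lr_warmup` percentage, and the percentage is forwarded as a fraction below 1 (sd-scripts accepts a fractional warmup value as a ratio of total steps). The same logic as a small pure function (the function name is mine, for illustration; the GUI's warning path is elided):

```python
def resolve_warmup(lr_warmup: float, lr_warmup_steps: float):
    """Mirror of the GUI's warmup precedence rules."""
    if lr_warmup_steps > 0:
        # An explicit step count wins, even if a percentage is also set.
        return int(lr_warmup_steps)
    if lr_warmup != 0:
        # Legacy percentage, passed through as a ratio in [0, 1).
        return lr_warmup / 100
    return 0
```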
     accelerate_path = get_executable_path("accelerate")

@@ -738,7 +891,7 @@ def train_model(
         log.error("accelerate not found")
         return TRAIN_BUTTON_VISIBLE

-    run_cmd = [rf'{accelerate_path}', "launch"]
+    run_cmd = [rf"{accelerate_path}", "launch"]

     run_cmd = AccelerateLaunch.run_cmd(
         run_cmd=run_cmd,

@@ -758,6 +911,10 @@ def train_model(

     if sdxl_checkbox:
         run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py")
+    elif sd3_checkbox:
+        run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py")
+    elif flux1_checkbox:
+        run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py")
     else:
         run_cmd.append(rf"{scriptdir}/sd-scripts/fine_tune.py")
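Branch order matters in the elif chain above: if more than one model checkbox were somehow ticked, SDXL wins over SD3, which wins over Flux.1, with the plain fine-tuning script as the fallback. Read as a small pure function (a hypothetical helper with the same precedence, not code from the repo):

```python
def pick_training_script(scriptdir: str, sdxl: bool, sd3: bool, flux1: bool) -> str:
    # Same precedence as the elif chain: SDXL > SD3 > Flux.1 > base fine-tune.
    if sdxl:
        return f"{scriptdir}/sd-scripts/sdxl_train.py"
    if sd3:
        return f"{scriptdir}/sd-scripts/sd3_train.py"
    if flux1:
        return f"{scriptdir}/sd-scripts/flux_train.py"
    return f"{scriptdir}/sd-scripts/fine_tune.py"
```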
@@ -766,7 +923,14 @@ def train_model(
         if use_latent_files == "Yes"
         else f"{train_dir}/{caption_metadata_filename}"
     )
-    cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs
+    cache_text_encoder_outputs = (
+        (sdxl_checkbox and sdxl_cache_text_encoder_outputs)
+        or (sd3_checkbox and sd3_cache_text_encoder_outputs)
+        or (flux1_checkbox and flux1_cache_text_encoder_outputs)
+    )
+    cache_text_encoder_outputs_to_disk = (
+        sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
+    ) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk)
     no_half_vae = sdxl_checkbox and sdxl_no_half_vae

     if max_data_loader_n_workers == "" or None:
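One context line in the hunk above is worth flagging: `if max_data_loader_n_workers == "" or None:` parses as `(max_data_loader_n_workers == "") or None`. The right operand is the constant `None`, which is always falsy, so the test is effectively just `== ""` and never matches an actual `None` value. A quick demonstration of the semantics (not a change to the diff itself):

```python
x = "8"
# Parsed as (x == "") or None -> False or None -> None, which is falsy.
print(bool(x == "" or None))  # False

# An explicit membership test says what was probably meant.
print(x in ("", None))        # False

x = None
print(bool(x == "" or None))  # False: None slips through the original check
print(x in ("", None))        # True
```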
@ -791,22 +955,31 @@ def train_model(
|
||||||
        "cache_latents": cache_latents,
        "cache_latents_to_disk": cache_latents_to_disk,
        "cache_text_encoder_outputs": cache_text_encoder_outputs,
        "cache_text_encoder_outputs_to_disk": cache_text_encoder_outputs_to_disk,
        "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
        "caption_dropout_rate": caption_dropout_rate,
        "caption_extension": caption_extension,
        "clip_l": flux1_clip_l if flux1_checkbox else clip_l if sd3_checkbox else None,
        "clip_skip": clip_skip if clip_skip != 0 else None,
        "color_aug": color_aug,
        "dataset_config": dataset_config,
        "dataset_repeats": int(dataset_repeats),
        "debiased_estimation_loss": debiased_estimation_loss,
        "disable_mmap_load_safetensors": disable_mmap_load_safetensors,
        "dynamo_backend": dynamo_backend,
        "enable_bucket": True,
        "flip_aug": flip_aug,
        "fp8_base": fp8_base,
        "full_bf16": full_bf16,
        "full_fp16": full_fp16,
        "fused_backward_pass": sd3_fused_backward_pass if sd3_checkbox else flux_fused_backward_pass if flux1_checkbox else fused_backward_pass,
        "fused_optimizer_groups": (
            int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
        ),
        "gradient_accumulation_steps": int(gradient_accumulation_steps),
        "gradient_checkpointing": gradient_checkpointing,
        "huber_c": huber_c,
        "huber_scale": huber_scale,
        "huber_schedule": huber_schedule,
        "huggingface_repo_id": huggingface_repo_id,
        "huggingface_token": huggingface_token,
@@ -828,11 +1001,13 @@ def train_model(
            learning_rate_te2 if sdxl_checkbox else None
        ),  # only for sdxl
        "logging_dir": logging_dir,
        "log_config": log_config,
        "log_tracker_name": log_tracker_name,
        "log_tracker_config": log_tracker_config,
        "loss_type": loss_type,
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
        "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
        "lr_warmup_steps": lr_warmup_steps,
        "masked_loss": masked_loss,
        "max_bucket_reso": int(max_bucket_reso),
@@ -886,6 +1061,10 @@ def train_model(
        "save_last_n_steps_state": (
            save_last_n_steps_state if save_last_n_steps_state != 0 else None
        ),
        "save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
        "save_last_n_epochs_state": (
            save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
        ),
        "save_model_as": save_model_as,
        "save_precision": save_precision,
        "save_state": save_state,
@@ -895,6 +1074,8 @@ def train_model(
        "sdpa": True if xformers == "sdpa" else None,
        "seed": int(seed) if int(seed) != 0 else None,
        "shuffle_caption": shuffle_caption,
        "skip_cache_check": skip_cache_check,
        "t5xxl": t5xxl if sd3_checkbox else flux1_t5xxl if flux1_checkbox else None,
        "train_batch_size": train_batch_size,
        "train_data_dir": image_folder,
        "train_text_encoder": train_text_encoder,
@@ -904,9 +1085,50 @@ def train_model(
        "v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None,
        "vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
        "wandb_api_key": wandb_api_key,
        "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
        "weighted_captions": weighted_captions,
        "xformers": True if xformers == "xformers" else None,
        # SD3 only Parameters
        # "cache_text_encoder_outputs": see previous assignment above for code
        # "cache_text_encoder_outputs_to_disk": see previous assignment above for code
        "clip_g": clip_g if sd3_checkbox else None,
        # "clip_l": see previous assignment above for code
        "logit_mean": logit_mean if sd3_checkbox else None,
        "logit_std": logit_std if sd3_checkbox else None,
        "mode_scale": mode_scale if sd3_checkbox else None,
        "save_clip": save_clip if sd3_checkbox else None,
        "save_t5xxl": save_t5xxl if sd3_checkbox else None,
        # "t5xxl": see previous assignment above for code
        "t5xxl_device": t5xxl_device if sd3_checkbox else None,
        "t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None,
        "text_encoder_batch_size": (
            sd3_text_encoder_batch_size if sd3_checkbox else None
        ),
        "weighting_scheme": weighting_scheme if sd3_checkbox else None,
        # Flux.1 specific parameters
        # "cache_text_encoder_outputs": see previous assignment above for code
        # "cache_text_encoder_outputs_to_disk": see previous assignment above for code
        "ae": ae if flux1_checkbox else None,
        # "clip_l": see previous assignment above for code
        # "t5xxl": see previous assignment above for code
        "discrete_flow_shift": discrete_flow_shift if flux1_checkbox else None,
        "model_prediction_type": model_prediction_type if flux1_checkbox else None,
        "timestep_sampling": timestep_sampling if flux1_checkbox else None,
        "split_mode": split_mode if flux1_checkbox else None,
        "train_blocks": train_blocks if flux1_checkbox else None,
        "t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None,
        "guidance_scale": guidance_scale if flux1_checkbox else None,
        "blockwise_fused_optimizers": (
            blockwise_fused_optimizers if flux1_checkbox else None
        ),
        "cpu_offload_checkpointing": (
            cpu_offload_checkpointing if flux1_checkbox else None
        ),
        "blocks_to_swap": blocks_to_swap if flux1_checkbox else None,
        "single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
        "double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
        "mem_eff_save": mem_eff_save if flux1_checkbox else None,
        "apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
    }

    # Given dictionary `config_toml_data`
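Many entries above use the `x if x != 0 else None` pattern so that unset values drop out of the TOML file. A related pitfall, noted in this release's dreambooth fix, is that filtering values out with an equality test drops legitimate zeros, because `0 == False` is `True` in Python. A minimal standalone sketch with hypothetical keys:

```python
raw = {"learning_rate_te1": 0, "train_text_encoder": False, "seed": None}

# Buggy filter: `v == False` also matches the integer 0, silently dropping a
# learning rate of 0 that was meant to disable text-encoder training.
buggy = {k: v for k, v in raw.items() if not (v is None or v == False)}

# Correct filter: identity checks leave the integer 0 intact.
fixed = {k: v for k, v in raw.items() if v is not None and v is not False}

print(buggy)  # {}
print(fixed)  # {'learning_rate_te1': 0}
```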
@@ -924,7 +1146,7 @@ def train_model(

    current_datetime = datetime.now()
    formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
    tmpfilename = rf"{output_dir}/config_finetune-{formatted_datetime}.toml"

    # Save the updated TOML data back to the file
    with open(tmpfilename, "w", encoding="utf-8") as toml_file:
        toml.dump(config_toml_data, toml_file)
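The config file name above embeds a timestamp so that successive runs never overwrite each other's TOML. A small sketch of the same `strftime` pattern, pinned to a fixed datetime so the output is deterministic:

```python
from datetime import datetime

# Same format string as the GUI uses for the temporary TOML file name.
stamp = datetime(2024, 7, 1, 12, 30, 5).strftime("%Y%m%d-%H%M%S")
name = f"config_finetune-{stamp}.toml"
print(name)  # config_finetune-20240701-123005.toml
```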
@@ -1090,7 +1312,9 @@ def finetune_tab(

        # Add SDXL Parameters
        sdxl_params = SDXLParameters(
            source_model.sdxl_checkbox,
            config=config,
            trainer="finetune",
        )

        with gr.Row():
@@ -1099,6 +1323,19 @@ def finetune_tab(
                    label="Train text encoder", value=True
                )

            # Add FLUX1 Parameters
            flux1_training = flux1Training(
                headless=headless,
                config=config,
                flux1_checkbox=source_model.flux1_checkbox,
                finetuning=True,
            )

            # Add SD3 Parameters
            sd3_training = sd3Training(
                headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
            )

        with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
            with gr.Row():
                gradient_accumulation_steps = gr.Slider(
@@ -1146,6 +1383,7 @@ def finetune_tab(
            source_model.v2,
            source_model.v_parameterization,
            source_model.sdxl_checkbox,
            source_model.flux1_checkbox,
            train_dir,
            image_folder,
            output_dir,
@@ -1163,6 +1401,7 @@ def finetune_tab(
            basic_training.learning_rate,
            basic_training.lr_scheduler,
            basic_training.lr_warmup,
            basic_training.lr_warmup_steps,
            dataset_repeats,
            basic_training.train_batch_size,
            basic_training.epoch,
@@ -1196,6 +1435,7 @@ def finetune_tab(
            advanced_training.save_state_on_train_end,
            advanced_training.resume,
            advanced_training.gradient_checkpointing,
            advanced_training.fp8_base,
            gradient_accumulation_steps,
            block_lr,
            advanced_training.mem_eff_attn,
@@ -1222,6 +1462,7 @@ def finetune_tab(
            basic_training.optimizer,
            basic_training.optimizer_args,
            basic_training.lr_scheduler_args,
            basic_training.lr_scheduler_type,
            advanced_training.noise_offset_type,
            advanced_training.noise_offset,
            advanced_training.noise_offset_random_strength,
@@ -1238,18 +1479,26 @@ def finetune_tab(
            advanced_training.loss_type,
            advanced_training.huber_schedule,
            advanced_training.huber_c,
            advanced_training.huber_scale,
            advanced_training.vae_batch_size,
            advanced_training.min_snr_gamma,
            weighted_captions,
            advanced_training.save_every_n_steps,
            advanced_training.save_last_n_steps,
            advanced_training.save_last_n_steps_state,
            advanced_training.save_last_n_epochs,
            advanced_training.save_last_n_epochs_state,
            advanced_training.skip_cache_check,
            advanced_training.log_with,
            advanced_training.wandb_api_key,
            advanced_training.wandb_run_name,
            advanced_training.log_tracker_name,
            advanced_training.log_tracker_config,
            advanced_training.log_config,
            advanced_training.scale_v_pred_loss_like_noise_pred,
            sdxl_params.disable_mmap_load_safetensors,
            sdxl_params.fused_backward_pass,
            sdxl_params.fused_optimizer_groups,
            sdxl_params.sdxl_cache_text_encoder_outputs,
            sdxl_params.sdxl_no_half_vae,
            advanced_training.min_timestep,
@@ -1268,6 +1517,44 @@ def finetune_tab(
            metadata.metadata_license,
            metadata.metadata_tags,
            metadata.metadata_title,
            # SD3 Parameters
            sd3_training.sd3_cache_text_encoder_outputs,
            sd3_training.sd3_cache_text_encoder_outputs_to_disk,
            sd3_training.clip_g,
            sd3_training.clip_l,
            sd3_training.logit_mean,
            sd3_training.logit_std,
            sd3_training.mode_scale,
            sd3_training.save_clip,
            sd3_training.save_t5xxl,
            sd3_training.t5xxl,
            sd3_training.t5xxl_device,
            sd3_training.t5xxl_dtype,
            sd3_training.sd3_text_encoder_batch_size,
            sd3_training.sd3_fused_backward_pass,
            sd3_training.weighting_scheme,
            source_model.sd3_checkbox,
            # Flux1 parameters
            flux1_training.flux1_cache_text_encoder_outputs,
            flux1_training.flux1_cache_text_encoder_outputs_to_disk,
            flux1_training.ae,
            flux1_training.clip_l,
            flux1_training.t5xxl,
            flux1_training.discrete_flow_shift,
            flux1_training.model_prediction_type,
            flux1_training.timestep_sampling,
            flux1_training.split_mode,
            flux1_training.train_blocks,
            flux1_training.t5xxl_max_token_length,
            flux1_training.guidance_scale,
            flux1_training.blockwise_fused_optimizers,
            flux1_training.flux_fused_backward_pass,
            flux1_training.cpu_offload_checkpointing,
            advanced_training.blocks_to_swap,
            flux1_training.single_blocks_to_swap,
            flux1_training.double_blocks_to_swap,
            flux1_training.mem_eff_save,
            flux1_training.apply_t5_attn_mask,
        ]

        configuration.button_open_config.click(
@@ -1353,4 +1640,4 @@ def finetune_tab(
    if os.path.exists(top_level_path):
        with open(top_level_path, "r", encoding="utf-8") as file:
            guides_top_level = file.read() + "\n"
    gr.Markdown(guides_top_level)
@@ -0,0 +1,273 @@
import gradio as gr
import subprocess
import os
import sys
from .common_gui import (
    get_saveasfilename_path,
    get_file_path,
    scriptdir,
    list_files,
    create_refresh_button,
    setup_environment,
)
from .custom_logging import setup_logging

# Set up logging
log = setup_logging()

folder_symbol = "\U0001f4c2"  # 📂
refresh_symbol = "\U0001f504"  # 🔄
save_style_symbol = "\U0001f4be"  # 💾
document_symbol = "\U0001F4C4"  # 📄

PYTHON = sys.executable


def extract_flux_lora(
    model_org,
    model_tuned,
    save_to,
    save_precision,
    dim,
    device,
    clamp_quantile,
    no_metadata,
    mem_eff_safe_open,
):
    # Check for required inputs
    if model_org == "" or model_tuned == "" or save_to == "":
        log.info(
            "Please provide all required inputs: original model, tuned model, and save path."
        )
        return

    # Check if source models exist
    if not os.path.isfile(model_org):
        log.info("The provided original model is not a file")
        return

    if not os.path.isfile(model_tuned):
        log.info("The provided tuned model is not a file")
        return

    # Prepare save path
    if os.path.dirname(save_to) == "":
        save_to = os.path.join(os.path.dirname(model_tuned), save_to)
    if os.path.isdir(save_to):
        save_to = os.path.join(save_to, "flux_lora.safetensors")
    if os.path.normpath(model_tuned) == os.path.normpath(save_to):
        path, ext = os.path.splitext(save_to)
        save_to = f"{path}_lora{ext}"
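The save-path handling above can be exercised in isolation. A sketch with a hypothetical helper `prepare_save_path` and example paths; the `os.path.isdir` branch is omitted because it touches the real filesystem:

```python
import os

def prepare_save_path(model_tuned: str, save_to: str) -> str:
    # A bare filename is resolved next to the tuned model.
    if os.path.dirname(save_to) == "":
        save_to = os.path.join(os.path.dirname(model_tuned), save_to)
    # Never overwrite the tuned model itself: add a "_lora" suffix instead.
    if os.path.normpath(model_tuned) == os.path.normpath(save_to):
        root, ext = os.path.splitext(save_to)
        save_to = f"{root}_lora{ext}"
    return save_to

print(prepare_save_path("/models/flux.safetensors", "out.safetensors"))
# e.g. /models/out.safetensors on POSIX
print(prepare_save_path("/models/flux.safetensors", "/models/flux.safetensors"))
# e.g. /models/flux_lora.safetensors on POSIX
```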
    run_cmd = [
        rf"{PYTHON}",
        rf"{scriptdir}/sd-scripts/networks/flux_extract_lora.py",
        "--model_org",
        rf"{model_org}",
        "--model_tuned",
        rf"{model_tuned}",
        "--save_to",
        rf"{save_to}",
        "--dim",
        str(dim),
        "--device",
        device,
        "--clamp_quantile",
        str(clamp_quantile),
    ]

    if save_precision:
        run_cmd.extend(["--save_precision", save_precision])

    if no_metadata:
        run_cmd.append("--no_metadata")

    if mem_eff_safe_open:
        run_cmd.append("--mem_eff_safe_open")

    env = setup_environment()

    # Reconstruct the safe command string for display
    command_to_run = " ".join(run_cmd)
    log.info(f"Executing command: {command_to_run}")

    # Run the command
    subprocess.run(run_cmd, env=env)


def gradio_flux_extract_lora_tab(headless=False):
    current_model_dir = os.path.join(scriptdir, "outputs")
    current_save_dir = os.path.join(scriptdir, "outputs")

    def list_models(path):
        return list(list_files(path, exts=[".safetensors"], all=True))

    with gr.Tab("Extract Flux LoRA"):
        gr.Markdown(
            "This utility can extract a LoRA network from a finetuned Flux model."
        )

        lora_ext = gr.Textbox(value="*.safetensors", visible=False)
        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
        model_ext = gr.Textbox(value="*.safetensors", visible=False)
        model_ext_name = gr.Textbox(value="Model types", visible=False)

        with gr.Group(), gr.Row():
            model_org = gr.Dropdown(
                label="Original Flux model (path to the original model)",
                interactive=True,
                choices=[""] + list_models(current_model_dir),
                value="",
                allow_custom_value=True,
            )
            create_refresh_button(
                model_org,
                lambda: None,
                lambda: {"choices": list_models(current_model_dir)},
                "open_folder_small",
            )
            button_model_org_file = gr.Button(
                folder_symbol,
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_model_org_file.click(
                get_file_path,
                inputs=[model_org, model_ext, model_ext_name],
                outputs=model_org,
                show_progress=False,
            )

            model_tuned = gr.Dropdown(
                label="Finetuned Flux model (path to the finetuned model to extract)",
                interactive=True,
                choices=[""] + list_models(current_model_dir),
                value="",
                allow_custom_value=True,
            )
            create_refresh_button(
                model_tuned,
                lambda: None,
                lambda: {"choices": list_models(current_model_dir)},
                "open_folder_small",
            )
            button_model_tuned_file = gr.Button(
                folder_symbol,
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_model_tuned_file.click(
                get_file_path,
                inputs=[model_tuned, model_ext, model_ext_name],
                outputs=model_tuned,
                show_progress=False,
            )

        with gr.Group(), gr.Row():
            save_to = gr.Dropdown(
                label="Save to (path where to save the extracted LoRA model...)",
                interactive=True,
                choices=[""] + list_models(current_save_dir),
                value="",
                allow_custom_value=True,
            )
            create_refresh_button(
                save_to,
                lambda: None,
                lambda: {"choices": list_models(current_save_dir)},
                "open_folder_small",
            )
            button_save_to = gr.Button(
                folder_symbol,
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_save_to.click(
                get_saveasfilename_path,
                inputs=[save_to, lora_ext, lora_ext_name],
                outputs=save_to,
                show_progress=False,
            )

            save_precision = gr.Dropdown(
                label="Save precision",
                choices=["None", "float", "fp16", "bf16"],
                value="None",
                interactive=True,
            )

        with gr.Row():
            dim = gr.Slider(
                minimum=1,
                maximum=1024,
                label="Network Dimension (Rank)",
                value=4,
                step=1,
                interactive=True,
            )
            device = gr.Dropdown(
                label="Device",
                choices=["cpu", "cuda"],
                value="cuda",
                interactive=True,
            )
            clamp_quantile = gr.Slider(
                minimum=0,
                maximum=1,
                label="Clamp Quantile",
                value=0.99,
                step=0.01,
                interactive=True,
            )

        with gr.Row():
            no_metadata = gr.Checkbox(
                label="No metadata (do not save sai modelspec metadata)",
                value=False,
                interactive=True,
            )
            mem_eff_safe_open = gr.Checkbox(
                label="Memory efficient safe open (experimental feature)",
                value=False,
                interactive=True,
            )

        extract_button = gr.Button("Extract Flux LoRA model")

        extract_button.click(
            extract_flux_lora,
            inputs=[
                model_org,
                model_tuned,
                save_to,
                save_precision,
                dim,
                device,
                clamp_quantile,
                no_metadata,
                mem_eff_safe_open,
            ],
            show_progress=False,
        )

        model_org.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
            inputs=model_org,
            outputs=model_org,
            show_progress=False,
        )
        model_tuned.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
            inputs=model_tuned,
            outputs=model_tuned,
            show_progress=False,
        )
        save_to.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
            inputs=save_to,
            outputs=save_to,
            show_progress=False,
        )
@@ -0,0 +1,470 @@
# Standard library imports
import os
import subprocess
import sys
import json

# Third-party imports
import gradio as gr

# Local module imports
from .common_gui import (
    get_saveasfilename_path,
    get_file_path,
    scriptdir,
    list_files,
    create_refresh_button,
    setup_environment,
)
from .custom_logging import setup_logging

# Set up logging
log = setup_logging()

folder_symbol = "\U0001f4c2"  # 📂
refresh_symbol = "\U0001f504"  # 🔄
save_style_symbol = "\U0001f4be"  # 💾
document_symbol = "\U0001F4C4"  # 📄

PYTHON = sys.executable


def check_model(model):
    if not model:
        return True
    if not os.path.isfile(model):
        log.info(f"The provided {model} is not a file")
        return False
    return True


def verify_conditions(flux_model, lora_models):
    lora_models_count = sum(1 for model in lora_models if model)
    if flux_model and lora_models_count >= 1:
        return True
    elif not flux_model and lora_models_count >= 2:
        return True
    return False
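The merge precondition above (a FLUX base model plus at least one LoRA, or at least two LoRA for a LoRA-only merge) is easy to sanity-check standalone; the function is restated here so the sketch is self-contained, with hypothetical file names:

```python
def verify_conditions(flux_model, lora_models):
    # Count only non-empty LoRA paths.
    lora_models_count = sum(1 for model in lora_models if model)
    if flux_model and lora_models_count >= 1:
        return True
    elif not flux_model and lora_models_count >= 2:
        return True
    return False

# Merging into a base model needs >= 1 LoRA; a LoRA-only merge needs >= 2.
print(verify_conditions("flux.safetensors", ["a.safetensors", "", "", ""]))  # True
print(verify_conditions("", ["a.safetensors", "", "", ""]))                  # False
print(verify_conditions("", ["a.safetensors", "b.safetensors", "", ""]))     # True
```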
|
||||||
|
|
||||||
|
class GradioFluxMergeLoRaTab:
|
||||||
|
def __init__(self, headless=False):
|
||||||
|
self.headless = headless
|
||||||
|
self.build_tab()
|
||||||
|
|
||||||
|
def save_inputs_to_json(self, file_path, inputs):
|
||||||
|
with open(file_path, "w", encoding="utf-8") as file:
|
||||||
|
json.dump(inputs, file)
|
||||||
|
log.info(f"Saved inputs to {file_path}")
|
||||||
|
|
||||||
|
def load_inputs_from_json(self, file_path):
|
||||||
|
with open(file_path, "r", encoding="utf-8") as file:
|
||||||
|
inputs = json.load(file)
|
||||||
|
log.info(f"Loaded inputs from {file_path}")
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def build_tab(self):
|
||||||
|
current_flux_model_dir = os.path.join(scriptdir, "outputs")
|
||||||
|
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||||
|
current_lora_model_dir = current_flux_model_dir
|
||||||
|
|
||||||
|
def list_flux_models(path):
|
||||||
|
nonlocal current_flux_model_dir
|
||||||
|
current_flux_model_dir = path
|
||||||
|
return list(list_files(path, exts=[".safetensors"], all=True))
|
||||||
|
|
||||||
|
def list_lora_models(path):
|
||||||
|
nonlocal current_lora_model_dir
|
||||||
|
current_lora_model_dir = path
|
||||||
|
return list(list_files(path, exts=[".safetensors"], all=True))
|
||||||
|
|
||||||
|
def list_save_to(path):
|
||||||
|
nonlocal current_save_dir
|
||||||
|
current_save_dir = path
|
||||||
|
return list(list_files(path, exts=[".safetensors"], all=True))
|
||||||
|
|
||||||
|
with gr.Tab("Merge FLUX LoRA"):
|
||||||
|
gr.Markdown(
|
||||||
|
"This utility can merge up to 4 LoRA into a FLUX model or alternatively merge up to 4 LoRA together."
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_ext = gr.Textbox(value="*.safetensors", visible=False)
|
||||||
|
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||||
|
flux_ext = gr.Textbox(value="*.safetensors", visible=False)
|
||||||
|
flux_ext_name = gr.Textbox(value="FLUX model types", visible=False)
|
||||||
|
|
||||||
|
with gr.Group(), gr.Row():
|
||||||
|
flux_model = gr.Dropdown(
|
||||||
|
label="FLUX Model (Optional. FLUX model path, if you want to merge it with LoRA files via the 'concat' method)",
|
||||||
|
interactive=True,
|
||||||
|
choices=[""] + list_flux_models(current_flux_model_dir),
|
||||||
|
value="",
|
||||||
|
allow_custom_value=True,
|
||||||
|
)
|
||||||
|
create_refresh_button(
|
||||||
|
flux_model,
|
||||||
|
lambda: None,
|
||||||
|
lambda: {"choices": list_flux_models(current_flux_model_dir)},
|
||||||
|
"open_folder_small",
|
||||||
|
)
|
||||||
|
flux_model_file = gr.Button(
|
||||||
|
folder_symbol,
|
||||||
|
elem_id="open_folder_small",
|
||||||
|
elem_classes=["tool"],
|
||||||
|
visible=(not self.headless),
|
||||||
|
)
|
||||||
|
flux_model_file.click(
|
||||||
|
get_file_path,
|
||||||
|
inputs=[flux_model, flux_ext, flux_ext_name],
|
||||||
|
outputs=flux_model,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
flux_model.change(
|
||||||
|
fn=lambda path: gr.Dropdown(choices=[""] + list_flux_models(path)),
|
||||||
|
inputs=flux_model,
|
||||||
|
outputs=flux_model,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Group(), gr.Row():
|
||||||
|
lora_a_model = gr.Dropdown(
|
||||||
|
label='LoRA model "A" (path to the LoRA A model)',
|
||||||
|
interactive=True,
|
||||||
|
choices=[""] + list_lora_models(current_lora_model_dir),
|
||||||
|
value="",
|
||||||
|
allow_custom_value=True,
|
||||||
|
)
|
||||||
|
create_refresh_button(
|
||||||
|
lora_a_model,
|
||||||
|
lambda: None,
|
||||||
|
lambda: {"choices": list_lora_models(current_lora_model_dir)},
|
||||||
|
"open_folder_small",
|
||||||
|
)
|
||||||
|
button_lora_a_model_file = gr.Button(
|
||||||
|
folder_symbol,
|
||||||
|
elem_id="open_folder_small",
|
||||||
|
elem_classes=["tool"],
|
||||||
|
visible=(not self.headless),
|
||||||
|
)
|
||||||
|
button_lora_a_model_file.click(
|
||||||
|
get_file_path,
|
||||||
|
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||||
|
outputs=lora_a_model,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_b_model = gr.Dropdown(
|
||||||
|
label='LoRA model "B" (path to the LoRA B model)',
|
||||||
|
interactive=True,
|
||||||
|
choices=[""] + list_lora_models(current_lora_model_dir),
|
||||||
|
value="",
|
||||||
|
allow_custom_value=True,
|
||||||
|
)
|
||||||
|
create_refresh_button(
|
||||||
|
lora_b_model,
|
||||||
|
lambda: None,
|
||||||
|
lambda: {"choices": list_lora_models(current_lora_model_dir)},
|
||||||
|
"open_folder_small",
|
||||||
|
)
|
||||||
|
button_lora_b_model_file = gr.Button(
|
||||||
|
folder_symbol,
|
||||||
|
elem_id="open_folder_small",
|
||||||
|
elem_classes=["tool"],
|
||||||
|
visible=(not self.headless),
|
||||||
|
)
|
||||||
|
button_lora_b_model_file.click(
|
||||||
|
get_file_path,
|
||||||
|
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||||
|
outputs=lora_b_model,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_a_model.change(
|
||||||
|
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
|
||||||
|
inputs=lora_a_model,
|
||||||
|
outputs=lora_a_model,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
lora_b_model.change(
|
||||||
|
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
|
||||||
|
inputs=lora_b_model,
|
||||||
|
outputs=lora_b_model,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
    with gr.Row():
        ratio_a = gr.Slider(
            label="Model A merge ratio (e.g. 0.5 means 50%)",
            minimum=0,
            maximum=2,
            step=0.01,
            value=0.0,
            interactive=True,
        )
        ratio_b = gr.Slider(
            label="Model B merge ratio (e.g. 0.5 means 50%)",
            minimum=0,
            maximum=2,
            step=0.01,
            value=0.0,
            interactive=True,
        )
    with gr.Group(), gr.Row():
        lora_c_model = gr.Dropdown(
            label='LoRA model "C" (path to the LoRA C model)',
            interactive=True,
            choices=[""] + list_lora_models(current_lora_model_dir),
            value="",
            allow_custom_value=True,
        )
        create_refresh_button(
            lora_c_model,
            lambda: None,
            lambda: {"choices": list_lora_models(current_lora_model_dir)},
            "open_folder_small",
        )
        button_lora_c_model_file = gr.Button(
            folder_symbol,
            elem_id="open_folder_small",
            elem_classes=["tool"],
            visible=(not self.headless),
        )
        button_lora_c_model_file.click(
            get_file_path,
            inputs=[lora_c_model, lora_ext, lora_ext_name],
            outputs=lora_c_model,
            show_progress=False,
        )

        lora_d_model = gr.Dropdown(
            label='LoRA model "D" (path to the LoRA D model)',
            interactive=True,
            choices=[""] + list_lora_models(current_lora_model_dir),
            value="",
            allow_custom_value=True,
        )
        create_refresh_button(
            lora_d_model,
            lambda: None,
            lambda: {"choices": list_lora_models(current_lora_model_dir)},
            "open_folder_small",
        )
        button_lora_d_model_file = gr.Button(
            folder_symbol,
            elem_id="open_folder_small",
            elem_classes=["tool"],
            visible=(not self.headless),
        )
        button_lora_d_model_file.click(
            get_file_path,
            inputs=[lora_d_model, lora_ext, lora_ext_name],
            outputs=lora_d_model,
            show_progress=False,
        )
        lora_c_model.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
            inputs=lora_c_model,
            outputs=lora_c_model,
            show_progress=False,
        )
        lora_d_model.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
            inputs=lora_d_model,
            outputs=lora_d_model,
            show_progress=False,
        )
    with gr.Row():
        ratio_c = gr.Slider(
            label="Model C merge ratio (e.g. 0.5 means 50%)",
            minimum=0,
            maximum=2,
            step=0.01,
            value=0.0,
            interactive=True,
        )
        ratio_d = gr.Slider(
            label="Model D merge ratio (e.g. 0.5 means 50%)",
            minimum=0,
            maximum=2,
            step=0.01,
            value=0.0,
            interactive=True,
        )
    with gr.Group(), gr.Row():
        save_to = gr.Dropdown(
            label="Save to (path for the file to save...)",
            interactive=True,
            choices=[""] + list_save_to(current_save_dir),
            value="",
            allow_custom_value=True,
        )
        create_refresh_button(
            save_to,
            lambda: None,
            lambda: {"choices": list_save_to(current_save_dir)},
            "open_folder_small",
        )
        button_save_to = gr.Button(
            folder_symbol,
            elem_id="open_folder_small",
            elem_classes=["tool"],
            visible=(not self.headless),
        )
        button_save_to.click(
            get_saveasfilename_path,
            inputs=[save_to, lora_ext, lora_ext_name],
            outputs=save_to,
            show_progress=False,
        )
        precision = gr.Radio(
            label="Merge precision",
            choices=["float", "fp16", "bf16"],
            value="float",
            interactive=True,
        )
        save_precision = gr.Radio(
            label="Save precision",
            choices=["float", "fp16", "bf16", "fp8"],
            value="fp16",
            interactive=True,
        )

    save_to.change(
        fn=lambda path: gr.Dropdown(choices=[""] + list_save_to(path)),
        inputs=save_to,
        outputs=save_to,
        show_progress=False,
    )

    with gr.Row():
        loading_device = gr.Dropdown(
            label="Loading device",
            choices=["cpu", "cuda"],
            value="cpu",
            interactive=True,
        )
        working_device = gr.Dropdown(
            label="Working device",
            choices=["cpu", "cuda"],
            value="cpu",
            interactive=True,
        )

    with gr.Row():
        concat = gr.Checkbox(label="Concat LoRA", value=False)
        shuffle = gr.Checkbox(label="Shuffle LoRA weights", value=False)
        no_metadata = gr.Checkbox(label="Don't save metadata", value=False)
        diffusers = gr.Checkbox(label="Diffusers LoRA", value=False)

    merge_button = gr.Button("Merge model")

    merge_button.click(
        self.merge_flux_lora,
        inputs=[
            flux_model,
            lora_a_model,
            lora_b_model,
            lora_c_model,
            lora_d_model,
            ratio_a,
            ratio_b,
            ratio_c,
            ratio_d,
            save_to,
            precision,
            save_precision,
            loading_device,
            working_device,
            concat,
            shuffle,
            no_metadata,
            diffusers,
        ],
        show_progress=False,
    )
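The four ratio sliders above feed a straightforward weighted sum: each selected LoRA's weights are scaled by its ratio and accumulated. A minimal sketch of that arithmetic on plain Python lists (the real merge script operates on safetensors tensors; `merge_weighted` here is a hypothetical illustration, not a kohya_ss function):

```python
# Hypothetical sketch: ratio-weighted accumulation of LoRA deltas,
# using plain lists in place of real weight tensors.
def merge_weighted(base, deltas, ratios):
    """Return base + sum(ratio_i * delta_i), element-wise."""
    merged = list(base)
    for delta, ratio in zip(deltas, ratios):
        for i, d in enumerate(delta):
            merged[i] += ratio * d
    return merged
```

With ratio 0.5 a delta contributes half its magnitude, which is what the "0.5 means 50%" slider labels describe.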
    def merge_flux_lora(
        self,
        flux_model,
        lora_a_model,
        lora_b_model,
        lora_c_model,
        lora_d_model,
        ratio_a,
        ratio_b,
        ratio_c,
        ratio_d,
        save_to,
        precision,
        save_precision,
        loading_device,
        working_device,
        concat,
        shuffle,
        no_metadata,
        diffusers,
    ):
        log.info("Merge FLUX LoRA...")
        models = [
            lora_a_model,
            lora_b_model,
            lora_c_model,
            lora_d_model,
        ]
        lora_models = [model for model in models if model]
        ratios = [
            ratio
            for model, ratio in zip(models, [ratio_a, ratio_b, ratio_c, ratio_d])
            if model
        ]

        # if not verify_conditions(flux_model, lora_models):
        #     log.info(
        #         "Warning: Either provide at least one LoRA model along with the FLUX model or at least two LoRA models if no FLUX model is provided."
        #     )
        #     return

        for model in [flux_model] + lora_models:
            if not check_model(model):
                return

        run_cmd = [rf"{PYTHON}", rf"{scriptdir}/sd-scripts/networks/flux_merge_lora.py"]

        if flux_model:
            run_cmd.extend(["--flux_model", rf"{flux_model}"])

        run_cmd.extend([
            "--save_precision", save_precision,
            "--precision", precision,
            "--save_to", rf"{save_to}",
            "--loading_device", loading_device,
            "--working_device", working_device,
        ])

        if lora_models:
            run_cmd.append("--models")
            run_cmd.extend(lora_models)
            run_cmd.append("--ratios")
            run_cmd.extend(map(str, ratios))

        if concat:
            run_cmd.append("--concat")
        if shuffle:
            run_cmd.append("--shuffle")
        if no_metadata:
            run_cmd.append("--no_metadata")
        if diffusers:
            run_cmd.append("--diffusers")

        env = setup_environment()

        # Reconstruct the safe command string for display
        command_to_run = " ".join(run_cmd)
        log.info(f"Executing command: {command_to_run}")

        # Run the command in the sd-scripts folder context
        subprocess.run(run_cmd, env=env)

        log.info("Done merging...")
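The list comprehensions at the top of `merge_flux_lora` drop blank dropdown values while keeping each surviving model paired with its ratio. The same logic, isolated as a small sketch for clarity (`pair_models_with_ratios` is an illustrative helper, not part of the codebase):

```python
# Sketch of the filtering in merge_flux_lora: blank model paths are dropped,
# and each surviving model keeps its paired ratio.
def pair_models_with_ratios(models, ratios):
    kept = [(m, r) for m, r in zip(models, ratios) if m]
    return [m for m, _ in kept], [r for _, r in kept]
```

Filtering both lists with the same `if model` condition is what keeps `--models` and `--ratios` aligned when they are appended to `run_cmd`.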
@@ -16,6 +16,7 @@ from .common_gui import (
     create_refresh_button, setup_environment
 )
 from .custom_logging import setup_logging
+from .sd_modeltype import SDModelType

 # Set up logging
 log = setup_logging()

@@ -145,6 +146,13 @@ class GradioMergeLoRaTab:
             show_progress=False,
         )

+        #secondary event on sd_model for auto-detection of SDXL
+        sd_model.change(
+            lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()),
+            inputs=sd_model,
+            outputs=sdxl_model
+        )
+
         with gr.Group(), gr.Row():
             lora_a_model = gr.Dropdown(
                 label='LoRA model "A" (path to the LoRA A model)',
@@ -0,0 +1,65 @@
from os.path import isfile
from safetensors import safe_open
import enum

# methodology is based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/sd_models.py#L379-L403


class ModelType(enum.Enum):
    UNKNOWN = 0
    SD1 = 1
    SD2 = 2
    SDXL = 3
    SD3 = 4
    FLUX1 = 5


class SDModelType:
    def __init__(self, safetensors_path):
        self.model_type = ModelType.UNKNOWN

        if not isfile(safetensors_path):
            return

        try:
            st = safe_open(filename=safetensors_path, framework="numpy", device="cpu")

            # print(st.keys())

            def hasKeyPrefix(pfx):
                return any(k.startswith(pfx) for k in st.keys())

            if "model.diffusion_model.x_embedder.proj.weight" in st.keys():
                self.model_type = ModelType.SD3
            elif (
                "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
                in st.keys()
                or "double_blocks.0.img_attn.norm.key_norm.scale" in st.keys()
            ):
                # print("flux1 model detected...")
                self.model_type = ModelType.FLUX1
            elif hasKeyPrefix("conditioner."):
                self.model_type = ModelType.SDXL
            elif hasKeyPrefix("cond_stage_model.model."):
                self.model_type = ModelType.SD2
            elif hasKeyPrefix("model."):
                self.model_type = ModelType.SD1
        except:
            pass

        # print(f"Model type: {self.model_type}")

    def Is_SD1(self):
        return self.model_type == ModelType.SD1

    def Is_SD2(self):
        return self.model_type == ModelType.SD2

    def Is_SDXL(self):
        return self.model_type == ModelType.SDXL

    def Is_SD3(self):
        return self.model_type == ModelType.SD3

    def Is_FLUX1(self):
        return self.model_type == ModelType.FLUX1
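`SDModelType` identifies a checkpoint purely from the names of its tensors: specific keys mark SD3 and FLUX.1, and key prefixes distinguish SDXL, SD2, and SD1. The check order matters, since the generic `model.` prefix would otherwise shadow the SD3 and FLUX.1 keys. The same decision tree, sketched over a plain set of key names instead of an open safetensors file (`detect_model_type` is an illustrative stand-in):

```python
# Sketch of SDModelType's key-prefix detection, applied to a plain set of
# tensor names. Specific keys are tested before generic prefixes.
def detect_model_type(keys):
    def has_prefix(pfx):
        return any(k.startswith(pfx) for k in keys)

    if "model.diffusion_model.x_embedder.proj.weight" in keys:
        return "SD3"
    if ("model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in keys
            or "double_blocks.0.img_attn.norm.key_norm.scale" in keys):
        return "FLUX1"
    if has_prefix("conditioner."):
        return "SDXL"  # SDXL keeps its text encoders under "conditioner."
    if has_prefix("cond_stage_model.model."):
        return "SD2"
    if has_prefix("model."):
        return "SD1"
    return "UNKNOWN"
```

This is what lets the GUI auto-tick the v2 and SDXL checkboxes when a model path is selected.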
@@ -70,6 +70,7 @@ def save_configuration(
     learning_rate,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     train_batch_size,
     epoch,
     save_every_n_epochs,

@@ -135,6 +136,7 @@ def save_configuration(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,

@@ -151,17 +153,23 @@ def save_configuration(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
     min_timestep,
     max_timestep,
     sdxl_no_half_vae,
@@ -229,6 +237,7 @@ def open_configuration(
     learning_rate,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     train_batch_size,
     epoch,
     save_every_n_epochs,

@@ -294,6 +303,7 @@ def open_configuration(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,

@@ -310,17 +320,23 @@ def open_configuration(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
     min_timestep,
     max_timestep,
     sdxl_no_half_vae,
@@ -381,6 +397,7 @@ def train_model(
     learning_rate,
     lr_scheduler,
     lr_warmup,
+    lr_warmup_steps,
     train_batch_size,
     epoch,
     save_every_n_epochs,

@@ -446,6 +463,7 @@ def train_model(
     optimizer,
     optimizer_args,
     lr_scheduler_args,
+    lr_scheduler_type,
     noise_offset_type,
     noise_offset,
     noise_offset_random_strength,

@@ -462,17 +480,23 @@ def train_model(
     loss_type,
     huber_schedule,
     huber_c,
+    huber_scale,
     vae_batch_size,
     min_snr_gamma,
     save_every_n_steps,
     save_last_n_steps,
     save_last_n_steps_state,
+    save_last_n_epochs,
+    save_last_n_epochs_state,
+    skip_cache_check,
     log_with,
     wandb_api_key,
     wandb_run_name,
     log_tracker_name,
     log_tracker_config,
+    log_config,
     scale_v_pred_loss_like_noise_pred,
+    disable_mmap_load_safetensors,
     min_timestep,
     max_timestep,
     sdxl_no_half_vae,
@@ -549,20 +573,6 @@ def train_model(
     # End of path validation
     #

-    # if not validate_paths(
-    #     dataset_config=dataset_config,
-    #     headless=headless,
-    #     log_tracker_config=log_tracker_config,
-    #     logging_dir=logging_dir,
-    #     output_dir=output_dir,
-    #     pretrained_model_name_or_path=pretrained_model_name_or_path,
-    #     reg_data_dir=reg_data_dir,
-    #     resume=resume,
-    #     train_data_dir=train_data_dir,
-    #     vae=vae,
-    # ):
-    #     return TRAIN_BUTTON_VISIBLE

     if token_string == "":
         output_message(msg="Token string is missing", headless=headless)
         return TRAIN_BUTTON_VISIBLE

@@ -588,13 +598,6 @@ def train_model(
         stop_text_encoder_training = math.ceil(
             float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
         )
-        if lr_warmup != 0:
-            lr_warmup_steps = round(
-                float(int(lr_warmup) * int(max_train_steps) / 100)
-            )
-        else:
-            lr_warmup_steps = 0
     else:
         stop_text_encoder_training = 0
         lr_warmup_steps = 0
@@ -657,11 +660,11 @@ def train_model(
         reg_factor = 1
     else:
         log.warning(
-            "Regularisation images are used... Will double the number of steps required..."
+            "Regularization images are used... Will double the number of steps required..."
         )
         reg_factor = 2

-    log.info(f"Regulatization factor: {reg_factor}")
+    log.info(f"Regularization factor: {reg_factor}")

     if max_train_steps == 0:
         # calculate max_train_steps
@@ -689,13 +692,18 @@ def train_model(
             float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
         )
-        if lr_warmup != 0:
-            lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
-        else:
-            lr_warmup_steps = 0

     log.info(f"Total steps: {total_steps}")

+    # Calculate lr_warmup_steps
+    if lr_warmup_steps > 0:
+        lr_warmup_steps = int(lr_warmup_steps)
+        if lr_warmup > 0:
+            log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
+    elif lr_warmup != 0:
+        lr_warmup_steps = lr_warmup / 100
+    else:
+        lr_warmup_steps = 0
+
     log.info(f"Train batch size: {train_batch_size}")
     log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
     log.info(f"Epoch: {epoch}")
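The hunk above replaces the old per-branch warmup computation with a single resolution rule: an explicit `lr_warmup_steps` value takes precedence (with a warning if both fields are set), and otherwise the `lr_warmup` percentage is passed through as a 0-1 ratio. A standalone sketch of that rule (`resolve_warmup` is an illustrative name, not a function in the codebase):

```python
# Sketch of the warmup resolution introduced in the diff above.
def resolve_warmup(lr_warmup, lr_warmup_steps):
    if lr_warmup_steps > 0:
        return int(lr_warmup_steps)  # explicit step count wins
    if lr_warmup != 0:
        return lr_warmup / 100  # percentage becomes a 0-1 ratio
    return 0
```

Centralizing this after the `max_train_steps` calculation means the same rule applies whether steps were given directly or derived from epochs.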
@@ -757,6 +765,7 @@ def train_model(
         "clip_skip": clip_skip if clip_skip != 0 else None,
         "color_aug": color_aug,
         "dataset_config": dataset_config,
+        "disable_mmap_load_safetensors": disable_mmap_load_safetensors,
         "dynamo_backend": dynamo_backend,
         "enable_bucket": enable_bucket,
         "epoch": int(epoch),

@@ -765,6 +774,7 @@ def train_model(
         "gradient_accumulation_steps": int(gradient_accumulation_steps),
         "gradient_checkpointing": gradient_checkpointing,
         "huber_c": huber_c,
+        "huber_scale": huber_scale,
         "huber_schedule": huber_schedule,
         "huggingface_repo_id": huggingface_repo_id,
         "huggingface_token": huggingface_token,

@@ -777,6 +787,7 @@ def train_model(
         "keep_tokens": int(keep_tokens),
         "learning_rate": learning_rate,
         "logging_dir": logging_dir,
+        "log_config": log_config,
         "log_tracker_name": log_tracker_name,
         "log_tracker_config": log_tracker_config,
         "loss_type": loss_type,

@@ -786,6 +797,7 @@ def train_model(
             int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
         ),
         "lr_scheduler_power": lr_scheduler_power,
+        "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
         "lr_warmup_steps": lr_warmup_steps,
         "max_bucket_reso": max_bucket_reso,
         "max_timestep": max_timestep if max_timestep != 0 else None,

@@ -840,6 +852,10 @@ def train_model(
         "save_last_n_steps_state": (
             save_last_n_steps_state if save_last_n_steps_state != 0 else None
         ),
+        "save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
+        "save_last_n_epochs_state": (
+            save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
+        ),
         "save_model_as": save_model_as,
         "save_precision": save_precision,
         "save_state": save_state,

@@ -849,6 +865,7 @@ def train_model(
         "sdpa": True if xformers == "sdpa" else None,
         "seed": int(seed) if int(seed) != 0 else None,
         "shuffle_caption": shuffle_caption,
+        "skip_cache_check": skip_cache_check,
         "stop_text_encoder_training": (
             stop_text_encoder_training if stop_text_encoder_training != 0 else None
         ),

@@ -862,8 +879,8 @@ def train_model(
         "vae": vae,
         "vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
         "wandb_api_key": wandb_api_key,
-        "wandb_run_name": wandb_run_name,
-        "weigts": weights,
+        "wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
+        "weights": weights,
         "use_object_template": True if template == "object template" else None,
         "use_style_template": True if template == "style template" else None,
         "xformers": True if xformers == "xformers" else None,
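The `!= 0` conditionals in this config dict connect to the `dreambooth_gui` fix noted in the changelog: in Python, `0 == False` evaluates to `True`, so filtering toml values with a plain `value == False` comparison would silently drop legitimate zeros. A type-aware check avoids that. A small illustrative sketch (`drop_false_flags` is a hypothetical helper, not the project's actual filter):

```python
# Why a type check matters when filtering config values: in Python,
# 0 == False is True, so a naive `v == False` filter would also drop
# integer zeros. Only genuine boolean False values should be removed.
def drop_false_flags(config):
    return {k: v for k, v in config.items()
            if not (isinstance(v, bool) and v is False)}
```

With this check, `{"seed": 0}` survives filtering while `{"xformers": False}` is dropped, matching the behavior the fix restores.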
@@ -1130,6 +1147,7 @@ def ti_tab(
             basic_training.learning_rate,
             basic_training.lr_scheduler,
             basic_training.lr_warmup,
+            basic_training.lr_warmup_steps,
             basic_training.train_batch_size,
             basic_training.epoch,
             basic_training.save_every_n_epochs,

@@ -1194,6 +1212,7 @@ def ti_tab(
             basic_training.optimizer,
             basic_training.optimizer_args,
             basic_training.lr_scheduler_args,
+            basic_training.lr_scheduler_type,
             advanced_training.noise_offset_type,
             advanced_training.noise_offset,
             advanced_training.noise_offset_random_strength,

@@ -1210,17 +1229,23 @@ def ti_tab(
             advanced_training.loss_type,
             advanced_training.huber_schedule,
             advanced_training.huber_c,
+            advanced_training.huber_scale,
             advanced_training.vae_batch_size,
             advanced_training.min_snr_gamma,
             advanced_training.save_every_n_steps,
             advanced_training.save_last_n_steps,
             advanced_training.save_last_n_steps_state,
+            advanced_training.save_last_n_epochs,
+            advanced_training.save_last_n_epochs_state,
+            advanced_training.skip_cache_check,
             advanced_training.log_with,
             advanced_training.wandb_api_key,
             advanced_training.wandb_run_name,
             advanced_training.log_tracker_name,
             advanced_training.log_tracker_config,
+            advanced_training.log_config,
             advanced_training.scale_v_pred_loss_like_noise_pred,
+            sdxl_params.disable_mmap_load_safetensors,
             advanced_training.min_timestep,
             advanced_training.max_timestep,
             sdxl_params.sdxl_no_half_vae,

@@ -1289,4 +1314,4 @@ def ti_tab(
         folders.reg_data_dir,
         folders.output_dir,
         folders.logging_dir,
     )
@ -0,0 +1,146 @@
|
||||||
|
{
|
||||||
|
"adaptive_noise_scale": 0,
|
||||||
|
"additional_parameters": "",
|
||||||
|
"async_upload": false,
|
||||||
|
"bucket_no_upscale": true,
|
||||||
|
"bucket_reso_steps": 64,
|
||||||
|
"cache_latents": true,
|
||||||
|
"cache_latents_to_disk": true,
|
||||||
|
"caption_dropout_every_n_epochs": 0,
|
||||||
|
"caption_dropout_rate": 0,
|
||||||
|
"caption_extension": ".txt",
|
||||||
|
"clip_g": "H:/ComfyUI2/models/clip/clip_g.safetensors",
|
||||||
|
"clip_l": "H:/ComfyUI2/models/clip/clip_l.safetensors",
|
||||||
|
"clip_skip": 1,
|
||||||
|
"color_aug": false,
|
||||||
|
"dataset_config": "",
|
||||||
|
"debiased_estimation_loss": false,
|
||||||
|
"disable_mmap_load_safetensors": false,
|
||||||
|
"dynamo_backend": "no",
|
||||||
|
"dynamo_mode": "default",
|
||||||
|
"dynamo_use_dynamic": false,
|
||||||
|
"dynamo_use_fullgraph": false,
|
||||||
|
"enable_bucket": true,
|
||||||
|
"epoch": 8,
|
||||||
|
"extra_accelerate_launch_args": "",
|
||||||
|
"flip_aug": false,
|
||||||
|
"full_bf16": false,
|
||||||
|
"full_fp16": false,
|
||||||
|
"fused_backward_pass": false,
|
||||||
|
"fused_optimizer_groups": 0,
|
||||||
|
"gpu_ids": "",
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"gradient_checkpointing": true,
|
||||||
|
"huber_c": 0.1,
|
||||||
|
"huber_schedule": "snr",
|
||||||
|
"huggingface_path_in_repo": "",
|
||||||
|
"huggingface_repo_id": "",
|
||||||
|
"huggingface_repo_type": "",
|
||||||
|
"huggingface_repo_visibility": "",
|
||||||
|
"huggingface_token": "",
|
||||||
|
"ip_noise_gamma": 0,
|
||||||
|
"ip_noise_gamma_random_strength": false,
|
||||||
|
  "keep_tokens": 0,
  "learning_rate": 5e-06,
  "learning_rate_te": 0,
  "learning_rate_te1": 1e-05,
  "learning_rate_te2": 1e-05,
  "log_config": false,
  "log_tracker_config": "",
  "log_tracker_name": "",
  "log_with": "",
  "logging_dir": "C:/Users/berna/Downloads/martini/logs/sd3",
  "logit_mean": 0,
  "logit_std": 1,
  "loss_type": "l2",
  "lr_scheduler": "cosine",
  "lr_scheduler_args": "",
  "lr_scheduler_num_cycles": 1,
  "lr_scheduler_power": 1,
  "lr_scheduler_type": "",
  "lr_warmup": 10,
  "main_process_port": 0,
  "masked_loss": false,
  "max_bucket_reso": 1536,
  "max_data_loader_n_workers": 0,
  "max_resolution": "512,512",
  "max_timestep": 1000,
  "max_token_length": 225,
  "max_train_epochs": 8,
  "max_train_steps": 1600,
  "mem_eff_attn": false,
  "metadata_author": "",
  "metadata_description": "",
  "metadata_license": "",
  "metadata_tags": "",
  "metadata_title": "",
  "min_bucket_reso": 256,
  "min_snr_gamma": 0,
  "min_timestep": 0,
  "mixed_precision": "bf16",
  "mode_scale": 1.29,
  "model_list": "custom",
  "multi_gpu": false,
  "multires_noise_discount": 0.3,
  "multires_noise_iterations": 0,
  "no_token_padding": false,
  "noise_offset": 0,
  "noise_offset_random_strength": false,
  "noise_offset_type": "Original",
  "num_cpu_threads_per_process": 2,
  "num_machines": 1,
  "num_processes": 1,
  "optimizer": "PagedAdamW8bit",
  "optimizer_args": "weight_decay=0.1 betas=.9,.95",
  "output_dir": "E:/models/sd3",
  "output_name": "sd3",
  "persistent_data_loader_workers": false,
  "pretrained_model_name_or_path": "E:/models/sd3/sd3_medium.safetensors",
  "prior_loss_weight": 1,
  "random_crop": false,
  "reg_data_dir": "",
  "resume": "",
  "resume_from_huggingface": "",
  "sample_every_n_epochs": 0,
  "sample_every_n_steps": 0,
  "sample_prompts": "",
  "sample_sampler": "euler_a",
  "save_as_bool": false,
  "save_clip": false,
  "save_every_n_epochs": 0,
  "save_every_n_steps": 0,
  "save_last_n_steps": 0,
  "save_last_n_steps_state": 0,
  "save_model_as": "safetensors",
  "save_precision": "fp16",
  "save_state": false,
  "save_state_on_train_end": false,
  "save_state_to_huggingface": false,
  "save_t5xxl": false,
  "scale_v_pred_loss_like_noise_pred": false,
  "sd3_cache_text_encoder_outputs": true,
  "sd3_cache_text_encoder_outputs_to_disk": true,
  "sd3_checkbox": true,
  "sd3_text_encoder_batch_size": 1,
  "sdxl": false,
  "sdxl_cache_text_encoder_outputs": false,
  "sdxl_no_half_vae": false,
  "seed": 1026,
  "shuffle_caption": false,
  "stop_text_encoder_training": 0,
  "t5xxl": "H:/ComfyUI2/models/clip/t5xxl_fp8_e4m3fn.safetensors",
  "t5xxl_device": "",
  "t5xxl_dtype": "bf16",
  "train_batch_size": 1,
  "train_data_dir": "C:/Users/berna/Downloads/martini/img2",
  "v2": false,
  "v_parameterization": false,
  "v_pred_like_loss": 0,
  "vae": "",
  "vae_batch_size": 0,
  "wandb_api_key": "",
  "wandb_run_name": "",
  "weighted_captions": false,
  "weighting_scheme": "logit_normal",
  "xformers": "sdpa"
}
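The changelog notes that the TE1/TE2 learning-rate handling was reworked so that 0 disables training of a text encoder while a blank GUI field omits the argument entirely. A minimal sketch of that convention, reading values back from a saved config like the one above (`te_mode` is a hypothetical helper, not the GUI's actual loader):

```python
import json

# Subset of a saved kohya_ss GUI config, as in the JSON above.
snippet = '{"learning_rate": 5e-06, "learning_rate_te1": 1e-05, "learning_rate_te2": 0}'
cfg = json.loads(snippet)

def te_mode(cfg, key):
    value = cfg.get(key, "")
    if value == "":
        return "unset"      # GUI field left blank: argument is not passed
    return "disabled" if value == 0 else "train"

print(te_mode(cfg, "learning_rate_te1"))  # train
print(te_mode(cfg, "learning_rate_te2"))  # disabled
```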
@@ -0,0 +1,146 @@
{
  "adaptive_noise_scale": 0,
  "additional_parameters": "",
  "async_upload": false,
  "bucket_no_upscale": true,
  "bucket_reso_steps": 64,
  "cache_latents": true,
  "cache_latents_to_disk": true,
  "caption_dropout_every_n_epochs": 0,
  "caption_dropout_rate": 0,
  "caption_extension": ".txt",
  "clip_g": "H:/ComfyUI2/models/clip/clip_g.safetensors",
  "clip_l": "H:/ComfyUI2/models/clip/clip_l.safetensors",
  "clip_skip": 1,
  "color_aug": false,
  "dataset_config": "",
  "debiased_estimation_loss": false,
  "disable_mmap_load_safetensors": false,
  "dynamo_backend": "no",
  "dynamo_mode": "default",
  "dynamo_use_dynamic": false,
  "dynamo_use_fullgraph": false,
  "enable_bucket": true,
  "epoch": 8,
  "extra_accelerate_launch_args": "",
  "flip_aug": false,
  "full_bf16": false,
  "full_fp16": false,
  "fused_backward_pass": false,
  "fused_optimizer_groups": 0,
  "gpu_ids": "",
  "gradient_accumulation_steps": 1,
  "gradient_checkpointing": true,
  "huber_c": 0.1,
  "huber_schedule": "snr",
  "huggingface_path_in_repo": "",
  "huggingface_repo_id": "",
  "huggingface_repo_type": "",
  "huggingface_repo_visibility": "",
  "huggingface_token": "",
  "ip_noise_gamma": 0,
  "ip_noise_gamma_random_strength": false,
  "keep_tokens": 0,
  "learning_rate": 5e-06,
  "learning_rate_te": 0,
  "learning_rate_te1": 1e-05,
  "learning_rate_te2": 1e-05,
  "log_config": false,
  "log_tracker_config": "",
  "log_tracker_name": "",
  "log_with": "",
  "logging_dir": "C:/Users/berna/Downloads/martini/logs/sd3",
  "logit_mean": 0,
  "logit_std": 1,
  "loss_type": "l2",
  "lr_scheduler": "cosine",
  "lr_scheduler_args": "",
  "lr_scheduler_num_cycles": 1,
  "lr_scheduler_power": 1,
  "lr_scheduler_type": "",
  "lr_warmup": 10,
  "main_process_port": 0,
  "masked_loss": false,
  "max_bucket_reso": 1536,
  "max_data_loader_n_workers": 0,
  "max_resolution": "512,512",
  "max_timestep": 1000,
  "max_token_length": 150,
  "max_train_epochs": 8,
  "max_train_steps": 1600,
  "mem_eff_attn": false,
  "metadata_author": "",
  "metadata_description": "",
  "metadata_license": "",
  "metadata_tags": "",
  "metadata_title": "",
  "min_bucket_reso": 256,
  "min_snr_gamma": 0,
  "min_timestep": 0,
  "mixed_precision": "bf16",
  "mode_scale": 1.29,
  "model_list": "custom",
  "multi_gpu": false,
  "multires_noise_discount": 0.3,
  "multires_noise_iterations": 0,
  "no_token_padding": false,
  "noise_offset": 0,
  "noise_offset_random_strength": false,
  "noise_offset_type": "Original",
  "num_cpu_threads_per_process": 2,
  "num_machines": 1,
  "num_processes": 1,
  "optimizer": "PagedAdamW8bit",
  "optimizer_args": "weight_decay=0.1 betas=.9,.95",
  "output_dir": "E:/models/sd3",
  "output_name": "sd3_v2",
  "persistent_data_loader_workers": false,
  "pretrained_model_name_or_path": "E:/models/sd3/sd3_medium.safetensors",
  "prior_loss_weight": 1,
  "random_crop": false,
  "reg_data_dir": "",
  "resume": "",
  "resume_from_huggingface": "",
  "sample_every_n_epochs": 0,
  "sample_every_n_steps": 0,
  "sample_prompts": "",
  "sample_sampler": "euler_a",
  "save_as_bool": false,
  "save_clip": false,
  "save_every_n_epochs": 0,
  "save_every_n_steps": 0,
  "save_last_n_steps": 0,
  "save_last_n_steps_state": 0,
  "save_model_as": "safetensors",
  "save_precision": "fp16",
  "save_state": false,
  "save_state_on_train_end": false,
  "save_state_to_huggingface": false,
  "save_t5xxl": false,
  "scale_v_pred_loss_like_noise_pred": false,
  "sd3_cache_text_encoder_outputs": true,
  "sd3_cache_text_encoder_outputs_to_disk": true,
  "sd3_checkbox": true,
  "sd3_text_encoder_batch_size": 1,
  "sdxl": false,
  "sdxl_cache_text_encoder_outputs": false,
  "sdxl_no_half_vae": false,
  "seed": 1026,
  "shuffle_caption": false,
  "stop_text_encoder_training": 0,
  "t5xxl": "H:/ComfyUI2/models/clip/t5xxl_fp8_e4m3fn.safetensors",
  "t5xxl_device": "",
  "t5xxl_dtype": "bf16",
  "train_batch_size": 1,
  "train_data_dir": "C:/Users/berna/Downloads/martini/img",
  "v2": false,
  "v_parameterization": false,
  "v_pred_like_loss": 0,
  "vae": "",
  "vae_batch_size": 0,
  "wandb_api_key": "",
  "wandb_run_name": "",
  "weighted_captions": false,
  "weighting_scheme": "logit_normal",
  "xformers": "sdpa"
}
@@ -0,0 +1,182 @@
{
  "LoRA_type": "Flux1",
  "LyCORIS_preset": "full",
  "adaptive_noise_scale": 0,
  "additional_parameters": "",
  "ae": "put the full path to ae.sft here",
  "apply_t5_attn_mask": true,
  "async_upload": false,
  "block_alphas": "",
  "block_dims": "",
  "block_lr_zero_threshold": "",
  "bucket_no_upscale": true,
  "bucket_reso_steps": 64,
  "bypass_mode": false,
  "cache_latents": true,
  "cache_latents_to_disk": true,
  "caption_dropout_every_n_epochs": 0,
  "caption_dropout_rate": 0,
  "caption_extension": ".txt",
  "clip_l": "put the full path to clip_l.safetensors here",
  "clip_skip": 1,
  "color_aug": false,
  "constrain": 0,
  "conv_alpha": 1,
  "conv_block_alphas": "",
  "conv_block_dims": "",
  "conv_dim": 1,
  "dataset_config": "",
  "debiased_estimation_loss": false,
  "decompose_both": false,
  "dim_from_weights": false,
  "discrete_flow_shift": 3,
  "dora_wd": false,
  "down_lr_weight": "",
  "dynamo_backend": "no",
  "dynamo_mode": "default",
  "dynamo_use_dynamic": false,
  "dynamo_use_fullgraph": false,
  "enable_bucket": true,
  "epoch": 1,
  "extra_accelerate_launch_args": "",
  "factor": -1,
  "flip_aug": false,
  "flux1_cache_text_encoder_outputs": true,
  "flux1_cache_text_encoder_outputs_to_disk": true,
  "flux1_checkbox": true,
  "fp8_base": true,
  "full_bf16": true,
  "full_fp16": false,
  "gpu_ids": "",
  "gradient_accumulation_steps": 1,
  "gradient_checkpointing": true,
  "guidance_scale": 1,
  "highvram": false,
  "huber_c": 0.1,
  "huber_schedule": "snr",
  "huggingface_path_in_repo": "",
  "huggingface_repo_id": "",
  "huggingface_repo_type": "",
  "huggingface_repo_visibility": "",
  "huggingface_token": "",
  "ip_noise_gamma": 0,
  "ip_noise_gamma_random_strength": false,
  "keep_tokens": 0,
  "learning_rate": 0.0003,
  "log_config": false,
  "log_tracker_config": "",
  "log_tracker_name": "",
  "log_with": "",
  "logging_dir": "./test/logs-saruman",
  "loraplus_lr_ratio": 0,
  "loraplus_text_encoder_lr_ratio": 0,
  "loraplus_unet_lr_ratio": 0,
  "loss_type": "l2",
  "lowvram": false,
  "lr_scheduler": "constant",
  "lr_scheduler_args": "",
  "lr_scheduler_num_cycles": 1,
  "lr_scheduler_power": 1,
  "lr_scheduler_type": "",
  "lr_warmup": 0,
  "main_process_port": 0,
  "masked_loss": false,
  "max_bucket_reso": 2048,
  "max_data_loader_n_workers": 0,
  "max_grad_norm": 1,
  "max_resolution": "512,512",
  "max_timestep": 1000,
  "max_token_length": 75,
  "max_train_epochs": 0,
  "max_train_steps": 1000,
  "mem_eff_attn": false,
  "mem_eff_save": false,
  "metadata_author": "",
  "metadata_description": "",
  "metadata_license": "",
  "metadata_tags": "",
  "metadata_title": "",
  "mid_lr_weight": "",
  "min_bucket_reso": 256,
  "min_snr_gamma": 7,
  "min_timestep": 0,
  "mixed_precision": "bf16",
  "model_list": "custom",
  "model_prediction_type": "raw",
  "module_dropout": 0,
  "multi_gpu": false,
  "multires_noise_discount": 0.3,
  "multires_noise_iterations": 0,
  "network_alpha": 16,
  "network_dim": 16,
  "network_dropout": 0,
  "network_weights": "",
  "noise_offset": 0.05,
  "noise_offset_random_strength": false,
  "noise_offset_type": "Original",
  "num_cpu_threads_per_process": 2,
  "num_machines": 1,
  "num_processes": 1,
  "optimizer": "AdamW8bit",
  "optimizer_args": "",
  "output_dir": "put the full path to output folder here",
  "output_name": "Flux.my-super-duper-model-name-goes-here-v1.0",
  "persistent_data_loader_workers": false,
  "pretrained_model_name_or_path": "put the full path to flux1-dev.safetensors here",
  "prior_loss_weight": 1,
  "random_crop": false,
  "rank_dropout": 0,
  "rank_dropout_scale": false,
  "reg_data_dir": "",
  "rescaled": false,
  "resume": "",
  "resume_from_huggingface": "",
  "sample_every_n_epochs": 0,
  "sample_every_n_steps": 0,
  "sample_prompts": "saruman posing under a stormy lightning sky, photorealistic --w 832 --h 1216 --s 20 --l 4 --d 42",
  "sample_sampler": "euler",
  "save_as_bool": false,
  "save_every_n_epochs": 1,
  "save_every_n_steps": 50,
  "save_last_n_steps": 0,
  "save_last_n_steps_state": 0,
  "save_model_as": "safetensors",
  "save_precision": "bf16",
  "save_state": false,
  "save_state_on_train_end": false,
  "save_state_to_huggingface": false,
  "scale_v_pred_loss_like_noise_pred": false,
  "scale_weight_norms": 0,
  "sdxl": false,
  "sdxl_cache_text_encoder_outputs": true,
  "sdxl_no_half_vae": true,
  "seed": 42,
  "shuffle_caption": false,
  "split_mode": false,
  "stop_text_encoder_training": 0,
  "t5xxl": "put the full path to the file here. Use the fp16 version",
  "t5xxl_max_token_length": 512,
  "text_encoder_lr": 0,
  "timestep_sampling": "sigmoid",
  "train_batch_size": 1,
  "train_blocks": "all",
  "train_data_dir": "put your image folder here",
  "train_norm": false,
  "train_on_input": true,
  "training_comment": "",
  "unet_lr": 0.0003,
  "unit": 1,
  "up_lr_weight": "",
  "use_cp": false,
  "use_scalar": false,
  "use_tucker": false,
  "v2": false,
  "v_parameterization": false,
  "v_pred_like_loss": 0,
  "vae": "",
  "vae_batch_size": 0,
  "wandb_api_key": "",
  "wandb_run_name": "",
  "weighted_captions": false,
  "xformers": "sdpa"
}
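Configs like the ones above are converted to training arguments, and the changelog's dreambooth_gui fix ("the conditional must check the type when comparing for False") exists because of a Python subtlety worth spelling out. A sketch (not the actual GUI code) of why a naive `False` filter would also drop a legitimate 0 value such as `learning_rate_te`:

```python
# In Python, 0 == False evaluates True, so filtering "False" values out of
# a config dict by equality also discards zeros that should be kept.
cfg = {"learning_rate_te": 0, "save_state": False, "seed": 42}

naive = {k: v for k, v in cfg.items() if v != False}   # drops the 0 too
fixed = {k: v for k, v in cfg.items()
         if not (isinstance(v, bool) and v is False)}  # type-checked filter

print("learning_rate_te" in naive)  # False: the 0 was wrongly filtered out
print("learning_rate_te" in fixed)  # True: only the boolean False is removed
```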
@@ -1,35 +1,38 @@
-accelerate==0.25.0
+accelerate==0.33.0
 aiofiles==23.2.1
 altair==4.2.2
-dadaptation==3.1
+dadaptation==3.2
 diffusers[torch]==0.25.0
 easygui==0.98.3
 einops==0.7.0
 fairscale==0.4.13
 ftfy==6.1.1
-gradio==4.43.0
+gradio==5.23.1
-huggingface-hub==0.20.1
+huggingface-hub==0.29.3
 imagesize==1.4.1
 invisible-watermark==0.2.0
 lion-pytorch==0.0.6
-lycoris_lora==2.2.0.post3
+lycoris_lora==3.1.0
 omegaconf==2.3.0
 onnx==1.16.1
-prodigyopt==1.0
+prodigyopt==1.1.2
 protobuf==3.20.3
 open-clip-torch==2.20.0
-opencv-python==4.7.0.68
+opencv-python==4.10.0.84
-prodigyopt==1.0
+prodigy-plus-schedule-free==1.8.0
 pytorch-lightning==1.9.0
+pytorch-optimizer==3.5.0
 rich>=13.7.1
-safetensors==0.4.2
+safetensors==0.4.4
+schedulefree==1.4
 scipy==1.11.4
+# for T5XXL tokenizer (SD3/FLUX)
+sentencepiece==0.2.0
 timm==0.6.12
 tk==0.1.0
 toml==0.10.2
-transformers==4.38.0
+transformers==4.44.2
 voluptuous==0.13.1
-wandb==0.15.11
+wandb==0.18.0
-scipy==1.11.4
+# for kohya_ss sd-scripts library
-# for kohya_ss library
+-e ./sd-scripts
--e ./sd-scripts # no_verify leave this to specify not checking this a verification stage
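The requirement bumps above are almost all exact `==` pins. A small self-contained sketch (names taken from the list above) of extracting those pins for a version audit, skipping comments, range specifiers, and the editable `-e ./sd-scripts` entry:

```python
import re

# Sample lines as they appear in the updated requirements file above.
lines = """
accelerate==0.33.0
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
rich>=13.7.1
-e ./sd-scripts
""".splitlines()

# Match only exact pins of the form "name==version".
PIN = re.compile(r"^([A-Za-z0-9_.\-\[\]]+)==([\w.+]+)$")
pins = dict(m.groups() for line in lines if (m := PIN.match(line.strip())))
print(pins)  # {'accelerate': '0.33.0', 'sentencepiece': '0.2.0'}
```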
@@ -1,5 +1,13 @@
-torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
-bitsandbytes==0.43.0
-tensorboard==2.15.2 tensorflow==2.15.0.post1
-onnxruntime-gpu==1.17.1
+# Custom index URL for specific packages
+--extra-index-url https://download.pytorch.org/whl/cu124
+
+torch==2.5.0+cu124
+torchvision==0.20.0+cu124
+xformers==0.0.28.post2
+
+bitsandbytes==0.44.0
+tensorboard==2.15.2
+tensorflow==2.15.0.post1
+onnxruntime-gpu==1.19.2
+
 -r requirements.txt
@@ -1,4 +1,4 @@
 xformers>=0.0.20
-bitsandbytes==0.43.0
+bitsandbytes==0.44.0
-accelerate==0.25.0
+accelerate==0.33.0
 tensorboard
@@ -1,5 +1,17 @@
-torch==2.1.0.post0+cxx11.abi torchvision==0.16.0.post0+cxx11.abi intel-extension-for-pytorch==2.1.20+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-tensorboard==2.15.2 tensorflow==2.15.0 intel-extension-for-tensorflow[xpu]==2.15.0.0
-mkl==2024.1.0 mkl-dpcpp==2024.1.0 oneccl-devel==2021.12.0 impi-devel==2021.12.0
-onnxruntime-openvino==1.17.1
+# Custom index URL for specific packages
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+
+torch==2.1.0.post3+cxx11.abi
+torchvision==0.16.0.post3+cxx11.abi
+intel-extension-for-pytorch==2.1.40+xpu
+oneccl_bind_pt==2.1.400+xpu
+
+tensorflow==2.15.1
+intel-extension-for-tensorflow[xpu]==2.15.0.1
+mkl==2024.2.0
+mkl-dpcpp==2024.2.0
+oneccl-devel==2021.13.0
+impi-devel==2021.13.0
+onnxruntime-openvino==1.18.0
+
 -r requirements.txt
@@ -1,4 +1,13 @@
-torch==2.3.0+rocm6.0 torchvision==0.18.0+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0
-tensorboard==2.14.1 tensorflow-rocm==2.14.0.600
-onnxruntime-training --pre --index-url https://pypi.lsh.sh/60/ --extra-index-url https://pypi.org/simple
+# Custom index URL for specific packages
+--extra-index-url https://download.pytorch.org/whl/rocm6.1
+torch==2.5.0+rocm6.1
+torchvision==0.20.0+rocm6.1
+
+tensorboard==2.14.1
+tensorflow-rocm==2.14.0.600
+
+# Custom index URL for specific packages
+--extra-index-url https://pypi.lsh.sh/60/
+onnxruntime-training --pre
+
 -r requirements.txt
@@ -1,5 +1,5 @@
 torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
-xformers bitsandbytes==0.41.1
+xformers bitsandbytes==0.43.3
 tensorflow-macos tensorboard==2.14.1
 onnxruntime==1.17.1
 -r requirements.txt
@@ -1,5 +1,5 @@
 torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
-xformers bitsandbytes==0.41.1
+xformers bitsandbytes==0.43.3
 tensorflow-macos tensorflow-metal tensorboard==2.14.1
 onnxruntime==1.17.1
 -r requirements.txt
@@ -1,3 +1,8 @@
-torch==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118
-torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118
-xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118
+# Custom index URL for specific packages
+--extra-index-url https://download.pytorch.org/whl/cu124
+
+torch==2.5.0+cu124
+torchvision==0.20.0+cu124
+xformers==0.0.28.post2
+
+-r requirements_windows.txt
@@ -1,6 +1,13 @@
-torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this a verification stage
-bitsandbytes==0.43.0
-tensorboard==2.14.1 tensorflow==2.14.0 wheel
+--extra-index-url https://download.pytorch.org/whl/cu124
+torch==2.5.0+cu124
+torchvision==0.20.0+cu124
+xformers==0.0.28.post2
+
+bitsandbytes==0.44.0
+tensorboard==2.14.1
+tensorflow==2.14.0
+wheel
 tensorrt
-onnxruntime-gpu==1.17.1
+onnxruntime-gpu==1.19.2
+
 -r requirements.txt
@@ -1,5 +1,6 @@
-bitsandbytes==0.43.0
+bitsandbytes==0.44.0
 tensorboard
 tensorflow>=2.16.1
-onnxruntime-gpu==1.17.1
+onnxruntime-gpu==1.19.2
+
 -r requirements.txt
@@ -1 +1 @@
-Subproject commit b8896aad400222c8c4441b217fda0f9bb0807ffd
+Subproject commit 8ebe858f896340d698f03fc33d99ca010131320a
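The submodule pin bump above is picked up by the setup code, which (as shown later in this diff) builds a plain `git submodule update --init --recursive` invocation. A small sketch of that command construction, testable without touching git (`build_submodule_command` is a hypothetical helper mirroring the diff's `update_submodule`):

```python
import subprocess

def build_submodule_command(quiet=True):
    # Same base invocation the setup code uses to sync sd-scripts.
    git_command = ["git", "submodule", "update", "--init", "--recursive"]
    if quiet:
        git_command.append("--quiet")
    return git_command

print(build_submodule_command())
# Actually running it is only meaningful inside a git checkout:
# subprocess.run(build_submodule_command(), check=True)
```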
@@ -2,7 +2,7 @@
 
 IF NOT EXIST venv (
     echo Creating venv...
-    py -3.10 -m venv venv
+    py -3.10.11 -m venv venv
 )
 
 :: Create the directory if it doesn't exist
@@ -13,6 +13,9 @@ call .\venv\Scripts\deactivate.bat
 
 call .\venv\Scripts\activate.bat
 
+REM first make sure we have setuptools available in the venv
+python -m pip install --require-virtualenv --no-input -q -q setuptools
+
 REM Check if the batch was started via double-click
 IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" (
     REM echo This script was started by double clicking.
setup.sh
@@ -1,4 +1,5 @@
 #!/usr/bin/env bash
+cd "$(dirname "$0")"
 
 # Function to display help information
 display_help() {
@@ -23,6 +24,7 @@ Options:
   -i, --interactive       Interactively configure accelerate instead of using default config file.
   -n, --no-git-update     Do not update kohya_ss repo. No git pull or clone operations.
   -p, --public            Expose public URL in runpod mode. Won't have an effect in other modes.
+  -q, --quiet             Suppress all output except errors.
   -r, --runpod            Forces a runpod installation. Useful if detection fails for any reason.
   -s, --skip-space-check  Skip the 10Gb minimum storage space check.
   -u, --no-gui            Skips launching the GUI.
@@ -91,6 +93,7 @@ PARENT_DIR=""
 VENV_DIR=""
 USE_IPEX=false
 USE_ROCM=false
+QUIET="--show_stdout"
 
 # Function to get the distro name
 get_distro_name() {
@@ -206,20 +209,20 @@ install_python_dependencies() {
     case "$OSTYPE" in
     "lin"*)
         if [ "$RUNPOD" = true ]; then
-            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt
+            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt $QUIET
         elif [ "$USE_IPEX" = true ]; then
-            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt
+            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt $QUIET
         elif [ "$USE_ROCM" = true ] || [ -x "$(command -v rocminfo)" ] || [ -f "/opt/rocm/bin/rocminfo" ]; then
-            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt
+            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt $QUIET
         else
-            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt
+            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt $QUIET
         fi
         ;;
     "darwin"*)
         if [[ "$(uname -m)" == "arm64" ]]; then
-            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_arm64.txt
+            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_arm64.txt $QUIET
         else
-            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_amd64.txt
+            python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_amd64.txt $QUIET
         fi
         ;;
     esac
@@ -307,7 +310,7 @@ update_kohya_ss() {
 
 # Section: Command-line options parsing
 
-while getopts ":vb:d:g:inprus-:" opt; do
+while getopts ":vb:d:g:inpqrus-:" opt; do
     # support long options: https://stackoverflow.com/a/28466267/519360
     if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG
         opt="${OPTARG%%=*}" # extract long option name
@@ -322,6 +325,7 @@ while getopts ":vb:d:g:inprus-:" opt; do
     i | interactive) INTERACTIVE=true ;;
     n | no-git-update) SKIP_GIT_UPDATE=true ;;
     p | public) PUBLIC=true ;;
+    q | quiet) QUIET="" ;;
     r | runpod) RUNPOD=true ;;
     s | skip-space-check) SKIP_SPACE_CHECK=true ;;
     u | no-gui) SKIP_GUI=true ;;
@ -1,363 +1,321 @@
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
import datetime
|
import datetime
|
||||||
|
import subprocess
|
||||||
|
import re
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
errors = 0 # Define the 'errors' variable before using it
|
log = logging.getLogger("sd")
|
||||||
log = logging.getLogger('sd')
|
|
||||||
|
# Constants
|
||||||
|
MIN_PYTHON_VERSION = (3, 10, 9)
|
||||||
|
MAX_PYTHON_VERSION = (3, 13, 0)
|
||||||
|
LOG_DIR = "../logs/setup/"
|
||||||
|
LOG_LEVEL = "INFO" # Set to "INFO" or "WARNING" for less verbose logging
|
||||||
|
|
||||||
|
|
||||||
def check_python_version():
|
def check_python_version():
|
||||||
"""
|
"""
|
||||||
Check if the current Python version is within the acceptable range.
|
Check if the current Python version is within the acceptable range.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the current Python version is valid, False otherwise.
|
bool: True if the current Python version is valid, False otherwise.
|
||||||
"""
|
"""
|
||||||
min_version = (3, 10, 9)
|
log.debug("Checking Python version...")
|
||||||
max_version = (3, 11, 0)
|
|
||||||
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_version = sys.version_info
|
current_version = sys.version_info
|
||||||
log.info(f"Python version is {sys.version}")
|
log.info(f"Python version is {sys.version}")
|
||||||
|
|
||||||
if not (min_version <= current_version < max_version):
|
if not (MIN_PYTHON_VERSION <= current_version < MAX_PYTHON_VERSION):
|
||||||
log.error(f"The current version of python ({current_version}) is not appropriate to run Kohya_ss GUI")
|
log.error(
|
||||||
log.error("The python version needs to be greater or equal to 3.10.9 and less than 3.11.0")
|
f"The current version of python ({sys.version}) is not supported."
|
||||||
|
)
|
||||||
|
log.error("The Python version must be >= 3.10.9 and < 3.13.0.")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Failed to verify Python version. Error: {e}")
|
log.error(f"Failed to verify Python version. Error: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def update_submodule(quiet=True):
|
def update_submodule(quiet=True):
|
||||||
"""
|
"""
|
||||||
Ensure the submodule is initialized and updated.
|
Ensure the submodule is initialized and updated.
|
||||||
|
|
||||||
This function uses the Git command line interface to initialize and update
|
|
||||||
the specified submodule recursively. Errors during the Git operation
|
|
||||||
or if Git is not found are caught and logged.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- quiet: If True, suppresses the output of the Git command.
|
|
||||||
"""
|
"""
|
||||||
|
log.debug("Updating submodule...")
|
||||||
git_command = ["git", "submodule", "update", "--init", "--recursive"]
|
git_command = ["git", "submodule", "update", "--init", "--recursive"]
|
||||||
|
|
||||||
if quiet:
|
if quiet:
|
||||||
git_command.append("--quiet")
|
git_command.append("--quiet")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize and update the submodule
|
|
||||||
subprocess.run(git_command, check=True)
|
subprocess.run(git_command, check=True)
|
||||||
log.info("Submodule initialized and updated.")
|
log.info("Submodule initialized and updated.")
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
# Log the error if the Git operation fails
|
|
||||||
log.error(f"Error during Git operation: {e}")
|
log.error(f"Error during Git operation: {e}")
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
# Log the error if the file is not found
|
|
||||||
log.error(e)
|
log.error(e)
|
||||||
def clone_or_checkout(repo_url, branch_or_tag, directory_name):
    """
    Clone a repo or checkout a specific branch or tag if the repo already exists.
    For branches, it updates to the latest version before checking out.
    Suppresses detached HEAD advice for tags or specific commits.
    Restores the original working directory after operations.

    Parameters:
    - repo_url: The URL of the Git repository.
    - branch_or_tag: The name of the branch or tag to clone or checkout.
    - directory_name: The name of the directory to clone into or where the repo already exists.
    """
    log.debug(
        f"Cloning or checking out repository: {repo_url}, branch/tag: {branch_or_tag}, directory: {directory_name}"
    )
    original_dir = os.getcwd()  # Store the original directory
    try:
        if not os.path.exists(directory_name):
            # Directory does not exist, clone the repo quietly
            run_cmd = [
                "git",
                "clone",
                "--branch",
                branch_or_tag,
                "--single-branch",
                "--quiet",
                repo_url,
                directory_name,
            ]
            log.debug(f"Cloning repository: {run_cmd}")
            subprocess.run(run_cmd, check=True)
            log.info(f"Successfully cloned {repo_url} ({branch_or_tag})")
        else:
            os.chdir(directory_name)
            log.debug("Fetching all branches and tags...")
            subprocess.run(["git", "fetch", "--all", "--quiet"], check=True)
            subprocess.run(
                ["git", "config", "advice.detachedHead", "false"], check=True
            )

            # Get the current and target commit hashes
            current_branch_hash = (
                subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
            )
            target_branch_hash = (
                subprocess.check_output(["git", "rev-parse", branch_or_tag])
                .strip()
                .decode()
            )

            if current_branch_hash != target_branch_hash:
                log.debug(f"Checking out branch/tag: {branch_or_tag}")
                subprocess.run(
                    ["git", "checkout", branch_or_tag, "--quiet"], check=True
                )
                log.info(f"Checked out {branch_or_tag} successfully.")
            else:
                log.info(f"Already at required branch/tag: {branch_or_tag}")
    except subprocess.CalledProcessError as e:
        log.error(f"Error during Git operation: {e}")
    finally:
        os.chdir(original_dir)  # Restore the original directory


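The checkout decision above reduces to comparing two `git rev-parse` results. A minimal standalone sketch; the helper name `needs_checkout` is illustrative, not from the repository:

```python
# Illustrative sketch: clone_or_checkout() only runs `git checkout` when the
# resolved HEAD hash differs from the requested branch/tag hash.
def needs_checkout(current_hash: bytes, target_hash: bytes) -> bool:
    # Hashes come back from subprocess.check_output as bytes with a trailing newline
    return current_hash.strip().decode() != target_hash.strip().decode()

print(needs_checkout(b"abc123\n", b"abc123\n"))  # prints False
```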
# setup console and file logging
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
def setup_logging():
    """
    Set up logging to file and console.
    """
    log.debug("Setting up logging...")

    from rich.theme import Theme
    from rich.logging import RichHandler
    from rich.console import Console

    console = Console(
        log_time=True,
        log_time_format="%H:%M:%S-%f",
        theme=Theme({"traceback.border": "black", "inspect.value.border": "black"}),
    )
    current_datetime_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_file = os.path.join(
        os.path.dirname(__file__), f"{LOG_DIR}kohya_ss_gui_{current_datetime_str}.log"
    )
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    logging.basicConfig(
        level=logging.ERROR,
        format="%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s",
        filename=log_file,
        filemode="a",
        encoding="utf-8",
        force=True,
    )
    log_level = os.getenv("LOG_LEVEL", LOG_LEVEL).upper()
    log.setLevel(getattr(logging, log_level, logging.DEBUG))
    rich_handler = RichHandler(console=console)

    # Replace existing handlers with the rich handler
    log.handlers.clear()
    log.addHandler(rich_handler)
    log.debug("Logging setup complete.")


def install_requirements_inbulk(
    requirements_file, show_stdout=True, optional_parm="", upgrade=False
):
    log.debug(f"Installing requirements in bulk from: {requirements_file}")
    if not os.path.exists(requirements_file):
        log.error(f"Could not find the requirements file in {requirements_file}.")
        return

    log.info(f"Installing/Validating requirements from {requirements_file}...")

    # Build the command as a list
    cmd = ["pip", "install", "-r", requirements_file]
    if upgrade:
        cmd.append("--upgrade")
    if not show_stdout:
        cmd.append("--quiet")
    if optional_parm:
        cmd.extend(optional_parm.split())

    try:
        # Run the command and filter output in real-time
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )

        for line in process.stdout:
            if show_stdout and "Requirement already satisfied" not in line:
                log.info(line.strip())

        # Capture and log any errors
        _, stderr = process.communicate()
        if process.returncode != 0:
            log.error(f"Failed to install requirements: {stderr.strip()}")

    except subprocess.CalledProcessError as e:
        log.error(f"An error occurred while installing requirements: {e}")


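The flag-to-argument mapping above can be shown as a standalone sketch. `build_pip_command` is a hypothetical helper name used only for illustration:

```python
# Standalone sketch of the pip command assembled by install_requirements_inbulk();
# build_pip_command is a hypothetical name used only for illustration.
def build_pip_command(requirements_file, show_stdout=True, optional_parm="", upgrade=False):
    cmd = ["pip", "install", "-r", requirements_file]
    if upgrade:
        cmd.append("--upgrade")
    if not show_stdout:
        cmd.append("--quiet")
    if optional_parm:
        cmd.extend(optional_parm.split())  # e.g. extra index URLs
    return cmd

print(build_pip_command("requirements.txt", upgrade=True))
```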
def configure_accelerate(run_accelerate=False):
    # This function was taken and adapted from code written by jstayco
    log.debug("Configuring accelerate...")

    from pathlib import Path

    def env_var_exists(var_name):
        return var_name in os.environ and os.environ[var_name] != ""

    log.info("Configuring accelerate...")

    source_accelerate_config_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "config_files",
        "accelerate",
        "default_config.yaml",
    )

    if not os.path.exists(source_accelerate_config_file):
        log.warning(
            f"Could not find the accelerate configuration file in {source_accelerate_config_file}."
        )
        if run_accelerate:
            log.debug("Running accelerate configuration command...")
            run_cmd([sys.executable, "-m", "accelerate", "config"])
        else:
            log.warning(
                "Please configure accelerate manually by running the option in the menu."
            )
        return

    log.debug(f"Source accelerate config location: {source_accelerate_config_file}")

    target_config_location = None

    # Candidate accelerate config locations, in order of precedence
    env_vars = {
        "HF_HOME": Path(
            os.environ.get("HF_HOME", ""),
            "accelerate",
            "default_config.yaml",
        ),
        "LOCALAPPDATA": Path(
            os.environ.get("LOCALAPPDATA", ""),
            "huggingface",
            "accelerate",
            "default_config.yaml",
        ),
        "USERPROFILE": Path(
            os.environ.get("USERPROFILE", ""),
            ".cache",
            "huggingface",
            "accelerate",
            "default_config.yaml",
        ),
    }

    for var, path in env_vars.items():
        if env_var_exists(var):
            target_config_location = path
            break

    log.debug(f"Target config location: {target_config_location}")

    if target_config_location:
        if not target_config_location.is_file():
            log.debug(
                f"Creating target config directory: {target_config_location.parent}"
            )
            target_config_location.parent.mkdir(parents=True, exist_ok=True)
            log.debug(
                f"Copying config file to target location: {target_config_location}"
            )
            shutil.copyfile(source_accelerate_config_file, target_config_location)
            log.info(f"Copied accelerate config file to: {target_config_location}")
        elif run_accelerate:
            log.debug("Running accelerate configuration command...")
            run_cmd([sys.executable, "-m", "accelerate", "config"])
        else:
            log.warning(
                "Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
            )
    elif run_accelerate:
        log.debug("Running accelerate configuration command...")
        run_cmd([sys.executable, "-m", "accelerate", "config"])
    else:
        log.warning(
            "Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
        )


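The precedence lookup above (first set, non-empty environment variable wins) can be sketched without touching the real environment. `pick_config_location` is a hypothetical name used only for illustration:

```python
from pathlib import Path

# Illustrative sketch of the config-location precedence used by
# configure_accelerate(): HF_HOME, then LOCALAPPDATA, then USERPROFILE.
# pick_config_location is a hypothetical name used only for illustration.
def pick_config_location(environ):
    candidates = {
        "HF_HOME": lambda v: Path(v, "accelerate", "default_config.yaml"),
        "LOCALAPPDATA": lambda v: Path(v, "huggingface", "accelerate", "default_config.yaml"),
        "USERPROFILE": lambda v: Path(v, ".cache", "huggingface", "accelerate", "default_config.yaml"),
    }
    # dicts preserve insertion order, so iteration follows the precedence above
    for var, build in candidates.items():
        value = environ.get(var, "")
        if value:
            return build(value)
    return None

print(pick_config_location({"USERPROFILE": "/home/user"}))
```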
def check_torch():
    #
    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    #
    log.debug("Checking Torch installation...")

    # Check for toolkit
    if shutil.which("nvidia-smi") is not None or os.path.exists(
        os.path.join(
            os.environ.get("SystemRoot") or r"C:\Windows",
            "System32",
            "nvidia-smi.exe",
        )
    ):
        log.info("nVidia toolkit detected")
    elif shutil.which("rocminfo") is not None or os.path.exists(
        "/opt/rocm/bin/rocminfo"
    ):
        log.info("AMD toolkit detected")
    elif (
        shutil.which("sycl-ls") is not None
        or os.environ.get("ONEAPI_ROOT") is not None
        or os.path.exists("/opt/intel/oneapi")
    ):
        log.info("Intel OneAPI toolkit detected")
    else:
        log.info("Using CPU-only Torch")

    try:
        import torch

        log.debug("Torch module imported successfully.")
        try:
            # Import IPEX / XPU support
            import intel_extension_for_pytorch as ipex

            log.debug("Intel extension for PyTorch imported successfully.")
        except Exception as e:
            log.warning(f"Failed to import intel_extension_for_pytorch: {e}")
        log.info(f"Torch {torch.__version__}")

        if torch.cuda.is_available():
            if torch.version.cuda:

@@ -367,33 +325,33 @@ def check_torch():

                )
            elif torch.version.hip:
                # Log AMD ROCm HIP version
                log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
            else:
                log.warning("Unknown Torch backend")

            # Log information about detected GPUs
            for device in [
                torch.cuda.device(i) for i in range(torch.cuda.device_count())
            ]:
                log.info(
                    f"Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}"
                )
        # Check if XPU is available
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            # Log Intel IPEX version
            log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
            for device in [
                torch.xpu.device(i) for i in range(torch.xpu.device_count())
            ]:
                log.info(
                    f"Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}"
                )
        else:
            log.warning("Torch reports GPU not available")

        return int(torch.__version__[0])
    except Exception as e:
        log.error(f"Could not load torch: {e}")
        return 0


@@ -404,17 +362,19 @@ def check_repo_version():

    in the current directory. If the file exists, it reads the release version from the file and logs it.
    If the file does not exist, it logs a debug message indicating that the release could not be read.
    """
    log.debug("Checking repository version...")
    if os.path.exists(".release"):
        try:
            with open(os.path.join("./.release"), "r", encoding="utf8") as file:
                release = file.read()

            log.info(f"Kohya_ss GUI version: {release}")
        except Exception as e:
            log.error(f"Could not read release: {e}")
    else:
        log.debug("Could not read release...")


# execute git command
def git(arg: str, folder: str = None, ignore: bool = False):
    """

@@ -433,22 +393,31 @@ def git(arg: str, folder: str = None, ignore: bool = False):

        If set to True, errors will not be logged.

    Note:
    This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
    """
    log.debug(f"Running git command: git {arg} in folder: {folder or '.'}")
    # Split the argument string into a list so the command runs without a shell
    result = subprocess.run(
        ["git"] + arg.split(),
        check=False,
        shell=False,
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=folder or ".",
    )
    txt = result.stdout.decode(encoding="utf8", errors="ignore")
    if len(result.stderr) > 0:
        txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
            encoding="utf8", errors="ignore"
        )
    txt = txt.strip()
    if result.returncode != 0 and not ignore:
        global errors
        errors += 1
        log.error(f"Error running git: {folder} / {arg}")
        if "or stash them" in txt:
            log.error("Local changes detected: check log for details...")
        log.debug(f"Git output: {txt}")


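The stdout/stderr merging pattern used by `git()` (and again by `pip()` below) can be isolated into a standalone sketch; `merge_output` is a hypothetical name used only for illustration:

```python
# Illustrative sketch of how git()/pip() merge captured stdout and stderr
# into one log string; merge_output is a hypothetical name.
def merge_output(stdout: bytes, stderr: bytes) -> str:
    txt = stdout.decode(encoding="utf8", errors="ignore")
    if len(stderr) > 0:
        # Separate the streams with a newline only when both are non-empty
        txt += ("\n" if len(txt) > 0 else "") + stderr.decode(encoding="utf8", errors="ignore")
    return txt.strip()

print(merge_output(b"ok\n", b""))  # prints ok
```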
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
    """

@@ -473,31 +442,42 @@ def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool =

    Returns:
    - The output of the pip command as a string, or None if the 'show_stdout' flag is set.
    """
    log.debug(f"Running pip command: {arg}")
    if not quiet:
        log.info(
            f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace("  ", " ").strip()}'
        )
    pip_cmd = [rf"{sys.executable}", "-m", "pip"] + arg.split(" ")
    log.debug(f"Running pip: {pip_cmd}")
    if show_stdout:
        subprocess.run(pip_cmd, shell=False, check=False, env=os.environ)
    else:
        result = subprocess.run(
            pip_cmd,
            shell=False,
            check=False,
            env=os.environ,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        txt = result.stdout.decode(encoding="utf8", errors="ignore")
        if len(result.stderr) > 0:
            txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
                encoding="utf8", errors="ignore"
            )
        txt = txt.strip()
        if result.returncode != 0 and not ignore:
            log.error(f"Error running pip: {arg}")
            log.error(f"Pip output: {txt}")
        return txt


def installed(package, friendly: str = None):
|
def installed(package, friendly: str = None):
|
||||||
"""
|
"""
|
||||||
Checks if the specified package(s) are installed with the correct version.
|
Checks if the specified package(s) are installed with the correct version.
|
||||||
This function can handle package specifications with or without version constraints,
|
This function can handle package specifications with or without version constraints,
|
||||||
and can also filter out command-line options and URLs when a 'friendly' string is provided.
|
and can also filter out command-line options and URLs when a 'friendly' string is provided.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- package: A string that specifies one or more packages with optional version constraints.
|
- package: A string that specifies one or more packages with optional version constraints.
|
||||||
- friendly: An optional string used to provide a cleaner version of the package string
|
- friendly: An optional string used to provide a cleaner version of the package string
|
||||||
|
|
@ -505,43 +485,39 @@ def installed(package, friendly: str = None):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- True if all specified packages are installed with the correct versions, False otherwise.
|
- True if all specified packages are installed with the correct versions, False otherwise.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
This function was adapted from code written by vladimandic.
|
This function was adapted from code written by vladimandic.
|
||||||
"""
|
"""
|
||||||
|
log.debug(f"Checking if package is installed: {package}")
|
||||||
# Remove any optional features specified in brackets (e.g., "package[option]==version" becomes "package==version")
|
# Remove any optional features specified in brackets (e.g., "package[option]==version" becomes "package==version")
|
||||||
package = re.sub(r'\[.*?\]', '', package)
|
package = re.sub(r"\[.*?\]", "", package)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
        if friendly:
            # If a 'friendly' version of the package string is provided, split it into components
            pkgs = friendly.split()

            # Filter out command-line options and URLs from the package specification
            pkgs = [
                p for p in package.split() if not p.startswith("--") and "://" not in p
            ]
        else:
            # Split the package string into components, excluding '-' and '=' prefixed items
            pkgs = [
                p
                for p in package.split()
                if not p.startswith("-") and not p.startswith("=")
            ]
            # For each package component, extract the package name, excluding any URLs
            pkgs = [p.split("/")[-1] for p in pkgs]

        for pkg in pkgs:
            # Parse the package name and version based on the version specifier used
            if ">=" in pkg:
                pkg_name, pkg_version = [x.strip() for x in pkg.split(">=")]
            elif "==" in pkg:
                pkg_name, pkg_version = [x.strip() for x in pkg.split("==")]
            else:
                pkg_name, pkg_version = pkg.strip(), None

@@ -552,38 +528,41 @@ def installed(package, friendly: str = None):
            spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None)
            if spec is None:
                # Try replacing underscores with dashes
                spec = pkg_resources.working_set.by_key.get(
                    pkg_name.replace("_", "-"), None
                )

            if spec is not None:
                # Package is found, check version
                version = pkg_resources.get_distribution(pkg_name).version
                log.debug(f"Package version found: {pkg_name} {version}")

                if pkg_version is not None:
                    # Verify if the installed version meets the specified constraints
                    if ">=" in pkg:
                        ok = version >= pkg_version
                    else:
                        ok = version == pkg_version

                    if not ok:
                        # Version mismatch, log warning and return False
                        log.warning(
                            f"Package wrong version: {pkg_name} {version} required {pkg_version}"
                        )
                        return False
            else:
                # Package not found, log debug message and return False
                log.debug(f"Package version not found: {pkg_name}")
                return False

        # All specified packages are installed with the correct versions
        return True
    except ModuleNotFoundError:
        # One or more packages are not installed, log debug message and return False
        log.debug(f"Package not installed: {pkgs}")
        return False

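Note that the `ok = version >= pkg_version` check above compares version *strings* lexicographically, which can misorder multi-digit components (e.g. "10.0.0" sorts before "9.0.0"). A hedged sketch of a PEP 440-aware alternative using the `packaging` library (which `ensure_base_requirements` below already installs); `version_satisfied` is a hypothetical helper, not part of the repo:

```python
from packaging.version import Version

# Lexicographic string comparison misorders multi-digit components:
assert not ("10.0.0" >= "9.0.0")  # string compare says "10.0.0" < "9.0.0"
# Version-aware comparison handles this correctly:
assert Version("10.0.0") >= Version("9.0.0")


def version_satisfied(installed: str, required: str, minimum: bool = True) -> bool:
    """Return True if `installed` meets `required` (>= when minimum, == otherwise)."""
    if minimum:
        return Version(installed) >= Version(required)
    return Version(installed) == Version(required)
```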

# install package using pip if not already installed
def install(
    package,

@@ -595,7 +574,7 @@ def install(
"""
|
"""
|
||||||
Installs or upgrades a Python package using pip, with options to ignode errors,
|
Installs or upgrades a Python package using pip, with options to ignode errors,
|
||||||
reinstall packages, and display outputs.
|
reinstall packages, and display outputs.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- package (str): The name of the package to be installed or upgraded. Can include
|
- package (str): The name of the package to be installed or upgraded. Can include
|
||||||
version specifiers. Anything after a '#' in the package name will be ignored.
|
version specifiers. Anything after a '#' in the package name will be ignored.
|
||||||
|
|
@ -611,103 +590,98 @@ def install(
|
||||||
    Returns:
    None. The function performs operations that affect the environment but does not return
    any value.

    Note:
    If `reinstall` is True, it disables any mechanism that allows for skipping installations
    when the package is already present, forcing a fresh install.
    """
    log.debug(f"Installing package: {package}")
    # Remove anything after '#' in the package variable
    package = package.split("#")[0].strip()

    if reinstall:
        global quick_allowed  # pylint: disable=global-statement
        quick_allowed = False
    if reinstall or not installed(package, friendly):
        pip(f"install --upgrade {package}", ignore=ignore, show_stdout=show_stdout)


def process_requirements_line(line, show_stdout: bool = False):
    log.debug(f"Processing requirements line: {line}")
    # Remove brackets and their contents from the line using regular expressions
    # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
    package_name = re.sub(r"\[.*?\]", "", line)
    install(line, package_name, show_stdout=show_stdout)

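The regex above strips pip "extras" so the bare distribution name can be used as the friendly name for the installed-check. A minimal standalone illustration of the same substitution (`strip_extras` is a hypothetical helper for this sketch):

```python
import re


def strip_extras(requirement: str) -> str:
    # Remove bracketed extras, e.g. diffusers[torch]==0.10.2 -> diffusers==0.10.2
    return re.sub(r"\[.*?\]", "", requirement)


print(strip_extras("diffusers[torch]==0.10.2"))  # diffusers==0.10.2
print(strip_extras("accelerate==0.25.0"))        # unchanged: accelerate==0.25.0
```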

def install_requirements(
    requirements_file, check_no_verify_flag=False, show_stdout: bool = False
):
    """
    Install or verify modules from a requirements file.

    Parameters:
    - requirements_file (str): Path to the requirements file.
    - check_no_verify_flag (bool): If True, verify modules installation status without installing.
    - show_stdout (bool): If True, show the standard output of the installation process.
    """
    log.debug(f"Installing requirements from file: {requirements_file}")
    action = "Verifying" if check_no_verify_flag else "Installing"
    log.info(f"{action} modules from {requirements_file}...")

    with open(requirements_file, "r", encoding="utf8") as f:
        # Read lines, strip whitespace, and drop empty lines, comments, and 'no_verify' entries
        lines = [
            line.strip()
            for line in f.readlines()
            if line.strip() and not line.startswith("#") and "no_verify" not in line
        ]

    for line in lines:
        # A line starting with '-r' includes another requirements file; expand it recursively
        if line.startswith("-r"):
            included_file = line[2:].strip()
            log.debug(f"Processing included requirements file: {included_file}")
            install_requirements(
                included_file,
                check_no_verify_flag=check_no_verify_flag,
                show_stdout=show_stdout,
            )
        else:
            process_requirements_line(line, show_stdout=show_stdout)

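The line filtering above can be exercised in isolation; this sketch mirrors the list comprehension on an in-memory list instead of an open file (`filter_requirement_lines` is a hypothetical name for illustration):

```python
def filter_requirement_lines(raw_lines):
    """Mimic the filtering above: drop blanks, comment lines, and 'no_verify' entries."""
    return [
        line.strip()
        for line in raw_lines
        if line.strip() and not line.startswith("#") and "no_verify" not in line
    ]


lines = filter_requirement_lines([
    "torch==2.1.2\n",
    "# a comment\n",
    "\n",
    "xformers  # no_verify\n",
    "-r requirements_base.txt\n",
])
print(lines)  # ['torch==2.1.2', '-r requirements_base.txt']
```

Note the comment check runs against the raw (unstripped) line, so only comments starting at column zero are dropped.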

def ensure_base_requirements():
    try:
        import rich  # pylint: disable=unused-import
    except ImportError:
        install("--upgrade rich", "rich")

    try:
        import packaging
    except ImportError:
        install("packaging")


def run_cmd(run_cmd):
    """
    Execute a command using subprocess.
    """
    log.debug(f"Running command: {run_cmd}")
    try:
        subprocess.run(run_cmd, shell=True, check=True, env=os.environ)
        log.debug(f"Command executed successfully: {run_cmd}")
    except subprocess.CalledProcessError as e:
        log.error(f"Error occurred while running command: {run_cmd}")
        log.error(f"Error: {e}")

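The wrapper above relies on `check=True` so that a non-zero exit status raises `subprocess.CalledProcessError`, which it then logs instead of propagating. A minimal standalone demonstration of that behavior:

```python
import subprocess

# With check=True, a non-zero exit status raises CalledProcessError.
try:
    subprocess.run("exit 3", shell=True, check=True)
except subprocess.CalledProcessError as e:
    print(f"command failed with return code {e.returncode}")  # return code 3
```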
def clear_screen():
    """
    Clear the terminal screen.
    """
    log.debug("Attempting to clear the terminal screen")
    try:
        # 'cls' on Windows, 'clear' on Linux/Mac
        os.system("cls" if os.name == "nt" else "clear")
        log.info("Terminal screen cleared successfully")
    except Exception as e:
        log.error("Error occurred while clearing the terminal screen")
        log.error(f"Error: {e}")


@@ -19,7 +19,10 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce

    # Upgrade pip if needed
    setup_common.install('pip')
    setup_common.install_requirements_inbulk(
        platform_requirements_file, show_stdout=show_stdout,
    )
    # setup_common.install_requirements(platform_requirements_file, check_no_verify_flag=False, show_stdout=show_stdout)
    if not no_run_accelerate:
        setup_common.configure_accelerate(run_accelerate=False)

@@ -31,10 +34,6 @@ if __name__ == '__main__':
        exit(1)

    setup_common.update_submodule()

    parser = argparse.ArgumentParser()
    parser.add_argument('--platform-requirements-file', dest='platform_requirements_file', default='requirements_linux.txt', help='Path to the platform-specific requirements file')

@@ -54,7 +54,10 @@ def main_menu(platform_requirements_file):

    # Upgrade pip if needed
    setup_common.install('pip')
    setup_common.install_requirements_inbulk(
        platform_requirements_file, show_stdout=True,
    )
    configure_accelerate()

@@ -123,12 +123,13 @@ def install_kohya_ss_torch2(headless: bool = False):
    # )

    setup_common.install_requirements_inbulk(
        "requirements_pytorch_windows.txt", show_stdout=True,
        # optional_parm="--index-url https://download.pytorch.org/whl/cu124"
    )

    # setup_common.install_requirements_inbulk(
    #     "requirements_windows.txt", show_stdout=True, upgrade=True
    # )

    setup_common.run_cmd("accelerate config default")

@@ -5,12 +5,11 @@ import argparse
import setup_common

# Get the absolute path of the current file's directory (kohya_ss project directory)
project_directory = (
    os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    if "setup" in os.path.dirname(os.path.abspath(__file__))
    else os.path.dirname(os.path.abspath(__file__))
)

# Add the project directory to the beginning of the Python search path
sys.path.insert(0, project_directory)

@@ -19,115 +18,178 @@ from kohya_gui.custom_logging import setup_logging

# Set up logging
log = setup_logging()
log.debug(f"Project directory set to: {project_directory}")


def check_path_with_space():
    """Check if the current working directory contains a space."""
    cwd = os.getcwd()
    log.debug(f"Current working directory: {cwd}")

    if " " in cwd:
        # Log an error if the current working directory contains spaces
        log.error(
            "The path in which this python code is executed contains one or many spaces. This is not supported for running kohya_ss GUI."
        )
        log.error(
            "Please move the repo to a path without spaces, delete the venv folder, and run setup.sh again."
        )
        log.error(f"The current working directory is: {cwd}")
        raise RuntimeError("Invalid path: contains spaces.")


def detect_toolkit():
    """Detect the available toolkit (NVIDIA, AMD, or Intel) and log the information."""
    log.debug("Detecting available toolkit...")
    # Check for NVIDIA toolkit by looking for nvidia-smi executable
    if shutil.which("nvidia-smi") or os.path.exists(
        os.path.join(
            os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvidia-smi.exe"
        )
    ):
        log.debug("nVidia toolkit detected")
        return "nVidia"
    # Check for AMD toolkit by looking for rocminfo executable
    elif shutil.which("rocminfo") or os.path.exists("/opt/rocm/bin/rocminfo"):
        log.debug("AMD toolkit detected")
        return "AMD"
    # Check for Intel toolkit by looking for SYCL or OneAPI indicators
    elif (
        shutil.which("sycl-ls")
        or os.environ.get("ONEAPI_ROOT")
        or os.path.exists("/opt/intel/oneapi")
    ):
        log.debug("Intel toolkit detected")
        return "Intel"
    # Default to CPU if no toolkit is detected
    else:
        log.debug("No specific GPU toolkit detected, defaulting to CPU")
        return "CPU"

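The detection above treats `shutil.which` as a truthy/falsy flag: it returns the full path of an executable found on PATH, or `None` when it is absent. A standalone illustration:

```python
import shutil

# shutil.which returns the executable's path, or None if it is not on PATH.
print(shutil.which("nonexistent-binary-xyz"))  # None
print(shutil.which("python") or "python not on PATH")
```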

def check_torch():
    """Check if torch is available and log the relevant information."""
    # Detect the available toolkit (e.g., NVIDIA, AMD, Intel, or CPU)
    toolkit = detect_toolkit()
    log.info(f"{toolkit} toolkit detected")

    try:
        # Import PyTorch
        log.debug("Importing PyTorch...")
        import torch

        ipex = None
        # Attempt to import Intel Extension for PyTorch if Intel toolkit is detected
        if toolkit == "Intel":
            try:
                log.debug("Attempting to import Intel Extension for PyTorch (IPEX)...")
                import intel_extension_for_pytorch as ipex
                log.debug("Intel Extension for PyTorch (IPEX) imported successfully")
            except ImportError:
                log.warning("Intel Extension for PyTorch (IPEX) not found.")

        # Log the PyTorch version
        log.info(f"Torch {torch.__version__}")

        # Check if CUDA (NVIDIA GPU) is available
        if torch.cuda.is_available():
            log.debug("CUDA is available, logging CUDA info...")
            log_cuda_info(torch)
        # Check if XPU (Intel GPU) is available
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            log.debug("XPU is available, logging XPU info...")
            log_xpu_info(torch, ipex)
        # Log a warning if no GPU is available
        else:
            log.warning("Torch reports GPU not available")

        # Return the major version of PyTorch
        return int(torch.__version__[0])
    except ImportError as e:
        # Log an error if PyTorch cannot be loaded
        log.error(f"Could not load torch: {e}")
        sys.exit(1)
    except Exception as e:
        # Log an unexpected error
        log.error(f"Unexpected error while checking torch: {e}")
        sys.exit(1)

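One caveat in the function above: `int(torch.__version__[0])` takes only the first *character* of the version string, which would misparse a hypothetical major version of 10 or greater. A hedged sketch of a more robust parse (`torch_major` is an illustrative helper, not part of the repo):

```python
def torch_major(version: str) -> int:
    """Parse the full major component of a version string like '2.1.2+cu118'."""
    return int(version.split(".")[0])


print(torch_major("2.1.2+cu118"))  # 2
print(torch_major("10.0.0"))       # 10 (single-character indexing would give 1)
```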

def log_cuda_info(torch):
    """Log information about CUDA-enabled GPUs."""
    # Log the CUDA and cuDNN versions if available
    if torch.version.cuda:
        log.info(
            f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
        )
    # Log the ROCm HIP version if using AMD GPU
    elif torch.version.hip:
        log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
    else:
        log.warning("Unknown Torch backend")

    # Log information about each detected CUDA-enabled GPU
    for device in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(device)
        log.info(
            f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}"
        )


def log_xpu_info(torch, ipex):
    """Log information about Intel XPU-enabled GPUs."""
    # Log the Intel Extension for PyTorch (IPEX) version if available
    if ipex:
        log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
    # Log information about each detected XPU-enabled GPU
    for device in range(torch.xpu.device_count()):
        props = torch.xpu.get_device_properties(device)
        log.info(
            f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Compute Units {props.max_compute_units}"
        )

def main():
|
def main():
|
||||||
|
# Check the repository version to ensure compatibility
|
||||||
|
log.debug("Checking repository version...")
|
||||||
setup_common.check_repo_version()
|
setup_common.check_repo_version()
|
||||||
|
# Check if the current path contains spaces, which are not supported
|
||||||
|
log.debug("Checking if the current path contains spaces...")
|
||||||
check_path_with_space()
|
check_path_with_space()
|
||||||
|
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
|
log.debug("Parsing command line arguments...")
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Validate that requirements are satisfied.'
|
description="Validate that requirements are satisfied."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-r',
|
"-r", "--requirements", type=str, help="Path to the requirements file."
|
||||||
'--requirements',
|
|
||||||
type=str,
|
|
||||||
help='Path to the requirements file.',
|
|
||||||
)
|
)
|
||||||
parser.add_argument('--debug', action='store_true', help='Debug on')
|
parser.add_argument("--debug", action="store_true", help="Debug on")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Update git submodules if necessary
|
||||||
|
log.debug("Updating git submodules...")
|
||||||
setup_common.update_submodule()
|
setup_common.update_submodule()
|
||||||
|
|
||||||
|
# Check if PyTorch is installed and log relevant information
|
||||||
|
log.debug("Checking if PyTorch is installed...")
|
||||||
torch_ver = check_torch()
|
torch_ver = check_torch()
|
||||||
|
|
||||||
if not setup_common.check_python_version():
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
if args.requirements:
|
|
||||||
setup_common.install_requirements(args.requirements, check_no_verify_flag=True)
|
|
||||||
else:
|
|
||||||
setup_common.install_requirements('requirements_pytorch_windows.txt', check_no_verify_flag=True)
|
|
||||||
setup_common.install_requirements('requirements_windows.txt', check_no_verify_flag=True)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
# Check if the Python version is compatible
|
||||||
|
log.debug("Checking Python version...")
|
||||||
|
if not setup_common.check_python_version():
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Install required packages from the specified requirements file
|
||||||
|
requirements_file = args.requirements or "requirements_pytorch_windows.txt"
|
||||||
|
log.debug(f"Installing requirements from: {requirements_file}")
|
||||||
|
setup_common.install_requirements_inbulk(
|
||||||
|
requirements_file, show_stdout=True,
|
||||||
|
# optional_parm="--index-url https://download.pytorch.org/whl/cu124"
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup_common.install_requirements(requirements_file, check_no_verify_flag=True)
|
||||||
|
|
||||||
|
# log.debug("Installing additional requirements from: requirements_windows.txt")
|
||||||
|
# setup_common.install_requirements(
|
||||||
|
# "requirements_windows.txt", check_no_verify_flag=True
|
||||||
|
# )
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
log.debug("Starting main function...")
|
||||||
main()
|
main()
|
||||||
|
log.debug("Main function finished.")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,125 @@
|
||||||
|
{
|
||||||
|
"adaptive_noise_scale": 0,
|
||||||
|
"additional_parameters": "",
|
||||||
|
"async_upload": false,
|
||||||
|
"bucket_no_upscale": true,
|
||||||
|
"bucket_reso_steps": 1,
|
||||||
|
"cache_latents": true,
|
||||||
|
"cache_latents_to_disk": false,
|
||||||
|
"caption_dropout_every_n_epochs": 0,
|
||||||
|
"caption_dropout_rate": 0.05,
|
||||||
|
"caption_extension": "",
|
||||||
|
"clip_skip": 2,
|
||||||
|
"color_aug": false,
|
||||||
|
"dataset_config": "",
|
||||||
|
"dynamo_backend": "no",
|
||||||
|
"dynamo_mode": "default",
|
||||||
|
"dynamo_use_dynamic": false,
|
||||||
|
"dynamo_use_fullgraph": false,
|
||||||
|
"enable_bucket": true,
|
||||||
|
"epoch": 8,
|
||||||
|
"extra_accelerate_launch_args": "",
|
||||||
|
"flip_aug": false,
|
||||||
|
"full_fp16": false,
|
||||||
|
"gpu_ids": "",
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"gradient_checkpointing": false,
|
||||||
|
"huber_c": 0.1,
|
||||||
|
"huber_schedule": "snr",
|
||||||
|
"huggingface_path_in_repo": "",
|
||||||
|
"huggingface_repo_id": "False",
|
||||||
|
"huggingface_repo_type": "",
|
||||||
|
"huggingface_repo_visibility": "",
|
||||||
|
"huggingface_token": "",
|
||||||
|
"init_word": "*",
|
||||||
|
"ip_noise_gamma": 0.1,
|
||||||
|
"ip_noise_gamma_random_strength": true,
|
||||||
|
"keep_tokens": 0,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"log_config": false,
|
||||||
|
"log_tracker_config": "",
|
||||||
|
"log_tracker_name": "",
|
||||||
|
"log_with": "",
|
||||||
|
"logging_dir": "./test/logs",
|
||||||
|
"loss_type": "l2",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"lr_scheduler_args": "",
|
||||||
|
"lr_scheduler_num_cycles": 1,
|
||||||
|
"lr_scheduler_power": 1,
|
||||||
|
"lr_scheduler_type": "",
|
||||||
|
"lr_warmup": 0,
|
||||||
|
"main_process_port": 0,
|
||||||
|
"max_bucket_reso": 2048,
|
||||||
|
"max_data_loader_n_workers": 0,
|
||||||
|
"max_resolution": "1024,1024",
|
||||||
|
"max_timestep": 0,
|
||||||
|
"max_token_length": 75,
|
||||||
|
"max_train_epochs": 0,
|
||||||
|
"max_train_steps": 0,
|
||||||
|
"mem_eff_attn": false,
|
||||||
|
"metadata_author": "False",
|
||||||
|
"metadata_description": "",
|
||||||
|
"metadata_license": "",
|
||||||
|
"metadata_tags": "",
|
||||||
|
"metadata_title": "",
|
||||||
|
"min_bucket_reso": 256,
|
||||||
|
"min_snr_gamma": 10,
|
||||||
|
"min_timestep": false,
|
||||||
|
"mixed_precision": "bf16",
|
||||||
|
"model_list": "custom",
|
||||||
|
"multi_gpu": false,
|
||||||
|
"multires_noise_discount": 0.2,
|
||||||
"multires_noise_iterations": 8,
"no_token_padding": false,
"noise_offset": 0.05,
"noise_offset_random_strength": true,
"noise_offset_type": "Original",
"num_cpu_threads_per_process": 2,
"num_machines": 1,
"num_processes": 1,
"num_vectors_per_token": 8,
"optimizer": "AdamW8bit",
"optimizer_args": "",
"output_dir": "./test/output",
"output_name": "TI-Adamw8bit-SDXL",
"persistent_data_loader_workers": false,
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
"prior_loss_weight": 1,
"random_crop": false,
"reg_data_dir": "",
"resume": "",
"resume_from_huggingface": "False",
"sample_every_n_epochs": 0,
"sample_every_n_steps": 20,
"sample_prompts": "a painting of man wearing a gas mask , by darius kawasaki",
"sample_sampler": "euler_a",
"save_as_bool": false,
"save_every_n_epochs": 1,
"save_every_n_steps": 0,
"save_last_n_steps": 0,
"save_last_n_steps_state": 0,
"save_model_as": "safetensors",
"save_precision": "fp16",
"save_state": false,
"save_state_on_train_end": false,
"save_state_to_huggingface": false,
"scale_v_pred_loss_like_noise_pred": false,
"sdxl": true,
"sdxl_no_half_vae": true,
"seed": 1234,
"shuffle_caption": false,
"stop_text_encoder_training": 0,
"template": "style template",
"token_string": "zxc",
"train_batch_size": 4,
"train_data_dir": "./test/img",
"v2": false,
"v_parameterization": false,
"v_pred_like_loss": 0,
"vae": "",
"vae_batch_size": 0,
"wandb_api_key": "",
"wandb_run_name": "",
"weights": "",
"xformers": "xformers"
}
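The commit message above notes a fix for a toml value-filtering bug: in Python, `0 == False` evaluates to `True`, so a naive "drop unset values" pass would also drop legitimate zeroes such as `"save_every_n_steps": 0` from a config like this one. A minimal sketch of the correct check (the helper name is illustrative, not the repository's actual function):

```python
# Sketch of the value-filtering pitfall: `0 == False` is True in Python,
# so dropping "False" values with `==` would also discard valid zeroes.
# Using an identity check (`is False`) keeps numeric zeroes in the output.

def filter_config_values(config: dict) -> dict:
    """Drop disabled flags and empty strings, but keep integer/float zeroes."""
    filtered = {}
    for key, value in config.items():
        # `value == False` would wrongly match 0 and 0.0 here.
        if value is False or value == "":
            continue
        filtered[key] = value
    return filtered

config = {"save_every_n_steps": 0, "save_state": False, "resume": "", "seed": 1234}
print(filter_config_values(config))  # {'save_every_n_steps': 0, 'seed': 1234}
```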
@@ -0,0 +1,40 @@
+[general]
+# define common settings here
+flip_aug = true
+color_aug = false
+keep_tokens_separator = "|||"
+shuffle_caption = false
+caption_tag_dropout_rate = 0
+caption_extension = ".txt"
+min_bucket_reso = 64
+max_bucket_reso = 2048
+
+[[datasets]]
+# define the first resolution here
+batch_size = 1
+enable_bucket = true
+resolution = [1024, 1024]
+
+[[datasets.subsets]]
+image_dir = "./test/img/10_darius kawasaki person"
+num_repeats = 10
+
+[[datasets]]
+# define the second resolution here
+batch_size = 1
+enable_bucket = true
+resolution = [768, 768]
+
+[[datasets.subsets]]
+image_dir = "./test/img/10_darius kawasaki person"
+num_repeats = 10
+
+[[datasets]]
+# define the third resolution here
+batch_size = 1
+enable_bucket = true
+resolution = [512, 512]
+
+[[datasets.subsets]]
+image_dir = "./test/img/10_darius kawasaki person"
+num_repeats = 10
@ -1,49 +1,75 @@
|
||||||
{
|
{
|
||||||
"adaptive_noise_scale": 0,
|
"adaptive_noise_scale": 0,
|
||||||
"additional_parameters": "",
|
"additional_parameters": "",
|
||||||
|
"async_upload": false,
|
||||||
"bucket_no_upscale": true,
|
"bucket_no_upscale": true,
|
||||||
"bucket_reso_steps": 64,
|
"bucket_reso_steps": 64,
|
||||||
"cache_latents": true,
|
"cache_latents": true,
|
||||||
"cache_latents_to_disk": false,
|
"cache_latents_to_disk": false,
|
||||||
"caption_dropout_every_n_epochs": 0.0,
|
"caption_dropout_every_n_epochs": 0,
|
||||||
"caption_dropout_rate": 0.05,
|
"caption_dropout_rate": 0.05,
|
||||||
"caption_extension": "",
|
"caption_extension": "",
|
||||||
"clip_skip": 2,
|
"clip_skip": 2,
|
||||||
"color_aug": false,
|
"color_aug": false,
|
||||||
"dataset_config": "./test/config/dataset.toml",
|
"dataset_config": "./test/config/dataset.toml",
|
||||||
|
"debiased_estimation_loss": false,
|
||||||
|
"disable_mmap_load_safetensors": false,
|
||||||
|
"dynamo_backend": "no",
|
||||||
|
"dynamo_mode": "default",
|
||||||
|
"dynamo_use_dynamic": false,
|
||||||
|
"dynamo_use_fullgraph": false,
|
||||||
"enable_bucket": true,
|
"enable_bucket": true,
|
||||||
"epoch": 1,
|
"epoch": 1,
|
||||||
|
"extra_accelerate_launch_args": "",
|
||||||
"flip_aug": false,
|
"flip_aug": false,
|
||||||
"full_bf16": false,
|
"full_bf16": false,
|
||||||
"full_fp16": false,
|
"full_fp16": false,
|
||||||
|
"fused_backward_pass": false,
|
||||||
|
"fused_optimizer_groups": 0,
|
||||||
"gpu_ids": "",
|
"gpu_ids": "",
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"gradient_checkpointing": false,
|
"gradient_checkpointing": false,
|
||||||
|
"huber_c": 0.1,
|
||||||
|
"huber_schedule": "snr",
|
||||||
|
"huggingface_path_in_repo": "",
|
||||||
|
"huggingface_repo_id": "",
|
||||||
|
"huggingface_repo_type": "",
|
||||||
|
"huggingface_repo_visibility": "",
|
||||||
|
"huggingface_token": "",
|
||||||
"ip_noise_gamma": 0,
|
"ip_noise_gamma": 0,
|
||||||
"ip_noise_gamma_random_strength": false,
|
"ip_noise_gamma_random_strength": false,
|
||||||
"keep_tokens": "0",
|
"keep_tokens": 0,
|
||||||
"learning_rate": 5e-05,
|
"learning_rate": 5e-05,
|
||||||
"learning_rate_te": 1e-05,
|
"learning_rate_te": 1e-05,
|
||||||
"learning_rate_te1": 1e-05,
|
"learning_rate_te1": 1e-05,
|
||||||
"learning_rate_te2": 1e-05,
|
"learning_rate_te2": 1e-05,
|
||||||
|
"log_config": false,
|
||||||
"log_tracker_config": "",
|
"log_tracker_config": "",
|
||||||
"log_tracker_name": "",
|
"log_tracker_name": "",
|
||||||
|
"log_with": "",
|
||||||
"logging_dir": "./test/logs",
|
"logging_dir": "./test/logs",
|
||||||
|
"loss_type": "l2",
|
||||||
"lr_scheduler": "constant",
|
"lr_scheduler": "constant",
|
||||||
"lr_scheduler_args": "",
|
"lr_scheduler_args": "T_max=100",
|
||||||
"lr_scheduler_num_cycles": "",
|
"lr_scheduler_num_cycles": 1,
|
||||||
"lr_scheduler_power": "",
|
"lr_scheduler_power": 1,
|
||||||
|
"lr_scheduler_type": "CosineAnnealingLR",
|
||||||
"lr_warmup": 0,
|
"lr_warmup": 0,
|
||||||
"main_process_port": 12345,
|
"main_process_port": 12345,
|
||||||
"masked_loss": false,
|
"masked_loss": false,
|
||||||
"max_bucket_reso": 2048,
|
"max_bucket_reso": 2048,
|
||||||
"max_data_loader_n_workers": "0",
|
"max_data_loader_n_workers": 0,
|
||||||
"max_resolution": "512,512",
|
"max_resolution": "512,512",
|
||||||
"max_timestep": 1000,
|
"max_timestep": 1000,
|
||||||
"max_token_length": "75",
|
"max_token_length": 75,
|
||||||
"max_train_epochs": "",
|
"max_train_epochs": 0,
|
||||||
"max_train_steps": "",
|
"max_train_steps": 0,
|
||||||
"mem_eff_attn": false,
|
"mem_eff_attn": false,
|
||||||
|
"metadata_author": "",
|
||||||
|
"metadata_description": "",
|
||||||
|
"metadata_license": "",
|
||||||
|
"metadata_tags": "",
|
||||||
|
"metadata_title": "",
|
||||||
"min_bucket_reso": 256,
|
"min_bucket_reso": 256,
|
||||||
"min_snr_gamma": 0,
|
"min_snr_gamma": 0,
|
||||||
"min_timestep": 0,
|
"min_timestep": 0,
|
||||||
|
|
@@ -65,14 +91,16 @@
 "output_name": "db-AdamW8bit-toml",
 "persistent_data_loader_workers": false,
 "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
-"prior_loss_weight": 1.0,
+"prior_loss_weight": 1,
 "random_crop": false,
 "reg_data_dir": "",
 "resume": "",
+"resume_from_huggingface": "",
 "sample_every_n_epochs": 0,
 "sample_every_n_steps": 25,
 "sample_prompts": "a painting of a gas mask , by darius kawasaki",
 "sample_sampler": "euler_a",
+"save_as_bool": false,
 "save_every_n_epochs": 1,
 "save_every_n_steps": 0,
 "save_last_n_steps": 0,
@@ -81,14 +109,16 @@
 "save_precision": "fp16",
 "save_state": false,
 "save_state_on_train_end": false,
+"save_state_to_huggingface": false,
 "scale_v_pred_loss_like_noise_pred": false,
 "sdxl": false,
-"seed": "1234",
+"sdxl_cache_text_encoder_outputs": false,
+"sdxl_no_half_vae": false,
+"seed": 1234,
 "shuffle_caption": false,
 "stop_text_encoder_training": 0,
 "train_batch_size": 4,
 "train_data_dir": "",
-"use_wandb": false,
 "v2": false,
 "v_parameterization": false,
 "v_pred_like_loss": 0,
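The config changes in this commit normalize legacy string-typed numerics (e.g. `"seed": "1234"` becomes `"seed": 1234`, `"max_token_length": "75"` becomes `75`). A minimal sketch of doing the same coercion when loading an older JSON config; the helper name is illustrative, not the repository's actual code:

```python
# Sketch: coerce quoted numeric values from a legacy config to real numbers,
# leaving genuine strings (like "fp16") and already-numeric values untouched.

def coerce_numeric(value):
    """Return int/float for numeric strings; pass other values through."""
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            try:
                return float(value)
            except ValueError:
                return value
    return value

old = {"seed": "1234", "max_token_length": "75", "save_precision": "fp16"}
new = {key: coerce_numeric(val) for key, val in old.items()}
print(new)  # {'seed': 1234, 'max_token_length': 75, 'save_precision': 'fp16'}
```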
@@ -1 +0,0 @@
-solo,simple background,teeth,grey background,from side,no humans,mask,1other,science fiction,cable,gas mask,tube,steampunk,machine
@@ -1 +0,0 @@
-no humans,what
@@ -1 +0,0 @@
-1girl,solo,nude,colored skin,monster,blue skin
@@ -1 +0,0 @@
-solo,upper body,horns,from side,no humans,blood,1other
@@ -1 +0,0 @@
-solo,1boy,male focus,mask,instrument,science fiction,realistic,music,gas mask
@@ -1 +0,0 @@
-solo,no humans,mask,helmet,robot,mecha,1other,science fiction,damaged,gas mask,steampunk
@@ -1 +0,0 @@
-solo,from side,no humans,mask,moon,helmet,portrait,1other,ambiguous gender,gas mask
@@ -1 +0,0 @@
-outdoors,sky,cloud,no humans,monster,realistic,desert