mirror of https://github.com/bmaltais/kohya_ss
v25.0.0 release (#3138)
* Add support for custom learning rate scheduler type to the GUI * Add .webp image extension support to BLIP2 captioning. * Check for --debug flag for gui command-line args at startup * Validate GPU ID accelerate input and return error when needed * Update to latest sd-scripts dev commit * Fix issue with pip upgrade * Remove confusing log after command execution. * piecewise_constant scheduler * Update to latest sd-scripts dev commit * fix: fixed docker-compose for passing models via volumes * Prevent providing the legacy learning_rate if unet or te learning rate is provided * Fix toml noise offset parameters based on selected type * Fix adaptive_noise_scale value not properly loading from json config * Fix prompt.txt location * Improve "print command" output format * Use output model name as wandb run name if not provided * Update sd-scripts dev release * Bump crate-ci/typos from 1.21.0 to 1.22.9 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.21.0 to 1.22.9. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.21.0...v1.22.9) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Bump docker/build-push-action from 5 to 6 Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 5 to 6. - [Release notes](https://github.com/docker/build-push-action/releases) - [Commits](https://github.com/docker/build-push-action/compare/v5...v6) --- updated-dependencies: - dependency-name: docker/build-push-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] <support@github.com> * Get latest sd3 code * Adding SD3 GUI elements * Fix interactivity * MVP GUI for SD3 * Fix text encoder issue * Add fork section to readme * Update sd3 commit * Merge security-fix * Update sc-script to latest code * Auto-detect model type for safetensors files Automatically tick the checkboxes for v2 and SDXL on the common training UI and LoRA extract/merge utilities. * autodetect-modeltype: remove unused lambda inputs * rework TE1/TE2 learning rate handling for SDXL dreambooth SDXL dreambooth apparently trains without the text encoders by default, requiring the `--train_text_encoder` flag to be passed so that the learning rates for TE1/TE2 are recognized. The toml handling now permits 0 to be passed as a learning rate in order to disable training of one or both text encoders. This behavior aligns with the description given on the GUI. TE1/TE2 learning rate parameters can be left blank on the GUI to not pass a value to the training script. * dreambooth_gui: fix toml value filtering condition In python3, `0 == False` will evaluate True. That can cause arg values of 0 to be wrongly eliminated from the toml output. The conditional must check the type when comparing for False. * autodetect-modeltype: also do the v2 checkbox in extract_lora * Update to latest dev branch code * bring back SDXLConfig accordion for dreambooth gui (#2694) b-fission <b-fission@users.noreply.github.com> * Update to latest sd3 branch commit * Fix merge issue * Update gradio version * Update to latest flux.1 code * Add Flux.1 Model checkbox and detection * Adding LoRA type "Flux1" to dropdown * Added Flux.1 parameters to GUI * Update sd-scripts and requirements * Add missing Flux.1 GUI parameters * Update to latest sd-scripts sd3 code * Fix issue with cache_text_encoder_outputs * Update to latest sd-scripts flux1 code * Adding new flux.1 options to GUI * Update to latest sd-scripts version of flux.1 * Adding guidance_scale option * Update to latest sd3 flux.1 sd-scripts * Add dreambooth and finetuning support for flux.1 * Update README * Fix t5xxl path issue in DB * add missing fp8_base parameter * Fix issue with guidance scale not being passed as float for values like 1 * Temporary fir for blockwise_fused_optimizers * Update to latest sd-scripts Flux.1 code * Fix blockwise_fused_optimizers typo * Add mem_eff_save option to GUI for Flux.1 * Added support for Flux.1 LoRA Merge * Update to latest sd-scripts sd3 branch code * Add diffusers option to flux.1 merge LoRA utility * Fix issue with split_mode and train_blocks * Updating requirements * Add flux_fused_backward_pass to dreambooth and finetuning * Update requirements_linux_docker.txt update accelerate version for linux_docker * Update to latest sd3 flux code * Add extract flux lora GUI * MErged latest sd3 branch code * Add support for split_qkv * Add missing network argument for split_qkv * Add timestep_sampling shift support * Update to latest sd-scripts flux.1 code * Add support for fp8_base_unet * Update requirements as per sd-scripts suggestion * Upgrade to cu124 * Update IPEX and ROCm * Fix issue with balancing when folder with name already exist * Update sd-scripts * Removed unsupported parameters from flux lora network * Bump crate-ci/typos from 1.23.6 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Update sd-scripts code * Adding flux_shift option to timestep_sampling * Update sd-scripts release * Add support for Train T5-XXL * Update sd-scripts submodule * Add support for cpu_offload_checkpointing to GUI * Force t5xxl_max_token_length to be served as an integer * Fix typo for flux_shift * Update to latest sd-scripts code * Grouping lora parameters * Validate if lora type is Flux1 when flux1_checkbox is true * Improve visual sectioning of parameters for lora * Add dark mode styles * Missed one color * Update sd-scripts and add support for t5xxl LR * Update transformers and wandb module * Fix issue with new text_encoder_lr parameter syntax * Add support for lr_warmup_steps override * Update lr_warmup_steps code * Removing stable-diffusion-1.5 default model * Fix for max_train_steps * Revert some changes * Preliminary support for Flux1 OFT * Fix logic typo * Update sd-scripts * Add support for Rank for layers * Update lora_gui.py Fixed minor typos of "Regularization" * Update dreambooth_gui.py Fixed minor typos of "Regularization" * Update textual_inversion_gui.py Fixed minor typos of "Regularization" * Add support for Blocks to train * Add missing network parms * Fix issue with old_lr_warmup_steps * Update sd-scripts * Add support for ScheduleFree Optimizer Type * Update sd-scripts * Update requirements_pytorch_windows.txt * Update requirements_pytorch_windows.txt * Update sd-scripts from origin * Another sd-script update * Adding support for blocks_to_swap option to gui * Fix xformers install issue * feat(docker): mount models folder as a volume * feat(docker): add models folder to .dockerignore * Add support for AdEMAMix8bit optimizer * Bump crate-ci/typos from 1.23.6 to 1.25.0 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.25.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.25.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Fix typo on README.md * Add new --noverify option to skip requirements validation on startup * Update startup GUI code * Update setup code * Update sd-scripts * Update sf-scripts * Update Lycoris support * Allow to specify tensorboard host via env var TENSORBOARD_HOST * Update sd-scripts version * Update sd-scripts release * Update sd-scripts * Add --skip_cache_check option to GUI * Fix requirements issue * Add support for LyCORIS LoRA when training Flux.1 * Pin huggingface-hub version for gradio 5 * Update sd-scripts * Add support for --save_last_n_epochs_state * Update sd-scripts to version with Differential Output Preservation * Increase maximum flux-lora merge strength to 2 * Update to latest sd-scripts * Update requirements syntax (for windows) * Update requirements for linux * Update torch version and validation output * Fix typo * Update README * Fix validation issue on linux * Update sd-scripts, improve requirements outputs * Update requirements_runpod.txt * Update requirements for onnxruntime-gpu Needed for compatibility with CUDA 12. * Update onnxruntime-gpu==1.19.2 * Update sd-scripts release * Add support for save_last_n_epochs * Update sd-scripts * Bump crate-ci/typos from 1.23.6 to 1.26.8 (#2940) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.26.8. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.26.8) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: bmaltais <bernard@ducourier.com> * fix 'cached_download' from 'huggingface_hub' (#2947) Describe the bug: cannot import name 'cached_download' from 'huggingface_hub' It's applyed for all platforms Co-authored-by: bmaltais <bernard@ducourier.com> * Add support for quiet output for linux setup * Fix quiet issue * Update sd-scripts * Update sd-scripts with blocks_to_swap support * Make blocks_to_swap visible in LoRA tab * Fix blocks_to_swap not properly working * Update sd-scripts and allow python 3.10 to 3.12 * Fix issue with max_train_steps * Fix max_train_steps_info error * Reverting all changes for max_train_steps * Update sd-scripts * Update sd-scripts * Update to latest sd-scripts * Add support for RAdamScheduleFree * Add support for huber_scale * Add support for fused_backward_pass for sd3 finetuning * Add support for prodigyplus.ProdigyPlusScheduleFree * SD3 LoRA training MVP * Make blocks_to_swap common * Add support for sd3 lora disable_mmap_load_safetensors * Add a bunch of missing SD3 parameters * Fix clip_l issue for missing path * Fix train_t5xxl issue * Fix network_module issue * Add uniform to weighting_scheme * Bump crate-ci/typos from 1.23.6 to 1.28.1 (#2996) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.28.1. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.28.1) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: bmaltais <bernard@ducourier.com> * Update README.md (#3031) * Bump crate-ci/typos from 1.23.6 to 1.29.0 (#3029) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.29.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.29.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: bmaltais <bernard@ducourier.com> * Update sd-scripts version * Update setup.sh (#3054) Enter the current directory before executing setup.sh, otherwise the installer might failed to find rqeuirements.txt * Removing wrong folder * Fix issue with SD3 Lora training blocks_to_swap and fused_backward_pass * Fix dreambooth issue * Update to lastest sd-scripts code * Run on novita (#3119) (#3120) * add run on novita * adjust position Co-authored-by: hugo <liyiligang@users.noreply.github.com> * Bump crate-ci/typos from 1.23.6 to 1.30.0 (#3101) Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.30.0. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.23.6...v1.30.0) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: bmaltais <bernard@ducourier.com> * updated prodigyopt to 1.1.2 and removed duplicated row in requirements.txt (#3065) * fixed names on LR Schedure dropdown (#3064) * Update to latest sd-scripts version * fixed names on LR Schedure dropdown (#3064) * Cleanup venv3 * Fix issue with gradio on new installations Add support for latest sd-scripts pytorch-optimizer * Update README for v25.0.0 release --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: b-fission <b-fission@users.noreply.github.com> Co-authored-by: DevArqSangoi <lucas.sangoi@gmail.com> Co-authored-by: Кирилл Москвин <retreat.cost@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: b-fission <131207849+b-fission@users.noreply.github.com> Co-authored-by: eftSharptooth <76253264+eftSharptooth@users.noreply.github.com> Co-authored-by: Disty0 <disty@disty.xyz> Co-authored-by: wcole3 <will.cole3@gmail.com> Co-authored-by: rohitanshu <85547195+iamrohitanshu@users.noreply.github.com> Co-authored-by: wzgrx <39661556+wzgrx@users.noreply.github.com> Co-authored-by: Vladimir Sotnikov <vladimir.s@alphakek.ai> Co-authored-by: bulieme0 <53142287+bulieme@users.noreply.github.com> Co-authored-by: Nicolas Pereira <41456803+hqnicolas@users.noreply.github.com> Co-authored-by: ruucm <ruucm.a@gmail.com> Co-authored-by: CaledoniaProject <CaledoniaProject@users.noreply.github.com> Co-authored-by: hugo <liyiligang@users.noreply.github.com> Co-authored-by: Koro <Koronos@users.noreply.github.com>pull/3051/head v25.0.0
parent
a1b16e44f0
commit
ed55e81997
|
|
@ -3,6 +3,7 @@ cudnn_windows/
|
|||
bitsandbytes_windows/
|
||||
bitsandbytes_windows_deprecated/
|
||||
dataset/
|
||||
models/
|
||||
__pycache__/
|
||||
venv/
|
||||
**/.hadolint.yml
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ jobs:
|
|||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
id: publish
|
||||
with:
|
||||
context: .
|
||||
|
|
|
|||
|
|
@ -18,4 +18,4 @@ jobs:
|
|||
- uses: actions/checkout@v4
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.23.6
|
||||
uses: crate-ci/typos@v1.30.0
|
||||
|
|
|
|||
|
|
@ -52,3 +52,5 @@ models
|
|||
data
|
||||
config.toml
|
||||
sd-scripts
|
||||
venv
|
||||
venv*
|
||||
86
README.md
86
README.md
|
|
@ -48,13 +48,20 @@ The GUI allows you to set the training parameters and generate and run the requi
|
|||
- [Potential Solutions](#potential-solutions)
|
||||
- [SDXL training](#sdxl-training)
|
||||
- [Masked loss](#masked-loss)
|
||||
- [Guides](#guides)
|
||||
- [Using Accelerate Lora Tab to Select GPU ID](#using-accelerate-lora-tab-to-select-gpu-id)
|
||||
- [Starting Accelerate in GUI](#starting-accelerate-in-gui)
|
||||
- [Running Multiple Instances (linux)](#running-multiple-instances-linux)
|
||||
- [Monitoring Processes](#monitoring-processes)
|
||||
- [Interesting Forks](#interesting-forks)
|
||||
- [Change History](#change-history)
|
||||
- [v25.0.0](#v2500)
|
||||
|
||||
## 🦒 Colab
|
||||
|
||||
This Colab notebook was not created or maintained by me; however, it appears to function effectively. The source can be found at: <https://github.com/camenduru/kohya_ss-colab>.
|
||||
|
||||
I would like to express my gratitude to camendutu for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository.
|
||||
I would like to express my gratitude to camenduru for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository.
|
||||
|
||||
| Colab | Info |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------ |
|
||||
|
|
@ -71,7 +78,7 @@ To install the necessary dependencies on a Windows system, follow these steps:
|
|||
1. Install [Python 3.10.11](https://www.python.org/ftp/python/3.10.11/python-3.10.11-amd64.exe).
|
||||
- During the installation process, ensure that you select the option to add Python to the 'PATH' environment variable.
|
||||
|
||||
2. Install [CUDA 11.8 toolkit](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Windows&target_arch=x86_64).
|
||||
2. Install [CUDA 12.4 toolkit](https://developer.nvidia.com/cuda-12-4-0-download-archive?target_os=Windows&target_arch=x86_64).
|
||||
|
||||
3. Install [Git](https://git-scm.com/download/win).
|
||||
|
||||
|
|
@ -129,7 +136,7 @@ To install the necessary dependencies on a Linux system, ensure that you fulfill
|
|||
apt install python3.10-venv
|
||||
```
|
||||
|
||||
- Install the CUDA 11.8 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64).
|
||||
- Install the CUDA 12.4 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-12-4-0-download-archive?target_os=Linux&target_arch=x86_64).
|
||||
|
||||
- Make sure you have Python version 3.10.9 or higher (but lower than 3.11.0) installed on your system.
|
||||
|
||||
|
|
@ -179,7 +186,7 @@ If you choose to use the interactive mode, the default values for the accelerate
|
|||
|
||||
To install the necessary components for Runpod and run kohya_ss, follow these steps:
|
||||
|
||||
1. Select the Runpod pytorch 2.0.1 template. This is important. Other templates may not work.
|
||||
1. Select the Runpod pytorch 2.2.0 template. This is important. Other templates may not work.
|
||||
|
||||
2. SSH into the Runpod.
|
||||
|
||||
|
|
@ -222,7 +229,9 @@ To run from a pre-built Runpod template, you can:
|
|||
3. Once deployed, connect to the Runpod on HTTP 3010 to access the kohya_ss GUI. You can also connect to auto1111 on HTTP 3000.
|
||||
|
||||
### Novita
|
||||
|
||||
#### Pre-built Novita template
|
||||
|
||||
1. Open the Novita template by clicking on <https://novita.ai/gpus-console?templateId=312>.
|
||||
|
||||
2. Deploy the template on the desired host.
|
||||
|
|
@ -339,13 +348,27 @@ To upgrade your installation on Linux or macOS, follow these steps:
|
|||
To launch the GUI service, you can use the provided scripts or run the `kohya_gui.py` script directly. Use the command line arguments listed below to configure the underlying service.
|
||||
|
||||
```text
|
||||
--listen: Specify the IP address to listen on for connections to Gradio.
|
||||
--username: Set a username for authentication.
|
||||
--password: Set a password for authentication.
|
||||
--server_port: Define the port to run the server listener on.
|
||||
--inbrowser: Open the Gradio UI in a web browser.
|
||||
--share: Share the Gradio UI.
|
||||
--language: Set custom language
|
||||
--help show this help message and exit
|
||||
--config CONFIG Path to the toml config file for interface defaults
|
||||
--debug Debug on
|
||||
--listen LISTEN IP to listen on for connections to Gradio
|
||||
--username USERNAME Username for authentication
|
||||
--password PASSWORD Password for authentication
|
||||
--server_port SERVER_PORT
|
||||
Port to run the server listener on
|
||||
--inbrowser Open in browser
|
||||
--share Share the gradio UI
|
||||
--headless Is the server headless
|
||||
--language LANGUAGE Set custom language
|
||||
--use-ipex Use IPEX environment
|
||||
--use-rocm Use ROCm environment
|
||||
--do_not_use_shell Enforce not to use shell=True when running external commands
|
||||
--do_not_share Do not share the gradio UI
|
||||
--requirements REQUIREMENTS
|
||||
requirements file to use for validation
|
||||
--root_path ROOT_PATH
|
||||
`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss
|
||||
--noverify Disable requirements verification
|
||||
```
|
||||
|
||||
### Launching the GUI on Windows
|
||||
|
|
@ -448,6 +471,45 @@ The feature is not fully tested, so there may be bugs. If you find any issues, p
|
|||
|
||||
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
|
||||
|
||||
## Guides
|
||||
|
||||
The following are guides extracted from issues discussions
|
||||
|
||||
### Using Accelerate Lora Tab to Select GPU ID
|
||||
|
||||
#### Starting Accelerate in GUI
|
||||
|
||||
- Open the kohya GUI on your desired port.
|
||||
- Open the `Accelerate launch` tab
|
||||
- Ensure the Multi-GPU checkbox is unchecked.
|
||||
- Set GPU IDs to the desired GPU (like 1).
|
||||
|
||||
#### Running Multiple Instances (linux)
|
||||
|
||||
- For tracking multiple processes, use separate kohya GUI instances on different ports (e.g., 7860, 7861).
|
||||
- Start instances using `nohup ./gui.sh --listen 0.0.0.0 --server_port <port> --headless > log.log 2>&1 &`.
|
||||
|
||||
#### Monitoring Processes
|
||||
|
||||
- Open each GUI in a separate browser tab.
|
||||
- For terminal access, use SSH and tools like `tmux` or `screen`.
|
||||
|
||||
For more details, visit the [GitHub issue](https://github.com/bmaltais/kohya_ss/issues/2577).
|
||||
|
||||
## Interesting Forks
|
||||
|
||||
To finetune HunyuanDiT models or create LoRAs, visit this [fork](https://github.com/Tencent/HunyuanDiT/tree/main/kohya_ss-hydit)
|
||||
|
||||
## Change History
|
||||
|
||||
See release information.
|
||||
### v25.0.0
|
||||
|
||||
This is a SIGNIFICANT upgrade. I am groing in uncharted territory here because kohya has not merged any of the recent flux.1 and sd3 updated to his code in his main branch yet... but I feel updates in his code has pretty much dried down and I think his code is probably ready for prime time. So instead of keeping my GUI in the cave man ages, I am opting to move the code for the GUI with support for flux.1 and sd3 to the main branch of my project. Perhaps this will bite me in the proverbias ass... but for those who would rather stay on the older pre "flux.1 and sd3" updates, you can always do:
|
||||
|
||||
```shell
|
||||
git checkout v24.1.7
|
||||
```
|
||||
|
||||
after cloning the repo.
|
||||
|
||||
For all the info regarding the new flux.1 and sd3 parameters, see <https://github.com/kohya-ss/sd-scripts/blob/sd3/README.md> for more details.
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ parms="parms"
|
|||
nin="nin"
|
||||
extention="extention" # Intentionally left
|
||||
nd="nd"
|
||||
pn="pn"
|
||||
shs="shs"
|
||||
sts="sts"
|
||||
scs="scs"
|
||||
|
|
|
|||
217
assets/style.css
217
assets/style.css
|
|
@ -1,4 +1,4 @@
|
|||
#open_folder_small{
|
||||
#open_folder_small {
|
||||
min-width: auto;
|
||||
flex-grow: 0;
|
||||
padding-left: 0.25em;
|
||||
|
|
@ -7,14 +7,14 @@
|
|||
font-size: 1.5em;
|
||||
}
|
||||
|
||||
#open_folder{
|
||||
#open_folder {
|
||||
height: auto;
|
||||
flex-grow: 0;
|
||||
padding-left: 0.25em;
|
||||
padding-right: 0.25em;
|
||||
}
|
||||
|
||||
#number_input{
|
||||
#number_input {
|
||||
min-width: min-content;
|
||||
flex-grow: 0.3;
|
||||
padding-left: 0.75em;
|
||||
|
|
@ -22,7 +22,7 @@
|
|||
}
|
||||
|
||||
.ver-class {
|
||||
color: #808080;
|
||||
color: #6d6d6d; /* Neutral dark gray */
|
||||
font-size: small;
|
||||
text-align: right;
|
||||
padding-right: 1em;
|
||||
|
|
@ -35,13 +35,212 @@
|
|||
}
|
||||
|
||||
#myTensorButton {
|
||||
background: radial-gradient(ellipse, #3a99ff, #52c8ff);
|
||||
background: #555c66; /* Muted dark gray */
|
||||
color: white;
|
||||
border: #296eb8;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
padding: 0.5em 1em;
|
||||
/* box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); Subtle shadow */
|
||||
/* transition: box-shadow 0.3s ease; */
|
||||
}
|
||||
|
||||
#myTensorButton:hover {
|
||||
/* box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); Slightly increased shadow on hover */
|
||||
}
|
||||
|
||||
#myTensorButtonStop {
|
||||
background: radial-gradient(ellipse, #52c8ff, #3a99ff);
|
||||
color: black;
|
||||
border: #296eb8;
|
||||
background: #777d85; /* Lighter muted gray */
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
padding: 0.5em 1em;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
/* transition: box-shadow 0.3s ease; */
|
||||
}
|
||||
|
||||
#myTensorButtonStop:hover {
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.advanced_background {
|
||||
background: #f4f4f4; /* Light neutral gray */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; /* Added transition for smooth shadow effect */
|
||||
}
|
||||
|
||||
.advanced_background:hover {
|
||||
background-color: #ebebeb; /* Slightly darker background on hover */
|
||||
border: 1px solid #ccc; /* Add a subtle border */
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.basic_background {
|
||||
background: #eaeff1; /* Muted cool gray */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.basic_background:hover {
|
||||
background-color: #dfe4e7; /* Slightly darker cool gray on hover */
|
||||
border: 1px solid #ccc;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.huggingface_background {
|
||||
background: #e0e4e7; /* Light gray with a hint of blue */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.huggingface_background:hover {
|
||||
background-color: #d6dce0; /* Slightly darker on hover */
|
||||
border: 1px solid #bbb;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.flux1_background {
|
||||
background: #ece9e6; /* Light beige tone */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.flux1_background:hover {
|
||||
background-color: #e2dfdb; /* Slightly darker beige on hover */
|
||||
border: 1px solid #ccc;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.preset_background {
|
||||
background: #f0f0f0; /* Light gray */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.preset_background:hover {
|
||||
background-color: #e6e6e6; /* Slightly darker on hover */
|
||||
border: 1px solid #ccc;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.samples_background {
|
||||
background: #d9dde1; /* Soft muted gray-blue */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.samples_background:hover {
|
||||
background-color: #cfd3d8; /* Slightly darker on hover */
|
||||
border: 1px solid #bbb;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
/* Dark mode styles */
|
||||
.dark .advanced_background {
|
||||
background: #172029; /* Slightly darker gradio dark theme */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; /* Added transition for smooth shadow effect */
|
||||
}
|
||||
|
||||
.dark .advanced_background:hover {
|
||||
background-color: #121920; /* Slightly darker background on hover */
|
||||
border: 1px solid #000000; /* Add a subtle border */
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.dark .basic_background {
|
||||
background: #172029;
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .basic_background:hover {
|
||||
background-color: #11181e;
|
||||
border: 1px solid #000000;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.dark .huggingface_background {
|
||||
background: #131c25;
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .huggingface_background:hover {
|
||||
background-color: #131c25;
|
||||
border: 1px solid #000000;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.dark .flux1_background {
|
||||
background: #131c25;
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .flux1_background:hover {
|
||||
background-color: #131c25;
|
||||
border: 1px solid #000000;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.dark .preset_background {
|
||||
background: #191d25;
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .preset_background:hover {
|
||||
background-color: #212530;
|
||||
border: 1px solid #000000;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.dark .samples_background {
|
||||
background: #101e2c;
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .samples_background:hover {
|
||||
background-color: #17293a;
|
||||
border: 1px solid #000000;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.flux1_rank_layers_background {
|
||||
background: #ece9e6; /* White background for clear theme */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.flux1_rank_layers_background:hover {
|
||||
background-color: #dddad7; /* Slightly darker on hover */
|
||||
border: 1px solid #ccc;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
||||
.dark .flux1_rank_layers_background {
|
||||
background: #131c25; /* Dark background for dark theme */
|
||||
padding: 1em;
|
||||
border-radius: 8px;
|
||||
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
.dark .flux1_rank_layers_background:hover {
|
||||
background-color: #131c25; /* Slightly darker on hover */
|
||||
border: 1px solid #000000;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
|
||||
}
|
||||
|
|
@ -48,6 +48,7 @@ learning_rate_te1 = 0.0001 # Learning rate text encoder 1
|
|||
learning_rate_te2 = 0.0001 # Learning rate text encoder 2
|
||||
lr_scheduler = "cosine" # LR Scheduler
|
||||
lr_scheduler_args = "" # LR Scheduler args
|
||||
lr_scheduler_type = "" # LR Scheduler type
|
||||
lr_warmup = 0 # LR Warmup (% of total steps)
|
||||
lr_scheduler_num_cycles = 1 # LR Scheduler num cycles
|
||||
lr_scheduler_power = 1.0 # LR Scheduler power
|
||||
|
|
@ -150,6 +151,9 @@ sample_prompts = "" # Sample prompts
|
|||
sample_sampler = "euler_a" # Sampler to use for image sampling
|
||||
|
||||
[sdxl]
|
||||
disable_mmap_load_safetensors = false # Disable mmap load safe tensors
|
||||
fused_backward_pass = false # Fused backward pass
|
||||
fused_optimizer_groups = 0 # Fused optimizer groups
|
||||
sdxl_cache_text_encoder_outputs = false # Cache text encoder outputs
|
||||
sdxl_no_half_vae = true # No half VAE
|
||||
|
||||
|
|
|
|||
|
|
@ -20,11 +20,13 @@ services:
|
|||
- /tmp
|
||||
volumes:
|
||||
- /tmp/.X11-unix:/tmp/.X11-unix
|
||||
- ./models:/app/models
|
||||
- ./dataset:/dataset
|
||||
- ./dataset/images:/app/data
|
||||
- ./dataset/logs:/app/logs
|
||||
- ./dataset/outputs:/app/outputs
|
||||
- ./dataset/regularization:/app/regularization
|
||||
- ./models:/app/models
|
||||
- ./.cache/config:/app/config
|
||||
- ./.cache/user:/home/1000/.cache
|
||||
- ./.cache/triton:/home/1000/.triton
|
||||
|
|
|
|||
|
|
@ -1,32 +1,27 @@
|
|||
## Updating a Local Branch with the Latest sd-scripts Changes
|
||||
## Updating a Local Submodule with the Latest sd-scripts Changes
|
||||
|
||||
To update your local branch with the most recent changes from kohya/sd-scripts, follow these steps:
|
||||
|
||||
1. Add sd-scripts as an alternative remote by executing the following command:
|
||||
1. When you wish to perform an update of the dev branch, execute the following commands:
|
||||
|
||||
```
|
||||
git remote add sd-scripts https://github.com/kohya-ss/sd-scripts.git
|
||||
```
|
||||
|
||||
2. When you wish to perform an update, execute the following commands:
|
||||
|
||||
```
|
||||
```bash
|
||||
cd sd-scripts
|
||||
git fetch
|
||||
git checkout dev
|
||||
git pull sd-scripts main
|
||||
git pull origin dev
|
||||
cd ..
|
||||
git add sd-scripts
|
||||
git commit -m "Update sd-scripts submodule to the latest on dev"
|
||||
```
|
||||
|
||||
Alternatively, if you want to obtain the latest code, even if it may be unstable:
|
||||
Alternatively, if you want to obtain the latest code from main:
|
||||
|
||||
```bash
|
||||
cd sd-scripts
|
||||
git fetch
|
||||
git checkout main
|
||||
git pull origin main
|
||||
cd ..
|
||||
git add sd-scripts
|
||||
git commit -m "Update sd-scripts submodule to the latest on main"
|
||||
```
|
||||
git checkout dev
|
||||
git pull sd-scripts dev
|
||||
```
|
||||
|
||||
3. If you encounter a conflict with the Readme file, you can resolve it by taking the following steps:
|
||||
|
||||
```
|
||||
git add README.md
|
||||
git merge --continue
|
||||
```
|
||||
|
||||
This may open a text editor for a commit message, but you can simply save and close it to proceed. Following these steps should resolve the conflict. If you encounter additional merge conflicts, consider them as valuable learning opportunities for personal growth.
|
||||
8
gui.bat
8
gui.bat
|
|
@ -7,11 +7,13 @@ call .\venv\Scripts\deactivate.bat
|
|||
|
||||
:: Activate the virtual environment
|
||||
call .\venv\Scripts\activate.bat
|
||||
|
||||
:: Update pip to latest version
|
||||
python -m pip install --upgrade pip -q
|
||||
|
||||
set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib
|
||||
|
||||
:: Validate requirements
|
||||
python.exe .\setup\validate_requirements.py
|
||||
if %errorlevel% neq 0 exit /b %errorlevel%
|
||||
echo Starting the GUI... this might take some time...
|
||||
|
||||
:: If the exit code is 0, run the kohya_gui.py script with the command-line arguments
|
||||
if %errorlevel% equ 0 (
|
||||
|
|
|
|||
30
gui.ps1
30
gui.ps1
|
|
@ -7,28 +7,18 @@ if ($env:VIRTUAL_ENV) {
|
|||
# Activate the virtual environment
|
||||
# Write-Host "Activating the virtual environment..."
|
||||
& .\venv\Scripts\activate
|
||||
|
||||
python.exe -m pip install --upgrade pip -q
|
||||
|
||||
$env:PATH += ";$($MyInvocation.MyCommand.Path)\venv\Lib\site-packages\torch\lib"
|
||||
|
||||
# Debug info about system
|
||||
# python.exe .\setup\debug_info.py
|
||||
Write-Host "Starting the GUI... this might take some time..."
|
||||
|
||||
# Validate the requirements and store the exit code
|
||||
python.exe .\setup\validate_requirements.py
|
||||
|
||||
# Check the exit code and stop execution if it is not 0
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Host "Failed to validate requirements. Exiting script..."
|
||||
exit $LASTEXITCODE
|
||||
$argsFromFile = @()
|
||||
if (Test-Path .\gui_parameters.txt) {
|
||||
$argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " }
|
||||
}
|
||||
$args_combo = $argsFromFile + $args
|
||||
# Write-Host "The arguments passed to this script were: $args_combo"
|
||||
python.exe kohya_gui.py $args_combo
|
||||
|
||||
# If the exit code is 0, read arguments from gui_parameters.txt (if it exists)
|
||||
# and run the kohya_gui.py script with the command-line arguments
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
$argsFromFile = @()
|
||||
if (Test-Path .\gui_parameters.txt) {
|
||||
$argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " }
|
||||
}
|
||||
$args_combo = $argsFromFile + $args
|
||||
# Write-Host "The arguments passed to this script were: $args_combo"
|
||||
python.exe kohya_gui.py $args_combo
|
||||
}
|
||||
|
|
|
|||
8
gui.sh
8
gui.sh
|
|
@ -111,10 +111,4 @@ then
|
|||
STARTUP_CMD=python
|
||||
fi
|
||||
|
||||
# Validate the requirements and run the script if successful
|
||||
if python "$SCRIPT_DIR/setup/validate_requirements.py" -r "$REQUIREMENTS_FILE"; then
|
||||
"${STARTUP_CMD}" $STARTUP_CMD_ARGS "$SCRIPT_DIR/kohya_gui.py" "$@"
|
||||
else
|
||||
echo "Validation failed. Exiting..."
|
||||
exit 1
|
||||
fi
|
||||
"${STARTUP_CMD}" $STARTUP_CMD_ARGS "$SCRIPT_DIR/kohya_gui.py" "--requirements=""$REQUIREMENTS_FILE" "$@"
|
||||
|
|
|
|||
235
kohya_gui.py
235
kohya_gui.py
|
|
@ -1,6 +1,10 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import subprocess
|
||||
import contextlib
|
||||
import gradio as gr
|
||||
|
||||
from kohya_gui.class_gui_config import KohyaSSGUIConfig
|
||||
from kohya_gui.dreambooth_gui import dreambooth_tab
|
||||
from kohya_gui.finetune_gui import finetune_tab
|
||||
|
|
@ -8,71 +12,43 @@ from kohya_gui.textual_inversion_gui import ti_tab
|
|||
from kohya_gui.utilities import utilities_tab
|
||||
from kohya_gui.lora_gui import lora_tab
|
||||
from kohya_gui.class_lora_tab import LoRATools
|
||||
|
||||
from kohya_gui.custom_logging import setup_logging
|
||||
from kohya_gui.localization_ext import add_javascript
|
||||
|
||||
PYTHON = sys.executable
|
||||
project_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def UI(**kwargs):
|
||||
add_javascript(kwargs.get("language"))
|
||||
css = ""
|
||||
# Function to read file content, suppressing any FileNotFoundError
|
||||
def read_file_content(file_path):
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
with open(file_path, "r", encoding="utf8") as file:
|
||||
return file.read()
|
||||
return ""
|
||||
|
||||
headless = kwargs.get("headless", False)
|
||||
log.info(f"headless: {headless}")
|
||||
# Function to initialize the Gradio UI interface
|
||||
def initialize_ui_interface(config, headless, use_shell, release_info, readme_content):
|
||||
# Load custom CSS if available
|
||||
css = read_file_content("./assets/style.css")
|
||||
|
||||
if os.path.exists("./assets/style.css"):
|
||||
with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
|
||||
log.debug("Load CSS...")
|
||||
css += file.read() + "\n"
|
||||
|
||||
if os.path.exists("./.release"):
|
||||
with open(os.path.join("./.release"), "r", encoding="utf8") as file:
|
||||
release = file.read()
|
||||
|
||||
if os.path.exists("./README.md"):
|
||||
with open(os.path.join("./README.md"), "r", encoding="utf8") as file:
|
||||
README = file.read()
|
||||
|
||||
interface = gr.Blocks(
|
||||
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
|
||||
)
|
||||
|
||||
config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
|
||||
|
||||
if config.is_config_loaded():
|
||||
log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")
|
||||
|
||||
use_shell_flag = True
|
||||
# if os.name == "posix":
|
||||
# use_shell_flag = True
|
||||
|
||||
use_shell_flag = config.get("settings.use_shell", use_shell_flag)
|
||||
|
||||
if kwargs.get("do_not_use_shell", False):
|
||||
use_shell_flag = False
|
||||
|
||||
if use_shell_flag:
|
||||
log.info("Using shell=True when running external commands...")
|
||||
|
||||
with interface:
|
||||
# Create the main Gradio Blocks interface
|
||||
ui_interface = gr.Blocks(css=css, title=f"Kohya_ss GUI {release_info}", theme=gr.themes.Default())
|
||||
with ui_interface:
|
||||
# Create tabs for different functionalities
|
||||
with gr.Tab("Dreambooth"):
|
||||
(
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = dreambooth_tab(
|
||||
headless=headless, config=config, use_shell_flag=use_shell_flag
|
||||
)
|
||||
) = dreambooth_tab(headless=headless, config=config, use_shell_flag=use_shell)
|
||||
with gr.Tab("LoRA"):
|
||||
lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
|
||||
lora_tab(headless=headless, config=config, use_shell_flag=use_shell)
|
||||
with gr.Tab("Textual Inversion"):
|
||||
ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
|
||||
ti_tab(headless=headless, config=config, use_shell_flag=use_shell)
|
||||
with gr.Tab("Finetuning"):
|
||||
finetune_tab(
|
||||
headless=headless, config=config, use_shell_flag=use_shell_flag
|
||||
)
|
||||
finetune_tab(headless=headless, config=config, use_shell_flag=use_shell)
|
||||
with gr.Tab("Utilities"):
|
||||
# Utilities tab requires inputs from the Dreambooth tab
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
|
|
@ -84,102 +60,97 @@ def UI(**kwargs):
|
|||
with gr.Tab("LoRA"):
|
||||
_ = LoRATools(headless=headless)
|
||||
with gr.Tab("About"):
|
||||
gr.Markdown(f"kohya_ss GUI release {release}")
|
||||
# About tab to display release information and README content
|
||||
gr.Markdown(f"kohya_ss GUI release {release_info}")
|
||||
with gr.Tab("README"):
|
||||
gr.Markdown(README)
|
||||
gr.Markdown(readme_content)
|
||||
|
||||
htmlStr = f"""
|
||||
<html>
|
||||
<body>
|
||||
<div class="ver-class">{release}</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
gr.HTML(htmlStr)
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
username = kwargs.get("username")
|
||||
password = kwargs.get("password")
|
||||
server_port = kwargs.get("server_port", 0)
|
||||
inbrowser = kwargs.get("inbrowser", False)
|
||||
share = kwargs.get("share", False)
|
||||
do_not_share = kwargs.get("do_not_share", False)
|
||||
server_name = kwargs.get("listen")
|
||||
root_path = kwargs.get("root_path", None)
|
||||
# Display release information in a div element
|
||||
gr.Markdown(f"<div class='ver-class'>{release_info}</div>")
|
||||
|
||||
launch_kwargs["server_name"] = server_name
|
||||
if username and password:
|
||||
launch_kwargs["auth"] = (username, password)
|
||||
if server_port > 0:
|
||||
launch_kwargs["server_port"] = server_port
|
||||
if inbrowser:
|
||||
launch_kwargs["inbrowser"] = inbrowser
|
||||
if do_not_share:
|
||||
launch_kwargs["share"] = False
|
||||
else:
|
||||
if share:
|
||||
launch_kwargs["share"] = share
|
||||
if root_path:
|
||||
launch_kwargs["root_path"] = root_path
|
||||
launch_kwargs["debug"] = True
|
||||
interface.launch(**launch_kwargs)
|
||||
return ui_interface
|
||||
|
||||
# Function to configure and launch the UI
|
||||
def UI(**kwargs):
|
||||
# Add custom JavaScript if specified
|
||||
add_javascript(kwargs.get("language"))
|
||||
log.info(f"headless: {kwargs.get('headless', False)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
# Load release and README information
|
||||
release_info = read_file_content("./.release")
|
||||
readme_content = read_file_content("./README.md")
|
||||
|
||||
# Load configuration from the specified file
|
||||
config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
|
||||
if config.is_config_loaded():
|
||||
log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")
|
||||
|
||||
# Determine if shell should be used for running external commands
|
||||
use_shell = not kwargs.get("do_not_use_shell", False) and config.get("settings.use_shell", True)
|
||||
if use_shell:
|
||||
log.info("Using shell=True when running external commands...")
|
||||
|
||||
# Initialize the Gradio UI interface
|
||||
ui_interface = initialize_ui_interface(config, kwargs.get("headless", False), use_shell, release_info, readme_content)
|
||||
|
||||
# Construct launch parameters using dictionary comprehension
|
||||
launch_params = {
|
||||
"server_name": kwargs.get("listen"),
|
||||
"auth": (kwargs["username"], kwargs["password"]) if kwargs.get("username") and kwargs.get("password") else None,
|
||||
"server_port": kwargs.get("server_port", 0) if kwargs.get("server_port", 0) > 0 else None,
|
||||
"inbrowser": kwargs.get("inbrowser", False),
|
||||
"share": False if kwargs.get("do_not_share", False) else kwargs.get("share", False),
|
||||
"root_path": kwargs.get("root_path", None),
|
||||
"debug": kwargs.get("debug", False),
|
||||
}
|
||||
|
||||
# This line filters out any key-value pairs from `launch_params` where the value is `None`, ensuring only valid parameters are passed to the `launch` function.
|
||||
launch_params = {k: v for k, v in launch_params.items() if v is not None}
|
||||
|
||||
# Launch the Gradio interface with the specified parameters
|
||||
ui_interface.launch(**launch_params)
|
||||
|
||||
# Function to initialize argument parser for command-line arguments
|
||||
def initialize_arg_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="./config.toml",
|
||||
help="Path to the toml config file for interface defaults",
|
||||
)
|
||||
parser.add_argument("--config", type=str, default="./config.toml", help="Path to the toml config file for interface defaults")
|
||||
parser.add_argument("--debug", action="store_true", help="Debug on")
|
||||
parser.add_argument(
|
||||
"--listen",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="IP to listen on for connections to Gradio",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--username", type=str, default="", help="Username for authentication"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--password", type=str, default="", help="Password for authentication"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Port to run the server listener on",
|
||||
)
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", help="IP to listen on for connections to Gradio")
|
||||
parser.add_argument("--username", type=str, default="", help="Username for authentication")
|
||||
parser.add_argument("--password", type=str, default="", help="Password for authentication")
|
||||
parser.add_argument("--server_port", type=int, default=0, help="Port to run the server listener on")
|
||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Is the server headless"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language", type=str, default=None, help="Set custom language"
|
||||
)
|
||||
|
||||
parser.add_argument("--headless", action="store_true", help="Is the server headless")
|
||||
parser.add_argument("--language", type=str, default=None, help="Set custom language")
|
||||
parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
|
||||
parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
|
||||
parser.add_argument("--do_not_use_shell", action="store_true", help="Enforce not to use shell=True when running external commands")
|
||||
parser.add_argument("--do_not_share", action="store_true", help="Do not share the gradio UI")
|
||||
parser.add_argument("--requirements", type=str, default=None, help="requirements file to use for validation")
|
||||
parser.add_argument("--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss")
|
||||
parser.add_argument("--noverify", action="store_true", help="Disable requirements verification")
|
||||
return parser
|
||||
|
||||
parser.add_argument(
|
||||
"--do_not_use_shell", action="store_true", help="Enforce not to use shell=True when running external commands"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--do_not_share", action="store_true", help="Do not share the gradio UI"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize argument parser and parse arguments
|
||||
parser = initialize_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set up logging
|
||||
# Set up logging based on the debug flag
|
||||
log = setup_logging(debug=args.debug)
|
||||
|
||||
# Verify requirements unless `noverify` flag is set
|
||||
if args.noverify:
|
||||
log.warning("Skipping requirements verification.")
|
||||
else:
|
||||
# Run the validation command to verify requirements
|
||||
validation_command = [PYTHON, os.path.join(project_dir, "setup", "validate_requirements.py")]
|
||||
|
||||
if args.requirements is not None:
|
||||
validation_command.append(f"--requirements={args.requirements}")
|
||||
|
||||
subprocess.run(validation_command, check=True)
|
||||
|
||||
# Launch the UI with the provided arguments
|
||||
UI(**vars(args))
|
||||
|
|
@ -102,7 +102,7 @@ def caption_images(
|
|||
postfix=postfix,
|
||||
)
|
||||
# Replace specified text in caption files if find and replace text is provided
|
||||
if find_text and replace_text:
|
||||
if find_text:
|
||||
find_replace(
|
||||
folder_path=images_dir,
|
||||
caption_file_ext=caption_ext,
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ def get_images_in_directory(directory_path):
|
|||
import os
|
||||
|
||||
# List of common image file extensions to look for
|
||||
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"]
|
||||
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]
|
||||
|
||||
# Generate a list of image file paths in the directory
|
||||
image_files = [
|
||||
|
|
|
|||
|
|
@ -3,6 +3,10 @@ import os
|
|||
import shlex
|
||||
|
||||
from .class_gui_config import KohyaSSGUIConfig
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
|
||||
class AccelerateLaunch:
|
||||
|
|
@ -79,12 +83,16 @@ class AccelerateLaunch:
|
|||
)
|
||||
self.dynamo_use_fullgraph = gr.Checkbox(
|
||||
label="Dynamo use fullgraph",
|
||||
value=self.config.get("accelerate_launch.dynamo_use_fullgraph", False),
|
||||
value=self.config.get(
|
||||
"accelerate_launch.dynamo_use_fullgraph", False
|
||||
),
|
||||
info="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
|
||||
)
|
||||
self.dynamo_use_dynamic = gr.Checkbox(
|
||||
label="Dynamo use dynamic",
|
||||
value=self.config.get("accelerate_launch.dynamo_use_dynamic", False),
|
||||
value=self.config.get(
|
||||
"accelerate_launch.dynamo_use_dynamic", False
|
||||
),
|
||||
info="Whether to enable dynamic shape tracing.",
|
||||
)
|
||||
|
||||
|
|
@ -103,6 +111,24 @@ class AccelerateLaunch:
|
|||
placeholder="example: 0,1",
|
||||
info=" What GPUs (by id) should be used for training on this machine as a comma-separated list",
|
||||
)
|
||||
|
||||
def validate_gpu_ids(value):
|
||||
if value == "":
|
||||
return
|
||||
if not (
|
||||
value.isdigit() and int(value) >= 0 and int(value) <= 128
|
||||
):
|
||||
log.error("GPU IDs must be an integer between 0 and 128")
|
||||
return
|
||||
else:
|
||||
for id in value.split(","):
|
||||
if not id.isdigit() or int(id) < 0 or int(id) > 128:
|
||||
log.error(
|
||||
"GPU IDs must be an integer between 0 and 128"
|
||||
)
|
||||
|
||||
self.gpu_ids.blur(fn=validate_gpu_ids, inputs=self.gpu_ids)
|
||||
|
||||
self.main_process_port = gr.Number(
|
||||
label="Main process port",
|
||||
value=self.config.get("accelerate_launch.main_process_port", 0),
|
||||
|
|
@ -137,8 +163,13 @@ class AccelerateLaunch:
|
|||
if "dynamo_use_dynamic" in kwargs and kwargs.get("dynamo_use_dynamic"):
|
||||
run_cmd.append("--dynamo_use_dynamic")
|
||||
|
||||
if "extra_accelerate_launch_args" in kwargs and kwargs["extra_accelerate_launch_args"] != "":
|
||||
extra_accelerate_launch_args = kwargs["extra_accelerate_launch_args"].replace('"', "")
|
||||
if (
|
||||
"extra_accelerate_launch_args" in kwargs
|
||||
and kwargs["extra_accelerate_launch_args"] != ""
|
||||
):
|
||||
extra_accelerate_launch_args = kwargs[
|
||||
"extra_accelerate_launch_args"
|
||||
].replace('"', "")
|
||||
for arg in extra_accelerate_launch_args.split():
|
||||
run_cmd.append(shlex.quote(arg))
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class AdvancedTraining:
|
|||
with gr.Row():
|
||||
self.loss_type = gr.Dropdown(
|
||||
label="Loss type",
|
||||
choices=["huber", "smooth_l1", "l2"],
|
||||
choices=["huber", "smooth_l1", "l1", "l2"],
|
||||
value=self.config.get("advanced.loss_type", "l2"),
|
||||
info="The type of loss to use and whether it's scheduled based on the timestep",
|
||||
)
|
||||
|
|
@ -168,6 +168,14 @@ class AdvancedTraining:
|
|||
step=0.01,
|
||||
info="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type",
|
||||
)
|
||||
self.huber_scale = gr.Number(
|
||||
label="Huber scale",
|
||||
value=self.config.get("advanced.huber_scale", 1.0),
|
||||
minimum=0.0,
|
||||
maximum=10.0,
|
||||
step=0.01,
|
||||
info="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.save_every_n_steps = gr.Number(
|
||||
|
|
@ -188,6 +196,18 @@ class AdvancedTraining:
|
|||
precision=0,
|
||||
info="(Optional) Save only the specified number of states (old models will be deleted)",
|
||||
)
|
||||
self.save_last_n_epochs = gr.Number(
|
||||
label="Save last N epochs",
|
||||
value=self.config.get("advanced.save_last_n_epochs", 0),
|
||||
precision=0,
|
||||
info="(Optional) Save only the specified number of epochs (old epochs will be deleted)",
|
||||
)
|
||||
self.save_last_n_epochs_state = gr.Number(
|
||||
label="Save last N epochs state",
|
||||
value=self.config.get("advanced.save_last_n_epochs_state", 0),
|
||||
precision=0,
|
||||
info="(Optional) Save only the specified number of epochs states (old models will be deleted)",
|
||||
)
|
||||
with gr.Row():
|
||||
|
||||
def full_options_update(full_fp16, full_bf16):
|
||||
|
|
@ -228,12 +248,16 @@ class AdvancedTraining:
|
|||
)
|
||||
|
||||
with gr.Row():
|
||||
if training_type == "lora":
|
||||
self.fp8_base = gr.Checkbox(
|
||||
label="fp8 base training (experimental)",
|
||||
info="U-Net and Text Encoder can be trained with fp8 (experimental)",
|
||||
value=self.config.get("advanced.fp8_base", False),
|
||||
)
|
||||
self.fp8_base = gr.Checkbox(
|
||||
label="fp8 base",
|
||||
info="Use fp8 for base model",
|
||||
value=self.config.get("advanced.fp8_base", False),
|
||||
)
|
||||
self.fp8_base_unet = gr.Checkbox(
|
||||
label="fp8 base unet",
|
||||
info="Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16.",
|
||||
value=self.config.get("advanced.fp8_base_unet", False),
|
||||
)
|
||||
self.full_fp16 = gr.Checkbox(
|
||||
label="Full fp16 training (experimental)",
|
||||
value=self.config.get("advanced.full_fp16", False),
|
||||
|
|
@ -255,6 +279,25 @@ class AdvancedTraining:
|
|||
outputs=[self.full_fp16, self.full_bf16],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.highvram = gr.Checkbox(
|
||||
label="highvram",
|
||||
value=self.config.get("advanced.highvram", False),
|
||||
info="Disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM)",
|
||||
interactive=True,
|
||||
)
|
||||
self.lowvram = gr.Checkbox(
|
||||
label="lowvram",
|
||||
value=self.config.get("advanced.lowvram", False),
|
||||
info="Enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)",
|
||||
interactive=True,
|
||||
)
|
||||
self.skip_cache_check = gr.Checkbox(
|
||||
label="Skip cache check",
|
||||
value=self.config.get("advanced.skip_cache_check", False),
|
||||
info="Skip cache check for faster training start",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.gradient_checkpointing = gr.Checkbox(
|
||||
label="Gradient checkpointing",
|
||||
|
|
@ -450,6 +493,15 @@ class AdvancedTraining:
|
|||
value=self.config.get("advanced.vae_batch_size", 0),
|
||||
step=1,
|
||||
)
|
||||
self.blocks_to_swap = gr.Slider(
|
||||
label="Blocks to swap",
|
||||
value=self.config.get("advanced.blocks_to_swap", 0),
|
||||
info="The number of blocks to swap. The default is None (no swap). These options must be combined with --fused_backward_pass or --blockwise_fused_optimizers. The recommended maximum value is 36.",
|
||||
minimum=0,
|
||||
maximum=57,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Group(), gr.Row():
|
||||
self.save_state = gr.Checkbox(
|
||||
label="Save training state",
|
||||
|
|
@ -534,6 +586,11 @@ class AdvancedTraining:
|
|||
self.current_log_tracker_config_dir = path if not path == "" else "."
|
||||
return list(list_files(path, exts=[".json"], all=True))
|
||||
|
||||
self.log_config = gr.Checkbox(
|
||||
label="Log config",
|
||||
value=self.config.get("advanced.log_config", False),
|
||||
info="Log training parameter to WANDB",
|
||||
)
|
||||
self.log_tracker_name = gr.Textbox(
|
||||
label="Log tracker name",
|
||||
value=self.config.get("advanced.log_tracker_name", ""),
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ class BasicTraining:
|
|||
learning_rate_value: float = "1e-6",
|
||||
lr_scheduler_value: str = "constant",
|
||||
lr_warmup_value: float = "0",
|
||||
lr_warmup_steps_value: int = 0,
|
||||
finetuning: bool = False,
|
||||
dreambooth: bool = False,
|
||||
config: dict = {},
|
||||
|
|
@ -44,10 +45,14 @@ class BasicTraining:
|
|||
self.learning_rate_value = learning_rate_value
|
||||
self.lr_scheduler_value = lr_scheduler_value
|
||||
self.lr_warmup_value = lr_warmup_value
|
||||
self.lr_warmup_steps_value= lr_warmup_steps_value
|
||||
self.finetuning = finetuning
|
||||
self.dreambooth = dreambooth
|
||||
self.config = config
|
||||
|
||||
# Initialize old_lr_warmup and old_lr_warmup_steps with default values
|
||||
self.old_lr_warmup = 0
|
||||
self.old_lr_warmup_steps = 0
|
||||
|
||||
# Initialize the UI components
|
||||
self.initialize_ui_components()
|
||||
|
|
@ -162,20 +167,37 @@ class BasicTraining:
|
|||
"cosine",
|
||||
"cosine_with_restarts",
|
||||
"linear",
|
||||
"piecewise_constant",
|
||||
"polynomial",
|
||||
"cosine_with_min_lr",
|
||||
"inverse_sqrt",
|
||||
"warmup_stable_decay",
|
||||
],
|
||||
value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value),
|
||||
)
|
||||
|
||||
|
||||
# Initialize the learning rate scheduler type dropdown
|
||||
self.lr_scheduler_type = gr.Dropdown(
|
||||
label="LR Scheduler type",
|
||||
info="(Optional) custom scheduler module name",
|
||||
choices=[
|
||||
"",
|
||||
"CosineAnnealingLR",
|
||||
],
|
||||
value=self.config.get("basic.lr_scheduler_type", ""),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
# Initialize the optimizer dropdown
|
||||
self.optimizer = gr.Dropdown(
|
||||
label="Optimizer",
|
||||
choices=[
|
||||
"AdamW",
|
||||
"AdamWScheduleFree",
|
||||
"AdamW8bit",
|
||||
"Adafactor",
|
||||
"bitsandbytes.optim.AdEMAMix8bit",
|
||||
"bitsandbytes.optim.PagedAdEMAMix8bit",
|
||||
"DAdaptation",
|
||||
"DAdaptAdaGrad",
|
||||
"DAdaptAdam",
|
||||
|
|
@ -190,11 +212,15 @@ class BasicTraining:
|
|||
"PagedAdamW32bit",
|
||||
"PagedLion8bit",
|
||||
"Prodigy",
|
||||
"prodigyplus.ProdigyPlusScheduleFree",
|
||||
"RAdamScheduleFree",
|
||||
"SGDNesterov",
|
||||
"SGDNesterov8bit",
|
||||
"SGDScheduleFree",
|
||||
],
|
||||
value=self.config.get("basic.optimizer", "AdamW8bit"),
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
def init_grad_and_lr_controls(self) -> None:
|
||||
|
|
@ -240,7 +266,7 @@ class BasicTraining:
|
|||
self.learning_rate = gr.Number(
|
||||
label=lr_label,
|
||||
value=self.config.get("basic.learning_rate", self.learning_rate_value),
|
||||
minimum=0,
|
||||
minimum=-1,
|
||||
maximum=1,
|
||||
info="Set to 0 to not train the Unet",
|
||||
)
|
||||
|
|
@ -251,7 +277,7 @@ class BasicTraining:
|
|||
"basic.learning_rate_te", self.learning_rate_value
|
||||
),
|
||||
visible=self.finetuning or self.dreambooth,
|
||||
minimum=0,
|
||||
minimum=-1,
|
||||
maximum=1,
|
||||
info="Set to 0 to not train the Text Encoder",
|
||||
)
|
||||
|
|
@ -262,7 +288,7 @@ class BasicTraining:
|
|||
"basic.learning_rate_te1", self.learning_rate_value
|
||||
),
|
||||
visible=False,
|
||||
minimum=0,
|
||||
minimum=-1,
|
||||
maximum=1,
|
||||
info="Set to 0 to not train the Text Encoder 1",
|
||||
)
|
||||
|
|
@ -273,7 +299,7 @@ class BasicTraining:
|
|||
"basic.learning_rate_te2", self.learning_rate_value
|
||||
),
|
||||
visible=False,
|
||||
minimum=0,
|
||||
minimum=-1,
|
||||
maximum=1,
|
||||
info="Set to 0 to not train the Text Encoder 2",
|
||||
)
|
||||
|
|
@ -285,25 +311,37 @@ class BasicTraining:
|
|||
maximum=100,
|
||||
step=1,
|
||||
)
|
||||
# Initialize the learning rate warmup steps override
|
||||
self.lr_warmup_steps = gr.Number(
|
||||
label="LR warmup steps (override)",
|
||||
value=self.config.get("basic.lr_warmup_steps", self.lr_warmup_steps_value),
|
||||
minimum=0,
|
||||
step=1,
|
||||
)
|
||||
|
||||
def lr_scheduler_changed(scheduler, value):
|
||||
def lr_scheduler_changed(scheduler, value, value_lr_warmup_steps):
|
||||
if scheduler == "constant":
|
||||
self.old_lr_warmup = value
|
||||
self.old_lr_warmup_steps = value_lr_warmup_steps
|
||||
value = 0
|
||||
value_lr_warmup_steps = 0
|
||||
interactive=False
|
||||
info="Can't use LR warmup with LR Scheduler constant... setting to 0 and disabling field..."
|
||||
else:
|
||||
if self.old_lr_warmup != 0:
|
||||
value = self.old_lr_warmup
|
||||
self.old_lr_warmup = 0
|
||||
if self.old_lr_warmup_steps != 0:
|
||||
value_lr_warmup_steps = self.old_lr_warmup_steps
|
||||
self.old_lr_warmup_steps = 0
|
||||
interactive=True
|
||||
info=""
|
||||
return gr.Slider(value=value, interactive=interactive, info=info)
|
||||
return gr.Slider(value=value, interactive=interactive, info=info), gr.Number(value=value_lr_warmup_steps, interactive=interactive, info=info)
|
||||
|
||||
self.lr_scheduler.change(
|
||||
lr_scheduler_changed,
|
||||
inputs=[self.lr_scheduler, self.lr_warmup],
|
||||
outputs=self.lr_warmup,
|
||||
inputs=[self.lr_scheduler, self.lr_warmup, self.lr_warmup_steps],
|
||||
outputs=[self.lr_warmup, self.lr_warmup_steps],
|
||||
)
|
||||
|
||||
def init_scheduler_controls(self) -> None:
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class CommandExecutor:
|
|||
|
||||
# Execute the command securely
|
||||
self.process = subprocess.Popen(run_cmd, **kwargs)
|
||||
log.info("Command executed.")
|
||||
log.debug("Command executed.")
|
||||
|
||||
def kill_command(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,336 @@
|
|||
import gradio as gr
|
||||
from typing import Tuple
|
||||
from .common_gui import (
|
||||
get_any_file_path,
|
||||
document_symbol,
|
||||
)
|
||||
|
||||
|
||||
class flux1Training:
|
||||
def __init__(
|
||||
self,
|
||||
headless: bool = False,
|
||||
finetuning: bool = False,
|
||||
training_type: str = "",
|
||||
config: dict = {},
|
||||
flux1_checkbox: gr.Checkbox = False,
|
||||
) -> None:
|
||||
self.headless = headless
|
||||
self.finetuning = finetuning
|
||||
self.training_type = training_type
|
||||
self.config = config
|
||||
self.flux1_checkbox = flux1_checkbox
|
||||
|
||||
# Define the behavior for changing noise offset type.
|
||||
def noise_offset_type_change(
|
||||
noise_offset_type: str,
|
||||
) -> Tuple[gr.Group, gr.Group]:
|
||||
if noise_offset_type == "Original":
|
||||
return (gr.Group(visible=True), gr.Group(visible=False))
|
||||
else:
|
||||
return (gr.Group(visible=False), gr.Group(visible=True))
|
||||
|
||||
with gr.Accordion(
|
||||
"Flux.1", open=True, visible=False, elem_classes=["flux1_background"]
|
||||
) as flux1_accordion:
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
self.ae = gr.Textbox(
|
||||
label="VAE Path",
|
||||
placeholder="Path to VAE model",
|
||||
value=self.config.get("flux1.ae", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.ae_button = gr.Button(
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
interactive=True,
|
||||
)
|
||||
self.ae_button.click(
|
||||
get_any_file_path,
|
||||
outputs=self.ae,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
self.clip_l = gr.Textbox(
|
||||
label="CLIP-L Path",
|
||||
placeholder="Path to CLIP-L model",
|
||||
value=self.config.get("flux1.clip_l", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.clip_l_button = gr.Button(
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
interactive=True,
|
||||
)
|
||||
self.clip_l_button.click(
|
||||
get_any_file_path,
|
||||
outputs=self.clip_l,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
self.t5xxl = gr.Textbox(
|
||||
label="T5-XXL Path",
|
||||
placeholder="Path to T5-XXL model",
|
||||
value=self.config.get("flux1.t5xxl", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.t5xxl_button = gr.Button(
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
interactive=True,
|
||||
)
|
||||
self.t5xxl_button.click(
|
||||
get_any_file_path,
|
||||
outputs=self.t5xxl,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
|
||||
self.discrete_flow_shift = gr.Number(
|
||||
label="Discrete Flow Shift",
|
||||
value=self.config.get("flux1.discrete_flow_shift", 3.0),
|
||||
info="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0",
|
||||
minimum=-1024,
|
||||
maximum=1024,
|
||||
step=0.01,
|
||||
interactive=True,
|
||||
)
|
||||
self.model_prediction_type = gr.Dropdown(
|
||||
label="Model Prediction Type",
|
||||
choices=["raw", "additive", "sigma_scaled"],
|
||||
value=self.config.get(
|
||||
"flux1.timestep_sampling", "sigma_scaled"
|
||||
),
|
||||
interactive=True,
|
||||
)
|
||||
self.timestep_sampling = gr.Dropdown(
|
||||
label="Timestep Sampling",
|
||||
choices=["flux_shift", "sigma", "shift", "sigmoid", "uniform"],
|
||||
value=self.config.get("flux1.timestep_sampling", "sigma"),
|
||||
interactive=True,
|
||||
)
|
||||
self.apply_t5_attn_mask = gr.Checkbox(
|
||||
label="Apply T5 Attention Mask",
|
||||
value=self.config.get("flux1.apply_t5_attn_mask", False),
|
||||
info="Apply attention mask to T5-XXL encode and FLUX double blocks ",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row(visible=True if not finetuning else False):
|
||||
self.split_mode = gr.Checkbox(
|
||||
label="Split Mode",
|
||||
value=self.config.get("flux1.split_mode", False),
|
||||
info="Split mode for Flux1",
|
||||
interactive=True,
|
||||
)
|
||||
self.train_blocks = gr.Dropdown(
|
||||
label="Train Blocks",
|
||||
choices=["all", "double", "single"],
|
||||
value=self.config.get("flux1.train_blocks", "all"),
|
||||
interactive=True,
|
||||
)
|
||||
self.split_qkv = gr.Checkbox(
|
||||
label="Split QKV",
|
||||
value=self.config.get("flux1.split_qkv", False),
|
||||
info="Split the projection layers of q/k/v/txt in the attention",
|
||||
interactive=True,
|
||||
)
|
||||
self.train_t5xxl = gr.Checkbox(
|
||||
label="Train T5-XXL",
|
||||
value=self.config.get("flux1.train_t5xxl", False),
|
||||
info="Train T5-XXL model",
|
||||
interactive=True,
|
||||
)
|
||||
self.cpu_offload_checkpointing = gr.Checkbox(
|
||||
label="CPU Offload Checkpointing",
|
||||
value=self.config.get("flux1.cpu_offload_checkpointing", False),
|
||||
info="[Experimental] Enable offloading of tensors to CPU during checkpointing",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
self.guidance_scale = gr.Number(
|
||||
label="Guidance Scale",
|
||||
value=self.config.get("flux1.guidance_scale", 3.5),
|
||||
info="Guidance scale for Flux1",
|
||||
minimum=0,
|
||||
maximum=1024,
|
||||
step=0.1,
|
||||
interactive=True,
|
||||
)
|
||||
self.t5xxl_max_token_length = gr.Number(
|
||||
label="T5-XXL Max Token Length",
|
||||
value=self.config.get("flux1.t5xxl_max_token_length", 512),
|
||||
info="Max token length for T5-XXL",
|
||||
minimum=0,
|
||||
maximum=4096,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
self.enable_all_linear = gr.Checkbox(
|
||||
label="Enable All Linear",
|
||||
value=self.config.get("flux1.enable_all_linear", False),
|
||||
info="(Only applicable to 'FLux1 OFT' LoRA) Target all linear connections in the MLP layer. The default is False, which targets only attention.",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.flux1_cache_text_encoder_outputs = gr.Checkbox(
|
||||
label="Cache Text Encoder Outputs",
|
||||
value=self.config.get(
|
||||
"flux1.cache_text_encoder_outputs", False
|
||||
),
|
||||
info="Cache text encoder outputs to speed up inference",
|
||||
interactive=True,
|
||||
)
|
||||
self.flux1_cache_text_encoder_outputs_to_disk = gr.Checkbox(
|
||||
label="Cache Text Encoder Outputs to Disk",
|
||||
value=self.config.get(
|
||||
"flux1.cache_text_encoder_outputs_to_disk", False
|
||||
),
|
||||
info="Cache text encoder outputs to disk to speed up inference",
|
||||
interactive=True,
|
||||
)
|
||||
self.mem_eff_save = gr.Checkbox(
|
||||
label="Memory Efficient Save",
|
||||
value=self.config.get("flux1.mem_eff_save", False),
|
||||
info="[Experimentsl] Enable memory efficient save. We do not recommend using it unless you are familiar with the code.",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
# self.blocks_to_swap = gr.Slider(
|
||||
# label="Blocks to swap",
|
||||
# value=self.config.get("flux1.blocks_to_swap", 0),
|
||||
# info="The number of blocks to swap. The default is None (no swap). These options must be combined with --fused_backward_pass or --blockwise_fused_optimizers. The recommended maximum value is 36.",
|
||||
# minimum=0,
|
||||
# maximum=57,
|
||||
# step=1,
|
||||
# interactive=True,
|
||||
# )
|
||||
self.single_blocks_to_swap = gr.Slider(
|
||||
label="Single Blocks to swap (depercated)",
|
||||
value=self.config.get("flux1.single_blocks_to_swap", 0),
|
||||
info="[Experimental] Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes.",
|
||||
minimum=0,
|
||||
maximum=19,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
self.double_blocks_to_swap = gr.Slider(
|
||||
label="Double Blocks to swap (depercated)",
|
||||
value=self.config.get("flux1.double_blocks_to_swap", 0),
|
||||
info="[Experimental] Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes.",
|
||||
minimum=0,
|
||||
maximum=38,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=True if finetuning else False):
|
||||
self.blockwise_fused_optimizers = gr.Checkbox(
|
||||
label="Blockwise Fused Optimizer",
|
||||
value=self.config.get(
|
||||
"flux1.blockwise_fused_optimizers", False
|
||||
),
|
||||
info="Enable blockwise optimizers for fused backward pass and optimizer step. Any optimizer can be used.",
|
||||
interactive=True,
|
||||
)
|
||||
self.cpu_offload_checkpointing = gr.Checkbox(
|
||||
label="CPU Offload Checkpointing",
|
||||
value=self.config.get("flux1.cpu_offload_checkpointing", False),
|
||||
info="[Experimental] Enable offloading of tensors to CPU during checkpointing",
|
||||
interactive=True,
|
||||
)
|
||||
self.flux_fused_backward_pass = gr.Checkbox(
|
||||
label="Fused Backward Pass",
|
||||
value=self.config.get("flux1.fused_backward_pass", False),
|
||||
info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(
|
||||
"Blocks to train",
|
||||
open=True,
|
||||
visible=False if finetuning else True,
|
||||
elem_classes=["flux1_blocks_to_train_background"],
|
||||
):
|
||||
with gr.Row():
|
||||
self.train_double_block_indices = gr.Textbox(
|
||||
label="train_double_block_indices",
|
||||
info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of double blocks is 19.",
|
||||
value=self.config.get("flux1.train_double_block_indices", "all"),
|
||||
interactive=True,
|
||||
)
|
||||
self.train_single_block_indices = gr.Textbox(
|
||||
label="train_single_block_indices",
|
||||
info="The indices are specified as a list of integers or a range of integers, like '0,1,5,8' or '0,1,4-5,7' or 'all' or 'none'. The number of single blocks is 38.",
|
||||
value=self.config.get("flux1.train_single_block_indices", "all"),
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(
|
||||
"Rank for layers",
|
||||
open=False,
|
||||
visible=False if finetuning else True,
|
||||
elem_classes=["flux1_rank_layers_background"],
|
||||
):
|
||||
with gr.Row():
|
||||
self.img_attn_dim = gr.Textbox(
|
||||
label="img_attn_dim",
|
||||
value=self.config.get("flux1.img_attn_dim", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.img_mlp_dim = gr.Textbox(
|
||||
label="img_mlp_dim",
|
||||
value=self.config.get("flux1.img_mlp_dim", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.img_mod_dim = gr.Textbox(
|
||||
label="img_mod_dim",
|
||||
value=self.config.get("flux1.img_mod_dim", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.single_dim = gr.Textbox(
|
||||
label="single_dim",
|
||||
value=self.config.get("flux1.single_dim", ""),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
self.txt_attn_dim = gr.Textbox(
|
||||
label="txt_attn_dim",
|
||||
value=self.config.get("flux1.txt_attn_dim", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.txt_mlp_dim = gr.Textbox(
|
||||
label="txt_mlp_dim",
|
||||
value=self.config.get("flux1.txt_mlp_dim", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.txt_mod_dim = gr.Textbox(
|
||||
label="txt_mod_dim",
|
||||
value=self.config.get("flux1.txt_mod_dim", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.single_mod_dim = gr.Textbox(
|
||||
label="single_mod_dim",
|
||||
value=self.config.get("flux1.single_mod_dim", ""),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
self.in_dims = gr.Textbox(
|
||||
label="in_dims",
|
||||
value=self.config.get("flux1.in_dims", ""),
|
||||
placeholder="e.g., [4,0,0,0,4]",
|
||||
info="Each number corresponds to img_in, time_in, vector_in, guidance_in, txt_in. The above example applies LoRA to all conditioning layers, with rank 4 for img_in, 2 for time_in, vector_in, guidance_in, and 4 for txt_in.",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
self.flux1_checkbox.change(
|
||||
lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox),
|
||||
inputs=[self.flux1_checkbox],
|
||||
outputs=[flux1_accordion],
|
||||
)
|
||||
|
|
@ -4,10 +4,12 @@ from .svd_merge_lora_gui import gradio_svd_merge_lora_tab
|
|||
from .verify_lora_gui import gradio_verify_lora_tab
|
||||
from .resize_lora_gui import gradio_resize_lora_tab
|
||||
from .extract_lora_gui import gradio_extract_lora_tab
|
||||
from .flux_extract_lora_gui import gradio_flux_extract_lora_tab
|
||||
from .convert_lcm_gui import gradio_convert_lcm_tab
|
||||
from .extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
|
||||
from .extract_lora_from_dylora_gui import gradio_extract_dylora_tab
|
||||
from .merge_lycoris_gui import gradio_merge_lycoris_tab
|
||||
from .flux_merge_lora_gui import GradioFluxMergeLoRaTab
|
||||
|
||||
|
||||
class LoRATools:
|
||||
|
|
@ -19,9 +21,11 @@ class LoRATools:
|
|||
gradio_extract_dylora_tab(headless=headless)
|
||||
gradio_convert_lcm_tab(headless=headless)
|
||||
gradio_extract_lora_tab(headless=headless)
|
||||
gradio_flux_extract_lora_tab(headless=headless)
|
||||
gradio_extract_lycoris_locon_tab(headless=headless)
|
||||
gradio_merge_lora_tab = GradioMergeLoRaTab()
|
||||
gradio_merge_lycoris_tab(headless=headless)
|
||||
gradio_svd_merge_lora_tab(headless=headless)
|
||||
gradio_resize_lora_tab(headless=headless)
|
||||
gradio_verify_lora_tab(headless=headless)
|
||||
GradioFluxMergeLoRaTab(headless=headless)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,10 @@ def create_prompt_file(sample_prompts, output_dir):
|
|||
Returns:
|
||||
str: The path to the prompt file.
|
||||
"""
|
||||
sample_prompts_path = os.path.join(output_dir, "prompt.txt")
|
||||
sample_prompts_path = os.path.join(output_dir, "sample/prompt.txt")
|
||||
|
||||
if not os.path.exists(os.path.dirname(sample_prompts_path)):
|
||||
os.makedirs(os.path.dirname(sample_prompts_path))
|
||||
|
||||
with open(sample_prompts_path, "w", encoding="utf-8") as f:
|
||||
f.write(sample_prompts)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,249 @@
|
|||
import gradio as gr
|
||||
from typing import Tuple
|
||||
from .common_gui import (
|
||||
get_folder_path,
|
||||
get_any_file_path,
|
||||
list_files,
|
||||
list_dirs,
|
||||
create_refresh_button,
|
||||
document_symbol,
|
||||
)
|
||||
|
||||
|
||||
class sd3Training:
|
||||
"""
|
||||
This class configures and initializes the advanced training settings for a machine learning model,
|
||||
including options for headless operation, fine-tuning, training type selection, and default directory paths.
|
||||
|
||||
Attributes:
|
||||
headless (bool): If True, run without the Gradio interface.
|
||||
finetuning (bool): If True, enables fine-tuning of the model.
|
||||
training_type (str): Specifies the type of training to perform.
|
||||
no_token_padding (gr.Checkbox): Checkbox to disable token padding.
|
||||
gradient_accumulation_steps (gr.Slider): Slider to set the number of gradient accumulation steps.
|
||||
weighted_captions (gr.Checkbox): Checkbox to enable weighted captions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headless: bool = False,
|
||||
finetuning: bool = False,
|
||||
training_type: str = "",
|
||||
config: dict = {},
|
||||
sd3_checkbox: gr.Checkbox = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the AdvancedTraining class with given settings.
|
||||
|
||||
Parameters:
|
||||
headless (bool): Run in headless mode without GUI.
|
||||
finetuning (bool): Enable model fine-tuning.
|
||||
training_type (str): The type of training to be performed.
|
||||
config (dict): Configuration options for the training process.
|
||||
"""
|
||||
self.headless = headless
|
||||
self.finetuning = finetuning
|
||||
self.training_type = training_type
|
||||
self.config = config
|
||||
self.sd3_checkbox = sd3_checkbox
|
||||
|
||||
# Define the behavior for changing noise offset type.
|
||||
def noise_offset_type_change(
|
||||
noise_offset_type: str,
|
||||
) -> Tuple[gr.Group, gr.Group]:
|
||||
"""
|
||||
Returns a tuple of Gradio Groups with visibility set based on the noise offset type.
|
||||
|
||||
Parameters:
|
||||
noise_offset_type (str): The selected noise offset type.
|
||||
|
||||
Returns:
|
||||
Tuple[gr.Group, gr.Group]: A tuple containing two Gradio Group elements with their visibility set.
|
||||
"""
|
||||
if noise_offset_type == "Original":
|
||||
return (gr.Group(visible=True), gr.Group(visible=False))
|
||||
else:
|
||||
return (gr.Group(visible=False), gr.Group(visible=True))
|
||||
|
||||
with gr.Accordion(
|
||||
"SD3", open=False, elem_id="sd3_tab", visible=False
|
||||
) as sd3_accordion:
|
||||
with gr.Group():
|
||||
gr.Markdown("### SD3 Specific Parameters")
|
||||
with gr.Row():
|
||||
self.weighting_scheme = gr.Dropdown(
|
||||
label="Weighting Scheme",
|
||||
choices=["logit_normal", "sigma_sqrt", "mode", "cosmap", "uniform"],
|
||||
value=self.config.get("sd3.weighting_scheme", "logit_normal"),
|
||||
interactive=True,
|
||||
)
|
||||
self.logit_mean = gr.Number(
|
||||
label="Logit Mean",
|
||||
value=self.config.get("sd3.logit_mean", 0.0),
|
||||
interactive=True,
|
||||
)
|
||||
self.logit_std = gr.Number(
|
||||
label="Logit Std",
|
||||
value=self.config.get("sd3.logit_std", 1.0),
|
||||
interactive=True,
|
||||
)
|
||||
self.mode_scale = gr.Number(
|
||||
label="Mode Scale",
|
||||
value=self.config.get("sd3.mode_scale", 1.29),
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.clip_l = gr.Textbox(
|
||||
label="CLIP-L Path",
|
||||
placeholder="Path to CLIP-L model",
|
||||
value=self.config.get("sd3.clip_l", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.clip_l_button = gr.Button(
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
interactive=True,
|
||||
)
|
||||
self.clip_l_button.click(
|
||||
get_any_file_path,
|
||||
outputs=self.clip_l,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
self.clip_g = gr.Textbox(
|
||||
label="CLIP-G Path",
|
||||
placeholder="Path to CLIP-G model",
|
||||
value=self.config.get("sd3.clip_g", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.clip_g_button = gr.Button(
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
interactive=True,
|
||||
)
|
||||
self.clip_g_button.click(
|
||||
get_any_file_path,
|
||||
outputs=self.clip_g,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
self.t5xxl = gr.Textbox(
|
||||
label="T5-XXL Path",
|
||||
placeholder="Path to T5-XXL model",
|
||||
value=self.config.get("sd3.t5xxl", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.t5xxl_button = gr.Button(
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
interactive=True,
|
||||
)
|
||||
self.t5xxl_button.click(
|
||||
get_any_file_path,
|
||||
outputs=self.t5xxl,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.save_clip = gr.Checkbox(
|
||||
label="Save CLIP models",
|
||||
value=self.config.get("sd3.save_clip", False),
|
||||
interactive=True,
|
||||
)
|
||||
self.save_t5xxl = gr.Checkbox(
|
||||
label="Save T5-XXL model",
|
||||
value=self.config.get("sd3.save_t5xxl", False),
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.t5xxl_device = gr.Textbox(
|
||||
label="T5-XXL Device",
|
||||
placeholder="Device for T5-XXL (e.g., cuda:0)",
|
||||
value=self.config.get("sd3.t5xxl_device", ""),
|
||||
interactive=True,
|
||||
)
|
||||
self.t5xxl_dtype = gr.Dropdown(
|
||||
label="T5-XXL Dtype",
|
||||
choices=["float32", "fp16", "bf16"],
|
||||
value=self.config.get("sd3.t5xxl_dtype", "bf16"),
|
||||
interactive=True,
|
||||
)
|
||||
self.sd3_text_encoder_batch_size = gr.Number(
|
||||
label="Text Encoder Batch Size",
|
||||
value=self.config.get("sd3.text_encoder_batch_size", 1),
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
self.sd3_cache_text_encoder_outputs = gr.Checkbox(
|
||||
label="Cache Text Encoder Outputs",
|
||||
value=self.config.get("sd3.cache_text_encoder_outputs", False),
|
||||
info="Cache text encoder outputs to speed up inference",
|
||||
interactive=True,
|
||||
)
|
||||
self.sd3_cache_text_encoder_outputs_to_disk = gr.Checkbox(
|
||||
label="Cache Text Encoder Outputs to Disk",
|
||||
value=self.config.get(
|
||||
"sd3.cache_text_encoder_outputs_to_disk", False
|
||||
),
|
||||
info="Cache text encoder outputs to disk to speed up inference",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
self.clip_l_dropout_rate = gr.Number(
|
||||
label="CLIP-L Dropout Rate",
|
||||
value=self.config.get("sd3.clip_l_dropout_rate", 0.0),
|
||||
interactive=True,
|
||||
minimum=0.0,
|
||||
info="Dropout rate for CLIP-L encoder"
|
||||
)
|
||||
self.clip_g_dropout_rate = gr.Number(
|
||||
label="CLIP-G Dropout Rate",
|
||||
value=self.config.get("sd3.clip_g_dropout_rate", 0.0),
|
||||
interactive=True,
|
||||
minimum=0.0,
|
||||
info="Dropout rate for CLIP-G encoder"
|
||||
)
|
||||
self.t5_dropout_rate = gr.Number(
|
||||
label="T5 Dropout Rate",
|
||||
value=self.config.get("sd3.t5_dropout_rate", 0.0),
|
||||
interactive=True,
|
||||
minimum=0.0,
|
||||
info="Dropout rate for T5-XXL encoder"
|
||||
)
|
||||
with gr.Row():
|
||||
self.sd3_fused_backward_pass = gr.Checkbox(
|
||||
label="Fused Backward Pass",
|
||||
value=self.config.get("sd3.fused_backward_pass", False),
|
||||
info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
|
||||
interactive=True,
|
||||
)
|
||||
self.disable_mmap_load_safetensors = gr.Checkbox(
|
||||
label="Disable mmap load safe tensors",
|
||||
info="Disable memory mapping when loading the model's .safetensors in SDXL.",
|
||||
value=self.config.get("sd3.disable_mmap_load_safetensors", False),
|
||||
)
|
||||
self.enable_scaled_pos_embed = gr.Checkbox(
|
||||
label="Enable Scaled Positional Embeddings",
|
||||
info="Enable scaled positional embeddings in the model.",
|
||||
value=self.config.get("sd3.enable_scaled_pos_embed", False),
|
||||
)
|
||||
self.pos_emb_random_crop_rate = gr.Number(
|
||||
label="Positional Embedding Random Crop Rate",
|
||||
value=self.config.get("sd3.pos_emb_random_crop_rate", 0.0),
|
||||
interactive=True,
|
||||
minimum=0.0,
|
||||
info="Random crop rate for positional embeddings"
|
||||
)
|
||||
|
||||
self.sd3_checkbox.change(
|
||||
lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox),
|
||||
inputs=[self.sd3_checkbox],
|
||||
outputs=[sd3_accordion],
|
||||
)
|
||||
|
|
@ -7,10 +7,12 @@ class SDXLParameters:
|
|||
sdxl_checkbox: gr.Checkbox,
|
||||
show_sdxl_cache_text_encoder_outputs: bool = True,
|
||||
config: KohyaSSGUIConfig = {},
|
||||
trainer: str = "",
|
||||
):
|
||||
self.sdxl_checkbox = sdxl_checkbox
|
||||
self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs
|
||||
self.config = config
|
||||
self.trainer = trainer
|
||||
|
||||
self.initialize_accordion()
|
||||
|
||||
|
|
@ -30,6 +32,41 @@ class SDXLParameters:
|
|||
info="Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.",
|
||||
value=self.config.get("sdxl.sdxl_no_half_vae", False),
|
||||
)
|
||||
self.fused_backward_pass = gr.Checkbox(
|
||||
label="Fused backward pass",
|
||||
info="Enable fused backward pass. This option is useful to reduce the GPU memory usage. Can't be used if Fused optimizer groups is > 0. Only AdaFactor is supported",
|
||||
value=self.config.get("sdxl.fused_backward_pass", False),
|
||||
visible=self.trainer == "finetune" or self.trainer == "dreambooth",
|
||||
)
|
||||
self.fused_optimizer_groups = gr.Number(
|
||||
label="Fused optimizer groups",
|
||||
info="Number of optimizer groups to fuse. This option is useful to reduce the GPU memory usage. Can't be used if Fused backward pass is enabled. Since the effect is limited to a certain number, it is recommended to specify 4-10.",
|
||||
value=self.config.get("sdxl.fused_optimizer_groups", 0),
|
||||
minimum=0,
|
||||
step=1,
|
||||
visible=self.trainer == "finetune" or self.trainer == "dreambooth",
|
||||
)
|
||||
self.disable_mmap_load_safetensors = gr.Checkbox(
|
||||
label="Disable mmap load safe tensors",
|
||||
info="Disable memory mapping when loading the model's .safetensors in SDXL.",
|
||||
value=self.config.get("sdxl.disable_mmap_load_safetensors", False),
|
||||
)
|
||||
|
||||
self.fused_backward_pass.change(
|
||||
lambda fused_backward_pass: gr.Number(
|
||||
interactive=not fused_backward_pass
|
||||
),
|
||||
inputs=[self.fused_backward_pass],
|
||||
outputs=[self.fused_optimizer_groups],
|
||||
)
|
||||
self.fused_optimizer_groups.change(
|
||||
lambda fused_optimizer_groups: gr.Checkbox(
|
||||
interactive=fused_optimizer_groups == 0
|
||||
),
|
||||
inputs=[self.fused_optimizer_groups],
|
||||
outputs=[self.fused_backward_pass],
|
||||
)
|
||||
|
||||
|
||||
self.sdxl_checkbox.change(
|
||||
lambda sdxl_checkbox: gr.Accordion(visible=sdxl_checkbox),
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ default_models = [
|
|||
"stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
|
||||
|
|
@ -245,19 +244,88 @@ class SourceModel:
|
|||
with gr.Column():
|
||||
with gr.Row():
|
||||
self.v2 = gr.Checkbox(
|
||||
label="v2", value=False, visible=False, min_width=60
|
||||
label="v2", value=False, visible=False, min_width=60,
|
||||
interactive=True,
|
||||
)
|
||||
self.v_parameterization = gr.Checkbox(
|
||||
label="v_parameterization",
|
||||
label="v_param",
|
||||
value=False,
|
||||
visible=False,
|
||||
min_width=130,
|
||||
interactive=True,
|
||||
)
|
||||
self.sdxl_checkbox = gr.Checkbox(
|
||||
label="SDXL",
|
||||
value=False,
|
||||
visible=False,
|
||||
min_width=60,
|
||||
interactive=True,
|
||||
)
|
||||
self.sd3_checkbox = gr.Checkbox(
|
||||
label="SD3",
|
||||
value=False,
|
||||
visible=False,
|
||||
min_width=60,
|
||||
interactive=True,
|
||||
)
|
||||
self.flux1_checkbox = gr.Checkbox(
|
||||
label="Flux.1",
|
||||
value=False,
|
||||
visible=False,
|
||||
min_width=60,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
def toggle_checkboxes(v2, v_parameterization, sdxl_checkbox, sd3_checkbox, flux1_checkbox):
|
||||
# Check if all checkboxes are unchecked
|
||||
if not v2 and not v_parameterization and not sdxl_checkbox and not sd3_checkbox and not flux1_checkbox:
|
||||
# If all unchecked, return new interactive checkboxes
|
||||
return (
|
||||
gr.Checkbox(interactive=True), # v2 checkbox
|
||||
gr.Checkbox(interactive=True), # v_parameterization checkbox
|
||||
gr.Checkbox(interactive=True), # sdxl_checkbox
|
||||
gr.Checkbox(interactive=True), # sd3_checkbox
|
||||
gr.Checkbox(interactive=True), # sd3_checkbox
|
||||
)
|
||||
else:
|
||||
# If any checkbox is checked, return checkboxes with current interactive state
|
||||
return (
|
||||
gr.Checkbox(interactive=v2), # v2 checkbox
|
||||
gr.Checkbox(interactive=v_parameterization), # v_parameterization checkbox
|
||||
gr.Checkbox(interactive=sdxl_checkbox), # sdxl_checkbox
|
||||
gr.Checkbox(interactive=sd3_checkbox), # sd3_checkbox
|
||||
gr.Checkbox(interactive=flux1_checkbox), # flux1_checkbox
|
||||
)
|
||||
|
||||
self.v2.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
self.v_parameterization.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
self.sdxl_checkbox.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
self.sd3_checkbox.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
self.flux1_checkbox.change(
|
||||
fn=toggle_checkboxes,
|
||||
inputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
outputs=[self.v2, self.v_parameterization, self.sdxl_checkbox, self.sd3_checkbox, self.flux1_checkbox],
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Column():
|
||||
gr.Group(visible=False)
|
||||
|
|
@ -294,6 +362,8 @@ class SourceModel:
|
|||
self.v2,
|
||||
self.v_parameterization,
|
||||
self.sdxl_checkbox,
|
||||
self.sd3_checkbox,
|
||||
self.flux1_checkbox,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from .common_gui import setup_environment
|
|||
|
||||
class TensorboardManager:
|
||||
DEFAULT_TENSORBOARD_PORT = 6006
|
||||
DEFAULT_TENSORBOARD_HOST = "0.0.0.0"
|
||||
|
||||
def __init__(self, logging_dir, headless: bool = False, wait_time=5):
|
||||
self.logging_dir = logging_dir
|
||||
|
|
@ -29,6 +30,9 @@ class TensorboardManager:
|
|||
self.tensorboard_port = os.environ.get(
|
||||
"TENSORBOARD_PORT", self.DEFAULT_TENSORBOARD_PORT
|
||||
)
|
||||
self.tensorboard_host = os.environ.get(
|
||||
"TENSORBOARD_HOST", self.DEFAULT_TENSORBOARD_HOST
|
||||
)
|
||||
self.log = setup_logging()
|
||||
self.thread = None
|
||||
self.stop_event = Event()
|
||||
|
|
@ -64,7 +68,7 @@ class TensorboardManager:
|
|||
"--logdir",
|
||||
logging_dir,
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
self.tensorboard_host,
|
||||
"--port",
|
||||
str(self.tensorboard_port),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ except ImportError:
|
|||
from easygui import msgbox, ynbox
|
||||
from typing import Optional
|
||||
from .custom_logging import setup_logging
|
||||
from .sd_modeltype import SDModelType
|
||||
|
||||
import os
|
||||
import re
|
||||
|
|
@ -327,7 +328,6 @@ def update_my_data(my_data):
|
|||
|
||||
# Convert values to int if they are strings
|
||||
for key in [
|
||||
"adaptive_noise_scale",
|
||||
"clip_skip",
|
||||
"epoch",
|
||||
"gradient_accumulation_steps",
|
||||
|
|
@ -379,7 +379,13 @@ def update_my_data(my_data):
|
|||
my_data[key] = int(75)
|
||||
|
||||
# Convert values to float if they are strings, correctly handling float representations
|
||||
for key in ["noise_offset", "learning_rate", "text_encoder_lr", "unet_lr"]:
|
||||
for key in [
|
||||
"adaptive_noise_scale",
|
||||
"noise_offset",
|
||||
"learning_rate",
|
||||
"text_encoder_lr",
|
||||
"unet_lr",
|
||||
]:
|
||||
value = my_data.get(key)
|
||||
if value is not None:
|
||||
try:
|
||||
|
|
@ -956,11 +962,15 @@ def set_pretrained_model_name_or_path_input(
|
|||
v2 = gr.Checkbox(value=False, visible=False)
|
||||
v_parameterization = gr.Checkbox(value=False, visible=False)
|
||||
sdxl = gr.Checkbox(value=True, visible=False)
|
||||
sd3 = gr.Checkbox(value=False, visible=False)
|
||||
flux1 = gr.Checkbox(value=False, visible=False)
|
||||
return (
|
||||
gr.Dropdown(),
|
||||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
)
|
||||
|
||||
# Check if the given pretrained_model_name_or_path is in the list of V2 base models
|
||||
|
|
@ -969,11 +979,15 @@ def set_pretrained_model_name_or_path_input(
|
|||
v2 = gr.Checkbox(value=True, visible=False)
|
||||
v_parameterization = gr.Checkbox(value=False, visible=False)
|
||||
sdxl = gr.Checkbox(value=False, visible=False)
|
||||
sd3 = gr.Checkbox(value=False, visible=False)
|
||||
flux1 = gr.Checkbox(value=False, visible=False)
|
||||
return (
|
||||
gr.Dropdown(),
|
||||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
)
|
||||
|
||||
# Check if the given pretrained_model_name_or_path is in the list of V parameterization models
|
||||
|
|
@ -984,11 +998,15 @@ def set_pretrained_model_name_or_path_input(
|
|||
v2 = gr.Checkbox(value=True, visible=False)
|
||||
v_parameterization = gr.Checkbox(value=True, visible=False)
|
||||
sdxl = gr.Checkbox(value=False, visible=False)
|
||||
sd3 = gr.Checkbox(value=False, visible=False)
|
||||
flux1 = gr.Checkbox(value=False, visible=False)
|
||||
return (
|
||||
gr.Dropdown(),
|
||||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
)
|
||||
|
||||
# Check if the given pretrained_model_name_or_path is in the list of V1 models
|
||||
|
|
@ -997,17 +1015,32 @@ def set_pretrained_model_name_or_path_input(
|
|||
v2 = gr.Checkbox(value=False, visible=False)
|
||||
v_parameterization = gr.Checkbox(value=False, visible=False)
|
||||
sdxl = gr.Checkbox(value=False, visible=False)
|
||||
sd3 = gr.Checkbox(value=False, visible=False)
|
||||
flux1 = gr.Checkbox(value=False, visible=False)
|
||||
return (
|
||||
gr.Dropdown(),
|
||||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
)
|
||||
|
||||
# Check if the model_list is set to 'custom'
|
||||
v2 = gr.Checkbox(visible=True)
|
||||
v_parameterization = gr.Checkbox(visible=True)
|
||||
sdxl = gr.Checkbox(visible=True)
|
||||
sd3 = gr.Checkbox(visible=True)
|
||||
flux1 = gr.Checkbox(visible=True)
|
||||
|
||||
# Auto-detect model type if safetensors file path is given
|
||||
if pretrained_model_name_or_path.lower().endswith(".safetensors"):
|
||||
detect = SDModelType(pretrained_model_name_or_path)
|
||||
v2 = gr.Checkbox(value=detect.Is_SD2(), visible=True)
|
||||
sdxl = gr.Checkbox(value=detect.Is_SDXL(), visible=True)
|
||||
sd3 = gr.Checkbox(value=detect.Is_SD3(), visible=True)
|
||||
flux1 = gr.Checkbox(value=detect.Is_FLUX1(), visible=True)
|
||||
#TODO: v_parameterization
|
||||
|
||||
# If a refresh method is provided, use it to update the choices for the Dropdown widget
|
||||
if refresh_method is not None:
|
||||
|
|
@ -1021,6 +1054,8 @@ def set_pretrained_model_name_or_path_input(
|
|||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
sd3,
|
||||
flux1,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1369,7 +1404,11 @@ def validate_file_path(file_path: str) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def validate_folder_path(folder_path: str, can_be_written_to: bool = False, create_if_not_exists: bool = False) -> bool:
|
||||
def validate_folder_path(
|
||||
folder_path: str,
|
||||
can_be_written_to: bool = False,
|
||||
create_if_not_exists: bool = False,
|
||||
) -> bool:
|
||||
if folder_path == "":
|
||||
return True
|
||||
msg = f"Validating {folder_path} existence{' and writability' if can_be_written_to else ''}..."
|
||||
|
|
@ -1387,6 +1426,7 @@ def validate_folder_path(folder_path: str, can_be_written_to: bool = False, crea
|
|||
log.info(f"{msg} SUCCESS")
|
||||
return True
|
||||
|
||||
|
||||
def validate_toml_file(file_path: str) -> bool:
|
||||
if file_path == "":
|
||||
return True
|
||||
|
|
@ -1425,11 +1465,14 @@ def validate_model_path(pretrained_model_name_or_path: str) -> bool:
|
|||
log.info(f"{msg} SUCCESS")
|
||||
else:
|
||||
# If not one of the default models, check if it's a valid local path
|
||||
if not validate_file_path(pretrained_model_name_or_path) and not validate_folder_path(pretrained_model_name_or_path):
|
||||
if not validate_file_path(
|
||||
pretrained_model_name_or_path
|
||||
) and not validate_folder_path(pretrained_model_name_or_path):
|
||||
log.info(f"{msg} FAILURE: not a valid file or folder")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_file_writable(file_path: str) -> bool:
|
||||
"""
|
||||
Checks if a file is writable.
|
||||
|
|
@ -1450,8 +1493,9 @@ def is_file_writable(file_path: str) -> bool:
|
|||
pass
|
||||
# If the file can be opened, it is considered writable
|
||||
return True
|
||||
except IOError:
|
||||
except IOError as e:
|
||||
# If an IOError occurs, the file cannot be written to
|
||||
log.info(f"Error: {e}. File '{file_path}' is not writable.")
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -1462,7 +1506,7 @@ def print_command_and_toml(run_cmd, tmpfilename):
|
|||
# Reconstruct the safe command string for display
|
||||
command_to_run = " ".join(run_cmd)
|
||||
|
||||
log.info(command_to_run)
|
||||
print(command_to_run)
|
||||
print("")
|
||||
|
||||
log.info(f"Showing toml config file: {tmpfilename}")
|
||||
|
|
@ -1489,10 +1533,11 @@ def validate_args_setting(input_string):
|
|||
)
|
||||
return False
|
||||
|
||||
|
||||
def setup_environment():
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = (
|
||||
fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,11 @@ from .custom_logging import setup_logging
|
|||
log = setup_logging()
|
||||
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging as log
|
||||
from easygui import msgbox
|
||||
|
||||
def dataset_balancing(concept_repeats, folder, insecure):
|
||||
|
||||
if not concept_repeats > 0:
|
||||
|
|
@ -78,7 +83,11 @@ def dataset_balancing(concept_repeats, folder, insecure):
|
|||
old_name = os.path.join(folder, subdir)
|
||||
new_name = os.path.join(folder, f"{repeats}_{subdir}")
|
||||
|
||||
os.rename(old_name, new_name)
|
||||
# Check if the new folder name already exists
|
||||
if os.path.exists(new_name):
|
||||
log.warning(f"Destination folder {new_name} already exists. Skipping...")
|
||||
else:
|
||||
os.rename(old_name, new_name)
|
||||
else:
|
||||
log.info(
|
||||
f"Skipping folder {subdir} because it does not match kohya_ss expected syntax..."
|
||||
|
|
@ -87,6 +96,7 @@ def dataset_balancing(concept_repeats, folder, insecure):
|
|||
msgbox("Dataset balancing completed...")
|
||||
|
||||
|
||||
|
||||
def warning(insecure):
|
||||
if insecure:
|
||||
if boolbox(
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@ from .common_gui import (
|
|||
SaveConfigFile,
|
||||
scriptdir,
|
||||
update_my_data,
|
||||
validate_file_path, validate_folder_path, validate_model_path,
|
||||
validate_file_path,
|
||||
validate_folder_path,
|
||||
validate_model_path,
|
||||
validate_args_setting,
|
||||
setup_environment,
|
||||
)
|
||||
|
|
@ -27,10 +29,13 @@ from .class_gui_config import KohyaSSGUIConfig
|
|||
from .class_source_model import SourceModel
|
||||
from .class_basic_training import BasicTraining
|
||||
from .class_advanced_training import AdvancedTraining
|
||||
from .class_sd3 import sd3Training
|
||||
from .class_folders import Folders
|
||||
from .class_command_executor import CommandExecutor
|
||||
from .class_huggingface import HuggingFace
|
||||
from .class_metadata import MetaData
|
||||
from .class_sdxl_parameters import SDXLParameters
|
||||
from .class_flux1 import flux1Training
|
||||
|
||||
from .dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
|
|
@ -60,6 +65,7 @@ def save_configuration(
|
|||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
flux1_checkbox,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -72,6 +78,7 @@ def save_configuration(
|
|||
learning_rate_te2,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
|
|
@ -84,6 +91,7 @@ def save_configuration(
|
|||
caption_extension,
|
||||
enable_bucket,
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
full_fp16,
|
||||
full_bf16,
|
||||
no_token_padding,
|
||||
|
|
@ -134,6 +142,7 @@ def save_configuration(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -150,18 +159,28 @@ def save_configuration(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
weighted_captions,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
fused_backward_pass,
|
||||
fused_optimizer_groups,
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
debiased_estimation_loss,
|
||||
|
|
@ -178,6 +197,44 @@ def save_configuration(
|
|||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
sd3_fused_backward_pass,
|
||||
clip_g,
|
||||
clip_l,
|
||||
logit_mean,
|
||||
logit_std,
|
||||
mode_scale,
|
||||
save_clip,
|
||||
save_t5xxl,
|
||||
t5xxl,
|
||||
t5xxl_device,
|
||||
t5xxl_dtype,
|
||||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Flux.1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
ae,
|
||||
flux1_clip_l,
|
||||
flux1_t5xxl,
|
||||
discrete_flow_shift,
|
||||
model_prediction_type,
|
||||
timestep_sampling,
|
||||
split_mode,
|
||||
train_blocks,
|
||||
t5xxl_max_token_length,
|
||||
guidance_scale,
|
||||
blockwise_fused_optimizers,
|
||||
flux_fused_backward_pass,
|
||||
cpu_offload_checkpointing,
|
||||
blocks_to_swap,
|
||||
single_blocks_to_swap,
|
||||
double_blocks_to_swap,
|
||||
mem_eff_save,
|
||||
apply_t5_attn_mask,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -218,6 +275,7 @@ def open_configuration(
|
|||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
flux1_checkbox,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -230,6 +288,7 @@ def open_configuration(
|
|||
learning_rate_te2,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
|
|
@ -242,6 +301,7 @@ def open_configuration(
|
|||
caption_extension,
|
||||
enable_bucket,
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
full_fp16,
|
||||
full_bf16,
|
||||
no_token_padding,
|
||||
|
|
@ -292,6 +352,7 @@ def open_configuration(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -308,18 +369,28 @@ def open_configuration(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
weighted_captions,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
fused_backward_pass,
|
||||
fused_optimizer_groups,
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
debiased_estimation_loss,
|
||||
|
|
@ -336,6 +407,44 @@ def open_configuration(
|
|||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
sd3_fused_backward_pass,
|
||||
clip_g,
|
||||
clip_l,
|
||||
logit_mean,
|
||||
logit_std,
|
||||
mode_scale,
|
||||
save_clip,
|
||||
save_t5xxl,
|
||||
t5xxl,
|
||||
t5xxl_device,
|
||||
t5xxl_dtype,
|
||||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Flux.1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
ae,
|
||||
flux1_clip_l,
|
||||
flux1_t5xxl,
|
||||
discrete_flow_shift,
|
||||
model_prediction_type,
|
||||
timestep_sampling,
|
||||
split_mode,
|
||||
train_blocks,
|
||||
t5xxl_max_token_length,
|
||||
guidance_scale,
|
||||
blockwise_fused_optimizers,
|
||||
flux_fused_backward_pass,
|
||||
cpu_offload_checkpointing,
|
||||
blocks_to_swap,
|
||||
single_blocks_to_swap,
|
||||
double_blocks_to_swap,
|
||||
mem_eff_save,
|
||||
apply_t5_attn_mask,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -371,6 +480,7 @@ def train_model(
|
|||
v2,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
flux1_checkbox,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -383,6 +493,7 @@ def train_model(
|
|||
learning_rate_te2,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
|
|
@ -395,6 +506,7 @@ def train_model(
|
|||
caption_extension,
|
||||
enable_bucket,
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
full_fp16,
|
||||
full_bf16,
|
||||
no_token_padding,
|
||||
|
|
@ -445,6 +557,7 @@ def train_model(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -461,18 +574,28 @@ def train_model(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
weighted_captions,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
fused_backward_pass,
|
||||
fused_optimizer_groups,
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
debiased_estimation_loss,
|
||||
|
|
@ -489,6 +612,44 @@ def train_model(
|
|||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
sd3_fused_backward_pass,
|
||||
clip_g,
|
||||
clip_l,
|
||||
logit_mean,
|
||||
logit_std,
|
||||
mode_scale,
|
||||
save_clip,
|
||||
save_t5xxl,
|
||||
t5xxl,
|
||||
t5xxl_device,
|
||||
t5xxl_dtype,
|
||||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Flux.1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
ae,
|
||||
flux1_clip_l,
|
||||
flux1_t5xxl,
|
||||
discrete_flow_shift,
|
||||
model_prediction_type,
|
||||
timestep_sampling,
|
||||
split_mode,
|
||||
train_blocks,
|
||||
t5xxl_max_token_length,
|
||||
guidance_scale,
|
||||
blockwise_fused_optimizers,
|
||||
flux_fused_backward_pass,
|
||||
cpu_offload_checkpointing,
|
||||
blocks_to_swap,
|
||||
single_blocks_to_swap,
|
||||
double_blocks_to_swap,
|
||||
mem_eff_save,
|
||||
apply_t5_attn_mask,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -524,10 +685,14 @@ def train_model(
|
|||
if not validate_file_path(log_tracker_config):
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True):
|
||||
if not validate_folder_path(
|
||||
logging_dir, can_be_written_to=True, create_if_not_exists=True
|
||||
):
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True):
|
||||
if not validate_folder_path(
|
||||
output_dir, can_be_written_to=True, create_if_not_exists=True
|
||||
):
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if not validate_model_path(pretrained_model_name_or_path):
|
||||
|
|
@ -544,26 +709,11 @@ def train_model(
|
|||
|
||||
if not validate_model_path(vae):
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
#
|
||||
# End of path validation
|
||||
#
|
||||
|
||||
# This function validates files or folder paths. Simply add new variables containing file of folder path
|
||||
# to validate below
|
||||
# if not validate_paths(
|
||||
# dataset_config=dataset_config,
|
||||
# headless=headless,
|
||||
# log_tracker_config=log_tracker_config,
|
||||
# logging_dir=logging_dir,
|
||||
# output_dir=output_dir,
|
||||
# pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
# reg_data_dir=reg_data_dir,
|
||||
# resume=resume,
|
||||
# train_data_dir=train_data_dir,
|
||||
# vae=vae,
|
||||
# ):
|
||||
# return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if not print_only and check_if_model_exist(
|
||||
output_name, output_dir, save_model_as, headless=headless
|
||||
):
|
||||
|
|
@ -573,15 +723,6 @@ def train_model(
|
|||
log.info(
|
||||
"Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
|
||||
)
|
||||
if max_train_steps > 0:
|
||||
if lr_warmup != 0:
|
||||
lr_warmup_steps = round(
|
||||
float(int(lr_warmup) * int(max_train_steps) / 100)
|
||||
)
|
||||
else:
|
||||
lr_warmup_steps = 0
|
||||
else:
|
||||
lr_warmup_steps = 0
|
||||
|
||||
if max_train_steps == 0:
|
||||
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
|
||||
|
|
@ -640,11 +781,11 @@ def train_model(
|
|||
reg_factor = 1
|
||||
else:
|
||||
log.warning(
|
||||
"Regularisation images are used... Will double the number of steps required..."
|
||||
"Regularization images are used... Will double the number of steps required..."
|
||||
)
|
||||
reg_factor = 2
|
||||
|
||||
log.info(f"Regulatization factor: {reg_factor}")
|
||||
log.info(f"Regularization factor: {reg_factor}")
|
||||
|
||||
if max_train_steps == 0:
|
||||
# calculate max_train_steps
|
||||
|
|
@ -664,13 +805,18 @@ def train_model(
|
|||
else:
|
||||
max_train_steps_info = f"Max train steps: {max_train_steps}"
|
||||
|
||||
if lr_warmup != 0:
|
||||
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
||||
else:
|
||||
lr_warmup_steps = 0
|
||||
|
||||
log.info(f"Total steps: {total_steps}")
|
||||
|
||||
# Calculate lr_warmup_steps
|
||||
if lr_warmup_steps > 0:
|
||||
lr_warmup_steps = int(lr_warmup_steps)
|
||||
if lr_warmup > 0:
|
||||
log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
|
||||
elif lr_warmup != 0:
|
||||
lr_warmup_steps = lr_warmup / 100
|
||||
else:
|
||||
lr_warmup_steps = 0
|
||||
|
||||
log.info(f"Train batch size: {train_batch_size}")
|
||||
log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
|
||||
log.info(f"Epoch: {epoch}")
|
||||
|
|
@ -682,7 +828,7 @@ def train_model(
|
|||
log.error("accelerate not found")
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
run_cmd = [rf'{accelerate_path}', "launch"]
|
||||
run_cmd = [rf"{accelerate_path}", "launch"]
|
||||
|
||||
run_cmd = AccelerateLaunch.run_cmd(
|
||||
run_cmd=run_cmd,
|
||||
|
|
@ -701,10 +847,23 @@ def train_model(
|
|||
)
|
||||
|
||||
if sdxl:
|
||||
run_cmd.append(rf'{scriptdir}/sd-scripts/sdxl_train.py')
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py")
|
||||
elif sd3_checkbox:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py")
|
||||
elif flux1_checkbox:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py")
|
||||
else:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/train_db.py")
|
||||
|
||||
cache_text_encoder_outputs = (
|
||||
(sdxl and sdxl_cache_text_encoder_outputs)
|
||||
or (sd3_checkbox and sd3_cache_text_encoder_outputs)
|
||||
or (flux1_checkbox and flux1_cache_text_encoder_outputs)
|
||||
)
|
||||
cache_text_encoder_outputs_to_disk = (
|
||||
sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
|
||||
) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk)
|
||||
no_half_vae = sdxl and sdxl_no_half_vae
|
||||
if max_data_loader_n_workers == "" or None:
|
||||
max_data_loader_n_workers = 0
|
||||
else:
|
||||
|
|
@ -715,6 +874,19 @@ def train_model(
|
|||
else:
|
||||
max_train_steps = int(max_train_steps)
|
||||
|
||||
if sdxl:
|
||||
train_text_encoder = (learning_rate_te1 != None and learning_rate_te1 > 0) or (
|
||||
learning_rate_te2 != None and learning_rate_te2 > 0
|
||||
)
|
||||
|
||||
fused_backward_pass_value = False
|
||||
if sd3_checkbox:
|
||||
fused_backward_pass_value = sd3_fused_backward_pass
|
||||
elif flux1_checkbox:
|
||||
fused_backward_pass_value = flux_fused_backward_pass
|
||||
else:
|
||||
fused_backward_pass_value = fused_backward_pass
|
||||
|
||||
# def save_huggingface_to_toml(self, toml_file_path: str):
|
||||
config_toml_data = {
|
||||
# Update the values in the TOML data
|
||||
|
|
@ -724,22 +896,32 @@ def train_model(
|
|||
"bucket_reso_steps": bucket_reso_steps,
|
||||
"cache_latents": cache_latents,
|
||||
"cache_latents_to_disk": cache_latents_to_disk,
|
||||
"cache_text_encoder_outputs": cache_text_encoder_outputs,
|
||||
"cache_text_encoder_outputs_to_disk": cache_text_encoder_outputs_to_disk,
|
||||
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
|
||||
"caption_dropout_rate": caption_dropout_rate,
|
||||
"caption_extension": caption_extension,
|
||||
"clip_l": flux1_clip_l if flux1_checkbox else clip_l if sd3_checkbox else None,
|
||||
"clip_skip": clip_skip if clip_skip != 0 else None,
|
||||
"color_aug": color_aug,
|
||||
"dataset_config": dataset_config,
|
||||
"debiased_estimation_loss": debiased_estimation_loss,
|
||||
"disable_mmap_load_safetensors": disable_mmap_load_safetensors,
|
||||
"dynamo_backend": dynamo_backend,
|
||||
"enable_bucket": enable_bucket,
|
||||
"epoch": int(epoch),
|
||||
"flip_aug": flip_aug,
|
||||
"fp8_base": fp8_base,
|
||||
"full_bf16": full_bf16,
|
||||
"full_fp16": full_fp16,
|
||||
"fused_backward_pass": fused_backward_pass_value,
|
||||
"fused_optimizer_groups": (
|
||||
int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
|
||||
),
|
||||
"gradient_accumulation_steps": int(gradient_accumulation_steps),
|
||||
"gradient_checkpointing": gradient_checkpointing,
|
||||
"huber_c": huber_c,
|
||||
"huber_scale": huber_scale,
|
||||
"huber_schedule": huber_schedule,
|
||||
"huggingface_path_in_repo": huggingface_path_in_repo,
|
||||
"huggingface_repo_id": huggingface_repo_id,
|
||||
|
|
@ -750,16 +932,11 @@ def train_model(
|
|||
"ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
|
||||
"keep_tokens": int(keep_tokens),
|
||||
"learning_rate": learning_rate, # both for sd1.5 and sdxl
|
||||
"learning_rate_te": (
|
||||
learning_rate_te if not sdxl and not 0 else None
|
||||
), # only for sd1.5 and not 0
|
||||
"learning_rate_te1": (
|
||||
learning_rate_te1 if sdxl and not 0 else None
|
||||
), # only for sdxl and not 0
|
||||
"learning_rate_te2": (
|
||||
learning_rate_te2 if sdxl and not 0 else None
|
||||
), # only for sdxl and not 0
|
||||
"learning_rate_te": learning_rate_te if not sdxl else None, # only for sd1.5
|
||||
"learning_rate_te1": learning_rate_te1 if sdxl else None, # only for sdxl
|
||||
"learning_rate_te2": learning_rate_te2 if sdxl else None, # only for sdxl
|
||||
"logging_dir": logging_dir,
|
||||
"log_config": log_config,
|
||||
"log_tracker_config": log_tracker_config,
|
||||
"log_tracker_name": log_tracker_name,
|
||||
"log_with": log_with,
|
||||
|
|
@ -767,15 +944,20 @@ def train_model(
|
|||
"lr_scheduler": lr_scheduler,
|
||||
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
|
||||
"lr_scheduler_num_cycles": (
|
||||
int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
|
||||
int(lr_scheduler_num_cycles)
|
||||
if lr_scheduler_num_cycles != ""
|
||||
else int(epoch)
|
||||
),
|
||||
"lr_scheduler_power": lr_scheduler_power,
|
||||
"lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
|
||||
"lr_warmup_steps": lr_warmup_steps,
|
||||
"masked_loss": masked_loss,
|
||||
"max_bucket_reso": max_bucket_reso,
|
||||
"max_timestep": max_timestep if max_timestep != 0 else None,
|
||||
"max_token_length": int(max_token_length),
|
||||
"max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
|
||||
"max_train_epochs": (
|
||||
int(max_train_epochs) if int(max_train_epochs) != 0 else None
|
||||
),
|
||||
"max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
|
||||
"mem_eff_attn": mem_eff_attn,
|
||||
"metadata_author": metadata_author,
|
||||
|
|
@ -789,6 +971,7 @@ def train_model(
|
|||
"mixed_precision": mixed_precision,
|
||||
"multires_noise_discount": multires_noise_discount,
|
||||
"multires_noise_iterations": multires_noise_iterations if not 0 else None,
|
||||
"no_half_vae": no_half_vae,
|
||||
"no_token_padding": no_token_padding,
|
||||
"noise_offset": noise_offset if not 0 else None,
|
||||
"noise_offset_random_strength": noise_offset_random_strength,
|
||||
|
|
@ -825,6 +1008,10 @@ def train_model(
|
|||
"save_last_n_steps_state": (
|
||||
save_last_n_steps_state if save_last_n_steps_state != 0 else None
|
||||
),
|
||||
"save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
|
||||
"save_last_n_epochs_state": (
|
||||
save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
|
||||
),
|
||||
"save_model_as": save_model_as,
|
||||
"save_precision": save_precision,
|
||||
"save_state": save_state,
|
||||
|
|
@ -834,20 +1021,65 @@ def train_model(
|
|||
"sdpa": True if xformers == "sdpa" else None,
|
||||
"seed": int(seed) if int(seed) != 0 else None,
|
||||
"shuffle_caption": shuffle_caption,
|
||||
"skip_cache_check": skip_cache_check,
|
||||
"stop_text_encoder_training": (
|
||||
stop_text_encoder_training if stop_text_encoder_training != 0 else None
|
||||
),
|
||||
"t5xxl": t5xxl if sd3_checkbox else flux1_t5xxl if flux1_checkbox else None,
|
||||
"train_batch_size": train_batch_size,
|
||||
"train_data_dir": train_data_dir,
|
||||
"train_text_encoder": train_text_encoder if sdxl else None,
|
||||
"v2": v2,
|
||||
"v_parameterization": v_parameterization,
|
||||
"v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None,
|
||||
"vae": vae,
|
||||
"vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
|
||||
"wandb_api_key": wandb_api_key,
|
||||
"wandb_run_name": wandb_run_name,
|
||||
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
|
||||
"weighted_captions": weighted_captions,
|
||||
"xformers": True if xformers == "xformers" else None,
|
||||
# SD3 only Parameters
|
||||
# "cache_text_encoder_outputs": see previous assignment above for code
|
||||
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
|
||||
"clip_g": clip_g if sd3_checkbox else None,
|
||||
# "clip_l": see previous assignment above for code
|
||||
"logit_mean": logit_mean if sd3_checkbox else None,
|
||||
"logit_std": logit_std if sd3_checkbox else None,
|
||||
"mode_scale": mode_scale if sd3_checkbox else None,
|
||||
"save_clip": save_clip if sd3_checkbox else None,
|
||||
"save_t5xxl": save_t5xxl if sd3_checkbox else None,
|
||||
# "t5xxl": see previous assignment above for code
|
||||
"t5xxl_device": t5xxl_device if sd3_checkbox else None,
|
||||
"t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None,
|
||||
"text_encoder_batch_size": (
|
||||
sd3_text_encoder_batch_size if sd3_checkbox else None
|
||||
),
|
||||
"weighting_scheme": weighting_scheme if sd3_checkbox else None,
|
||||
# Flux.1 specific parameters
|
||||
# "cache_text_encoder_outputs": see previous assignment above for code
|
||||
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
|
||||
"ae": ae if flux1_checkbox else None,
|
||||
# "clip_l": see previous assignment above for code
|
||||
# "t5xxl": see previous assignment above for code
|
||||
"discrete_flow_shift": discrete_flow_shift if flux1_checkbox else None,
|
||||
"model_prediction_type": model_prediction_type if flux1_checkbox else None,
|
||||
"timestep_sampling": timestep_sampling if flux1_checkbox else None,
|
||||
"split_mode": split_mode if flux1_checkbox else None,
|
||||
"train_blocks": train_blocks if flux1_checkbox else None,
|
||||
"t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None,
|
||||
"guidance_scale": guidance_scale if flux1_checkbox else None,
|
||||
"blockwise_fused_optimizers": (
|
||||
blockwise_fused_optimizers if flux1_checkbox else None
|
||||
),
|
||||
# "flux_fused_backward_pass": see previous assignment of fused_backward_pass in above code
|
||||
"cpu_offload_checkpointing": (
|
||||
cpu_offload_checkpointing if flux1_checkbox else None
|
||||
),
|
||||
"blocks_to_swap": blocks_to_swap if flux1_checkbox or sd3_checkbox else None,
|
||||
"single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
|
||||
"double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
|
||||
"mem_eff_save": mem_eff_save if flux1_checkbox else None,
|
||||
"apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
|
||||
}
|
||||
|
||||
# Given dictionary `config_toml_data`
|
||||
|
|
@ -855,7 +1087,7 @@ def train_model(
|
|||
config_toml_data = {
|
||||
key: value
|
||||
for key, value in config_toml_data.items()
|
||||
if value not in ["", False, None]
|
||||
if not any([value == "", value is False, value is None])
|
||||
}
|
||||
|
||||
config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)
|
||||
|
|
@ -865,7 +1097,7 @@ def train_model(
|
|||
|
||||
current_datetime = datetime.now()
|
||||
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
|
||||
tmpfilename = fr"{output_dir}/config_dreambooth-{formatted_datetime}.toml"
|
||||
tmpfilename = rf"{output_dir}/config_dreambooth-{formatted_datetime}.toml"
|
||||
|
||||
# Save the updated TOML data back to the file
|
||||
with open(tmpfilename, "w", encoding="utf-8") as toml_file:
|
||||
|
|
@ -875,7 +1107,7 @@ def train_model(
|
|||
log.error(f"Failed to write TOML file: {toml_file.name}")
|
||||
|
||||
run_cmd.append(f"--config_file")
|
||||
run_cmd.append(rf'{tmpfilename}')
|
||||
run_cmd.append(rf"{tmpfilename}")
|
||||
|
||||
# Initialize a dictionary with always-included keyword arguments
|
||||
kwargs_for_training = {
|
||||
|
|
@ -981,6 +1213,26 @@ def dreambooth_tab(
|
|||
config=config,
|
||||
)
|
||||
|
||||
# Add SDXL Parameters
|
||||
sdxl_params = SDXLParameters(
|
||||
source_model.sdxl_checkbox,
|
||||
config=config,
|
||||
trainer="finetune",
|
||||
)
|
||||
|
||||
# Add FLUX1 Parameters
|
||||
flux1_training = flux1Training(
|
||||
headless=headless,
|
||||
config=config,
|
||||
flux1_checkbox=source_model.flux1_checkbox,
|
||||
finetuning=True,
|
||||
)
|
||||
|
||||
# Add SD3 Parameters
|
||||
sd3_training = sd3Training(
|
||||
headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
|
||||
advanced_training = AdvancedTraining(headless=headless, config=config)
|
||||
advanced_training.color_aug.change(
|
||||
|
|
@ -1011,6 +1263,7 @@ def dreambooth_tab(
|
|||
source_model.v2,
|
||||
source_model.v_parameterization,
|
||||
source_model.sdxl_checkbox,
|
||||
source_model.flux1_checkbox,
|
||||
folders.logging_dir,
|
||||
source_model.train_data_dir,
|
||||
folders.reg_data_dir,
|
||||
|
|
@ -1023,6 +1276,7 @@ def dreambooth_tab(
|
|||
basic_training.learning_rate_te2,
|
||||
basic_training.lr_scheduler,
|
||||
basic_training.lr_warmup,
|
||||
basic_training.lr_warmup_steps,
|
||||
basic_training.train_batch_size,
|
||||
basic_training.epoch,
|
||||
basic_training.save_every_n_epochs,
|
||||
|
|
@ -1035,6 +1289,7 @@ def dreambooth_tab(
|
|||
basic_training.caption_extension,
|
||||
basic_training.enable_bucket,
|
||||
advanced_training.gradient_checkpointing,
|
||||
advanced_training.fp8_base,
|
||||
advanced_training.full_fp16,
|
||||
advanced_training.full_bf16,
|
||||
advanced_training.no_token_padding,
|
||||
|
|
@ -1084,6 +1339,7 @@ def dreambooth_tab(
|
|||
basic_training.optimizer,
|
||||
basic_training.optimizer_args,
|
||||
basic_training.lr_scheduler_args,
|
||||
basic_training.lr_scheduler_type,
|
||||
advanced_training.noise_offset_type,
|
||||
advanced_training.noise_offset,
|
||||
advanced_training.noise_offset_random_strength,
|
||||
|
|
@ -1100,18 +1356,28 @@ def dreambooth_tab(
|
|||
advanced_training.loss_type,
|
||||
advanced_training.huber_schedule,
|
||||
advanced_training.huber_c,
|
||||
advanced_training.huber_scale,
|
||||
advanced_training.vae_batch_size,
|
||||
advanced_training.min_snr_gamma,
|
||||
advanced_training.weighted_captions,
|
||||
advanced_training.save_every_n_steps,
|
||||
advanced_training.save_last_n_steps,
|
||||
advanced_training.save_last_n_steps_state,
|
||||
advanced_training.save_last_n_epochs,
|
||||
advanced_training.save_last_n_epochs_state,
|
||||
advanced_training.skip_cache_check,
|
||||
advanced_training.log_with,
|
||||
advanced_training.wandb_api_key,
|
||||
advanced_training.wandb_run_name,
|
||||
advanced_training.log_tracker_name,
|
||||
advanced_training.log_tracker_config,
|
||||
advanced_training.log_config,
|
||||
advanced_training.scale_v_pred_loss_like_noise_pred,
|
||||
sdxl_params.disable_mmap_load_safetensors,
|
||||
sdxl_params.fused_backward_pass,
|
||||
sdxl_params.fused_optimizer_groups,
|
||||
sdxl_params.sdxl_cache_text_encoder_outputs,
|
||||
sdxl_params.sdxl_no_half_vae,
|
||||
advanced_training.min_timestep,
|
||||
advanced_training.max_timestep,
|
||||
advanced_training.debiased_estimation_loss,
|
||||
|
|
@ -1128,6 +1394,44 @@ def dreambooth_tab(
|
|||
metadata.metadata_license,
|
||||
metadata.metadata_tags,
|
||||
metadata.metadata_title,
|
||||
# SD3 Parameters
|
||||
sd3_training.sd3_cache_text_encoder_outputs,
|
||||
sd3_training.sd3_cache_text_encoder_outputs_to_disk,
|
||||
sd3_training.sd3_fused_backward_pass,
|
||||
sd3_training.clip_g,
|
||||
sd3_training.clip_l,
|
||||
sd3_training.logit_mean,
|
||||
sd3_training.logit_std,
|
||||
sd3_training.mode_scale,
|
||||
sd3_training.save_clip,
|
||||
sd3_training.save_t5xxl,
|
||||
sd3_training.t5xxl,
|
||||
sd3_training.t5xxl_device,
|
||||
sd3_training.t5xxl_dtype,
|
||||
sd3_training.sd3_text_encoder_batch_size,
|
||||
sd3_training.weighting_scheme,
|
||||
source_model.sd3_checkbox,
|
||||
# Flux1 parameters
|
||||
flux1_training.flux1_cache_text_encoder_outputs,
|
||||
flux1_training.flux1_cache_text_encoder_outputs_to_disk,
|
||||
flux1_training.ae,
|
||||
flux1_training.clip_l,
|
||||
flux1_training.t5xxl,
|
||||
flux1_training.discrete_flow_shift,
|
||||
flux1_training.model_prediction_type,
|
||||
flux1_training.timestep_sampling,
|
||||
flux1_training.split_mode,
|
||||
flux1_training.train_blocks,
|
||||
flux1_training.t5xxl_max_token_length,
|
||||
flux1_training.guidance_scale,
|
||||
flux1_training.blockwise_fused_optimizers,
|
||||
flux1_training.flux_fused_backward_pass,
|
||||
flux1_training.cpu_offload_checkpointing,
|
||||
advanced_training.blocks_to_swap,
|
||||
flux1_training.single_blocks_to_swap,
|
||||
flux1_training.double_blocks_to_swap,
|
||||
flux1_training.mem_eff_save,
|
||||
flux1_training.apply_t5_attn_mask,
|
||||
]
|
||||
|
||||
configuration.button_open_config.click(
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from .common_gui import (
|
|||
)
|
||||
|
||||
from .custom_logging import setup_logging
|
||||
from .sd_modeltype import SDModelType
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
|
@ -337,6 +338,19 @@ def gradio_extract_lora_tab(
|
|||
outputs=[load_tuned_model_to, load_original_model_to],
|
||||
)
|
||||
|
||||
#secondary event on model_tuned for auto-detection of v2/SDXL
|
||||
def change_modeltype_model_tuned(path):
|
||||
detect = SDModelType(path)
|
||||
v2 = gr.Checkbox(value=detect.Is_SD2())
|
||||
sdxl = gr.Checkbox(value=detect.Is_SDXL())
|
||||
return v2, sdxl
|
||||
|
||||
model_tuned.change(
|
||||
change_modeltype_model_tuned,
|
||||
inputs=model_tuned,
|
||||
outputs=[v2, sdxl]
|
||||
)
|
||||
|
||||
extract_button = gr.Button("Extract LoRA model")
|
||||
|
||||
extract_button.click(
|
||||
|
|
|
|||
|
|
@ -18,14 +18,18 @@ from .common_gui import (
|
|||
SaveConfigFile,
|
||||
scriptdir,
|
||||
update_my_data,
|
||||
validate_file_path, validate_folder_path, validate_model_path,
|
||||
validate_args_setting, setup_environment,
|
||||
validate_file_path,
|
||||
validate_folder_path,
|
||||
validate_model_path,
|
||||
validate_args_setting,
|
||||
setup_environment,
|
||||
)
|
||||
from .class_accelerate_launch import AccelerateLaunch
|
||||
from .class_configuration_file import ConfigurationFile
|
||||
from .class_source_model import SourceModel
|
||||
from .class_basic_training import BasicTraining
|
||||
from .class_advanced_training import AdvancedTraining
|
||||
from .class_sd3 import sd3Training
|
||||
from .class_folders import Folders
|
||||
from .class_sdxl_parameters import SDXLParameters
|
||||
from .class_command_executor import CommandExecutor
|
||||
|
|
@ -34,6 +38,7 @@ from .class_sample_images import SampleImages, create_prompt_file
|
|||
from .class_huggingface import HuggingFace
|
||||
from .class_metadata import MetaData
|
||||
from .class_gui_config import KohyaSSGUIConfig
|
||||
from .class_flux1 import flux1Training
|
||||
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
|
|
@ -65,6 +70,7 @@ def save_configuration(
|
|||
v2,
|
||||
v_parameterization,
|
||||
sdxl_checkbox,
|
||||
flux1_checkbox,
|
||||
train_dir,
|
||||
image_folder,
|
||||
output_dir,
|
||||
|
|
@ -82,6 +88,7 @@ def save_configuration(
|
|||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
dataset_repeats,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
|
|
@ -116,6 +123,7 @@ def save_configuration(
|
|||
save_state_on_train_end,
|
||||
resume,
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
gradient_accumulation_steps,
|
||||
block_lr,
|
||||
mem_eff_attn,
|
||||
|
|
@ -142,6 +150,7 @@ def save_configuration(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -158,18 +167,26 @@ def save_configuration(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
weighted_captions,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
fused_backward_pass,
|
||||
fused_optimizer_groups,
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
|
|
@ -188,6 +205,44 @@ def save_configuration(
|
|||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
sd3_fused_backward_pass,
|
||||
clip_g,
|
||||
clip_l,
|
||||
logit_mean,
|
||||
logit_std,
|
||||
mode_scale,
|
||||
save_clip,
|
||||
save_t5xxl,
|
||||
t5xxl,
|
||||
t5xxl_device,
|
||||
t5xxl_dtype,
|
||||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Flux.1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
ae,
|
||||
flux1_clip_l,
|
||||
flux1_t5xxl,
|
||||
discrete_flow_shift,
|
||||
model_prediction_type,
|
||||
timestep_sampling,
|
||||
split_mode,
|
||||
train_blocks,
|
||||
t5xxl_max_token_length,
|
||||
guidance_scale,
|
||||
blockwise_fused_optimizers,
|
||||
flux_fused_backward_pass,
|
||||
cpu_offload_checkpointing,
|
||||
blocks_to_swap,
|
||||
single_blocks_to_swap,
|
||||
double_blocks_to_swap,
|
||||
mem_eff_save,
|
||||
apply_t5_attn_mask,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -231,6 +286,7 @@ def open_configuration(
|
|||
v2,
|
||||
v_parameterization,
|
||||
sdxl_checkbox,
|
||||
flux1_checkbox,
|
||||
train_dir,
|
||||
image_folder,
|
||||
output_dir,
|
||||
|
|
@ -248,6 +304,7 @@ def open_configuration(
|
|||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
dataset_repeats,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
|
|
@ -282,6 +339,7 @@ def open_configuration(
|
|||
save_state_on_train_end,
|
||||
resume,
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
gradient_accumulation_steps,
|
||||
block_lr,
|
||||
mem_eff_attn,
|
||||
|
|
@ -308,6 +366,7 @@ def open_configuration(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -324,18 +383,26 @@ def open_configuration(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
weighted_captions,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
fused_backward_pass,
|
||||
fused_optimizer_groups,
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
|
|
@ -354,6 +421,44 @@ def open_configuration(
|
|||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
sd3_fused_backward_pass,
|
||||
clip_g,
|
||||
clip_l,
|
||||
logit_mean,
|
||||
logit_std,
|
||||
mode_scale,
|
||||
save_clip,
|
||||
save_t5xxl,
|
||||
t5xxl,
|
||||
t5xxl_device,
|
||||
t5xxl_dtype,
|
||||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Flux.1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
ae,
|
||||
flux1_clip_l,
|
||||
flux1_t5xxl,
|
||||
discrete_flow_shift,
|
||||
model_prediction_type,
|
||||
timestep_sampling,
|
||||
split_mode,
|
||||
train_blocks,
|
||||
t5xxl_max_token_length,
|
||||
guidance_scale,
|
||||
blockwise_fused_optimizers,
|
||||
flux_fused_backward_pass,
|
||||
cpu_offload_checkpointing,
|
||||
blocks_to_swap,
|
||||
single_blocks_to_swap,
|
||||
double_blocks_to_swap,
|
||||
mem_eff_save,
|
||||
apply_t5_attn_mask,
|
||||
training_preset,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
|
|
@ -403,6 +508,7 @@ def train_model(
|
|||
v2,
|
||||
v_parameterization,
|
||||
sdxl_checkbox,
|
||||
flux1_checkbox,
|
||||
train_dir,
|
||||
image_folder,
|
||||
output_dir,
|
||||
|
|
@ -420,6 +526,7 @@ def train_model(
|
|||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
dataset_repeats,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
|
|
@ -454,6 +561,7 @@ def train_model(
|
|||
save_state_on_train_end,
|
||||
resume,
|
||||
gradient_checkpointing,
|
||||
fp8_base,
|
||||
gradient_accumulation_steps,
|
||||
block_lr,
|
||||
mem_eff_attn,
|
||||
|
|
@ -480,6 +588,7 @@ def train_model(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -496,18 +605,26 @@ def train_model(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
weighted_captions,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
fused_backward_pass,
|
||||
fused_optimizer_groups,
|
||||
sdxl_cache_text_encoder_outputs,
|
||||
sdxl_no_half_vae,
|
||||
min_timestep,
|
||||
|
|
@ -526,6 +643,44 @@ def train_model(
|
|||
metadata_license,
|
||||
metadata_tags,
|
||||
metadata_title,
|
||||
# SD3 parameters
|
||||
sd3_cache_text_encoder_outputs,
|
||||
sd3_cache_text_encoder_outputs_to_disk,
|
||||
sd3_fused_backward_pass,
|
||||
clip_g,
|
||||
clip_l,
|
||||
logit_mean,
|
||||
logit_std,
|
||||
mode_scale,
|
||||
save_clip,
|
||||
save_t5xxl,
|
||||
t5xxl,
|
||||
t5xxl_device,
|
||||
t5xxl_dtype,
|
||||
sd3_text_encoder_batch_size,
|
||||
weighting_scheme,
|
||||
sd3_checkbox,
|
||||
# Flux.1
|
||||
flux1_cache_text_encoder_outputs,
|
||||
flux1_cache_text_encoder_outputs_to_disk,
|
||||
ae,
|
||||
flux1_clip_l,
|
||||
flux1_t5xxl,
|
||||
discrete_flow_shift,
|
||||
model_prediction_type,
|
||||
timestep_sampling,
|
||||
split_mode,
|
||||
train_blocks,
|
||||
t5xxl_max_token_length,
|
||||
guidance_scale,
|
||||
blockwise_fused_optimizers,
|
||||
flux_fused_backward_pass,
|
||||
cpu_offload_checkpointing,
|
||||
blocks_to_swap,
|
||||
single_blocks_to_swap,
|
||||
double_blocks_to_swap,
|
||||
mem_eff_save,
|
||||
apply_t5_attn_mask,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
|
@ -569,10 +724,14 @@ def train_model(
|
|||
if not validate_file_path(log_tracker_config):
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if not validate_folder_path(logging_dir, can_be_written_to=True, create_if_not_exists=True):
|
||||
if not validate_folder_path(
|
||||
logging_dir, can_be_written_to=True, create_if_not_exists=True
|
||||
):
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if not validate_folder_path(output_dir, can_be_written_to=True, create_if_not_exists=True):
|
||||
if not validate_folder_path(
|
||||
output_dir, can_be_written_to=True, create_if_not_exists=True
|
||||
):
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if not validate_model_path(pretrained_model_name_or_path):
|
||||
|
|
@ -585,18 +744,6 @@ def train_model(
|
|||
# End of path validation
|
||||
#
|
||||
|
||||
# if not validate_paths(
|
||||
# dataset_config=dataset_config,
|
||||
# finetune_image_folder=image_folder,
|
||||
# headless=headless,
|
||||
# log_tracker_config=log_tracker_config,
|
||||
# logging_dir=logging_dir,
|
||||
# output_dir=output_dir,
|
||||
# pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
# resume=resume,
|
||||
# ):
|
||||
# return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if not print_only and check_if_model_exist(
|
||||
output_name, output_dir, save_model_as, headless
|
||||
):
|
||||
|
|
@ -727,10 +874,16 @@ def train_model(
|
|||
|
||||
log.info(max_train_steps_info)
|
||||
|
||||
if max_train_steps != 0:
|
||||
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
||||
# Calculate lr_warmup_steps
|
||||
if lr_warmup_steps > 0:
|
||||
lr_warmup_steps = int(lr_warmup_steps)
|
||||
if lr_warmup > 0:
|
||||
log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
|
||||
elif lr_warmup != 0:
|
||||
lr_warmup_steps = lr_warmup / 100
|
||||
else:
|
||||
lr_warmup_steps = 0
|
||||
|
||||
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
|
||||
|
||||
accelerate_path = get_executable_path("accelerate")
|
||||
|
|
@ -738,7 +891,7 @@ def train_model(
|
|||
log.error("accelerate not found")
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
run_cmd = [rf'{accelerate_path}', "launch"]
|
||||
run_cmd = [rf"{accelerate_path}", "launch"]
|
||||
|
||||
run_cmd = AccelerateLaunch.run_cmd(
|
||||
run_cmd=run_cmd,
|
||||
|
|
@ -758,6 +911,10 @@ def train_model(
|
|||
|
||||
if sdxl_checkbox:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train.py")
|
||||
elif sd3_checkbox:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train.py")
|
||||
elif flux1_checkbox:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train.py")
|
||||
else:
|
||||
run_cmd.append(rf"{scriptdir}/sd-scripts/fine_tune.py")
|
||||
|
||||
|
|
@ -766,7 +923,14 @@ def train_model(
|
|||
if use_latent_files == "Yes"
|
||||
else f"{train_dir}/{caption_metadata_filename}"
|
||||
)
|
||||
cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs
|
||||
cache_text_encoder_outputs = (
|
||||
(sdxl_checkbox and sdxl_cache_text_encoder_outputs)
|
||||
or (sd3_checkbox and sd3_cache_text_encoder_outputs)
|
||||
or (flux1_checkbox and flux1_cache_text_encoder_outputs)
|
||||
)
|
||||
cache_text_encoder_outputs_to_disk = (
|
||||
sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
|
||||
) or (flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk)
|
||||
no_half_vae = sdxl_checkbox and sdxl_no_half_vae
|
||||
|
||||
if max_data_loader_n_workers == "" or None:
|
||||
|
|
@ -791,22 +955,31 @@ def train_model(
|
|||
"cache_latents": cache_latents,
|
||||
"cache_latents_to_disk": cache_latents_to_disk,
|
||||
"cache_text_encoder_outputs": cache_text_encoder_outputs,
|
||||
"cache_text_encoder_outputs_to_disk": cache_text_encoder_outputs_to_disk,
|
||||
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
|
||||
"caption_dropout_rate": caption_dropout_rate,
|
||||
"caption_extension": caption_extension,
|
||||
"clip_l": flux1_clip_l if flux1_checkbox else clip_l if sd3_checkbox else None,
|
||||
"clip_skip": clip_skip if clip_skip != 0 else None,
|
||||
"color_aug": color_aug,
|
||||
"dataset_config": dataset_config,
|
||||
"dataset_repeats": int(dataset_repeats),
|
||||
"debiased_estimation_loss": debiased_estimation_loss,
|
||||
"disable_mmap_load_safetensors": disable_mmap_load_safetensors,
|
||||
"dynamo_backend": dynamo_backend,
|
||||
"enable_bucket": True,
|
||||
"flip_aug": flip_aug,
|
||||
"fp8_base": fp8_base,
|
||||
"full_bf16": full_bf16,
|
||||
"full_fp16": full_fp16,
|
||||
"fused_backward_pass": sd3_fused_backward_pass if sd3_checkbox else flux_fused_backward_pass if flux1_checkbox else fused_backward_pass,
|
||||
"fused_optimizer_groups": (
|
||||
int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
|
||||
),
|
||||
"gradient_accumulation_steps": int(gradient_accumulation_steps),
|
||||
"gradient_checkpointing": gradient_checkpointing,
|
||||
"huber_c": huber_c,
|
||||
"huber_scale": huber_scale,
|
||||
"huber_schedule": huber_schedule,
|
||||
"huggingface_repo_id": huggingface_repo_id,
|
||||
"huggingface_token": huggingface_token,
|
||||
|
|
@ -828,11 +1001,13 @@ def train_model(
|
|||
learning_rate_te2 if sdxl_checkbox else None
|
||||
), # only for sdxl
|
||||
"logging_dir": logging_dir,
|
||||
"log_config": log_config,
|
||||
"log_tracker_name": log_tracker_name,
|
||||
"log_tracker_config": log_tracker_config,
|
||||
"loss_type": loss_type,
|
||||
"lr_scheduler": lr_scheduler,
|
||||
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
|
||||
"lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
|
||||
"lr_warmup_steps": lr_warmup_steps,
|
||||
"masked_loss": masked_loss,
|
||||
"max_bucket_reso": int(max_bucket_reso),
|
||||
|
|
@ -886,6 +1061,10 @@ def train_model(
|
|||
"save_last_n_steps_state": (
|
||||
save_last_n_steps_state if save_last_n_steps_state != 0 else None
|
||||
),
|
||||
"save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
|
||||
"save_last_n_epochs_state": (
|
||||
save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
|
||||
),
|
||||
"save_model_as": save_model_as,
|
||||
"save_precision": save_precision,
|
||||
"save_state": save_state,
|
||||
|
|
@ -895,6 +1074,8 @@ def train_model(
|
|||
"sdpa": True if xformers == "sdpa" else None,
|
||||
"seed": int(seed) if int(seed) != 0 else None,
|
||||
"shuffle_caption": shuffle_caption,
|
||||
"skip_cache_check": skip_cache_check,
|
||||
"t5xxl": t5xxl if sd3_checkbox else flux1_t5xxl if flux1_checkbox else None,
|
||||
"train_batch_size": train_batch_size,
|
||||
"train_data_dir": image_folder,
|
||||
"train_text_encoder": train_text_encoder,
|
||||
|
|
@ -904,9 +1085,50 @@ def train_model(
|
|||
"v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None,
|
||||
"vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
|
||||
"wandb_api_key": wandb_api_key,
|
||||
"wandb_run_name": wandb_run_name,
|
||||
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
|
||||
"weighted_captions": weighted_captions,
|
||||
"xformers": True if xformers == "xformers" else None,
|
||||
# SD3 only Parameters
|
||||
# "cache_text_encoder_outputs": see previous assignment above for code
|
||||
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
|
||||
"clip_g": clip_g if sd3_checkbox else None,
|
||||
# "clip_l": see previous assignment above for code
|
||||
"logit_mean": logit_mean if sd3_checkbox else None,
|
||||
"logit_std": logit_std if sd3_checkbox else None,
|
||||
"mode_scale": mode_scale if sd3_checkbox else None,
|
||||
"save_clip": save_clip if sd3_checkbox else None,
|
||||
"save_t5xxl": save_t5xxl if sd3_checkbox else None,
|
||||
# "t5xxl": see previous assignment above for code
|
||||
"t5xxl_device": t5xxl_device if sd3_checkbox else None,
|
||||
"t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None,
|
||||
"text_encoder_batch_size": (
|
||||
sd3_text_encoder_batch_size if sd3_checkbox else None
|
||||
),
|
||||
"weighting_scheme": weighting_scheme if sd3_checkbox else None,
|
||||
# Flux.1 specific parameters
|
||||
# "cache_text_encoder_outputs": see previous assignment above for code
|
||||
# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
|
||||
"ae": ae if flux1_checkbox else None,
|
||||
# "clip_l": see previous assignment above for code
|
||||
# "t5xxl": see previous assignment above for code
|
||||
"discrete_flow_shift": discrete_flow_shift if flux1_checkbox else None,
|
||||
"model_prediction_type": model_prediction_type if flux1_checkbox else None,
|
||||
"timestep_sampling": timestep_sampling if flux1_checkbox else None,
|
||||
"split_mode": split_mode if flux1_checkbox else None,
|
||||
"train_blocks": train_blocks if flux1_checkbox else None,
|
||||
"t5xxl_max_token_length": t5xxl_max_token_length if flux1_checkbox else None,
|
||||
"guidance_scale": guidance_scale if flux1_checkbox else None,
|
||||
"blockwise_fused_optimizers": (
|
||||
blockwise_fused_optimizers if flux1_checkbox else None
|
||||
),
|
||||
"cpu_offload_checkpointing": (
|
||||
cpu_offload_checkpointing if flux1_checkbox else None
|
||||
),
|
||||
"blocks_to_swap": blocks_to_swap if flux1_checkbox else None,
|
||||
"single_blocks_to_swap": single_blocks_to_swap if flux1_checkbox else None,
|
||||
"double_blocks_to_swap": double_blocks_to_swap if flux1_checkbox else None,
|
||||
"mem_eff_save": mem_eff_save if flux1_checkbox else None,
|
||||
"apply_t5_attn_mask": apply_t5_attn_mask if flux1_checkbox else None,
|
||||
}
|
||||
|
||||
# Given dictionary `config_toml_data`
|
||||
|
|
@ -924,7 +1146,7 @@ def train_model(
|
|||
|
||||
current_datetime = datetime.now()
|
||||
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
|
||||
tmpfilename = fr"{output_dir}/config_finetune-{formatted_datetime}.toml"
|
||||
tmpfilename = rf"{output_dir}/config_finetune-{formatted_datetime}.toml"
|
||||
# Save the updated TOML data back to the file
|
||||
with open(tmpfilename, "w", encoding="utf-8") as toml_file:
|
||||
toml.dump(config_toml_data, toml_file)
|
||||
|
|
@ -1090,7 +1312,9 @@ def finetune_tab(
|
|||
|
||||
# Add SDXL Parameters
|
||||
sdxl_params = SDXLParameters(
|
||||
source_model.sdxl_checkbox, config=config
|
||||
source_model.sdxl_checkbox,
|
||||
config=config,
|
||||
trainer="finetune",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
|
|
@ -1099,6 +1323,19 @@ def finetune_tab(
|
|||
label="Train text encoder", value=True
|
||||
)
|
||||
|
||||
# Add FLUX1 Parameters
|
||||
flux1_training = flux1Training(
|
||||
headless=headless,
|
||||
config=config,
|
||||
flux1_checkbox=source_model.flux1_checkbox,
|
||||
finetuning=True,
|
||||
)
|
||||
|
||||
# Add SD3 Parameters
|
||||
sd3_training = sd3Training(
|
||||
headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
|
||||
with gr.Row():
|
||||
gradient_accumulation_steps = gr.Slider(
|
||||
|
|
@ -1146,6 +1383,7 @@ def finetune_tab(
|
|||
source_model.v2,
|
||||
source_model.v_parameterization,
|
||||
source_model.sdxl_checkbox,
|
||||
source_model.flux1_checkbox,
|
||||
train_dir,
|
||||
image_folder,
|
||||
output_dir,
|
||||
|
|
@ -1163,6 +1401,7 @@ def finetune_tab(
|
|||
basic_training.learning_rate,
|
||||
basic_training.lr_scheduler,
|
||||
basic_training.lr_warmup,
|
||||
basic_training.lr_warmup_steps,
|
||||
dataset_repeats,
|
||||
basic_training.train_batch_size,
|
||||
basic_training.epoch,
|
||||
|
|
@ -1196,6 +1435,7 @@ def finetune_tab(
|
|||
advanced_training.save_state_on_train_end,
|
||||
advanced_training.resume,
|
||||
advanced_training.gradient_checkpointing,
|
||||
advanced_training.fp8_base,
|
||||
gradient_accumulation_steps,
|
||||
block_lr,
|
||||
advanced_training.mem_eff_attn,
|
||||
|
|
@ -1222,6 +1462,7 @@ def finetune_tab(
|
|||
basic_training.optimizer,
|
||||
basic_training.optimizer_args,
|
||||
basic_training.lr_scheduler_args,
|
||||
basic_training.lr_scheduler_type,
|
||||
advanced_training.noise_offset_type,
|
||||
advanced_training.noise_offset,
|
||||
advanced_training.noise_offset_random_strength,
|
||||
|
|
@ -1238,18 +1479,26 @@ def finetune_tab(
|
|||
advanced_training.loss_type,
|
||||
advanced_training.huber_schedule,
|
||||
advanced_training.huber_c,
|
||||
advanced_training.huber_scale,
|
||||
advanced_training.vae_batch_size,
|
||||
advanced_training.min_snr_gamma,
|
||||
weighted_captions,
|
||||
advanced_training.save_every_n_steps,
|
||||
advanced_training.save_last_n_steps,
|
||||
advanced_training.save_last_n_steps_state,
|
||||
advanced_training.save_last_n_epochs,
|
||||
advanced_training.save_last_n_epochs_state,
|
||||
advanced_training.skip_cache_check,
|
||||
advanced_training.log_with,
|
||||
advanced_training.wandb_api_key,
|
||||
advanced_training.wandb_run_name,
|
||||
advanced_training.log_tracker_name,
|
||||
advanced_training.log_tracker_config,
|
||||
advanced_training.log_config,
|
||||
advanced_training.scale_v_pred_loss_like_noise_pred,
|
||||
sdxl_params.disable_mmap_load_safetensors,
|
||||
sdxl_params.fused_backward_pass,
|
||||
sdxl_params.fused_optimizer_groups,
|
||||
sdxl_params.sdxl_cache_text_encoder_outputs,
|
||||
sdxl_params.sdxl_no_half_vae,
|
||||
advanced_training.min_timestep,
|
||||
|
|
@ -1268,6 +1517,44 @@ def finetune_tab(
|
|||
metadata.metadata_license,
|
||||
metadata.metadata_tags,
|
||||
metadata.metadata_title,
|
||||
# SD3 Parameters
|
||||
sd3_training.sd3_cache_text_encoder_outputs,
|
||||
sd3_training.sd3_cache_text_encoder_outputs_to_disk,
|
||||
sd3_training.clip_g,
|
||||
sd3_training.clip_l,
|
||||
sd3_training.logit_mean,
|
||||
sd3_training.logit_std,
|
||||
sd3_training.mode_scale,
|
||||
sd3_training.save_clip,
|
||||
sd3_training.save_t5xxl,
|
||||
sd3_training.t5xxl,
|
||||
sd3_training.t5xxl_device,
|
||||
sd3_training.t5xxl_dtype,
|
||||
sd3_training.sd3_text_encoder_batch_size,
|
||||
sd3_training.sd3_fused_backward_pass,
|
||||
sd3_training.weighting_scheme,
|
||||
source_model.sd3_checkbox,
|
||||
# Flux1 parameters
|
||||
flux1_training.flux1_cache_text_encoder_outputs,
|
||||
flux1_training.flux1_cache_text_encoder_outputs_to_disk,
|
||||
flux1_training.ae,
|
||||
flux1_training.clip_l,
|
||||
flux1_training.t5xxl,
|
||||
flux1_training.discrete_flow_shift,
|
||||
flux1_training.model_prediction_type,
|
||||
flux1_training.timestep_sampling,
|
||||
flux1_training.split_mode,
|
||||
flux1_training.train_blocks,
|
||||
flux1_training.t5xxl_max_token_length,
|
||||
flux1_training.guidance_scale,
|
||||
flux1_training.blockwise_fused_optimizers,
|
||||
flux1_training.flux_fused_backward_pass,
|
||||
flux1_training.cpu_offload_checkpointing,
|
||||
advanced_training.blocks_to_swap,
|
||||
flux1_training.single_blocks_to_swap,
|
||||
flux1_training.double_blocks_to_swap,
|
||||
flux1_training.mem_eff_save,
|
||||
flux1_training.apply_t5_attn_mask,
|
||||
]
|
||||
|
||||
configuration.button_open_config.click(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,273 @@
|
|||
import gradio as gr
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_file_path,
|
||||
scriptdir,
|
||||
list_files,
|
||||
create_refresh_button,
|
||||
setup_environment,
|
||||
)
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
||||
def extract_flux_lora(
|
||||
model_org,
|
||||
model_tuned,
|
||||
save_to,
|
||||
save_precision,
|
||||
dim,
|
||||
device,
|
||||
clamp_quantile,
|
||||
no_metadata,
|
||||
mem_eff_safe_open,
|
||||
):
|
||||
# Check for required inputs
|
||||
if model_org == "" or model_tuned == "" or save_to == "":
|
||||
log.info(
|
||||
"Please provide all required inputs: original model, tuned model, and save path."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if source models exist
|
||||
if not os.path.isfile(model_org):
|
||||
log.info("The provided original model is not a file")
|
||||
return
|
||||
|
||||
if not os.path.isfile(model_tuned):
|
||||
log.info("The provided tuned model is not a file")
|
||||
return
|
||||
|
||||
# Prepare save path
|
||||
if os.path.dirname(save_to) == "":
|
||||
save_to = os.path.join(os.path.dirname(model_tuned), save_to)
|
||||
if os.path.isdir(save_to):
|
||||
save_to = os.path.join(save_to, "flux_lora.safetensors")
|
||||
if os.path.normpath(model_tuned) == os.path.normpath(save_to):
|
||||
path, ext = os.path.splitext(save_to)
|
||||
save_to = f"{path}_lora{ext}"
|
||||
|
||||
run_cmd = [
|
||||
rf"{PYTHON}",
|
||||
rf"{scriptdir}/sd-scripts/networks/flux_extract_lora.py",
|
||||
"--model_org",
|
||||
rf"{model_org}",
|
||||
"--model_tuned",
|
||||
rf"{model_tuned}",
|
||||
"--save_to",
|
||||
rf"{save_to}",
|
||||
"--dim",
|
||||
str(dim),
|
||||
"--device",
|
||||
device,
|
||||
"--clamp_quantile",
|
||||
str(clamp_quantile),
|
||||
]
|
||||
|
||||
if save_precision:
|
||||
run_cmd.extend(["--save_precision", save_precision])
|
||||
|
||||
if no_metadata:
|
||||
run_cmd.append("--no_metadata")
|
||||
|
||||
if mem_eff_safe_open:
|
||||
run_cmd.append("--mem_eff_safe_open")
|
||||
|
||||
env = setup_environment()
|
||||
|
||||
# Reconstruct the safe command string for display
|
||||
command_to_run = " ".join(run_cmd)
|
||||
log.info(f"Executing command: {command_to_run}")
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
|
||||
|
||||
def gradio_flux_extract_lora_tab(headless=False):
|
||||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||
|
||||
def list_models(path):
|
||||
return list(list_files(path, exts=[".safetensors"], all=True))
|
||||
|
||||
with gr.Tab("Extract Flux LoRA"):
|
||||
gr.Markdown(
|
||||
"This utility can extract a LoRA network from a finetuned Flux model."
|
||||
)
|
||||
|
||||
lora_ext = gr.Textbox(value="*.safetensors", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
model_ext = gr.Textbox(value="*.safetensors", visible=False)
|
||||
model_ext_name = gr.Textbox(value="Model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
model_org = gr.Dropdown(
|
||||
label="Original Flux model (path to the original model)",
|
||||
interactive=True,
|
||||
choices=[""] + list_models(current_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
model_org,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_model_org_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_model_org_file.click(
|
||||
get_file_path,
|
||||
inputs=[model_org, model_ext, model_ext_name],
|
||||
outputs=model_org,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
model_tuned = gr.Dropdown(
|
||||
label="Finetuned Flux model (path to the finetuned model to extract)",
|
||||
interactive=True,
|
||||
choices=[""] + list_models(current_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
model_tuned,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_model_tuned_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_model_tuned_file.click(
|
||||
get_file_path,
|
||||
inputs=[model_tuned, model_ext, model_ext_name],
|
||||
outputs=model_tuned,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
save_to = gr.Dropdown(
|
||||
label="Save to (path where to save the extracted LoRA model...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_models(current_save_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
save_to,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_save_to.click(
|
||||
get_saveasfilename_path,
|
||||
inputs=[save_to, lora_ext, lora_ext_name],
|
||||
outputs=save_to,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
save_precision = gr.Dropdown(
|
||||
label="Save precision",
|
||||
choices=["None", "float", "fp16", "bf16"],
|
||||
value="None",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
dim = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
label="Network Dimension (Rank)",
|
||||
value=4,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
choices=["cpu", "cuda"],
|
||||
value="cuda",
|
||||
interactive=True,
|
||||
)
|
||||
clamp_quantile = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
label="Clamp Quantile",
|
||||
value=0.99,
|
||||
step=0.01,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
no_metadata = gr.Checkbox(
|
||||
label="No metadata (do not save sai modelspec metadata)",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
mem_eff_safe_open = gr.Checkbox(
|
||||
label="Memory efficient safe open (experimental feature)",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
extract_button = gr.Button("Extract Flux LoRA model")
|
||||
|
||||
extract_button.click(
|
||||
extract_flux_lora,
|
||||
inputs=[
|
||||
model_org,
|
||||
model_tuned,
|
||||
save_to,
|
||||
save_precision,
|
||||
dim,
|
||||
device,
|
||||
clamp_quantile,
|
||||
no_metadata,
|
||||
mem_eff_safe_open,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
model_org.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
|
||||
inputs=model_org,
|
||||
outputs=model_org,
|
||||
show_progress=False,
|
||||
)
|
||||
model_tuned.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
|
||||
inputs=model_tuned,
|
||||
outputs=model_tuned,
|
||||
show_progress=False,
|
||||
)
|
||||
save_to.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
|
||||
inputs=save_to,
|
||||
outputs=save_to,
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
@ -0,0 +1,470 @@
|
|||
# Standard library imports
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import json
|
||||
|
||||
# Third-party imports
|
||||
import gradio as gr
|
||||
|
||||
# Local module imports
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_file_path,
|
||||
scriptdir,
|
||||
list_files,
|
||||
create_refresh_button,
|
||||
setup_environment,
|
||||
)
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
||||
def check_model(model):
|
||||
if not model:
|
||||
return True
|
||||
if not os.path.isfile(model):
|
||||
log.info(f"The provided {model} is not a file")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def verify_conditions(flux_model, lora_models):
|
||||
lora_models_count = sum(1 for model in lora_models if model)
|
||||
if flux_model and lora_models_count >= 1:
|
||||
return True
|
||||
elif not flux_model and lora_models_count >= 2:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class GradioFluxMergeLoRaTab:
|
||||
def __init__(self, headless=False):
|
||||
self.headless = headless
|
||||
self.build_tab()
|
||||
|
||||
def save_inputs_to_json(self, file_path, inputs):
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
json.dump(inputs, file)
|
||||
log.info(f"Saved inputs to {file_path}")
|
||||
|
||||
def load_inputs_from_json(self, file_path):
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
inputs = json.load(file)
|
||||
log.info(f"Loaded inputs from {file_path}")
|
||||
return inputs
|
||||
|
||||
def build_tab(self):
|
||||
current_flux_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||
current_lora_model_dir = current_flux_model_dir
|
||||
|
||||
def list_flux_models(path):
|
||||
nonlocal current_flux_model_dir
|
||||
current_flux_model_dir = path
|
||||
return list(list_files(path, exts=[".safetensors"], all=True))
|
||||
|
||||
def list_lora_models(path):
|
||||
nonlocal current_lora_model_dir
|
||||
current_lora_model_dir = path
|
||||
return list(list_files(path, exts=[".safetensors"], all=True))
|
||||
|
||||
def list_save_to(path):
|
||||
nonlocal current_save_dir
|
||||
current_save_dir = path
|
||||
return list(list_files(path, exts=[".safetensors"], all=True))
|
||||
|
||||
with gr.Tab("Merge FLUX LoRA"):
|
||||
gr.Markdown(
|
||||
"This utility can merge up to 4 LoRA into a FLUX model or alternatively merge up to 4 LoRA together."
|
||||
)
|
||||
|
||||
lora_ext = gr.Textbox(value="*.safetensors", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
flux_ext = gr.Textbox(value="*.safetensors", visible=False)
|
||||
flux_ext_name = gr.Textbox(value="FLUX model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
flux_model = gr.Dropdown(
|
||||
label="FLUX Model (Optional. FLUX model path, if you want to merge it with LoRA files via the 'concat' method)",
|
||||
interactive=True,
|
||||
choices=[""] + list_flux_models(current_flux_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
flux_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_flux_models(current_flux_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
flux_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
flux_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[flux_model, flux_ext, flux_ext_name],
|
||||
outputs=flux_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
flux_model.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_flux_models(path)),
|
||||
inputs=flux_model,
|
||||
outputs=flux_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
lora_a_model = gr.Dropdown(
|
||||
label='LoRA model "A" (path to the LoRA A model)',
|
||||
interactive=True,
|
||||
choices=[""] + list_lora_models(current_lora_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
lora_a_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_lora_models(current_lora_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_a_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_a_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
lora_b_model = gr.Dropdown(
|
||||
label='LoRA model "B" (path to the LoRA B model)',
|
||||
interactive=True,
|
||||
choices=[""] + list_lora_models(current_lora_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
lora_b_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_lora_models(current_lora_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_b_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[lora_b_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_b_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
lora_a_model.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
|
||||
inputs=lora_a_model,
|
||||
outputs=lora_a_model,
|
||||
show_progress=False,
|
||||
)
|
||||
lora_b_model.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
|
||||
inputs=lora_b_model,
|
||||
outputs=lora_b_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
ratio_a = gr.Slider(
|
||||
label="Model A merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=2,
|
||||
step=0.01,
|
||||
value=0.0,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
ratio_b = gr.Slider(
|
||||
label="Model B merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=2,
|
||||
step=0.01,
|
||||
value=0.0,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
lora_c_model = gr.Dropdown(
|
||||
label='LoRA model "C" (path to the LoRA C model)',
|
||||
interactive=True,
|
||||
choices=[""] + list_lora_models(current_lora_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
lora_c_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_lora_models(current_lora_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_c_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_lora_c_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[lora_c_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_c_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
lora_d_model = gr.Dropdown(
|
||||
label='LoRA model "D" (path to the LoRA D model)',
|
||||
interactive=True,
|
||||
choices=[""] + list_lora_models(current_lora_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
lora_d_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_lora_models(current_lora_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_d_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_lora_d_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[lora_d_model, lora_ext, lora_ext_name],
|
||||
outputs=lora_d_model,
|
||||
show_progress=False,
|
||||
)
|
||||
lora_c_model.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
|
||||
inputs=lora_c_model,
|
||||
outputs=lora_c_model,
|
||||
show_progress=False,
|
||||
)
|
||||
lora_d_model.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_lora_models(path)),
|
||||
inputs=lora_d_model,
|
||||
outputs=lora_d_model,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
ratio_c = gr.Slider(
|
||||
label="Model C merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=2,
|
||||
step=0.01,
|
||||
value=0.0,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
ratio_d = gr.Slider(
|
||||
label="Model D merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=2,
|
||||
step=0.01,
|
||||
value=0.0,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
save_to = gr.Dropdown(
|
||||
label="Save to (path for the file to save...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_save_to(current_save_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
save_to,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_save_to.click(
|
||||
get_saveasfilename_path,
|
||||
inputs=[save_to, lora_ext, lora_ext_name],
|
||||
outputs=save_to,
|
||||
show_progress=False,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Merge precision",
|
||||
choices=["float", "fp16", "bf16"],
|
||||
value="float",
|
||||
interactive=True,
|
||||
)
|
||||
save_precision = gr.Radio(
|
||||
label="Save precision",
|
||||
choices=["float", "fp16", "bf16", "fp8"],
|
||||
value="fp16",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
save_to.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_save_to(path)),
|
||||
inputs=save_to,
|
||||
outputs=save_to,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
loading_device = gr.Dropdown(
|
||||
label="Loading device",
|
||||
choices=["cpu", "cuda"],
|
||||
value="cpu",
|
||||
interactive=True,
|
||||
)
|
||||
working_device = gr.Dropdown(
|
||||
label="Working device",
|
||||
choices=["cpu", "cuda"],
|
||||
value="cpu",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
concat = gr.Checkbox(label="Concat LoRA", value=False)
|
||||
shuffle = gr.Checkbox(label="Shuffle LoRA weights", value=False)
|
||||
no_metadata = gr.Checkbox(label="Don't save metadata", value=False)
|
||||
diffusers = gr.Checkbox(label="Diffusers LoRA", value=False)
|
||||
|
||||
merge_button = gr.Button("Merge model")
|
||||
|
||||
merge_button.click(
|
||||
self.merge_flux_lora,
|
||||
inputs=[
|
||||
flux_model,
|
||||
lora_a_model,
|
||||
lora_b_model,
|
||||
lora_c_model,
|
||||
lora_d_model,
|
||||
ratio_a,
|
||||
ratio_b,
|
||||
ratio_c,
|
||||
ratio_d,
|
||||
save_to,
|
||||
precision,
|
||||
save_precision,
|
||||
loading_device,
|
||||
working_device,
|
||||
concat,
|
||||
shuffle,
|
||||
no_metadata,
|
||||
diffusers,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
def merge_flux_lora(
|
||||
self,
|
||||
flux_model,
|
||||
lora_a_model,
|
||||
lora_b_model,
|
||||
lora_c_model,
|
||||
lora_d_model,
|
||||
ratio_a,
|
||||
ratio_b,
|
||||
ratio_c,
|
||||
ratio_d,
|
||||
save_to,
|
||||
precision,
|
||||
save_precision,
|
||||
loading_device,
|
||||
working_device,
|
||||
concat,
|
||||
shuffle,
|
||||
no_metadata,
|
||||
difffusers,
|
||||
):
|
||||
log.info("Merge FLUX LoRA...")
|
||||
models = [
|
||||
lora_a_model,
|
||||
lora_b_model,
|
||||
lora_c_model,
|
||||
lora_d_model,
|
||||
]
|
||||
lora_models = [model for model in models if model]
|
||||
ratios = [ratio for model, ratio in zip(models, [ratio_a, ratio_b, ratio_c, ratio_d]) if model]
|
||||
|
||||
# if not verify_conditions(flux_model, lora_models):
|
||||
# log.info(
|
||||
# "Warning: Either provide at least one LoRA model along with the FLUX model or at least two LoRA models if no FLUX model is provided."
|
||||
# )
|
||||
# return
|
||||
|
||||
for model in [flux_model] + lora_models:
|
||||
if not check_model(model):
|
||||
return
|
||||
|
||||
run_cmd = [rf"{PYTHON}", rf"{scriptdir}/sd-scripts/networks/flux_merge_lora.py"]
|
||||
|
||||
if flux_model:
|
||||
run_cmd.extend(["--flux_model", rf"{flux_model}"])
|
||||
|
||||
run_cmd.extend([
|
||||
"--save_precision", save_precision,
|
||||
"--precision", precision,
|
||||
"--save_to", rf"{save_to}",
|
||||
"--loading_device", loading_device,
|
||||
"--working_device", working_device,
|
||||
])
|
||||
|
||||
if lora_models:
|
||||
run_cmd.append("--models")
|
||||
run_cmd.extend(lora_models)
|
||||
run_cmd.append("--ratios")
|
||||
run_cmd.extend(map(str, ratios))
|
||||
|
||||
if concat:
|
||||
run_cmd.append("--concat")
|
||||
if shuffle:
|
||||
run_cmd.append("--shuffle")
|
||||
if no_metadata:
|
||||
run_cmd.append("--no_metadata")
|
||||
if difffusers:
|
||||
run_cmd.append("--diffusers")
|
||||
|
||||
env = setup_environment()
|
||||
|
||||
# Reconstruct the safe command string for display
|
||||
command_to_run = " ".join(run_cmd)
|
||||
log.info(f"Executing command: {command_to_run}")
|
||||
|
||||
# Run the command in the sd-scripts folder context
|
||||
subprocess.run(run_cmd, env=env)
|
||||
|
||||
log.info("Done merging...")
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -16,6 +16,7 @@ from .common_gui import (
|
|||
create_refresh_button, setup_environment
|
||||
)
|
||||
from .custom_logging import setup_logging
|
||||
from .sd_modeltype import SDModelType
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
|
@ -145,6 +146,13 @@ class GradioMergeLoRaTab:
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
#secondary event on sd_model for auto-detection of SDXL
|
||||
sd_model.change(
|
||||
lambda path: gr.Checkbox(value=SDModelType(path).Is_SDXL()),
|
||||
inputs=sd_model,
|
||||
outputs=sdxl_model
|
||||
)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
lora_a_model = gr.Dropdown(
|
||||
label='LoRA model "A" (path to the LoRA A model)',
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
from os.path import isfile
|
||||
from safetensors import safe_open
|
||||
import enum
|
||||
|
||||
# methodology is based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/sd_models.py#L379-L403
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
UNKNOWN = 0
|
||||
SD1 = 1
|
||||
SD2 = 2
|
||||
SDXL = 3
|
||||
SD3 = 4
|
||||
FLUX1 = 5
|
||||
|
||||
|
||||
class SDModelType:
|
||||
def __init__(self, safetensors_path):
|
||||
self.model_type = ModelType.UNKNOWN
|
||||
|
||||
if not isfile(safetensors_path):
|
||||
return
|
||||
|
||||
try:
|
||||
st = safe_open(filename=safetensors_path, framework="numpy", device="cpu")
|
||||
|
||||
# print(st.keys())
|
||||
|
||||
def hasKeyPrefix(pfx):
|
||||
return any(k.startswith(pfx) for k in st.keys())
|
||||
|
||||
if "model.diffusion_model.x_embedder.proj.weight" in st.keys():
|
||||
self.model_type = ModelType.SD3
|
||||
elif (
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
|
||||
in st.keys()
|
||||
or "double_blocks.0.img_attn.norm.key_norm.scale" in st.keys()
|
||||
):
|
||||
# print("flux1 model detected...")
|
||||
self.model_type = ModelType.FLUX1
|
||||
elif hasKeyPrefix("conditioner."):
|
||||
self.model_type = ModelType.SDXL
|
||||
elif hasKeyPrefix("cond_stage_model.model."):
|
||||
self.model_type = ModelType.SD2
|
||||
elif hasKeyPrefix("model."):
|
||||
self.model_type = ModelType.SD1
|
||||
except:
|
||||
pass
|
||||
|
||||
# print(f"Model type: {self.model_type}")
|
||||
|
||||
def Is_SD1(self):
|
||||
return self.model_type == ModelType.SD1
|
||||
|
||||
def Is_SD2(self):
|
||||
return self.model_type == ModelType.SD2
|
||||
|
||||
def Is_SDXL(self):
|
||||
return self.model_type == ModelType.SDXL
|
||||
|
||||
def Is_SD3(self):
|
||||
return self.model_type == ModelType.SD3
|
||||
|
||||
def Is_FLUX1(self):
|
||||
return self.model_type == ModelType.FLUX1
|
||||
|
|
@ -70,6 +70,7 @@ def save_configuration(
|
|||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
|
|
@ -135,6 +136,7 @@ def save_configuration(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -151,17 +153,23 @@ def save_configuration(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
sdxl_no_half_vae,
|
||||
|
|
@ -229,6 +237,7 @@ def open_configuration(
|
|||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
|
|
@ -294,6 +303,7 @@ def open_configuration(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -310,17 +320,23 @@ def open_configuration(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
sdxl_no_half_vae,
|
||||
|
|
@ -381,6 +397,7 @@ def train_model(
|
|||
learning_rate,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
lr_warmup_steps,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
|
|
@ -446,6 +463,7 @@ def train_model(
|
|||
optimizer,
|
||||
optimizer_args,
|
||||
lr_scheduler_args,
|
||||
lr_scheduler_type,
|
||||
noise_offset_type,
|
||||
noise_offset,
|
||||
noise_offset_random_strength,
|
||||
|
|
@ -462,17 +480,23 @@ def train_model(
|
|||
loss_type,
|
||||
huber_schedule,
|
||||
huber_c,
|
||||
huber_scale,
|
||||
vae_batch_size,
|
||||
min_snr_gamma,
|
||||
save_every_n_steps,
|
||||
save_last_n_steps,
|
||||
save_last_n_steps_state,
|
||||
save_last_n_epochs,
|
||||
save_last_n_epochs_state,
|
||||
skip_cache_check,
|
||||
log_with,
|
||||
wandb_api_key,
|
||||
wandb_run_name,
|
||||
log_tracker_name,
|
||||
log_tracker_config,
|
||||
log_config,
|
||||
scale_v_pred_loss_like_noise_pred,
|
||||
disable_mmap_load_safetensors,
|
||||
min_timestep,
|
||||
max_timestep,
|
||||
sdxl_no_half_vae,
|
||||
|
|
@ -549,20 +573,6 @@ def train_model(
|
|||
# End of path validation
|
||||
#
|
||||
|
||||
# if not validate_paths(
|
||||
# dataset_config=dataset_config,
|
||||
# headless=headless,
|
||||
# log_tracker_config=log_tracker_config,
|
||||
# logging_dir=logging_dir,
|
||||
# output_dir=output_dir,
|
||||
# pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
# reg_data_dir=reg_data_dir,
|
||||
# resume=resume,
|
||||
# train_data_dir=train_data_dir,
|
||||
# vae=vae,
|
||||
# ):
|
||||
# return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
if token_string == "":
|
||||
output_message(msg="Token string is missing", headless=headless)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
|
@ -588,13 +598,6 @@ def train_model(
|
|||
stop_text_encoder_training = math.ceil(
|
||||
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
|
||||
)
|
||||
|
||||
if lr_warmup != 0:
|
||||
lr_warmup_steps = round(
|
||||
float(int(lr_warmup) * int(max_train_steps) / 100)
|
||||
)
|
||||
else:
|
||||
lr_warmup_steps = 0
|
||||
else:
|
||||
stop_text_encoder_training = 0
|
||||
lr_warmup_steps = 0
|
||||
|
|
@ -657,11 +660,11 @@ def train_model(
|
|||
reg_factor = 1
|
||||
else:
|
||||
log.warning(
|
||||
"Regularisation images are used... Will double the number of steps required..."
|
||||
"Regularization images are used... Will double the number of steps required..."
|
||||
)
|
||||
reg_factor = 2
|
||||
|
||||
log.info(f"Regulatization factor: {reg_factor}")
|
||||
log.info(f"Regularization factor: {reg_factor}")
|
||||
|
||||
if max_train_steps == 0:
|
||||
# calculate max_train_steps
|
||||
|
|
@ -689,13 +692,18 @@ def train_model(
|
|||
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
|
||||
)
|
||||
|
||||
if lr_warmup != 0:
|
||||
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
||||
else:
|
||||
lr_warmup_steps = 0
|
||||
|
||||
log.info(f"Total steps: {total_steps}")
|
||||
|
||||
# Calculate lr_warmup_steps
|
||||
if lr_warmup_steps > 0:
|
||||
lr_warmup_steps = int(lr_warmup_steps)
|
||||
if lr_warmup > 0:
|
||||
log.warning("Both lr_warmup and lr_warmup_steps are set. lr_warmup_steps will be used.")
|
||||
elif lr_warmup != 0:
|
||||
lr_warmup_steps = lr_warmup / 100
|
||||
else:
|
||||
lr_warmup_steps = 0
|
||||
|
||||
log.info(f"Train batch size: {train_batch_size}")
|
||||
log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
|
||||
log.info(f"Epoch: {epoch}")
|
||||
|
|
@ -757,6 +765,7 @@ def train_model(
|
|||
"clip_skip": clip_skip if clip_skip != 0 else None,
|
||||
"color_aug": color_aug,
|
||||
"dataset_config": dataset_config,
|
||||
"disable_mmap_load_safetensors": disable_mmap_load_safetensors,
|
||||
"dynamo_backend": dynamo_backend,
|
||||
"enable_bucket": enable_bucket,
|
||||
"epoch": int(epoch),
|
||||
|
|
@ -765,6 +774,7 @@ def train_model(
|
|||
"gradient_accumulation_steps": int(gradient_accumulation_steps),
|
||||
"gradient_checkpointing": gradient_checkpointing,
|
||||
"huber_c": huber_c,
|
||||
"huber_scale": huber_scale,
|
||||
"huber_schedule": huber_schedule,
|
||||
"huggingface_repo_id": huggingface_repo_id,
|
||||
"huggingface_token": huggingface_token,
|
||||
|
|
@ -777,6 +787,7 @@ def train_model(
|
|||
"keep_tokens": int(keep_tokens),
|
||||
"learning_rate": learning_rate,
|
||||
"logging_dir": logging_dir,
|
||||
"log_config": log_config,
|
||||
"log_tracker_name": log_tracker_name,
|
||||
"log_tracker_config": log_tracker_config,
|
||||
"loss_type": loss_type,
|
||||
|
|
@ -786,6 +797,7 @@ def train_model(
|
|||
int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
|
||||
),
|
||||
"lr_scheduler_power": lr_scheduler_power,
|
||||
"lr_scheduler_type": lr_scheduler_type if lr_scheduler_type != "" else None,
|
||||
"lr_warmup_steps": lr_warmup_steps,
|
||||
"max_bucket_reso": max_bucket_reso,
|
||||
"max_timestep": max_timestep if max_timestep != 0 else None,
|
||||
|
|
@ -840,6 +852,10 @@ def train_model(
|
|||
"save_last_n_steps_state": (
|
||||
save_last_n_steps_state if save_last_n_steps_state != 0 else None
|
||||
),
|
||||
"save_last_n_epochs": save_last_n_epochs if save_last_n_epochs != 0 else None,
|
||||
"save_last_n_epochs_state": (
|
||||
save_last_n_epochs_state if save_last_n_epochs_state != 0 else None
|
||||
),
|
||||
"save_model_as": save_model_as,
|
||||
"save_precision": save_precision,
|
||||
"save_state": save_state,
|
||||
|
|
@ -849,6 +865,7 @@ def train_model(
|
|||
"sdpa": True if xformers == "sdpa" else None,
|
||||
"seed": int(seed) if int(seed) != 0 else None,
|
||||
"shuffle_caption": shuffle_caption,
|
||||
"skip_cache_check": skip_cache_check,
|
||||
"stop_text_encoder_training": (
|
||||
stop_text_encoder_training if stop_text_encoder_training != 0 else None
|
||||
),
|
||||
|
|
@ -862,8 +879,8 @@ def train_model(
|
|||
"vae": vae,
|
||||
"vae_batch_size": vae_batch_size if vae_batch_size != 0 else None,
|
||||
"wandb_api_key": wandb_api_key,
|
||||
"wandb_run_name": wandb_run_name,
|
||||
"weigts": weights,
|
||||
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
|
||||
"weights": weights,
|
||||
"use_object_template": True if template == "object template" else None,
|
||||
"use_style_template": True if template == "style template" else None,
|
||||
"xformers": True if xformers == "xformers" else None,
|
||||
|
|
@ -1130,6 +1147,7 @@ def ti_tab(
|
|||
basic_training.learning_rate,
|
||||
basic_training.lr_scheduler,
|
||||
basic_training.lr_warmup,
|
||||
basic_training.lr_warmup_steps,
|
||||
basic_training.train_batch_size,
|
||||
basic_training.epoch,
|
||||
basic_training.save_every_n_epochs,
|
||||
|
|
@ -1194,6 +1212,7 @@ def ti_tab(
|
|||
basic_training.optimizer,
|
||||
basic_training.optimizer_args,
|
||||
basic_training.lr_scheduler_args,
|
||||
basic_training.lr_scheduler_type,
|
||||
advanced_training.noise_offset_type,
|
||||
advanced_training.noise_offset,
|
||||
advanced_training.noise_offset_random_strength,
|
||||
|
|
@ -1210,17 +1229,23 @@ def ti_tab(
|
|||
advanced_training.loss_type,
|
||||
advanced_training.huber_schedule,
|
||||
advanced_training.huber_c,
|
||||
advanced_training.huber_scale,
|
||||
advanced_training.vae_batch_size,
|
||||
advanced_training.min_snr_gamma,
|
||||
advanced_training.save_every_n_steps,
|
||||
advanced_training.save_last_n_steps,
|
||||
advanced_training.save_last_n_steps_state,
|
||||
advanced_training.save_last_n_epochs,
|
||||
advanced_training.save_last_n_epochs_state,
|
||||
advanced_training.skip_cache_check,
|
||||
advanced_training.log_with,
|
||||
advanced_training.wandb_api_key,
|
||||
advanced_training.wandb_run_name,
|
||||
advanced_training.log_tracker_name,
|
||||
advanced_training.log_tracker_config,
|
||||
advanced_training.log_config,
|
||||
advanced_training.scale_v_pred_loss_like_noise_pred,
|
||||
sdxl_params.disable_mmap_load_safetensors,
|
||||
advanced_training.min_timestep,
|
||||
advanced_training.max_timestep,
|
||||
sdxl_params.sdxl_no_half_vae,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,146 @@
|
|||
{
|
||||
"adaptive_noise_scale": 0,
|
||||
"additional_parameters": "",
|
||||
"async_upload": false,
|
||||
"bucket_no_upscale": true,
|
||||
"bucket_reso_steps": 64,
|
||||
"cache_latents": true,
|
||||
"cache_latents_to_disk": true,
|
||||
"caption_dropout_every_n_epochs": 0,
|
||||
"caption_dropout_rate": 0,
|
||||
"caption_extension": ".txt",
|
||||
"clip_g": "H:/ComfyUI2/models/clip/clip_g.safetensors",
|
||||
"clip_l": "H:/ComfyUI2/models/clip/clip_l.safetensors",
|
||||
"clip_skip": 1,
|
||||
"color_aug": false,
|
||||
"dataset_config": "",
|
||||
"debiased_estimation_loss": false,
|
||||
"disable_mmap_load_safetensors": false,
|
||||
"dynamo_backend": "no",
|
||||
"dynamo_mode": "default",
|
||||
"dynamo_use_dynamic": false,
|
||||
"dynamo_use_fullgraph": false,
|
||||
"enable_bucket": true,
|
||||
"epoch": 8,
|
||||
"extra_accelerate_launch_args": "",
|
||||
"flip_aug": false,
|
||||
"full_bf16": false,
|
||||
"full_fp16": false,
|
||||
"fused_backward_pass": false,
|
||||
"fused_optimizer_groups": 0,
|
||||
"gpu_ids": "",
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": true,
|
||||
"huber_c": 0.1,
|
||||
"huber_schedule": "snr",
|
||||
"huggingface_path_in_repo": "",
|
||||
"huggingface_repo_id": "",
|
||||
"huggingface_repo_type": "",
|
||||
"huggingface_repo_visibility": "",
|
||||
"huggingface_token": "",
|
||||
"ip_noise_gamma": 0,
|
||||
"ip_noise_gamma_random_strength": false,
|
||||
"keep_tokens": 0,
|
||||
"learning_rate": 5e-06,
|
||||
"learning_rate_te": 0,
|
||||
"learning_rate_te1": 1e-05,
|
||||
"learning_rate_te2": 1e-05,
|
||||
"log_config": false,
|
||||
"log_tracker_config": "",
|
||||
"log_tracker_name": "",
|
||||
"log_with": "",
|
||||
"logging_dir": "C:/Users/berna/Downloads/martini/logs/sd3",
|
||||
"logit_mean": 0,
|
||||
"logit_std": 1,
|
||||
"loss_type": "l2",
|
||||
"lr_scheduler": "cosine",
|
||||
"lr_scheduler_args": "",
|
||||
"lr_scheduler_num_cycles": 1,
|
||||
"lr_scheduler_power": 1,
|
||||
"lr_scheduler_type": "",
|
||||
"lr_warmup": 10,
|
||||
"main_process_port": 0,
|
||||
"masked_loss": false,
|
||||
"max_bucket_reso": 1536,
|
||||
"max_data_loader_n_workers": 0,
|
||||
"max_resolution": "512,512",
|
||||
"max_timestep": 1000,
|
||||
"max_token_length": 225,
|
||||
"max_train_epochs": 8,
|
||||
"max_train_steps": 1600,
|
||||
"mem_eff_attn": false,
|
||||
"metadata_author": "",
|
||||
"metadata_description": "",
|
||||
"metadata_license": "",
|
||||
"metadata_tags": "",
|
||||
"metadata_title": "",
|
||||
"min_bucket_reso": 256,
|
||||
"min_snr_gamma": 0,
|
||||
"min_timestep": 0,
|
||||
"mixed_precision": "bf16",
|
||||
"mode_scale": 1.29,
|
||||
"model_list": "custom",
|
||||
"multi_gpu": false,
|
||||
"multires_noise_discount": 0.3,
|
||||
"multires_noise_iterations": 0,
|
||||
"no_token_padding": false,
|
||||
"noise_offset": 0,
|
||||
"noise_offset_random_strength": false,
|
||||
"noise_offset_type": "Original",
|
||||
"num_cpu_threads_per_process": 2,
|
||||
"num_machines": 1,
|
||||
"num_processes": 1,
|
||||
"optimizer": "PagedAdamW8bit",
|
||||
"optimizer_args": "weight_decay=0.1 betas=.9,.95",
|
||||
"output_dir": "E:/models/sd3",
|
||||
"output_name": "sd3",
|
||||
"persistent_data_loader_workers": false,
|
||||
"pretrained_model_name_or_path": "E:/models/sd3/sd3_medium.safetensors",
|
||||
"prior_loss_weight": 1,
|
||||
"random_crop": false,
|
||||
"reg_data_dir": "",
|
||||
"resume": "",
|
||||
"resume_from_huggingface": "",
|
||||
"sample_every_n_epochs": 0,
|
||||
"sample_every_n_steps": 0,
|
||||
"sample_prompts": "",
|
||||
"sample_sampler": "euler_a",
|
||||
"save_as_bool": false,
|
||||
"save_clip": false,
|
||||
"save_every_n_epochs": 0,
|
||||
"save_every_n_steps": 0,
|
||||
"save_last_n_steps": 0,
|
||||
"save_last_n_steps_state": 0,
|
||||
"save_model_as": "safetensors",
|
||||
"save_precision": "fp16",
|
||||
"save_state": false,
|
||||
"save_state_on_train_end": false,
|
||||
"save_state_to_huggingface": false,
|
||||
"save_t5xxl": false,
|
||||
"scale_v_pred_loss_like_noise_pred": false,
|
||||
"sd3_cache_text_encoder_outputs": true,
|
||||
"sd3_cache_text_encoder_outputs_to_disk": true,
|
||||
"sd3_checkbox": true,
|
||||
"sd3_text_encoder_batch_size": 1,
|
||||
"sdxl": false,
|
||||
"sdxl_cache_text_encoder_outputs": false,
|
||||
"sdxl_no_half_vae": false,
|
||||
"seed": 1026,
|
||||
"shuffle_caption": false,
|
||||
"stop_text_encoder_training": 0,
|
||||
"t5xxl": "H:/ComfyUI2/models/clip/t5xxl_fp8_e4m3fn.safetensors",
|
||||
"t5xxl_device": "",
|
||||
"t5xxl_dtype": "bf16",
|
||||
"train_batch_size": 1,
|
||||
"train_data_dir": "C:/Users/berna/Downloads/martini/img2",
|
||||
"v2": false,
|
||||
"v_parameterization": false,
|
||||
"v_pred_like_loss": 0,
|
||||
"vae": "",
|
||||
"vae_batch_size": 0,
|
||||
"wandb_api_key": "",
|
||||
"wandb_run_name": "",
|
||||
"weighted_captions": false,
|
||||
"weighting_scheme": "logit_normal",
|
||||
"xformers": "sdpa"
|
||||
}
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
{
|
||||
"adaptive_noise_scale": 0,
|
||||
"additional_parameters": "",
|
||||
"async_upload": false,
|
||||
"bucket_no_upscale": true,
|
||||
"bucket_reso_steps": 64,
|
||||
"cache_latents": true,
|
||||
"cache_latents_to_disk": true,
|
||||
"caption_dropout_every_n_epochs": 0,
|
||||
"caption_dropout_rate": 0,
|
||||
"caption_extension": ".txt",
|
||||
"clip_g": "H:/ComfyUI2/models/clip/clip_g.safetensors",
|
||||
"clip_l": "H:/ComfyUI2/models/clip/clip_l.safetensors",
|
||||
"clip_skip": 1,
|
||||
"color_aug": false,
|
||||
"dataset_config": "",
|
||||
"debiased_estimation_loss": false,
|
||||
"disable_mmap_load_safetensors": false,
|
||||
"dynamo_backend": "no",
|
||||
"dynamo_mode": "default",
|
||||
"dynamo_use_dynamic": false,
|
||||
"dynamo_use_fullgraph": false,
|
||||
"enable_bucket": true,
|
||||
"epoch": 8,
|
||||
"extra_accelerate_launch_args": "",
|
||||
"flip_aug": false,
|
||||
"full_bf16": false,
|
||||
"full_fp16": false,
|
||||
"fused_backward_pass": false,
|
||||
"fused_optimizer_groups": 0,
|
||||
"gpu_ids": "",
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": true,
|
||||
"huber_c": 0.1,
|
||||
"huber_schedule": "snr",
|
||||
"huggingface_path_in_repo": "",
|
||||
"huggingface_repo_id": "",
|
||||
"huggingface_repo_type": "",
|
||||
"huggingface_repo_visibility": "",
|
||||
"huggingface_token": "",
|
||||
"ip_noise_gamma": 0,
|
||||
"ip_noise_gamma_random_strength": false,
|
||||
"keep_tokens": 0,
|
||||
"learning_rate": 5e-06,
|
||||
"learning_rate_te": 0,
|
||||
"learning_rate_te1": 1e-05,
|
||||
"learning_rate_te2": 1e-05,
|
||||
"log_config": false,
|
||||
"log_tracker_config": "",
|
||||
"log_tracker_name": "",
|
||||
"log_with": "",
|
||||
"logging_dir": "C:/Users/berna/Downloads/martini/logs/sd3",
|
||||
"logit_mean": 0,
|
||||
"logit_std": 1,
|
||||
"loss_type": "l2",
|
||||
"lr_scheduler": "cosine",
|
||||
"lr_scheduler_args": "",
|
||||
"lr_scheduler_num_cycles": 1,
|
||||
"lr_scheduler_power": 1,
|
||||
"lr_scheduler_type": "",
|
||||
"lr_warmup": 10,
|
||||
"main_process_port": 0,
|
||||
"masked_loss": false,
|
||||
"max_bucket_reso": 1536,
|
||||
"max_data_loader_n_workers": 0,
|
||||
"max_resolution": "512,512",
|
||||
"max_timestep": 1000,
|
||||
"max_token_length": 150,
|
||||
"max_train_epochs": 8,
|
||||
"max_train_steps": 1600,
|
||||
"mem_eff_attn": false,
|
||||
"metadata_author": "",
|
||||
"metadata_description": "",
|
||||
"metadata_license": "",
|
||||
"metadata_tags": "",
|
||||
"metadata_title": "",
|
||||
"min_bucket_reso": 256,
|
||||
"min_snr_gamma": 0,
|
||||
"min_timestep": 0,
|
||||
"mixed_precision": "bf16",
|
||||
"mode_scale": 1.29,
|
||||
"model_list": "custom",
|
||||
"multi_gpu": false,
|
||||
"multires_noise_discount": 0.3,
|
||||
"multires_noise_iterations": 0,
|
||||
"no_token_padding": false,
|
||||
"noise_offset": 0,
|
||||
"noise_offset_random_strength": false,
|
||||
"noise_offset_type": "Original",
|
||||
"num_cpu_threads_per_process": 2,
|
||||
"num_machines": 1,
|
||||
"num_processes": 1,
|
||||
"optimizer": "PagedAdamW8bit",
|
||||
"optimizer_args": "weight_decay=0.1 betas=.9,.95",
|
||||
"output_dir": "E:/models/sd3",
|
||||
"output_name": "sd3_v2",
|
||||
"persistent_data_loader_workers": false,
|
||||
"pretrained_model_name_or_path": "E:/models/sd3/sd3_medium.safetensors",
|
||||
"prior_loss_weight": 1,
|
||||
"random_crop": false,
|
||||
"reg_data_dir": "",
|
||||
"resume": "",
|
||||
"resume_from_huggingface": "",
|
||||
"sample_every_n_epochs": 0,
|
||||
"sample_every_n_steps": 0,
|
||||
"sample_prompts": "",
|
||||
"sample_sampler": "euler_a",
|
||||
"save_as_bool": false,
|
||||
"save_clip": false,
|
||||
"save_every_n_epochs": 0,
|
||||
"save_every_n_steps": 0,
|
||||
"save_last_n_steps": 0,
|
||||
"save_last_n_steps_state": 0,
|
||||
"save_model_as": "safetensors",
|
||||
"save_precision": "fp16",
|
||||
"save_state": false,
|
||||
"save_state_on_train_end": false,
|
||||
"save_state_to_huggingface": false,
|
||||
"save_t5xxl": false,
|
||||
"scale_v_pred_loss_like_noise_pred": false,
|
||||
"sd3_cache_text_encoder_outputs": true,
|
||||
"sd3_cache_text_encoder_outputs_to_disk": true,
|
||||
"sd3_checkbox": true,
|
||||
"sd3_text_encoder_batch_size": 1,
|
||||
"sdxl": false,
|
||||
"sdxl_cache_text_encoder_outputs": false,
|
||||
"sdxl_no_half_vae": false,
|
||||
"seed": 1026,
|
||||
"shuffle_caption": false,
|
||||
"stop_text_encoder_training": 0,
|
||||
"t5xxl": "H:/ComfyUI2/models/clip/t5xxl_fp8_e4m3fn.safetensors",
|
||||
"t5xxl_device": "",
|
||||
"t5xxl_dtype": "bf16",
|
||||
"train_batch_size": 1,
|
||||
"train_data_dir": "C:/Users/berna/Downloads/martini/img",
|
||||
"v2": false,
|
||||
"v_parameterization": false,
|
||||
"v_pred_like_loss": 0,
|
||||
"vae": "",
|
||||
"vae_batch_size": 0,
|
||||
"wandb_api_key": "",
|
||||
"wandb_run_name": "",
|
||||
"weighted_captions": false,
|
||||
"weighting_scheme": "logit_normal",
|
||||
"xformers": "sdpa"
|
||||
}
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
{
|
||||
"LoRA_type": "Flux1",
|
||||
"LyCORIS_preset": "full",
|
||||
"adaptive_noise_scale": 0,
|
||||
"additional_parameters": "",
|
||||
"ae": "put the full path to ae.sft here",
|
||||
"apply_t5_attn_mask": true,
|
||||
"async_upload": false,
|
||||
"block_alphas": "",
|
||||
"block_dims": "",
|
||||
"block_lr_zero_threshold": "",
|
||||
"bucket_no_upscale": true,
|
||||
"bucket_reso_steps": 64,
|
||||
"bypass_mode": false,
|
||||
"cache_latents": true,
|
||||
"cache_latents_to_disk": true,
|
||||
"caption_dropout_every_n_epochs": 0,
|
||||
"caption_dropout_rate": 0,
|
||||
"caption_extension": ".txt",
|
||||
"clip_l": "put the full path to clip_l.safetensors here",
|
||||
"clip_skip": 1,
|
||||
"color_aug": false,
|
||||
"constrain": 0,
|
||||
"conv_alpha": 1,
|
||||
"conv_block_alphas": "",
|
||||
"conv_block_dims": "",
|
||||
"conv_dim": 1,
|
||||
"dataset_config": "",
|
||||
"debiased_estimation_loss": false,
|
||||
"decompose_both": false,
|
||||
"dim_from_weights": false,
|
||||
"discrete_flow_shift": 3,
|
||||
"dora_wd": false,
|
||||
"down_lr_weight": "",
|
||||
"dynamo_backend": "no",
|
||||
"dynamo_mode": "default",
|
||||
"dynamo_use_dynamic": false,
|
||||
"dynamo_use_fullgraph": false,
|
||||
"enable_bucket": true,
|
||||
"epoch": 1,
|
||||
"extra_accelerate_launch_args": "",
|
||||
"factor": -1,
|
||||
"flip_aug": false,
|
||||
"flux1_cache_text_encoder_outputs": true,
|
||||
"flux1_cache_text_encoder_outputs_to_disk": true,
|
||||
"flux1_checkbox": true,
|
||||
"fp8_base": true,
|
||||
"full_bf16": true,
|
||||
"full_fp16": false,
|
||||
"gpu_ids": "",
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": true,
|
||||
"guidance_scale": 1,
|
||||
"highvram": false,
|
||||
"huber_c": 0.1,
|
||||
"huber_schedule": "snr",
|
||||
"huggingface_path_in_repo": "",
|
||||
"huggingface_repo_id": "",
|
||||
"huggingface_repo_type": "",
|
||||
"huggingface_repo_visibility": "",
|
||||
"huggingface_token": "",
|
||||
"ip_noise_gamma": 0,
|
||||
"ip_noise_gamma_random_strength": false,
|
||||
"keep_tokens": 0,
|
||||
"learning_rate": 0.0003,
|
||||
"log_config": false,
|
||||
"log_tracker_config": "",
|
||||
"log_tracker_name": "",
|
||||
"log_with": "",
|
||||
"logging_dir": "./test/logs-saruman",
|
||||
"loraplus_lr_ratio": 0,
|
||||
"loraplus_text_encoder_lr_ratio": 0,
|
||||
"loraplus_unet_lr_ratio": 0,
|
||||
"loss_type": "l2",
|
||||
"lowvram": false,
|
||||
"lr_scheduler": "constant",
|
||||
"lr_scheduler_args": "",
|
||||
"lr_scheduler_num_cycles": 1,
|
||||
"lr_scheduler_power": 1,
|
||||
"lr_scheduler_type": "",
|
||||
"lr_warmup": 0,
|
||||
"main_process_port": 0,
|
||||
"masked_loss": false,
|
||||
"max_bucket_reso": 2048,
|
||||
"max_data_loader_n_workers": 0,
|
||||
"max_grad_norm": 1,
|
||||
"max_resolution": "512,512",
|
||||
"max_timestep": 1000,
|
||||
"max_token_length": 75,
|
||||
"max_train_epochs": 0,
|
||||
"max_train_steps": 1000,
|
||||
"mem_eff_attn": false,
|
||||
"mem_eff_save": false,
|
||||
"metadata_author": "",
|
||||
"metadata_description": "",
|
||||
"metadata_license": "",
|
||||
"metadata_tags": "",
|
||||
"metadata_title": "",
|
||||
"mid_lr_weight": "",
|
||||
"min_bucket_reso": 256,
|
||||
"min_snr_gamma": 7,
|
||||
"min_timestep": 0,
|
||||
"mixed_precision": "bf16",
|
||||
"model_list": "custom",
|
||||
"model_prediction_type": "raw",
|
||||
"module_dropout": 0,
|
||||
"multi_gpu": false,
|
||||
"multires_noise_discount": 0.3,
|
||||
"multires_noise_iterations": 0,
|
||||
"network_alpha": 16,
|
||||
"network_dim": 16,
|
||||
"network_dropout": 0,
|
||||
"network_weights": "",
|
||||
"noise_offset": 0.05,
|
||||
"noise_offset_random_strength": false,
|
||||
"noise_offset_type": "Original",
|
||||
"num_cpu_threads_per_process": 2,
|
||||
"num_machines": 1,
|
||||
"num_processes": 1,
|
||||
"optimizer": "AdamW8bit",
|
||||
"optimizer_args": "",
|
||||
"output_dir": "put the full path to output folder here",
|
||||
"output_name": "Flux.my-super-duper-model-name-goes-here-v1.0",
|
||||
"persistent_data_loader_workers": false,
|
||||
"pretrained_model_name_or_path": "put the full path to flux1-dev.safetensors here",
|
||||
"prior_loss_weight": 1,
|
||||
"random_crop": false,
|
||||
"rank_dropout": 0,
|
||||
"rank_dropout_scale": false,
|
||||
"reg_data_dir": "",
|
||||
"rescaled": false,
|
||||
"resume": "",
|
||||
"resume_from_huggingface": "",
|
||||
"sample_every_n_epochs": 0,
|
||||
"sample_every_n_steps": 0,
|
||||
"sample_prompts": "saruman posing under a stormy lightning sky, photorealistic --w 832 --h 1216 --s 20 --l 4 --d 42",
|
||||
"sample_sampler": "euler",
|
||||
"save_as_bool": false,
|
||||
"save_every_n_epochs": 1,
|
||||
"save_every_n_steps": 50,
|
||||
"save_last_n_steps": 0,
|
||||
"save_last_n_steps_state": 0,
|
||||
"save_model_as": "safetensors",
|
||||
"save_precision": "bf16",
|
||||
"save_state": false,
|
||||
"save_state_on_train_end": false,
|
||||
"save_state_to_huggingface": false,
|
||||
"scale_v_pred_loss_like_noise_pred": false,
|
||||
"scale_weight_norms": 0,
|
||||
"sdxl": false,
|
||||
"sdxl_cache_text_encoder_outputs": true,
|
||||
"sdxl_no_half_vae": true,
|
||||
"seed": 42,
|
||||
"shuffle_caption": false,
|
||||
"split_mode": false,
|
||||
"stop_text_encoder_training": 0,
|
||||
"t5xxl": "put the full path to the file here. Use the fp16 version",
|
||||
"t5xxl_max_token_length": 512,
|
||||
"text_encoder_lr": 0,
|
||||
"timestep_sampling": "sigmoid",
|
||||
"train_batch_size": 1,
|
||||
"train_blocks": "all",
|
||||
"train_data_dir": "put your image folder here",
|
||||
"train_norm": false,
|
||||
"train_on_input": true,
|
||||
"training_comment": "",
|
||||
"unet_lr": 0.0003,
|
||||
"unit": 1,
|
||||
"up_lr_weight": "",
|
||||
"use_cp": false,
|
||||
"use_scalar": false,
|
||||
"use_tucker": false,
|
||||
"v2": false,
|
||||
"v_parameterization": false,
|
||||
"v_pred_like_loss": 0,
|
||||
"vae": "",
|
||||
"vae_batch_size": 0,
|
||||
"wandb_api_key": "",
|
||||
"wandb_run_name": "",
|
||||
"weighted_captions": false,
|
||||
"xformers": "sdpa"
|
||||
}
|
||||
|
|
@ -1,35 +1,38 @@
|
|||
accelerate==0.25.0
|
||||
accelerate==0.33.0
|
||||
aiofiles==23.2.1
|
||||
altair==4.2.2
|
||||
dadaptation==3.1
|
||||
dadaptation==3.2
|
||||
diffusers[torch]==0.25.0
|
||||
easygui==0.98.3
|
||||
einops==0.7.0
|
||||
fairscale==0.4.13
|
||||
ftfy==6.1.1
|
||||
gradio==4.43.0
|
||||
huggingface-hub==0.20.1
|
||||
gradio==5.23.1
|
||||
huggingface-hub==0.29.3
|
||||
imagesize==1.4.1
|
||||
invisible-watermark==0.2.0
|
||||
lion-pytorch==0.0.6
|
||||
lycoris_lora==2.2.0.post3
|
||||
lycoris_lora==3.1.0
|
||||
omegaconf==2.3.0
|
||||
onnx==1.16.1
|
||||
prodigyopt==1.0
|
||||
prodigyopt==1.1.2
|
||||
protobuf==3.20.3
|
||||
open-clip-torch==2.20.0
|
||||
opencv-python==4.7.0.68
|
||||
prodigyopt==1.0
|
||||
opencv-python==4.10.0.84
|
||||
prodigy-plus-schedule-free==1.8.0
|
||||
pytorch-lightning==1.9.0
|
||||
pytorch-optimizer==3.5.0
|
||||
rich>=13.7.1
|
||||
safetensors==0.4.2
|
||||
safetensors==0.4.4
|
||||
schedulefree==1.4
|
||||
scipy==1.11.4
|
||||
# for T5XXL tokenizer (SD3/FLUX)
|
||||
sentencepiece==0.2.0
|
||||
timm==0.6.12
|
||||
tk==0.1.0
|
||||
toml==0.10.2
|
||||
transformers==4.38.0
|
||||
transformers==4.44.2
|
||||
voluptuous==0.13.1
|
||||
wandb==0.15.11
|
||||
scipy==1.11.4
|
||||
# for kohya_ss library
|
||||
-e ./sd-scripts # no_verify leave this to specify not checking this a verification stage
|
||||
wandb==0.18.0
|
||||
# for kohya_ss sd-scripts library
|
||||
-e ./sd-scripts
|
||||
|
|
|
|||
|
|
@ -1,5 +1,13 @@
|
|||
torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
bitsandbytes==0.43.0
|
||||
tensorboard==2.15.2 tensorflow==2.15.0.post1
|
||||
onnxruntime-gpu==1.17.1
|
||||
# Custom index URL for specific packages
|
||||
--extra-index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
torch==2.5.0+cu124
|
||||
torchvision==0.20.0+cu124
|
||||
xformers==0.0.28.post2
|
||||
|
||||
bitsandbytes==0.44.0
|
||||
tensorboard==2.15.2
|
||||
tensorflow==2.15.0.post1
|
||||
onnxruntime-gpu==1.19.2
|
||||
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
xformers>=0.0.20
|
||||
bitsandbytes==0.43.0
|
||||
accelerate==0.25.0
|
||||
bitsandbytes==0.44.0
|
||||
accelerate==0.33.0
|
||||
tensorboard
|
||||
|
|
@ -1,5 +1,17 @@
|
|||
torch==2.1.0.post0+cxx11.abi torchvision==0.16.0.post0+cxx11.abi intel-extension-for-pytorch==2.1.20+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
tensorboard==2.15.2 tensorflow==2.15.0 intel-extension-for-tensorflow[xpu]==2.15.0.0
|
||||
mkl==2024.1.0 mkl-dpcpp==2024.1.0 oneccl-devel==2021.12.0 impi-devel==2021.12.0
|
||||
onnxruntime-openvino==1.17.1
|
||||
# Custom index URL for specific packages
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
|
||||
torch==2.1.0.post3+cxx11.abi
|
||||
torchvision==0.16.0.post3+cxx11.abi
|
||||
intel-extension-for-pytorch==2.1.40+xpu
|
||||
oneccl_bind_pt==2.1.400+xpu
|
||||
|
||||
tensorflow==2.15.1
|
||||
intel-extension-for-tensorflow[xpu]==2.15.0.1
|
||||
mkl==2024.2.0
|
||||
mkl-dpcpp==2024.2.0
|
||||
oneccl-devel==2021.13.0
|
||||
impi-devel==2021.13.0
|
||||
onnxruntime-openvino==1.18.0
|
||||
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -1,4 +1,13 @@
|
|||
torch==2.3.0+rocm6.0 torchvision==0.18.0+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0
|
||||
tensorboard==2.14.1 tensorflow-rocm==2.14.0.600
|
||||
onnxruntime-training --pre --index-url https://pypi.lsh.sh/60/ --extra-index-url https://pypi.org/simple
|
||||
# Custom index URL for specific packages
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.1
|
||||
torch==2.5.0+rocm6.1
|
||||
torchvision==0.20.0+rocm6.1
|
||||
|
||||
tensorboard==2.14.1
|
||||
tensorflow-rocm==2.14.0.600
|
||||
|
||||
# Custom index URL for specific packages
|
||||
--extra-index-url https://pypi.lsh.sh/60/
|
||||
onnxruntime-training --pre
|
||||
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
xformers bitsandbytes==0.41.1
|
||||
xformers bitsandbytes==0.43.3
|
||||
tensorflow-macos tensorboard==2.14.1
|
||||
onnxruntime==1.17.1
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
xformers bitsandbytes==0.41.1
|
||||
xformers bitsandbytes==0.43.3
|
||||
tensorflow-macos tensorflow-metal tensorboard==2.14.1
|
||||
onnxruntime==1.17.1
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
torch==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
# Custom index URL for specific packages
|
||||
--extra-index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
torch==2.5.0+cu124
|
||||
torchvision==0.20.0+cu124
|
||||
xformers==0.0.28.post2
|
||||
|
||||
-r requirements_windows.txt
|
||||
|
|
@ -1,6 +1,13 @@
|
|||
torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this a verification stage
|
||||
bitsandbytes==0.43.0
|
||||
tensorboard==2.14.1 tensorflow==2.14.0 wheel
|
||||
--extra-index-url https://download.pytorch.org/whl/cu124
|
||||
torch==2.5.0+cu124
|
||||
torchvision==0.20.0+cu124
|
||||
xformers==0.0.28.post2
|
||||
|
||||
bitsandbytes==0.44.0
|
||||
tensorboard==2.14.1
|
||||
tensorflow==2.14.0
|
||||
wheel
|
||||
tensorrt
|
||||
onnxruntime-gpu==1.17.1
|
||||
onnxruntime-gpu==1.19.2
|
||||
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
bitsandbytes==0.43.0
|
||||
bitsandbytes==0.44.0
|
||||
tensorboard
|
||||
tensorflow>=2.16.1
|
||||
onnxruntime-gpu==1.17.1
|
||||
onnxruntime-gpu==1.19.2
|
||||
|
||||
-r requirements.txt
|
||||
|
|
@ -1 +1 @@
|
|||
Subproject commit b8896aad400222c8c4441b217fda0f9bb0807ffd
|
||||
Subproject commit 8ebe858f896340d698f03fc33d99ca010131320a
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
IF NOT EXIST venv (
|
||||
echo Creating venv...
|
||||
py -3.10 -m venv venv
|
||||
py -3.10.11 -m venv venv
|
||||
)
|
||||
|
||||
:: Create the directory if it doesn't exist
|
||||
|
|
@ -13,6 +13,9 @@ call .\venv\Scripts\deactivate.bat
|
|||
|
||||
call .\venv\Scripts\activate.bat
|
||||
|
||||
REM first make sure we have setuptools available in the venv
|
||||
python -m pip install --require-virtualenv --no-input -q -q setuptools
|
||||
|
||||
REM Check if the batch was started via double-click
|
||||
IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" (
|
||||
REM echo This script was started by double clicking.
|
||||
|
|
|
|||
18
setup.sh
18
setup.sh
|
|
@ -1,4 +1,5 @@
|
|||
#!/usr/bin/env bash
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
# Function to display help information
|
||||
display_help() {
|
||||
|
|
@ -23,6 +24,7 @@ Options:
|
|||
-i, --interactive Interactively configure accelerate instead of using default config file.
|
||||
-n, --no-git-update Do not update kohya_ss repo. No git pull or clone operations.
|
||||
-p, --public Expose public URL in runpod mode. Won't have an effect in other modes.
|
||||
-q, --quiet Suppress all output except errors.
|
||||
-r, --runpod Forces a runpod installation. Useful if detection fails for any reason.
|
||||
-s, --skip-space-check Skip the 10Gb minimum storage space check.
|
||||
-u, --no-gui Skips launching the GUI.
|
||||
|
|
@ -91,6 +93,7 @@ PARENT_DIR=""
|
|||
VENV_DIR=""
|
||||
USE_IPEX=false
|
||||
USE_ROCM=false
|
||||
QUIET="--show_stdout"
|
||||
|
||||
# Function to get the distro name
|
||||
get_distro_name() {
|
||||
|
|
@ -206,20 +209,20 @@ install_python_dependencies() {
|
|||
case "$OSTYPE" in
|
||||
"lin"*)
|
||||
if [ "$RUNPOD" = true ]; then
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt $QUIET
|
||||
elif [ "$USE_IPEX" = true ]; then
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt $QUIET
|
||||
elif [ "$USE_ROCM" = true ] || [ -x "$(command -v rocminfo)" ] || [ -f "/opt/rocm/bin/rocminfo" ]; then
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt $QUIET
|
||||
else
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt $QUIET
|
||||
fi
|
||||
;;
|
||||
"darwin"*)
|
||||
if [[ "$(uname -m)" == "arm64" ]]; then
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_arm64.txt
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_arm64.txt $QUIET
|
||||
else
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_amd64.txt
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_macos_amd64.txt $QUIET
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
|
|
@ -307,7 +310,7 @@ update_kohya_ss() {
|
|||
|
||||
# Section: Command-line options parsing
|
||||
|
||||
while getopts ":vb:d:g:inprus-:" opt; do
|
||||
while getopts ":vb:d:g:inpqrus-:" opt; do
|
||||
# support long options: https://stackoverflow.com/a/28466267/519360
|
||||
if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG
|
||||
opt="${OPTARG%%=*}" # extract long option name
|
||||
|
|
@ -322,6 +325,7 @@ while getopts ":vb:d:g:inprus-:" opt; do
|
|||
i | interactive) INTERACTIVE=true ;;
|
||||
n | no-git-update) SKIP_GIT_UPDATE=true ;;
|
||||
p | public) PUBLIC=true ;;
|
||||
q | quiet) QUIET="" ;;
|
||||
r | runpod) RUNPOD=true ;;
|
||||
s | skip-space-check) SKIP_SPACE_CHECK=true ;;
|
||||
u | no-gui) SKIP_GUI=true ;;
|
||||
|
|
|
|||
|
|
@ -1,363 +1,321 @@
|
|||
import subprocess
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import logging
|
||||
import shutil
|
||||
import datetime
|
||||
import subprocess
|
||||
import re
|
||||
import pkg_resources
|
||||
|
||||
errors = 0 # Define the 'errors' variable before using it
|
||||
log = logging.getLogger('sd')
|
||||
log = logging.getLogger("sd")
|
||||
|
||||
# Constants
|
||||
MIN_PYTHON_VERSION = (3, 10, 9)
|
||||
MAX_PYTHON_VERSION = (3, 13, 0)
|
||||
LOG_DIR = "../logs/setup/"
|
||||
LOG_LEVEL = "INFO" # Set to "INFO" or "WARNING" for less verbose logging
|
||||
|
||||
|
||||
def check_python_version():
|
||||
"""
|
||||
Check if the current Python version is within the acceptable range.
|
||||
|
||||
Returns:
|
||||
bool: True if the current Python version is valid, False otherwise.
|
||||
bool: True if the current Python version is valid, False otherwise.
|
||||
"""
|
||||
min_version = (3, 10, 9)
|
||||
max_version = (3, 11, 0)
|
||||
|
||||
from packaging import version
|
||||
|
||||
log.debug("Checking Python version...")
|
||||
try:
|
||||
current_version = sys.version_info
|
||||
log.info(f"Python version is {sys.version}")
|
||||
|
||||
if not (min_version <= current_version < max_version):
|
||||
log.error(f"The current version of python ({current_version}) is not appropriate to run Kohya_ss GUI")
|
||||
log.error("The python version needs to be greater or equal to 3.10.9 and less than 3.11.0")
|
||||
if not (MIN_PYTHON_VERSION <= current_version < MAX_PYTHON_VERSION):
|
||||
log.error(
|
||||
f"The current version of python ({sys.version}) is not supported."
|
||||
)
|
||||
log.error("The Python version must be >= 3.10.9 and < 3.13.0.")
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Failed to verify Python version. Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def update_submodule(quiet=True):
|
||||
"""
|
||||
Ensure the submodule is initialized and updated.
|
||||
|
||||
This function uses the Git command line interface to initialize and update
|
||||
the specified submodule recursively. Errors during the Git operation
|
||||
or if Git is not found are caught and logged.
|
||||
|
||||
Parameters:
|
||||
- quiet: If True, suppresses the output of the Git command.
|
||||
"""
|
||||
log.debug("Updating submodule...")
|
||||
git_command = ["git", "submodule", "update", "--init", "--recursive"]
|
||||
|
||||
if quiet:
|
||||
git_command.append("--quiet")
|
||||
|
||||
try:
|
||||
# Initialize and update the submodule
|
||||
subprocess.run(git_command, check=True)
|
||||
log.info("Submodule initialized and updated.")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
# Log the error if the Git operation fails
|
||||
log.error(f"Error during Git operation: {e}")
|
||||
except FileNotFoundError as e:
|
||||
# Log the error if the file is not found
|
||||
log.error(e)
|
||||
|
||||
# def read_tag_version_from_file(file_path):
|
||||
# """
|
||||
# Read the tag version from a given file.
|
||||
|
||||
# Parameters:
|
||||
# - file_path: The path to the file containing the tag version.
|
||||
|
||||
# Returns:
|
||||
# The tag version as a string.
|
||||
# """
|
||||
# with open(file_path, 'r') as file:
|
||||
# # Read the first line and strip whitespace
|
||||
# tag_version = file.readline().strip()
|
||||
# return tag_version
|
||||
|
||||
def clone_or_checkout(repo_url, branch_or_tag, directory_name):
|
||||
"""
|
||||
Clone a repo or checkout a specific branch or tag if the repo already exists.
|
||||
For branches, it updates to the latest version before checking out.
|
||||
Suppresses detached HEAD advice for tags or specific commits.
|
||||
Restores the original working directory after operations.
|
||||
|
||||
Parameters:
|
||||
- repo_url: The URL of the Git repository.
|
||||
- branch_or_tag: The name of the branch or tag to clone or checkout.
|
||||
- directory_name: The name of the directory to clone into or where the repo already exists.
|
||||
"""
|
||||
original_dir = os.getcwd() # Store the original directory
|
||||
log.debug(
|
||||
f"Cloning or checking out repository: {repo_url}, branch/tag: {branch_or_tag}, directory: {directory_name}"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
try:
|
||||
if not os.path.exists(directory_name):
|
||||
# Directory does not exist, clone the repo quietly
|
||||
|
||||
# Construct the command as a string for logging
|
||||
# run_cmd = f"git clone --branch {branch_or_tag} --single-branch --quiet {repo_url} {directory_name}"
|
||||
run_cmd = ["git", "clone", "--branch", branch_or_tag, "--single-branch", "--quiet", repo_url, directory_name]
|
||||
|
||||
|
||||
# Log the command
|
||||
log.debug(run_cmd)
|
||||
|
||||
# Run the command
|
||||
process = subprocess.Popen(
|
||||
run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
output, error = process.communicate()
|
||||
|
||||
if error and not error.startswith("Note: switching to"):
|
||||
log.warning(error)
|
||||
else:
|
||||
log.info(f"Successfully cloned sd-scripts {branch_or_tag}")
|
||||
|
||||
run_cmd = [
|
||||
"git",
|
||||
"clone",
|
||||
"--branch",
|
||||
branch_or_tag,
|
||||
"--single-branch",
|
||||
"--quiet",
|
||||
repo_url,
|
||||
directory_name,
|
||||
]
|
||||
log.debug(f"Cloning repository: {run_cmd}")
|
||||
subprocess.run(run_cmd, check=True)
|
||||
log.info(f"Successfully cloned {repo_url} ({branch_or_tag})")
|
||||
else:
|
||||
os.chdir(directory_name)
|
||||
log.debug("Fetching all branches and tags...")
|
||||
subprocess.run(["git", "fetch", "--all", "--quiet"], check=True)
|
||||
subprocess.run(["git", "config", "advice.detachedHead", "false"], check=True)
|
||||
subprocess.run(
|
||||
["git", "config", "advice.detachedHead", "false"], check=True
|
||||
)
|
||||
|
||||
# Get the current branch or commit hash
|
||||
current_branch_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
|
||||
tag_branch_hash = subprocess.check_output(["git", "rev-parse", branch_or_tag]).strip().decode()
|
||||
current_branch_hash = (
|
||||
subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
|
||||
)
|
||||
target_branch_hash = (
|
||||
subprocess.check_output(["git", "rev-parse", branch_or_tag])
|
||||
.strip()
|
||||
.decode()
|
||||
)
|
||||
|
||||
if current_branch_hash != tag_branch_hash:
|
||||
run_cmd = f"git checkout {branch_or_tag} --quiet"
|
||||
# Log the command
|
||||
log.debug(run_cmd)
|
||||
|
||||
# Execute the checkout command
|
||||
process = subprocess.Popen(run_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
output, error = process.communicate()
|
||||
|
||||
if error:
|
||||
log.warning(error.decode())
|
||||
else:
|
||||
log.info(f"Checked out sd-scripts {branch_or_tag} successfully.")
|
||||
if current_branch_hash != target_branch_hash:
|
||||
log.debug(f"Checking out branch/tag: {branch_or_tag}")
|
||||
subprocess.run(
|
||||
["git", "checkout", branch_or_tag, "--quiet"], check=True
|
||||
)
|
||||
log.info(f"Checked out {branch_or_tag} successfully.")
|
||||
else:
|
||||
log.info(f"Current branch of sd-scripts is already at the required release {branch_or_tag}.")
|
||||
log.info(f"Already at required branch/tag: {branch_or_tag}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.error(f"Error during Git operation: {e}")
|
||||
finally:
|
||||
os.chdir(original_dir) # Restore the original directory
|
||||
os.chdir(original_dir)
|
||||
|
||||
# setup console and file logging
|
||||
def setup_logging(clean=False):
|
||||
#
|
||||
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
||||
#
|
||||
|
||||
def setup_logging():
|
||||
"""
|
||||
Set up logging to file and console.
|
||||
"""
|
||||
log.debug("Setting up logging...")
|
||||
|
||||
from rich.theme import Theme
|
||||
from rich.logging import RichHandler
|
||||
from rich.console import Console
|
||||
from rich.pretty import install as pretty_install
|
||||
from rich.traceback import install as traceback_install
|
||||
|
||||
console = Console(
|
||||
log_time=True,
|
||||
log_time_format='%H:%M:%S-%f',
|
||||
theme=Theme(
|
||||
{
|
||||
'traceback.border': 'black',
|
||||
'traceback.border.syntax_error': 'black',
|
||||
'inspect.value.border': 'black',
|
||||
}
|
||||
),
|
||||
log_time_format="%H:%M:%S-%f",
|
||||
theme=Theme({"traceback.border": "black", "inspect.value.border": "black"}),
|
||||
)
|
||||
# logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
# logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
|
||||
current_datetime = datetime.datetime.now()
|
||||
current_datetime_str = current_datetime.strftime('%Y%m%d-%H%M%S')
|
||||
current_datetime_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
log_file = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
f'../logs/setup/kohya_ss_gui_{current_datetime_str}.log',
|
||||
os.path.dirname(__file__), f"{LOG_DIR}kohya_ss_gui_{current_datetime_str}.log"
|
||||
)
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
|
||||
# Create directories if they don't exist
|
||||
log_directory = os.path.dirname(log_file)
|
||||
os.makedirs(log_directory, exist_ok=True)
|
||||
|
||||
level = logging.INFO
|
||||
logging.basicConfig(
|
||||
level=logging.ERROR,
|
||||
format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
|
||||
format="%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s",
|
||||
filename=log_file,
|
||||
filemode='a',
|
||||
encoding='utf-8',
|
||||
filemode="a",
|
||||
encoding="utf-8",
|
||||
force=True,
|
||||
)
|
||||
log.setLevel(
|
||||
logging.DEBUG
|
||||
) # log to file is always at level debug for facility `sd`
|
||||
pretty_install(console=console)
|
||||
traceback_install(
|
||||
console=console,
|
||||
extra_lines=1,
|
||||
width=console.width,
|
||||
word_wrap=False,
|
||||
indent_guides=False,
|
||||
suppress=[],
|
||||
)
|
||||
rh = RichHandler(
|
||||
show_time=True,
|
||||
omit_repeated_times=False,
|
||||
show_level=True,
|
||||
show_path=False,
|
||||
markup=False,
|
||||
rich_tracebacks=True,
|
||||
log_time_format='%H:%M:%S-%f',
|
||||
level=level,
|
||||
console=console,
|
||||
)
|
||||
rh.set_name(level)
|
||||
while log.hasHandlers() and len(log.handlers) > 0:
|
||||
log.removeHandler(log.handlers[0])
|
||||
log.addHandler(rh)
|
||||
log_level = os.getenv("LOG_LEVEL", LOG_LEVEL).upper()
|
||||
log.setLevel(getattr(logging, log_level, logging.DEBUG))
|
||||
rich_handler = RichHandler(console=console)
|
||||
|
||||
# Replace existing handlers with the rich handler
|
||||
log.handlers.clear()
|
||||
log.addHandler(rich_handler)
|
||||
log.debug("Logging setup complete.")
|
||||
|
||||
|
||||
def install_requirements_inbulk(requirements_file, show_stdout=True, optional_parm="", upgrade = False):
|
||||
def install_requirements_inbulk(
|
||||
requirements_file, show_stdout=True, optional_parm="", upgrade=False
|
||||
):
|
||||
log.debug(f"Installing requirements in bulk from: {requirements_file}")
|
||||
if not os.path.exists(requirements_file):
|
||||
log.error(f'Could not find the requirements file in {requirements_file}.')
|
||||
log.error(f"Could not find the requirements file in {requirements_file}.")
|
||||
return
|
||||
|
||||
log.info(f'Installing requirements from {requirements_file}...')
|
||||
log.info(f"Installing/Validating requirements from {requirements_file}...")
|
||||
|
||||
# Build the command as a list
|
||||
cmd = ["pip", "install", "-r", requirements_file]
|
||||
if upgrade:
|
||||
optional_parm += " -U"
|
||||
cmd.append("--upgrade")
|
||||
if not show_stdout:
|
||||
cmd.append("--quiet")
|
||||
if optional_parm:
|
||||
cmd.extend(optional_parm.split())
|
||||
|
||||
if show_stdout:
|
||||
run_cmd(f'pip install -r {requirements_file} {optional_parm}')
|
||||
else:
|
||||
run_cmd(f'pip install -r {requirements_file} {optional_parm} --quiet')
|
||||
log.info(f'Requirements from {requirements_file} installed.')
|
||||
try:
|
||||
# Run the command and filter output in real-time
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True
|
||||
)
|
||||
|
||||
for line in process.stdout:
|
||||
if "Requirement already satisfied" not in line:
|
||||
log.info(line.strip()) if show_stdout else None
|
||||
|
||||
# Capture and log any errors
|
||||
_, stderr = process.communicate()
|
||||
if process.returncode != 0:
|
||||
log.error(f"Failed to install requirements: {stderr.strip()}")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.error(f"An error occurred while installing requirements: {e}")
|
||||
|
||||
|
||||
def configure_accelerate(run_accelerate=False):
|
||||
#
|
||||
# This function was taken and adapted from code written by jstayco
|
||||
#
|
||||
|
||||
log.debug("Configuring accelerate...")
|
||||
from pathlib import Path
|
||||
|
||||
def env_var_exists(var_name):
|
||||
return var_name in os.environ and os.environ[var_name] != ''
|
||||
return var_name in os.environ and os.environ[var_name] != ""
|
||||
|
||||
log.info('Configuring accelerate...')
|
||||
log.info("Configuring accelerate...")
|
||||
|
||||
source_accelerate_config_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'..',
|
||||
'config_files',
|
||||
'accelerate',
|
||||
'default_config.yaml',
|
||||
"..",
|
||||
"config_files",
|
||||
"accelerate",
|
||||
"default_config.yaml",
|
||||
)
|
||||
|
||||
if not os.path.exists(source_accelerate_config_file):
|
||||
log.warning(
|
||||
f"Could not find the accelerate configuration file in {source_accelerate_config_file}."
|
||||
)
|
||||
if run_accelerate:
|
||||
run_cmd('accelerate config')
|
||||
log.debug("Running accelerate configuration command...")
|
||||
run_cmd([sys.executable, "-m", "accelerate", "config"])
|
||||
else:
|
||||
log.warning(
|
||||
f'Could not find the accelerate configuration file in {source_accelerate_config_file}. Please configure accelerate manually by runningthe option in the menu.'
|
||||
"Please configure accelerate manually by running the option in the menu."
|
||||
)
|
||||
return
|
||||
|
||||
log.debug(
|
||||
f'Source accelerate config location: {source_accelerate_config_file}'
|
||||
)
|
||||
log.debug(f"Source accelerate config location: {source_accelerate_config_file}")
|
||||
|
||||
target_config_location = None
|
||||
|
||||
log.debug(
|
||||
f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, "
|
||||
f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, "
|
||||
f"USERPROFILE: {os.environ.get('USERPROFILE')}"
|
||||
)
|
||||
if env_var_exists('HF_HOME'):
|
||||
target_config_location = Path(
|
||||
os.environ['HF_HOME'], 'accelerate', 'default_config.yaml'
|
||||
)
|
||||
elif env_var_exists('LOCALAPPDATA'):
|
||||
target_config_location = Path(
|
||||
os.environ['LOCALAPPDATA'],
|
||||
'huggingface',
|
||||
'accelerate',
|
||||
'default_config.yaml',
|
||||
)
|
||||
elif env_var_exists('USERPROFILE'):
|
||||
target_config_location = Path(
|
||||
os.environ['USERPROFILE'],
|
||||
'.cache',
|
||||
'huggingface',
|
||||
'accelerate',
|
||||
'default_config.yaml',
|
||||
)
|
||||
env_vars = {
|
||||
"HF_HOME": Path(os.environ.get("HF_HOME", "")),
|
||||
"LOCALAPPDATA": Path(
|
||||
os.environ.get("LOCALAPPDATA", ""),
|
||||
"huggingface",
|
||||
"accelerate",
|
||||
"default_config.yaml",
|
||||
),
|
||||
"USERPROFILE": Path(
|
||||
os.environ.get("USERPROFILE", ""),
|
||||
".cache",
|
||||
"huggingface",
|
||||
"accelerate",
|
||||
"default_config.yaml",
|
||||
),
|
||||
}
|
||||
|
||||
log.debug(f'Target config location: {target_config_location}')
|
||||
for var, path in env_vars.items():
|
||||
if env_var_exists(var):
|
||||
target_config_location = path
|
||||
break
|
||||
|
||||
log.debug(f"Target config location: {target_config_location}")
|
||||
|
||||
if target_config_location:
|
||||
if not target_config_location.is_file():
|
||||
log.debug(
|
||||
f"Creating target config directory: {target_config_location.parent}"
|
||||
)
|
||||
target_config_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
log.debug(
|
||||
f'Target accelerate config location: {target_config_location}'
|
||||
f"Copying config file to target location: {target_config_location}"
|
||||
)
|
||||
shutil.copyfile(
|
||||
source_accelerate_config_file, target_config_location
|
||||
)
|
||||
log.info(
|
||||
f'Copied accelerate config file to: {target_config_location}'
|
||||
)
|
||||
else:
|
||||
if run_accelerate:
|
||||
run_cmd('accelerate config')
|
||||
else:
|
||||
log.warning(
|
||||
'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.'
|
||||
)
|
||||
else:
|
||||
if run_accelerate:
|
||||
run_cmd('accelerate config')
|
||||
shutil.copyfile(source_accelerate_config_file, target_config_location)
|
||||
log.info(f"Copied accelerate config file to: {target_config_location}")
|
||||
elif run_accelerate:
|
||||
log.debug("Running accelerate configuration command...")
|
||||
run_cmd([sys.executable, "-m", "accelerate", "config"])
|
||||
else:
|
||||
log.warning(
|
||||
'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.'
|
||||
"Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
|
||||
)
|
||||
elif run_accelerate:
|
||||
log.debug("Running accelerate configuration command...")
|
||||
run_cmd([sys.executable, "-m", "accelerate", "config"])
|
||||
else:
|
||||
log.warning(
|
||||
"Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config."
|
||||
)
|
||||
|
||||
|
||||
def check_torch():
|
||||
log.debug("Checking Torch installation...")
|
||||
#
|
||||
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
||||
# This function was adapted from code written by vladimandic: https://github.com/vladimandic/automatic/commits/master
|
||||
#
|
||||
|
||||
# Check for toolkit
|
||||
if shutil.which('nvidia-smi') is not None or os.path.exists(
|
||||
if shutil.which("nvidia-smi") is not None or os.path.exists(
|
||||
os.path.join(
|
||||
os.environ.get('SystemRoot') or r'C:\Windows',
|
||||
'System32',
|
||||
'nvidia-smi.exe',
|
||||
os.environ.get("SystemRoot") or r"C:\Windows",
|
||||
"System32",
|
||||
"nvidia-smi.exe",
|
||||
)
|
||||
):
|
||||
log.info('nVidia toolkit detected')
|
||||
elif shutil.which('rocminfo') is not None or os.path.exists(
|
||||
'/opt/rocm/bin/rocminfo'
|
||||
log.info("nVidia toolkit detected")
|
||||
elif shutil.which("rocminfo") is not None or os.path.exists(
|
||||
"/opt/rocm/bin/rocminfo"
|
||||
):
|
||||
log.info('AMD toolkit detected')
|
||||
elif (shutil.which('sycl-ls') is not None
|
||||
or os.environ.get('ONEAPI_ROOT') is not None
|
||||
or os.path.exists('/opt/intel/oneapi')):
|
||||
log.info('Intel OneAPI toolkit detected')
|
||||
log.info("AMD toolkit detected")
|
||||
elif (
|
||||
shutil.which("sycl-ls") is not None
|
||||
or os.environ.get("ONEAPI_ROOT") is not None
|
||||
or os.path.exists("/opt/intel/oneapi")
|
||||
):
|
||||
log.info("Intel OneAPI toolkit detected")
|
||||
else:
|
||||
log.info('Using CPU-only Torch')
|
||||
log.info("Using CPU-only Torch")
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
log.debug("Torch module imported successfully.")
|
||||
try:
|
||||
# Import IPEX / XPU support
|
||||
import intel_extension_for_pytorch as ipex
|
||||
except Exception:
|
||||
pass
|
||||
log.info(f'Torch {torch.__version__}')
|
||||
|
||||
log.debug("Intel extension for PyTorch imported successfully.")
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to import intel_extension_for_pytorch: {e}")
|
||||
log.info(f"Torch {torch.__version__}")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch.version.cuda:
|
||||
|
|
@ -367,33 +325,33 @@ def check_torch():
|
|||
)
|
||||
elif torch.version.hip:
|
||||
# Log AMD ROCm HIP version
|
||||
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
|
||||
log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
|
||||
else:
|
||||
log.warning('Unknown Torch backend')
|
||||
log.warning("Unknown Torch backend")
|
||||
|
||||
# Log information about detected GPUs
|
||||
for device in [
|
||||
torch.cuda.device(i) for i in range(torch.cuda.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
|
||||
f"Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}"
|
||||
)
|
||||
# Check if XPU is available
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
# Log Intel IPEX version
|
||||
log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
|
||||
log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
|
||||
for device in [
|
||||
torch.xpu.device(i) for i in range(torch.xpu.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
|
||||
f"Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}"
|
||||
)
|
||||
else:
|
||||
log.warning('Torch reports GPU not available')
|
||||
log.warning("Torch reports GPU not available")
|
||||
|
||||
return int(torch.__version__[0])
|
||||
except Exception as e:
|
||||
# log.warning(f'Could not load torch: {e}')
|
||||
log.error(f"Could not load torch: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
|
|
@ -404,16 +362,18 @@ def check_repo_version():
|
|||
in the current directory. If the file exists, it reads the release version from the file and logs it.
|
||||
If the file does not exist, it logs a debug message indicating that the release could not be read.
|
||||
"""
|
||||
if os.path.exists('.release'):
|
||||
log.debug("Checking repository version...")
|
||||
if os.path.exists(".release"):
|
||||
try:
|
||||
with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
|
||||
release= file.read()
|
||||
with open(os.path.join("./.release"), "r", encoding="utf8") as file:
|
||||
release = file.read()
|
||||
|
||||
log.info(f'Kohya_ss GUI version: {release}')
|
||||
log.info(f"Kohya_ss GUI version: {release}")
|
||||
except Exception as e:
|
||||
log.error(f'Could not read release: {e}')
|
||||
log.error(f"Could not read release: {e}")
|
||||
else:
|
||||
log.debug('Could not read release...')
|
||||
log.debug("Could not read release...")
|
||||
|
||||
|
||||
# execute git command
|
||||
def git(arg: str, folder: str = None, ignore: bool = False):
|
||||
|
|
@ -433,22 +393,31 @@ def git(arg: str, folder: str = None, ignore: bool = False):
|
|||
If set to True, errors will not be logged.
|
||||
|
||||
Note:
|
||||
This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
|
||||
This function was adapted from code written by vladimandic: https://github.com/vladimandic/automatic/commits/master
|
||||
"""
|
||||
|
||||
git_cmd = os.environ.get('GIT', "git")
|
||||
result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.')
|
||||
log.debug(f"Running git command: git {arg} in folder: {folder or '.'}")
|
||||
result = subprocess.run(
|
||||
["git", arg],
|
||||
check=False,
|
||||
shell=True,
|
||||
env=os.environ,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=folder or ".",
|
||||
)
|
||||
txt = result.stdout.decode(encoding="utf8", errors="ignore")
|
||||
if len(result.stderr) > 0:
|
||||
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
|
||||
txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
|
||||
encoding="utf8", errors="ignore"
|
||||
)
|
||||
txt = txt.strip()
|
||||
if result.returncode != 0 and not ignore:
|
||||
global errors
|
||||
errors += 1
|
||||
log.error(f'Error running git: {folder} / {arg}')
|
||||
if 'or stash them' in txt:
|
||||
log.error(f'Local changes detected: check log for details...')
|
||||
log.debug(f'Git output: {txt}')
|
||||
log.error(f"Error running git: {folder} / {arg}")
|
||||
if "or stash them" in txt:
|
||||
log.error(f"Local changes detected: check log for details...")
|
||||
log.debug(f"Git output: {txt}")
|
||||
|
||||
|
||||
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
|
||||
|
|
@ -473,25 +442,36 @@ def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool =
|
|||
Returns:
|
||||
- The output of the pip command as a string, or None if the 'show_stdout' flag is set.
|
||||
"""
|
||||
# arg = arg.replace('>=', '==')
|
||||
log.debug(f"Running pip command: {arg}")
|
||||
if not quiet:
|
||||
log.info(f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace(" ", " ").strip()}')
|
||||
log.debug(f"Running pip: {arg}")
|
||||
log.info(
|
||||
f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace(" ", " ").strip()}'
|
||||
)
|
||||
pip_cmd = [rf"{sys.executable}", "-m", "pip"] + arg.split(" ")
|
||||
log.debug(f"Running pip: {pip_cmd}")
|
||||
if show_stdout:
|
||||
subprocess.run(f'"{sys.executable}" -m pip {arg}', shell=True, check=False, env=os.environ)
|
||||
subprocess.run(pip_cmd, shell=False, check=False, env=os.environ)
|
||||
else:
|
||||
result = subprocess.run(f'"{sys.executable}" -m pip {arg}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
result = subprocess.run(
|
||||
pip_cmd,
|
||||
shell=False,
|
||||
check=False,
|
||||
env=os.environ,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
txt = result.stdout.decode(encoding="utf8", errors="ignore")
|
||||
if len(result.stderr) > 0:
|
||||
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
|
||||
txt += ("\n" if len(txt) > 0 else "") + result.stderr.decode(
|
||||
encoding="utf8", errors="ignore"
|
||||
)
|
||||
txt = txt.strip()
|
||||
if result.returncode != 0 and not ignore:
|
||||
global errors # pylint: disable=global-statement
|
||||
errors += 1
|
||||
log.error(f'Error running pip: {arg}')
|
||||
log.debug(f'Pip output: {txt}')
|
||||
log.error(f"Error running pip: {arg}")
|
||||
log.error(f"Pip output: {txt}")
|
||||
return txt
|
||||
|
||||
|
||||
def installed(package, friendly: str = None):
|
||||
"""
|
||||
Checks if the specified package(s) are installed with the correct version.
|
||||
|
|
@ -509,9 +489,9 @@ def installed(package, friendly: str = None):
|
|||
Note:
|
||||
This function was adapted from code written by vladimandic.
|
||||
"""
|
||||
|
||||
log.debug(f"Checking if package is installed: {package}")
|
||||
# Remove any optional features specified in brackets (e.g., "package[option]==version" becomes "package==version")
|
||||
package = re.sub(r'\[.*?\]', '', package)
|
||||
package = re.sub(r"\[.*?\]", "", package)
|
||||
|
||||
try:
|
||||
if friendly:
|
||||
|
|
@ -520,28 +500,24 @@ def installed(package, friendly: str = None):
|
|||
|
||||
# Filter out command-line options and URLs from the package specification
|
||||
pkgs = [
|
||||
p
|
||||
for p in package.split()
|
||||
if not p.startswith('--') and "://" not in p
|
||||
p for p in package.split() if not p.startswith("--") and "://" not in p
|
||||
]
|
||||
else:
|
||||
# Split the package string into components, excluding '-' and '=' prefixed items
|
||||
pkgs = [
|
||||
p
|
||||
for p in package.split()
|
||||
if not p.startswith('-') and not p.startswith('=')
|
||||
if not p.startswith("-") and not p.startswith("=")
|
||||
]
|
||||
# For each package component, extract the package name, excluding any URLs
|
||||
pkgs = [
|
||||
p.split('/')[-1] for p in pkgs
|
||||
]
|
||||
pkgs = [p.split("/")[-1] for p in pkgs]
|
||||
|
||||
for pkg in pkgs:
|
||||
# Parse the package name and version based on the version specifier used
|
||||
if '>=' in pkg:
|
||||
pkg_name, pkg_version = [x.strip() for x in pkg.split('>=')]
|
||||
elif '==' in pkg:
|
||||
pkg_name, pkg_version = [x.strip() for x in pkg.split('==')]
|
||||
if ">=" in pkg:
|
||||
pkg_name, pkg_version = [x.strip() for x in pkg.split(">=")]
|
||||
elif "==" in pkg:
|
||||
pkg_name, pkg_version = [x.strip() for x in pkg.split("==")]
|
||||
else:
|
||||
pkg_name, pkg_version = pkg.strip(), None
|
||||
|
||||
|
|
@ -552,38 +528,41 @@ def installed(package, friendly: str = None):
|
|||
spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None)
|
||||
if spec is None:
|
||||
# Try replacing underscores with dashes
|
||||
spec = pkg_resources.working_set.by_key.get(pkg_name.replace('_', '-'), None)
|
||||
spec = pkg_resources.working_set.by_key.get(
|
||||
pkg_name.replace("_", "-"), None
|
||||
)
|
||||
|
||||
if spec is not None:
|
||||
# Package is found, check version
|
||||
version = pkg_resources.get_distribution(pkg_name).version
|
||||
log.debug(f'Package version found: {pkg_name} {version}')
|
||||
log.debug(f"Package version found: {pkg_name} {version}")
|
||||
|
||||
if pkg_version is not None:
|
||||
# Verify if the installed version meets the specified constraints
|
||||
if '>=' in pkg:
|
||||
if ">=" in pkg:
|
||||
ok = version >= pkg_version
|
||||
else:
|
||||
ok = version == pkg_version
|
||||
|
||||
if not ok:
|
||||
# Version mismatch, log warning and return False
|
||||
log.warning(f'Package wrong version: {pkg_name} {version} required {pkg_version}')
|
||||
log.warning(
|
||||
f"Package wrong version: {pkg_name} {version} required {pkg_version}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
# Package not found, log debug message and return False
|
||||
log.debug(f'Package version not found: {pkg_name}')
|
||||
log.debug(f"Package version not found: {pkg_name}")
|
||||
return False
|
||||
|
||||
# All specified packages are installed with the correct versions
|
||||
return True
|
||||
except ModuleNotFoundError:
|
||||
# One or more packages are not installed, log debug message and return False
|
||||
log.debug(f'Package not installed: {pkgs}')
|
||||
log.debug(f"Package not installed: {pkgs}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
# install package using pip if not already installed
|
||||
def install(
|
||||
package,
|
||||
|
|
@ -616,98 +595,93 @@ def install(
|
|||
If `reinstall` is True, it disables any mechanism that allows for skipping installations
|
||||
when the package is already present, forcing a fresh install.
|
||||
"""
|
||||
log.debug(f"Installing package: {package}")
|
||||
# Remove anything after '#' in the package variable
|
||||
package = package.split('#')[0].strip()
|
||||
package = package.split("#")[0].strip()
|
||||
|
||||
if reinstall:
|
||||
global quick_allowed # pylint: disable=global-statement
|
||||
global quick_allowed # pylint: disable=global-statement
|
||||
quick_allowed = False
|
||||
if reinstall or not installed(package, friendly):
|
||||
pip(f'install --upgrade {package}', ignore=ignore, show_stdout=show_stdout)
|
||||
pip(f"install --upgrade {package}", ignore=ignore, show_stdout=show_stdout)
|
||||
|
||||
|
||||
def process_requirements_line(line, show_stdout: bool = False):
|
||||
log.debug(f"Processing requirements line: {line}")
|
||||
# Remove brackets and their contents from the line using regular expressions
|
||||
# e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
|
||||
package_name = re.sub(r'\[.*?\]', '', line)
|
||||
package_name = re.sub(r"\[.*?\]", "", line)
|
||||
install(line, package_name, show_stdout=show_stdout)
|
||||
|
||||
|
||||
def install_requirements(requirements_file, check_no_verify_flag=False, show_stdout: bool = False):
|
||||
if check_no_verify_flag:
|
||||
log.info(f'Verifying modules installation status from {requirements_file}...')
|
||||
else:
|
||||
log.info(f'Installing modules from {requirements_file}...')
|
||||
with open(requirements_file, 'r', encoding='utf8') as f:
|
||||
# Read lines from the requirements file, strip whitespace, and filter out empty lines, comments, and lines starting with '.'
|
||||
if check_no_verify_flag:
|
||||
lines = [
|
||||
line.strip()
|
||||
for line in f.readlines()
|
||||
if line.strip() != ''
|
||||
and not line.startswith('#')
|
||||
and line is not None
|
||||
and 'no_verify' not in line
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
line.strip()
|
||||
for line in f.readlines()
|
||||
if line.strip() != ''
|
||||
and not line.startswith('#')
|
||||
and line is not None
|
||||
]
|
||||
def install_requirements(
|
||||
requirements_file, check_no_verify_flag=False, show_stdout: bool = False
|
||||
):
|
||||
"""
|
||||
Install or verify modules from a requirements file.
|
||||
|
||||
# Iterate over each line and install the requirements
|
||||
for line in lines:
|
||||
# Check if the line starts with '-r' to include another requirements file
|
||||
if line.startswith('-r'):
|
||||
# Get the path to the included requirements file
|
||||
included_file = line[2:].strip()
|
||||
# Expand the included requirements file recursively
|
||||
install_requirements(included_file, check_no_verify_flag=check_no_verify_flag, show_stdout=show_stdout)
|
||||
else:
|
||||
process_requirements_line(line, show_stdout=show_stdout)
|
||||
Parameters:
|
||||
- requirements_file (str): Path to the requirements file.
|
||||
- check_no_verify_flag (bool): If True, verify modules installation status without installing.
|
||||
- show_stdout (bool): If True, show the standard output of the installation process.
|
||||
"""
|
||||
log.debug(f"Installing requirements from file: {requirements_file}")
|
||||
action = "Verifying" if check_no_verify_flag else "Installing"
|
||||
log.info(f"{action} modules from {requirements_file}...")
|
||||
|
||||
with open(requirements_file, "r", encoding="utf8") as f:
|
||||
lines = [
|
||||
line.strip()
|
||||
for line in f.readlines()
|
||||
if line.strip() and not line.startswith("#") and "no_verify" not in line
|
||||
]
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("-r"):
|
||||
included_file = line[2:].strip()
|
||||
log.debug(f"Processing included requirements file: {included_file}")
|
||||
install_requirements(
|
||||
included_file,
|
||||
check_no_verify_flag=check_no_verify_flag,
|
||||
show_stdout=show_stdout,
|
||||
)
|
||||
else:
|
||||
process_requirements_line(line, show_stdout=show_stdout)
|
||||
|
||||
|
||||
def ensure_base_requirements():
|
||||
try:
|
||||
import rich # pylint: disable=unused-import
|
||||
import rich # pylint: disable=unused-import
|
||||
except ImportError:
|
||||
install('--upgrade rich', 'rich')
|
||||
install("--upgrade rich", "rich")
|
||||
|
||||
try:
|
||||
import packaging
|
||||
except ImportError:
|
||||
install('packaging')
|
||||
install("packaging")
|
||||
|
||||
|
||||
def run_cmd(run_cmd):
|
||||
"""
|
||||
Execute a command using subprocess.
|
||||
"""
|
||||
log.debug(f"Running command: {run_cmd}")
|
||||
try:
|
||||
subprocess.run(run_cmd, shell=True, check=False, env=os.environ)
|
||||
subprocess.run(run_cmd, shell=True, check=True, env=os.environ)
|
||||
log.debug(f"Command executed successfully: {run_cmd}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.error(f'Error occurred while running command: {run_cmd}')
|
||||
log.error(f'Error: {e}')
|
||||
|
||||
|
||||
def delete_file(file_path):
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
def write_to_file(file_path, content):
|
||||
try:
|
||||
with open(file_path, 'w') as file:
|
||||
file.write(content)
|
||||
except IOError as e:
|
||||
print(f'Error occurred while writing to file: {file_path}')
|
||||
print(f'Error: {e}')
|
||||
log.error(f"Error occurred while running command: {run_cmd}")
|
||||
log.error(f"Error: {e}")
|
||||
|
||||
|
||||
def clear_screen():
|
||||
# Check the current operating system to execute the correct clear screen command
|
||||
if os.name == 'nt': # If the operating system is Windows
|
||||
os.system('cls')
|
||||
else: # If the operating system is Linux or Mac
|
||||
os.system('clear')
|
||||
|
||||
"""
|
||||
Clear the terminal screen.
|
||||
"""
|
||||
log.debug("Attempting to clear the terminal screen")
|
||||
try:
|
||||
os.system("cls" if os.name == "nt" else "clear")
|
||||
log.info("Terminal screen cleared successfully")
|
||||
except Exception as e:
|
||||
log.error("Error occurred while clearing the terminal screen")
|
||||
log.error(f"Error: {e}")
|
||||
|
|
|
|||
|
|
@ -19,7 +19,10 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce
|
|||
|
||||
# Upgrade pip if needed
|
||||
setup_common.install('pip')
|
||||
setup_common.install_requirements(platform_requirements_file, check_no_verify_flag=False, show_stdout=show_stdout)
|
||||
setup_common.install_requirements_inbulk(
|
||||
platform_requirements_file, show_stdout=show_stdout,
|
||||
)
|
||||
# setup_common.install_requirements(platform_requirements_file, check_no_verify_flag=False, show_stdout=show_stdout)
|
||||
if not no_run_accelerate:
|
||||
setup_common.configure_accelerate(run_accelerate=False)
|
||||
|
||||
|
|
@ -32,10 +35,6 @@ if __name__ == '__main__':
|
|||
|
||||
setup_common.update_submodule()
|
||||
|
||||
# setup_common.clone_or_checkout(
|
||||
# "https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts"
|
||||
# )
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--platform-requirements-file', dest='platform_requirements_file', default='requirements_linux.txt', help='Path to the platform-specific requirements file')
|
||||
parser.add_argument('--show_stdout', dest='show_stdout', action='store_true', help='Whether to show stdout during installation')
|
||||
|
|
|
|||
|
|
@ -54,7 +54,10 @@ def main_menu(platform_requirements_file):
|
|||
|
||||
# Upgrade pip if needed
|
||||
setup_common.install('pip')
|
||||
setup_common.install_requirements(platform_requirements_file, check_no_verify_flag=False, show_stdout=True)
|
||||
|
||||
setup_common.install_requirements_inbulk(
|
||||
platform_requirements_file, show_stdout=True,
|
||||
)
|
||||
configure_accelerate()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -123,12 +123,13 @@ def install_kohya_ss_torch2(headless: bool = False):
|
|||
# )
|
||||
|
||||
setup_common.install_requirements_inbulk(
|
||||
"requirements_pytorch_windows.txt", show_stdout=True, optional_parm="--index-url https://download.pytorch.org/whl/cu118"
|
||||
"requirements_pytorch_windows.txt", show_stdout=True,
|
||||
# optional_parm="--index-url https://download.pytorch.org/whl/cu124"
|
||||
)
|
||||
|
||||
setup_common.install_requirements_inbulk(
|
||||
"requirements_windows.txt", show_stdout=True, upgrade=True
|
||||
)
|
||||
# setup_common.install_requirements_inbulk(
|
||||
# "requirements_windows.txt", show_stdout=True, upgrade=True
|
||||
# )
|
||||
|
||||
setup_common.run_cmd("accelerate config default")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,12 +5,11 @@ import argparse
|
|||
import setup_common
|
||||
|
||||
# Get the absolute path of the current file's directory (Kohua_SS project directory)
|
||||
project_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Check if the "setup" directory is present in the project_directory
|
||||
if "setup" in project_directory:
|
||||
# If the "setup" directory is present, move one level up to the parent directory
|
||||
project_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
project_directory = (
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if "setup" in os.path.dirname(os.path.abspath(__file__))
|
||||
else os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
|
||||
# Add the project directory to the beginning of the Python search path
|
||||
sys.path.insert(0, project_directory)
|
||||
|
|
@ -19,115 +18,178 @@ from kohya_gui.custom_logging import setup_logging
|
|||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
log.debug(f"Project directory set to: {project_directory}")
|
||||
|
||||
def check_path_with_space():
|
||||
# Get the current working directory
|
||||
"""Check if the current working directory contains a space."""
|
||||
cwd = os.getcwd()
|
||||
|
||||
# Check if the current working directory contains a space
|
||||
log.debug(f"Current working directory: {cwd}")
|
||||
if " " in cwd:
|
||||
log.error("The path in which this python code is executed contain one or many spaces. This is not supported for running kohya_ss GUI.")
|
||||
log.error("Please move the repo to a path without spaces, delete the venv folder and run setup.sh again.")
|
||||
log.error("The current working directory is: " + cwd)
|
||||
exit(1)
|
||||
# Log an error if the current working directory contains spaces
|
||||
log.error(
|
||||
"The path in which this python code is executed contains one or many spaces. This is not supported for running kohya_ss GUI."
|
||||
)
|
||||
log.error(
|
||||
"Please move the repo to a path without spaces, delete the venv folder, and run setup.sh again."
|
||||
)
|
||||
log.error(f"The current working directory is: {cwd}")
|
||||
raise RuntimeError("Invalid path: contains spaces.")
|
||||
|
||||
def check_torch():
|
||||
# Check for toolkit
|
||||
if shutil.which('nvidia-smi') is not None or os.path.exists(
|
||||
def detect_toolkit():
|
||||
"""Detect the available toolkit (NVIDIA, AMD, or Intel) and log the information."""
|
||||
log.debug("Detecting available toolkit...")
|
||||
# Check for NVIDIA toolkit by looking for nvidia-smi executable
|
||||
if shutil.which("nvidia-smi") or os.path.exists(
|
||||
os.path.join(
|
||||
os.environ.get('SystemRoot') or r'C:\Windows',
|
||||
'System32',
|
||||
'nvidia-smi.exe',
|
||||
os.environ.get("SystemRoot", r"C:\Windows"), "System32", "nvidia-smi.exe"
|
||||
)
|
||||
):
|
||||
log.info('nVidia toolkit detected')
|
||||
elif shutil.which('rocminfo') is not None or os.path.exists(
|
||||
'/opt/rocm/bin/rocminfo'
|
||||
log.debug("nVidia toolkit detected")
|
||||
return "nVidia"
|
||||
# Check for AMD toolkit by looking for rocminfo executable
|
||||
elif shutil.which("rocminfo") or os.path.exists("/opt/rocm/bin/rocminfo"):
|
||||
log.debug("AMD toolkit detected")
|
||||
return "AMD"
|
||||
# Check for Intel toolkit by looking for SYCL or OneAPI indicators
|
||||
elif (
|
||||
shutil.which("sycl-ls")
|
||||
or os.environ.get("ONEAPI_ROOT")
|
||||
or os.path.exists("/opt/intel/oneapi")
|
||||
):
|
||||
log.info('AMD toolkit detected')
|
||||
elif (shutil.which('sycl-ls') is not None
|
||||
or os.environ.get('ONEAPI_ROOT') is not None
|
||||
or os.path.exists('/opt/intel/oneapi')):
|
||||
log.info('Intel OneAPI toolkit detected')
|
||||
log.debug("Intel toolkit detected")
|
||||
return "Intel"
|
||||
# Default to CPU if no toolkit is detected
|
||||
else:
|
||||
log.info('Using CPU-only Torch')
|
||||
log.debug("No specific GPU toolkit detected, defaulting to CPU")
|
||||
return "CPU"
|
||||
|
||||
def check_torch():
|
||||
"""Check if torch is available and log the relevant information."""
|
||||
# Detect the available toolkit (e.g., NVIDIA, AMD, Intel, or CPU)
|
||||
toolkit = detect_toolkit()
|
||||
log.info(f"{toolkit} toolkit detected")
|
||||
|
||||
try:
|
||||
# Import PyTorch
|
||||
log.debug("Importing PyTorch...")
|
||||
import torch
|
||||
try:
|
||||
# Import IPEX / XPU support
|
||||
import intel_extension_for_pytorch as ipex
|
||||
except Exception:
|
||||
pass
|
||||
log.info(f'Torch {torch.__version__}')
|
||||
|
||||
ipex = None
|
||||
# Attempt to import Intel Extension for PyTorch if Intel toolkit is detected
|
||||
if toolkit == "Intel":
|
||||
try:
|
||||
log.debug("Attempting to import Intel Extension for PyTorch (IPEX)...")
|
||||
import intel_extension_for_pytorch as ipex
|
||||
log.debug("Intel Extension for PyTorch (IPEX) imported successfully")
|
||||
except ImportError:
|
||||
log.warning("Intel Extension for PyTorch (IPEX) not found.")
|
||||
|
||||
# Log the PyTorch version
|
||||
log.info(f"Torch {torch.__version__}")
|
||||
|
||||
# Check if CUDA (NVIDIA GPU) is available
|
||||
if torch.cuda.is_available():
|
||||
if torch.version.cuda:
|
||||
# Log nVidia CUDA and cuDNN versions
|
||||
log.info(
|
||||
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
|
||||
)
|
||||
elif torch.version.hip:
|
||||
# Log AMD ROCm HIP version
|
||||
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
|
||||
else:
|
||||
log.warning('Unknown Torch backend')
|
||||
|
||||
# Log information about detected GPUs
|
||||
for device in [
|
||||
torch.cuda.device(i) for i in range(torch.cuda.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
|
||||
)
|
||||
# Check if XPU is available
|
||||
log.debug("CUDA is available, logging CUDA info...")
|
||||
log_cuda_info(torch)
|
||||
# Check if XPU (Intel GPU) is available
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
# Log Intel IPEX version
|
||||
log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
|
||||
for device in [
|
||||
torch.xpu.device(i) for i in range(torch.xpu.device_count())
|
||||
]:
|
||||
log.info(
|
||||
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
|
||||
)
|
||||
log.debug("XPU is available, logging XPU info...")
|
||||
log_xpu_info(torch, ipex)
|
||||
# Log a warning if no GPU is available
|
||||
else:
|
||||
log.warning('Torch reports GPU not available')
|
||||
log.warning("Torch reports GPU not available")
|
||||
|
||||
# Return the major version of PyTorch
|
||||
return int(torch.__version__[0])
|
||||
except ImportError as e:
|
||||
# Log an error if PyTorch cannot be loaded
|
||||
log.error(f"Could not load torch: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
log.error(f'Could not load torch: {e}')
|
||||
# Log an unexpected error
|
||||
log.error(f"Unexpected error while checking torch: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def main():
|
||||
setup_common.check_repo_version()
|
||||
def log_cuda_info(torch):
|
||||
"""Log information about CUDA-enabled GPUs."""
|
||||
# Log the CUDA and cuDNN versions if available
|
||||
if torch.version.cuda:
|
||||
log.info(
|
||||
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
|
||||
)
|
||||
# Log the ROCm HIP version if using AMD GPU
|
||||
elif torch.version.hip:
|
||||
log.info(f"Torch backend: AMD ROCm HIP {torch.version.hip}")
|
||||
else:
|
||||
log.warning("Unknown Torch backend")
|
||||
|
||||
# Log information about each detected CUDA-enabled GPU
|
||||
for device in range(torch.cuda.device_count()):
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
log.info(
|
||||
f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}"
|
||||
)
|
||||
|
||||
def log_xpu_info(torch, ipex):
|
||||
"""Log information about Intel XPU-enabled GPUs."""
|
||||
# Log the Intel Extension for PyTorch (IPEX) version if available
|
||||
if ipex:
|
||||
log.info(f"Torch backend: Intel IPEX {ipex.__version__}")
|
||||
# Log information about each detected XPU-enabled GPU
|
||||
for device in range(torch.xpu.device_count()):
|
||||
props = torch.xpu.get_device_properties(device)
|
||||
log.info(
|
||||
f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Compute Units {props.max_compute_units}"
|
||||
)
|
||||
|
||||
def main():
|
||||
# Check the repository version to ensure compatibility
|
||||
log.debug("Checking repository version...")
|
||||
setup_common.check_repo_version()
|
||||
# Check if the current path contains spaces, which are not supported
|
||||
log.debug("Checking if the current path contains spaces...")
|
||||
check_path_with_space()
|
||||
|
||||
# Parse command line arguments
|
||||
log.debug("Parsing command line arguments...")
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Validate that requirements are satisfied.'
|
||||
description="Validate that requirements are satisfied."
|
||||
)
|
||||
parser.add_argument(
|
||||
'-r',
|
||||
'--requirements',
|
||||
type=str,
|
||||
help='Path to the requirements file.',
|
||||
"-r", "--requirements", type=str, help="Path to the requirements file."
|
||||
)
|
||||
parser.add_argument('--debug', action='store_true', help='Debug on')
|
||||
parser.add_argument("--debug", action="store_true", help="Debug on")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Update git submodules if necessary
|
||||
log.debug("Updating git submodules...")
|
||||
setup_common.update_submodule()
|
||||
|
||||
# Check if PyTorch is installed and log relevant information
|
||||
log.debug("Checking if PyTorch is installed...")
|
||||
torch_ver = check_torch()
|
||||
|
||||
# Check if the Python version is compatible
|
||||
log.debug("Checking Python version...")
|
||||
if not setup_common.check_python_version():
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
if args.requirements:
|
||||
setup_common.install_requirements(args.requirements, check_no_verify_flag=True)
|
||||
else:
|
||||
setup_common.install_requirements('requirements_pytorch_windows.txt', check_no_verify_flag=True)
|
||||
setup_common.install_requirements('requirements_windows.txt', check_no_verify_flag=True)
|
||||
# Install required packages from the specified requirements file
|
||||
requirements_file = args.requirements or "requirements_pytorch_windows.txt"
|
||||
log.debug(f"Installing requirements from: {requirements_file}")
|
||||
setup_common.install_requirements_inbulk(
|
||||
requirements_file, show_stdout=True,
|
||||
# optional_parm="--index-url https://download.pytorch.org/whl/cu124"
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# setup_common.install_requirements(requirements_file, check_no_verify_flag=True)
|
||||
|
||||
# log.debug("Installing additional requirements from: requirements_windows.txt")
|
||||
# setup_common.install_requirements(
|
||||
# "requirements_windows.txt", check_no_verify_flag=True
|
||||
# )
|
||||
|
||||
if __name__ == "__main__":
|
||||
log.debug("Starting main function...")
|
||||
main()
|
||||
log.debug("Main function finished.")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,125 @@
|
|||
{
|
||||
"adaptive_noise_scale": 0,
|
||||
"additional_parameters": "",
|
||||
"async_upload": false,
|
||||
"bucket_no_upscale": true,
|
||||
"bucket_reso_steps": 1,
|
||||
"cache_latents": true,
|
||||
"cache_latents_to_disk": false,
|
||||
"caption_dropout_every_n_epochs": 0,
|
||||
"caption_dropout_rate": 0.05,
|
||||
"caption_extension": "",
|
||||
"clip_skip": 2,
|
||||
"color_aug": false,
|
||||
"dataset_config": "",
|
||||
"dynamo_backend": "no",
|
||||
"dynamo_mode": "default",
|
||||
"dynamo_use_dynamic": false,
|
||||
"dynamo_use_fullgraph": false,
|
||||
"enable_bucket": true,
|
||||
"epoch": 8,
|
||||
"extra_accelerate_launch_args": "",
|
||||
"flip_aug": false,
|
||||
"full_fp16": false,
|
||||
"gpu_ids": "",
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": false,
|
||||
"huber_c": 0.1,
|
||||
"huber_schedule": "snr",
|
||||
"huggingface_path_in_repo": "",
|
||||
"huggingface_repo_id": "False",
|
||||
"huggingface_repo_type": "",
|
||||
"huggingface_repo_visibility": "",
|
||||
"huggingface_token": "",
|
||||
"init_word": "*",
|
||||
"ip_noise_gamma": 0.1,
|
||||
"ip_noise_gamma_random_strength": true,
|
||||
"keep_tokens": 0,
|
||||
"learning_rate": 0.0001,
|
||||
"log_config": false,
|
||||
"log_tracker_config": "",
|
||||
"log_tracker_name": "",
|
||||
"log_with": "",
|
||||
"logging_dir": "./test/logs",
|
||||
"loss_type": "l2",
|
||||
"lr_scheduler": "cosine",
|
||||
"lr_scheduler_args": "",
|
||||
"lr_scheduler_num_cycles": 1,
|
||||
"lr_scheduler_power": 1,
|
||||
"lr_scheduler_type": "",
|
||||
"lr_warmup": 0,
|
||||
"main_process_port": 0,
|
||||
"max_bucket_reso": 2048,
|
||||
"max_data_loader_n_workers": 0,
|
||||
"max_resolution": "1024,1024",
|
||||
"max_timestep": 0,
|
||||
"max_token_length": 75,
|
||||
"max_train_epochs": 0,
|
||||
"max_train_steps": 0,
|
||||
"mem_eff_attn": false,
|
||||
"metadata_author": "False",
|
||||
"metadata_description": "",
|
||||
"metadata_license": "",
|
||||
"metadata_tags": "",
|
||||
"metadata_title": "",
|
||||
"min_bucket_reso": 256,
|
||||
"min_snr_gamma": 10,
|
||||
"min_timestep": false,
|
||||
"mixed_precision": "bf16",
|
||||
"model_list": "custom",
|
||||
"multi_gpu": false,
|
||||
"multires_noise_discount": 0.2,
|
||||
"multires_noise_iterations": 8,
|
||||
"no_token_padding": false,
|
||||
"noise_offset": 0.05,
|
||||
"noise_offset_random_strength": true,
|
||||
"noise_offset_type": "Original",
|
||||
"num_cpu_threads_per_process": 2,
|
||||
"num_machines": 1,
|
||||
"num_processes": 1,
|
||||
"num_vectors_per_token": 8,
|
||||
"optimizer": "AdamW8bit",
|
||||
"optimizer_args": "",
|
||||
"output_dir": "./test/output",
|
||||
"output_name": "TI-Adamw8bit-SDXL",
|
||||
"persistent_data_loader_workers": false,
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"prior_loss_weight": 1,
|
||||
"random_crop": false,
|
||||
"reg_data_dir": "",
|
||||
"resume": "",
|
||||
"resume_from_huggingface": "False",
|
||||
"sample_every_n_epochs": 0,
|
||||
"sample_every_n_steps": 20,
|
||||
"sample_prompts": "a painting of man wearing a gas mask , by darius kawasaki",
|
||||
"sample_sampler": "euler_a",
|
||||
"save_as_bool": false,
|
||||
"save_every_n_epochs": 1,
|
||||
"save_every_n_steps": 0,
|
||||
"save_last_n_steps": 0,
|
||||
"save_last_n_steps_state": 0,
|
||||
"save_model_as": "safetensors",
|
||||
"save_precision": "fp16",
|
||||
"save_state": false,
|
||||
"save_state_on_train_end": false,
|
||||
"save_state_to_huggingface": false,
|
||||
"scale_v_pred_loss_like_noise_pred": false,
|
||||
"sdxl": true,
|
||||
"sdxl_no_half_vae": true,
|
||||
"seed": 1234,
|
||||
"shuffle_caption": false,
|
||||
"stop_text_encoder_training": 0,
|
||||
"template": "style template",
|
||||
"token_string": "zxc",
|
||||
"train_batch_size": 4,
|
||||
"train_data_dir": "./test/img",
|
||||
"v2": false,
|
||||
"v_parameterization": false,
|
||||
"v_pred_like_loss": 0,
|
||||
"vae": "",
|
||||
"vae_batch_size": 0,
|
||||
"wandb_api_key": "",
|
||||
"wandb_run_name": "",
|
||||
"weights": "",
|
||||
"xformers": "xformers"
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
[general]
|
||||
# define common settings here
|
||||
flip_aug = true
|
||||
color_aug = false
|
||||
keep_tokens_separator= "|||"
|
||||
shuffle_caption = false
|
||||
caption_tag_dropout_rate = 0
|
||||
caption_extension = ".txt"
|
||||
min_bucket_reso = 64
|
||||
max_bucket_reso = 2048
|
||||
|
||||
[[datasets]]
|
||||
# define the first resolution here
|
||||
batch_size = 1
|
||||
enable_bucket = true
|
||||
resolution = [1024, 1024]
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "./test/img/10_darius kawasaki person"
|
||||
num_repeats = 10
|
||||
|
||||
[[datasets]]
|
||||
# define the second resolution here
|
||||
batch_size = 1
|
||||
enable_bucket = true
|
||||
resolution = [768, 768]
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "./test/img/10_darius kawasaki person"
|
||||
num_repeats = 10
|
||||
|
||||
[[datasets]]
|
||||
# define the third resolution here
|
||||
batch_size = 1
|
||||
enable_bucket = true
|
||||
resolution = [512, 512]
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "./test/img/10_darius kawasaki person"
|
||||
num_repeats = 10
|
||||
|
|
@ -1,49 +1,75 @@
|
|||
{
|
||||
"adaptive_noise_scale": 0,
|
||||
"additional_parameters": "",
|
||||
"async_upload": false,
|
||||
"bucket_no_upscale": true,
|
||||
"bucket_reso_steps": 64,
|
||||
"cache_latents": true,
|
||||
"cache_latents_to_disk": false,
|
||||
"caption_dropout_every_n_epochs": 0.0,
|
||||
"caption_dropout_every_n_epochs": 0,
|
||||
"caption_dropout_rate": 0.05,
|
||||
"caption_extension": "",
|
||||
"clip_skip": 2,
|
||||
"color_aug": false,
|
||||
"dataset_config": "./test/config/dataset.toml",
|
||||
"debiased_estimation_loss": false,
|
||||
"disable_mmap_load_safetensors": false,
|
||||
"dynamo_backend": "no",
|
||||
"dynamo_mode": "default",
|
||||
"dynamo_use_dynamic": false,
|
||||
"dynamo_use_fullgraph": false,
|
||||
"enable_bucket": true,
|
||||
"epoch": 1,
|
||||
"extra_accelerate_launch_args": "",
|
||||
"flip_aug": false,
|
||||
"full_bf16": false,
|
||||
"full_fp16": false,
|
||||
"fused_backward_pass": false,
|
||||
"fused_optimizer_groups": 0,
|
||||
"gpu_ids": "",
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": false,
|
||||
"huber_c": 0.1,
|
||||
"huber_schedule": "snr",
|
||||
"huggingface_path_in_repo": "",
|
||||
"huggingface_repo_id": "",
|
||||
"huggingface_repo_type": "",
|
||||
"huggingface_repo_visibility": "",
|
||||
"huggingface_token": "",
|
||||
"ip_noise_gamma": 0,
|
||||
"ip_noise_gamma_random_strength": false,
|
||||
"keep_tokens": "0",
|
||||
"keep_tokens": 0,
|
||||
"learning_rate": 5e-05,
|
||||
"learning_rate_te": 1e-05,
|
||||
"learning_rate_te1": 1e-05,
|
||||
"learning_rate_te2": 1e-05,
|
||||
"log_config": false,
|
||||
"log_tracker_config": "",
|
||||
"log_tracker_name": "",
|
||||
"log_with": "",
|
||||
"logging_dir": "./test/logs",
|
||||
"loss_type": "l2",
|
||||
"lr_scheduler": "constant",
|
||||
"lr_scheduler_args": "",
|
||||
"lr_scheduler_num_cycles": "",
|
||||
"lr_scheduler_power": "",
|
||||
"lr_scheduler_args": "T_max=100",
|
||||
"lr_scheduler_num_cycles": 1,
|
||||
"lr_scheduler_power": 1,
|
||||
"lr_scheduler_type": "CosineAnnealingLR",
|
||||
"lr_warmup": 0,
|
||||
"main_process_port": 12345,
|
||||
"masked_loss": false,
|
||||
"max_bucket_reso": 2048,
|
||||
"max_data_loader_n_workers": "0",
|
||||
"max_data_loader_n_workers": 0,
|
||||
"max_resolution": "512,512",
|
||||
"max_timestep": 1000,
|
||||
"max_token_length": "75",
|
||||
"max_train_epochs": "",
|
||||
"max_train_steps": "",
|
||||
"max_token_length": 75,
|
||||
"max_train_epochs": 0,
|
||||
"max_train_steps": 0,
|
||||
"mem_eff_attn": false,
|
||||
"metadata_author": "",
|
||||
"metadata_description": "",
|
||||
"metadata_license": "",
|
||||
"metadata_tags": "",
|
||||
"metadata_title": "",
|
||||
"min_bucket_reso": 256,
|
||||
"min_snr_gamma": 0,
|
||||
"min_timestep": 0,
|
||||
|
|
@ -65,14 +91,16 @@
|
|||
"output_name": "db-AdamW8bit-toml",
|
||||
"persistent_data_loader_workers": false,
|
||||
"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
|
||||
"prior_loss_weight": 1.0,
|
||||
"prior_loss_weight": 1,
|
||||
"random_crop": false,
|
||||
"reg_data_dir": "",
|
||||
"resume": "",
|
||||
"resume_from_huggingface": "",
|
||||
"sample_every_n_epochs": 0,
|
||||
"sample_every_n_steps": 25,
|
||||
"sample_prompts": "a painting of a gas mask , by darius kawasaki",
|
||||
"sample_sampler": "euler_a",
|
||||
"save_as_bool": false,
|
||||
"save_every_n_epochs": 1,
|
||||
"save_every_n_steps": 0,
|
||||
"save_last_n_steps": 0,
|
||||
|
|
@ -81,14 +109,16 @@
|
|||
"save_precision": "fp16",
|
||||
"save_state": false,
|
||||
"save_state_on_train_end": false,
|
||||
"save_state_to_huggingface": false,
|
||||
"scale_v_pred_loss_like_noise_pred": false,
|
||||
"sdxl": false,
|
||||
"seed": "1234",
|
||||
"sdxl_cache_text_encoder_outputs": false,
|
||||
"sdxl_no_half_vae": false,
|
||||
"seed": 1234,
|
||||
"shuffle_caption": false,
|
||||
"stop_text_encoder_training": 0,
|
||||
"train_batch_size": 4,
|
||||
"train_data_dir": "",
|
||||
"use_wandb": false,
|
||||
"v2": false,
|
||||
"v_parameterization": false,
|
||||
"v_pred_like_loss": 0,
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
solo,simple background,teeth,grey background,from side,no humans,mask,1other,science fiction,cable,gas mask,tube,steampunk,machine
|
||||
|
|
@ -1 +0,0 @@
|
|||
no humans,what
|
||||
|
|
@ -1 +0,0 @@
|
|||
1girl,solo,nude,colored skin,monster,blue skin
|
||||
|
|
@ -1 +0,0 @@
|
|||
solo,upper body,horns,from side,no humans,blood,1other
|
||||
|
|
@ -1 +0,0 @@
|
|||
solo,1boy,male focus,mask,instrument,science fiction,realistic,music,gas mask
|
||||
|
|
@ -1 +0,0 @@
|
|||
solo,no humans,mask,helmet,robot,mecha,1other,science fiction,damaged,gas mask,steampunk
|
||||
|
|
@ -1 +0,0 @@
|
|||
solo,from side,no humans,mask,moon,helmet,portrait,1other,ambiguous gender,gas mask
|
||||
|
|
@ -1 +0,0 @@
|
|||
outdoors,sky,cloud,no humans,monster,realistic,desert
|
||||
Loading…
Reference in New Issue