mirror of https://github.com/bmaltais/kohya_ss
commit
850c1a9c49
|
|
@ -18,4 +18,4 @@ jobs:
|
|||
- uses: actions/checkout@v4
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.31.2
|
||||
uses: crate-ci/typos@v1.32.0
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
3.11
|
||||
3.10
|
||||
|
|
|
|||
252
README.md
252
README.md
|
|
@ -19,18 +19,20 @@ Support for Linux and macOS is also available. While Linux support is actively m
|
|||
- [Kohya's GUI](#kohyas-gui)
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [🦒 Colab](#-colab)
|
||||
- [Installation Methods](#installation-methods)
|
||||
- [Using `uv` (Recommended)](#using-uv-recommended)
|
||||
- [Using `pip` (Traditional Method)](#using-pip-traditional-method)
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Installation](#installation)
|
||||
- [Windows](#windows)
|
||||
- [Windows Pre-requirements](#windows-pre-requirements)
|
||||
- [Setup Windows](#setup-windows)
|
||||
- [Linux and macOS](#linux-and-macos)
|
||||
- [Linux Pre-requirements](#linux-pre-requirements)
|
||||
- [Setup Linux](#setup-linux)
|
||||
- [Install Location](#install-location)
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Installing Prerequisites on Windows](#installing-prerequisites-on-windows)
|
||||
- [Installing Prerequisites on Linux--macos](#installing-prerequisites-on-linux--macos)
|
||||
- [Cloning the Repository](#cloning-the-repository)
|
||||
- [Installation Methods](#installation-methods)
|
||||
- [Using `uv` (Recommended)](#using-uv-recommended)
|
||||
- [For Windows](#for-windows)
|
||||
- [For Linux](#for-linux)
|
||||
- [Using `pip` (Traditional Method)](#using-pip-traditional-method)
|
||||
- [Using `pip` For Windows](#using-pip-for-windows)
|
||||
- [Using `pip` For Linux and macOS](#using-pip-for-linux-and-macos)
|
||||
- [Using `conda`](#using-conda)
|
||||
- [Optional: Install Location Details](#optional-install-location-details)
|
||||
- [Runpod](#runpod)
|
||||
- [Novita](#novita)
|
||||
- [Docker](#docker)
|
||||
|
|
@ -75,7 +77,64 @@ I would like to express my gratitude to camenduru for their valuable contributio
|
|||
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------ |
|
||||
| [](https://colab.research.google.com/github/camenduru/kohya_ss-colab/blob/main/kohya_ss_colab.ipynb) | kohya_ss_gui_colab |
|
||||
|
||||
## Installation Methods
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before you begin, make sure your system meets the following minimum requirements:
|
||||
|
||||
- **Python**
|
||||
- Windows: Version **3.11.9**
|
||||
- Linux/macOS: Version **3.10.9 or higher**, but **below 3.11.0**
|
||||
- **Git** – Required for cloning the repository
|
||||
- **NVIDIA CUDA Toolkit** – Version 12.8 or compatible
|
||||
- **NVIDIA GPU** – Required for training; VRAM needs vary
|
||||
- **(Optional) NVIDIA cuDNN** – Improves training speed and batch size
|
||||
- **Windows only** – Visual Studio 2015–2022 Redistributables
|
||||
|
||||
#### Installing Prerequisites on Windows
|
||||
|
||||
1. Install [Python 3.11.9](https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe)
|
||||
✅ Enable the "Add to PATH" option during setup
|
||||
|
||||
2. Install [CUDA 12.8 Toolkit](https://developer.nvidia.com/cuda-12-8-0-download-archive?target_os=Windows&target_arch=x86_64)
|
||||
|
||||
3. Install [Git](https://git-scm.com/download/win)
|
||||
|
||||
4. Install [Visual Studio Redistributables](https://aka.ms/vs/17/release/vc_redist.x64.exe)
|
||||
|
||||
|
||||
#### Installing Prerequisites on Linux / macOS
|
||||
|
||||
1. Install Python (Make sure you have Python version 3.10.9 or higher (but lower than 3.11.0) installed on your system.)
|
||||
On Ubuntu 22.04 or later:
|
||||
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install python3.10 python3.10-venv
|
||||
```
|
||||
|
||||
2. Install [CUDA 12.8 Toolkit](https://developer.nvidia.com/cuda-12-8-0-download-archive?target_os=Linux&target_arch=x86_64)
|
||||
Follow the instructions for your distribution.
|
||||
|
||||
> [!NOTE]
|
||||
> macOS is only supported via the **pip method**.
|
||||
> CUDA is usually not required and may not be compatible with Apple Silicon GPUs.
|
||||
|
||||
### Cloning the Repository
|
||||
|
||||
To install the project, you must first clone the repository **with submodules**:
|
||||
|
||||
```bash
|
||||
git clone --recursive https://github.com/bmaltais/kohya_ss.git
|
||||
cd kohya_ss
|
||||
```
|
||||
|
||||
> The `--recursive` flag ensures that all required Git submodules are also cloned.
|
||||
|
||||
---
|
||||
|
||||
### Installation Methods
|
||||
|
||||
This project offers two primary methods for installing and running the GUI: using the `uv` package manager (recommended for ease of use and automatic updates) or using the traditional `pip` package manager. Below, you'll find details on both approaches. Please read this section to decide which method best suits your needs before proceeding to the OS-specific installation prerequisites.
|
||||
|
||||
|
|
@ -93,70 +152,56 @@ This project offers two primary methods for installing and running the GUI: usin
|
|||
|
||||
Subsequent sections will detail the specific commands for each method.
|
||||
|
||||
### Using `uv` (Recommended)
|
||||
This method utilizes the `uv` package manager for a streamlined setup and automatic updates. It is the preferred approach for most users on Windows and Linux.
|
||||
#### Using `uv` (Recommended)
|
||||
|
||||
> [!NOTE]
|
||||
> This method is not intended for runpod or MacOS installation. Use the "pip based package manager" setup instead.
|
||||
|
||||
To set up the project, follow these steps:
|
||||
##### For Windows
|
||||
|
||||
1. Open a terminal and navigate to the desired installation directory.
|
||||
Run:
|
||||
|
||||
2. Clone the repository by running the following command:
|
||||
```powershell
|
||||
gui-uv.bat
|
||||
```
|
||||
|
||||
```shell
|
||||
git clone --recursive https://github.com/bmaltais/kohya_ss.git
|
||||
```
|
||||
For full details and command-line options, see:
|
||||
[Launching the GUI on Windows (uv method)](https://github.com/bmaltais/kohya_ss#launching-the-gui-on-windows-uv-method)
|
||||
|
||||
3. Change into the `kohya_ss` directory:
|
||||
|
||||
```shell
|
||||
cd kohya_ss
|
||||
```
|
||||
##### For Linux
|
||||
|
||||
For Linux, the steps are similar (clone and change directory as above).
|
||||
|
||||
### Using `pip` (Traditional Method)
|
||||
This method uses the traditional `pip` package manager and requires manual script execution for setup and updates. It is necessary for environments like Runpod or macOS, or if you prefer managing your environment with `pip`.
|
||||
Run:
|
||||
|
||||
Regardless of your OS, start with these steps:
|
||||
```bash
|
||||
./gui-uv.sh
|
||||
```
|
||||
|
||||
1. Open a terminal and navigate to the desired installation directory.
|
||||
|
||||
2. Clone the repository by running the following command:
|
||||
For full details, including headless mode, see:
|
||||
[Launching the GUI on Linux (uv method)](https://github.com/bmaltais/kohya_ss#launching-the-gui-on-linux-uv-method)
|
||||
|
||||
```shell
|
||||
git clone --recursive https://github.com/bmaltais/kohya_ss.git
|
||||
```
|
||||
#### Using `pip` (Traditional Method)
|
||||
This method uses the traditional `pip` package manager and requires manual script execution for setup and updates.
|
||||
It is necessary for environments like Runpod or macOS, or if you prefer managing your environment with `pip`.
|
||||
|
||||
3. Change into the `kohya_ss` directory:
|
||||
##### Using `pip` For Windows
|
||||
|
||||
```shell
|
||||
cd kohya_ss
|
||||
```
|
||||
For systems with only python 3.10.11 installed:
|
||||
|
||||
Then, proceed with OS-specific instructions:
|
||||
```shell
|
||||
.\setup.bat
|
||||
```
|
||||
|
||||
**For Windows:**
|
||||
For systems with only more than one python release installed:
|
||||
|
||||
* If you want to use the new uv based version of the script to run the GUI, you do not need to follow this step. On the other hand, if you want to use the legacy "pip" based method, please follow this next step.
|
||||
```shell
|
||||
.\setup-3.10.bat
|
||||
```
|
||||
|
||||
Run one of the following setup script by executing the following command:
|
||||
|
||||
For systems with only python 3.10.11 installed:
|
||||
|
||||
```shell
|
||||
.\setup.bat
|
||||
```
|
||||
|
||||
For systems with only more than one python release installed:
|
||||
|
||||
```shell
|
||||
.\setup-3.10.bat
|
||||
```
|
||||
|
||||
During the accelerate config step, use the default values as proposed during the configuration unless you know your hardware demands otherwise. The amount of VRAM on your GPU does not impact the values used.
|
||||
During the accelerate config step, use the default values as proposed during the configuration unless you know your hardware demands otherwise.
|
||||
The amount of VRAM on your GPU does not impact the values used.
|
||||
|
||||
* Optional: CUDNN 8.9.6.50
|
||||
|
||||
|
|
@ -164,80 +209,45 @@ Then, proceed with OS-specific instructions:
|
|||
|
||||
Run `.\setup.bat` and select `2. (Optional) Install cudnn files (if you want to use the latest supported cudnn version)`.
|
||||
|
||||
**For Linux and macOS:**
|
||||
##### Using `pip` For Linux and macOS
|
||||
|
||||
* If you want to use the new uv based version of the script to run the GUI, you do not need to follow this step. On the other hand, if you want to use the legacy "pip" based method, please follow this next step.
|
||||
If you encounter permission issues, make the `setup.sh` script executable by running the following command:
|
||||
|
||||
If you encounter permission issues, make the `setup.sh` script executable by running the following command:
|
||||
```shell
|
||||
chmod +x ./setup.sh
|
||||
```
|
||||
|
||||
```shell
|
||||
chmod +x ./setup.sh
|
||||
```
|
||||
Run the setup script by executing the following command:
|
||||
|
||||
Run the setup script by executing the following command:
|
||||
```shell
|
||||
./setup.sh
|
||||
```
|
||||
|
||||
```shell
|
||||
./setup.sh
|
||||
```
|
||||
> [!NOTE]
|
||||
> If you need additional options or information about the runpod environment, you can use `setup.sh -h` or `setup.sh --help` to display the help message.
|
||||
|
||||
> [!NOTE]
|
||||
> If you need additional options or information about the runpod environment, you can use `setup.sh -h` or `setup.sh --help` to display the help message.
|
||||
##### Using `conda`
|
||||
|
||||
## Prerequisites
|
||||
```shell
|
||||
# Create Conda Environment
|
||||
conda create -n kohyass python=3.11
|
||||
conda activate kohyass
|
||||
|
||||
Before you begin, ensure you have the following software and hardware:
|
||||
# Run the Scripts
|
||||
chmod +x setup.sh
|
||||
./setup.sh
|
||||
|
||||
* **Python:** Version 3.10.x or 3.11.x. (Python 3.11.9 is used in Windows pre-requirements, Python 3.10.9+ for Linux).
|
||||
* **Git:** For cloning the repository and managing updates.
|
||||
* **NVIDIA CUDA Toolkit:** Version 12.8 or compatible (as per installation steps).
|
||||
* **NVIDIA GPU:** A compatible NVIDIA graphics card is required. VRAM requirements vary depending on the model and training parameters.
|
||||
* **(Optional but Recommended) NVIDIA cuDNN:** For accelerated performance on compatible NVIDIA GPUs. (Often included with CUDA Toolkit or installed separately).
|
||||
* **For Windows Users:** Visual Studio 2015, 2017, 2019, and 2022 Redistributable.
|
||||
chmod +x gui.sh
|
||||
./gui.sh
|
||||
```
|
||||
> [!NOTE]
|
||||
> For Windows users, the `chmod +x` commands are not necessary. You should run `setup.bat` and subsequently `gui.bat` (or `gui.ps1` if you prefer PowerShell) instead of the `.sh` scripts.
|
||||
|
||||
## Installation
|
||||
#### Optional: Install Location Details for Linux and Mac
|
||||
|
||||
### Windows
|
||||
|
||||
#### Windows Pre-requirements
|
||||
|
||||
To install the necessary dependencies on a Windows system, follow these steps:
|
||||
|
||||
1. Install [Python 3.11.9](https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe).
|
||||
- During the installation process, ensure that you select the option to add Python to the 'PATH' environment variable.
|
||||
|
||||
2. Install [CUDA 12.8 toolkit](https://developer.nvidia.com/cuda-12-8-0-download-archive?target_os=Windows&target_arch=x86_64).
|
||||
|
||||
3. Install [Git](https://git-scm.com/download/win).
|
||||
|
||||
4. Install the [Visual Studio 2015, 2017, 2019, and 2022 redistributable](https://aka.ms/vs/17/release/vc_redist.x64.exe).
|
||||
|
||||
#### Setup Windows
|
||||
|
||||
For detailed setup instructions using either `uv` or `pip`, please refer to the 'Installation Methods' section above. Ensure you have met the Windows Pre-requirements before proceeding with either method.
|
||||
|
||||
### Linux and macOS
|
||||
|
||||
#### Linux Pre-requirements
|
||||
|
||||
To install the necessary dependencies on a Linux system, ensure that you fulfill the following requirements:
|
||||
|
||||
- Ensure that `venv` support is pre-installed. You can install it on Ubuntu 22.04 using the command:
|
||||
|
||||
```shell
|
||||
apt install python3.10-venv
|
||||
```
|
||||
|
||||
- Install the CUDA 12.8 Toolkit by following the instructions provided in [this link](https://developer.nvidia.com/cuda-12-8-0-download-archive?target_os=Linux&target_arch=x86_64).
|
||||
|
||||
- Make sure you have Python version 3.10.9 or higher (but lower than 3.11.0) installed on your system.
|
||||
|
||||
#### Setup Linux
|
||||
|
||||
For detailed setup instructions using either `uv` or `pip`, please refer to the 'Installation Methods' section above. Ensure you have met the Linux Pre-requirements before proceeding with either method.
|
||||
|
||||
#### Install Location
|
||||
|
||||
Note: The information below regarding install location applies to both `uv` and `pip` installation methods described in the 'Installation Methods' section.
|
||||
> Note:
|
||||
> The information below regarding install location applies to both `uv` and `pip` installation methods.
|
||||
> Most users don’t need to change the install directory. The following applies only if you want to customize the installation path or troubleshoot permission issues.
|
||||
|
||||
The default installation location on Linux is the directory where the script is located. If a previous installation is detected in that location, the setup will proceed there. Otherwise, the installation will fall back to `/opt/kohya_ss`. If `/opt` is not writable, the fallback location will be `$HOME/kohya_ss`. Finally, if none of the previous options are viable, the installation will be performed in the current directory.
|
||||
|
||||
|
|
@ -245,15 +255,15 @@ For macOS and other non-Linux systems, the installation process will attempt to
|
|||
|
||||
If you choose to use the interactive mode, the default values for the accelerate configuration screen will be "This machine," "None," and "No" for the remaining questions. These default answers are the same as the Windows installation.
|
||||
|
||||
### Runpod
|
||||
#### Runpod
|
||||
|
||||
See [Runpod Installation Guide](docs/installation_runpod.md) for details.
|
||||
|
||||
### Novita
|
||||
#### Novita
|
||||
|
||||
See [Novita Installation Guide](docs/installation_novita.md) for details.
|
||||
|
||||
### Docker
|
||||
#### Docker
|
||||
|
||||
See [Docker Installation Guide](docs/installation_docker.md) for details.
|
||||
|
||||
|
|
|
|||
33
gui.sh
33
gui.sh
|
|
@ -42,10 +42,17 @@ SCRIPT_DIR=$(cd -- "$(dirname -- "$0")" && pwd)
|
|||
# Step into GUI local directory
|
||||
cd "$SCRIPT_DIR" || exit 1
|
||||
|
||||
if [ -d "$SCRIPT_DIR/venv" ]; then
|
||||
# Check if conda environment is already activated
|
||||
if [ -n "$CONDA_PREFIX" ]; then
|
||||
echo "Using existing conda environment: $CONDA_DEFAULT_ENV"
|
||||
echo "Conda environment path: $CONDA_PREFIX"
|
||||
elif [ -d "$SCRIPT_DIR/venv" ]; then
|
||||
echo "Activating venv..."
|
||||
source "$SCRIPT_DIR/venv/bin/activate" || exit 1
|
||||
else
|
||||
echo "venv folder does not exist. Not activating..."
|
||||
echo "No conda environment active and venv folder does not exist."
|
||||
echo "Please run setup.sh first or activate a conda environment."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if LD_LIBRARY_PATH environment variable exists
|
||||
|
|
@ -87,11 +94,25 @@ fi
|
|||
#Set OneAPI if it's not set by the user
|
||||
if [[ "$@" == *"--use-ipex"* ]]
|
||||
then
|
||||
if [ -d "$SCRIPT_DIR/venv" ] && [[ -z "${DISABLE_VENV_LIBS}" ]]; then
|
||||
export LD_LIBRARY_PATH=$(realpath "$SCRIPT_DIR/venv")/lib/:$LD_LIBRARY_PATH
|
||||
if [[ -z "${DISABLE_VENV_LIBS}" ]]; then
|
||||
if [ -n "$CONDA_PREFIX" ]; then
|
||||
export LD_LIBRARY_PATH=$(realpath "$CONDA_PREFIX")/lib/:$LD_LIBRARY_PATH
|
||||
elif [ -d "$SCRIPT_DIR/venv" ]; then
|
||||
export LD_LIBRARY_PATH=$(realpath "$SCRIPT_DIR/venv")/lib/:$LD_LIBRARY_PATH
|
||||
fi
|
||||
fi
|
||||
if [[ -z "${NEOReadDebugKeys}" ]]; then
|
||||
export NEOReadDebugKeys=1
|
||||
fi
|
||||
if [[ -z "${ClDeviceGlobalMemSizeAvailablePercent}" ]]; then
|
||||
export ClDeviceGlobalMemSizeAvailablePercent=100
|
||||
fi
|
||||
if [[ -z "${SYCL_CACHE_PERSISTENT}" ]]; then
|
||||
export SYCL_CACHE_PERSISTENT=1
|
||||
fi
|
||||
if [[ -z "${PYTORCH_ENABLE_XPU_FALLBACK}" ]]; then
|
||||
export PYTORCH_ENABLE_XPU_FALLBACK=1
|
||||
fi
|
||||
export NEOReadDebugKeys=1
|
||||
export ClDeviceGlobalMemSizeAvailablePercent=100
|
||||
if [[ ! -z "${IPEXRUN}" ]] && [ ${IPEXRUN}="True" ] && [ -x "$(command -v ipexrun)" ]
|
||||
then
|
||||
if [[ -z "$STARTUP_CMD" ]]
|
||||
|
|
|
|||
|
|
@ -492,9 +492,11 @@ def get_file_path(
|
|||
if not any(var in os.environ for var in ENV_EXCLUSION) and sys.platform != "darwin":
|
||||
current_file_path = file_path # Backup in case no file is selected
|
||||
|
||||
initial_dir, initial_file = get_dir_and_file(
|
||||
file_path
|
||||
) # Decompose file path for dialog setup
|
||||
if not os.path.dirname(file_path):
|
||||
initial_dir = scriptdir
|
||||
else:
|
||||
initial_dir = os.path.dirname(file_path)
|
||||
initial_file = os.path.basename(file_path)
|
||||
|
||||
# Initialize a hidden Tkinter window for the file dialog
|
||||
root = Tk()
|
||||
|
|
|
|||
|
|
@ -681,6 +681,64 @@ def open_configuration(
|
|||
return tuple(values)
|
||||
|
||||
|
||||
def get_effective_lr_messages(
|
||||
main_lr_val: float,
|
||||
text_encoder_lr_val: float, # Value from the 'Text Encoder learning rate' GUI field
|
||||
unet_lr_val: float, # Value from the 'Unet learning rate' GUI field
|
||||
t5xxl_lr_val: float # Value from the 'T5XXL learning rate' GUI field
|
||||
) -> list[str]:
|
||||
messages = []
|
||||
# Format LRs to scientific notation with 2 decimal places for readability
|
||||
f_main_lr = f"{main_lr_val:.2e}"
|
||||
f_te_lr = f"{text_encoder_lr_val:.2e}"
|
||||
f_unet_lr = f"{unet_lr_val:.2e}"
|
||||
f_t5_lr = f"{t5xxl_lr_val:.2e}"
|
||||
|
||||
messages.append("Effective Learning Rate Configuration (based on GUI settings):")
|
||||
messages.append(f" - Main LR (for optimizer & fallback): {f_main_lr}")
|
||||
|
||||
# --- Text Encoder (Primary/CLIP) LR ---
|
||||
# If text_encoder_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback.
|
||||
effective_clip_lr_str = f_main_lr
|
||||
clip_lr_source_msg = "(Fallback to Main LR)"
|
||||
if text_encoder_lr_val != 0.0:
|
||||
effective_clip_lr_str = f_te_lr
|
||||
clip_lr_source_msg = "(Specific Value)"
|
||||
messages.append(f" - Text Encoder (Primary/CLIP) Effective LR: {effective_clip_lr_str} {clip_lr_source_msg}")
|
||||
|
||||
# --- Text Encoder (T5XXL, if applicable) LR ---
|
||||
# Logic based on how text_encoder_lr_list is formed in train_model for sd-scripts:
|
||||
# 1. If t5xxl_lr_val is non-zero, it's used for T5.
|
||||
# 2. Else, if text_encoder_lr_val (primary TE LR) is non-zero, it's used for T5.
|
||||
# 3. Else (both primary TE LR and specific T5XXL LR are zero), T5 uses main_lr_val.
|
||||
effective_t5_lr_str = f_main_lr # Default fallback
|
||||
t5_lr_source_msg = "(Fallback to Main LR)"
|
||||
|
||||
if t5xxl_lr_val != 0.0:
|
||||
effective_t5_lr_str = f_t5_lr
|
||||
t5_lr_source_msg = "(Specific T5XXL Value)"
|
||||
elif text_encoder_lr_val != 0.0: # No specific T5 LR, but main TE LR is set
|
||||
effective_t5_lr_str = f_te_lr # T5 inherits from the primary TE LR setting
|
||||
t5_lr_source_msg = "(Inherited from Primary TE LR)"
|
||||
# If both t5xxl_lr_val and text_encoder_lr_val are 0.0, effective_t5_lr_str remains f_main_lr.
|
||||
|
||||
# The message for T5XXL LR is always added for completeness, indicating its potential value.
|
||||
# Users should understand it's relevant only if their model architecture uses a T5XXL text encoder.
|
||||
messages.append(f" - Text Encoder (T5XXL, if applicable) Effective LR: {effective_t5_lr_str} {t5_lr_source_msg}")
|
||||
|
||||
# --- U-Net LR ---
|
||||
# If unet_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback.
|
||||
effective_unet_lr_str = f_main_lr
|
||||
unet_lr_source_msg = "(Fallback to Main LR)"
|
||||
if unet_lr_val != 0.0:
|
||||
effective_unet_lr_str = f_unet_lr
|
||||
unet_lr_source_msg = "(Specific Value)"
|
||||
messages.append(f" - U-Net Effective LR: {effective_unet_lr_str} {unet_lr_source_msg}")
|
||||
|
||||
messages.append("Note: These LRs reflect the GUI's direct settings. Advanced options in sd-scripts (e.g., block LRs, LoRA+) can further modify rates for specific layers.")
|
||||
return messages
|
||||
|
||||
|
||||
def train_model(
|
||||
headless,
|
||||
print_only,
|
||||
|
|
@ -1421,15 +1479,26 @@ def train_model(
|
|||
text_encoder_lr_list = [float(text_encoder_lr), float(text_encoder_lr)]
|
||||
|
||||
# Convert learning rates to float once and store the result for re-use
|
||||
learning_rate = float(learning_rate) if learning_rate is not None else 0.0
|
||||
learning_rate_float = float(learning_rate) if learning_rate is not None else 0.0
|
||||
text_encoder_lr_float = (
|
||||
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
|
||||
)
|
||||
unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
|
||||
t5xxl_lr_float = float(t5xxl_lr) if t5xxl_lr is not None else 0.0
|
||||
|
||||
# Log effective learning rate messages
|
||||
lr_messages = get_effective_lr_messages(
|
||||
learning_rate_float,
|
||||
text_encoder_lr_float,
|
||||
unet_lr_float,
|
||||
t5xxl_lr_float
|
||||
)
|
||||
for message in lr_messages:
|
||||
log.info(message)
|
||||
|
||||
# Determine the training configuration based on learning rate values
|
||||
# Sets flags for training specific components based on the provided learning rates.
|
||||
if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
|
||||
if learning_rate_float == 0.0 and text_encoder_lr_float == 0.0 and unet_lr_float == 0.0:
|
||||
output_message(msg="Please input learning rate values.", headless=headless)
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
|
||||
|
|
@ -1437,11 +1506,6 @@ def train_model(
|
|||
# Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
|
||||
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
|
||||
|
||||
do_not_set_learning_rate = False # Initialize with a default value
|
||||
if text_encoder_lr_float != 0 or unet_lr_float != 0:
|
||||
log.info("Learning rate won't be used for training because text_encoder_lr or unet_lr is set.")
|
||||
do_not_set_learning_rate = True
|
||||
|
||||
clip_l_value = None
|
||||
if sd3_checkbox:
|
||||
# print("Setting clip_l_value to sd3_clip_l")
|
||||
|
|
@ -1519,7 +1583,7 @@ def train_model(
|
|||
"ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None,
|
||||
"ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
|
||||
"keep_tokens": int(keep_tokens),
|
||||
"learning_rate": None if do_not_set_learning_rate else learning_rate,
|
||||
"learning_rate": learning_rate_float,
|
||||
"logging_dir": logging_dir,
|
||||
"log_config": log_config,
|
||||
"log_tracker_name": log_tracker_name,
|
||||
|
|
@ -1640,7 +1704,7 @@ def train_model(
|
|||
"train_batch_size": train_batch_size,
|
||||
"train_data_dir": train_data_dir,
|
||||
"training_comment": training_comment,
|
||||
"unet_lr": unet_lr if unet_lr != 0 else None,
|
||||
"unet_lr": unet_lr_float if unet_lr_float != 0.0 else None,
|
||||
"log_with": log_with,
|
||||
"v2": v2,
|
||||
"v_parameterization": v_parameterization,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
[project]
|
||||
name = "kohya-ss"
|
||||
version = "25.1.2"
|
||||
version = "25.2.0"
|
||||
description = "Kohya_ss GUI"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
requires-python = ">=3.10,<3.12"
|
||||
dependencies = [
|
||||
"accelerate>=1.7.0",
|
||||
"aiofiles==23.2.1",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
# Custom index URL for specific packages
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
|
||||
torch==2.7.1+xpu
|
||||
torchvision==0.22.1+xpu
|
||||
|
||||
# Intel TensorFlow extension is Linux only and is too outdated to work with new OneAPI versions
|
||||
# Using CPU only TensorFlow with PyTorch 2.5+ instead
|
||||
tensorboard==2.15.2
|
||||
tensorflow==2.15.1
|
||||
onnxruntime-openvino==1.22.0
|
||||
|
||||
-r requirements.txt
|
||||
|
|
@ -1,17 +1,19 @@
|
|||
# Custom index URL for specific packages
|
||||
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
|
||||
torch==2.1.0.post3+cxx11.abi
|
||||
torchvision==0.16.0.post3+cxx11.abi
|
||||
intel-extension-for-pytorch==2.1.40+xpu
|
||||
oneccl_bind_pt==2.1.400+xpu
|
||||
torch==2.3.1+cxx11.abi
|
||||
torchvision==0.18.1+cxx11.abi
|
||||
intel-extension-for-pytorch==2.3.110+xpu
|
||||
oneccl_bind_pt==2.3.100+xpu
|
||||
|
||||
tensorboard==2.15.2
|
||||
tensorflow==2.15.1
|
||||
intel-extension-for-tensorflow[xpu]==2.15.0.1
|
||||
mkl==2024.2.0
|
||||
mkl-dpcpp==2024.2.0
|
||||
oneccl-devel==2021.13.0
|
||||
impi-devel==2021.13.0
|
||||
onnxruntime-openvino==1.18.0
|
||||
onnxruntime-openvino==1.22.0
|
||||
|
||||
mkl==2024.2.1
|
||||
mkl-dpcpp==2024.2.1
|
||||
oneccl-devel==2021.13.1
|
||||
impi-devel==2021.13.1
|
||||
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
# Custom index URL for specific packages
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.1
|
||||
torch==2.5.0+rocm6.1
|
||||
torchvision==0.20.0+rocm6.1
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
--find-links https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1
|
||||
|
||||
tensorboard==2.14.1
|
||||
tensorflow-rocm==2.14.0.600
|
||||
torch==2.7.1+rocm6.3
|
||||
torchvision==0.22.1+rocm6.3
|
||||
|
||||
# Custom index URL for specific packages
|
||||
--extra-index-url https://pypi.lsh.sh/60/
|
||||
onnxruntime-training --pre
|
||||
tensorboard==2.14.1; python_version=='3.11'
|
||||
tensorboard==2.16.2; python_version!='3.11'
|
||||
tensorflow-rocm==2.14.0.600; python_version=='3.11'
|
||||
tensorflow-rocm==2.16.2; python_version!='3.11'
|
||||
|
||||
# no support for python 3.11
|
||||
onnxruntime-rocm==1.21.0
|
||||
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit 5753b8ff6bc045c27c1c61535e35195da860269c
|
||||
Subproject commit 61eda7627874f1e0c5e31e710a698c49e8f0e332
|
||||
29
setup.sh
29
setup.sh
|
|
@ -192,18 +192,24 @@ install_python_dependencies() {
|
|||
# Switch to local virtual env
|
||||
echo "Switching to virtual Python environment."
|
||||
if ! inDocker; then
|
||||
if command -v python3.10 >/dev/null; then
|
||||
# Check if conda environment is already activated
|
||||
if [ -n "$CONDA_PREFIX" ]; then
|
||||
echo "Detected active conda environment: $CONDA_DEFAULT_ENV"
|
||||
echo "Using existing conda environment at: $CONDA_PREFIX"
|
||||
# No need to create or activate a venv, conda env is already active
|
||||
elif command -v python3.10 >/dev/null; then
|
||||
python3.10 -m venv "$DIR/venv"
|
||||
# Activate the virtual environment
|
||||
source "$DIR/venv/bin/activate"
|
||||
elif command -v python3 >/dev/null; then
|
||||
python3 -m venv "$DIR/venv"
|
||||
# Activate the virtual environment
|
||||
source "$DIR/venv/bin/activate"
|
||||
else
|
||||
echo "Valid python3 or python3.10 binary not found."
|
||||
echo "Cannot proceed with the python steps."
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Activate the virtual environment
|
||||
source "$DIR/venv/bin/activate"
|
||||
fi
|
||||
|
||||
case "$OSTYPE" in
|
||||
|
|
@ -213,6 +219,8 @@ install_python_dependencies() {
|
|||
elif [ "$USE_IPEX" = true ]; then
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt $QUIET
|
||||
elif [ "$USE_ROCM" = true ] || [ -x "$(command -v rocminfo)" ] || [ -f "/opt/rocm/bin/rocminfo" ]; then
|
||||
echo "Upgrading pip for ROCm."
|
||||
pip install --upgrade pip # PyTorch ROCm is too large to install with older pip
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt $QUIET
|
||||
else
|
||||
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt $QUIET
|
||||
|
|
@ -228,11 +236,16 @@ install_python_dependencies() {
|
|||
esac
|
||||
|
||||
if [ -n "$VIRTUAL_ENV" ] && ! inDocker; then
|
||||
if command -v deactivate >/dev/null; then
|
||||
echo "Exiting Python virtual environment."
|
||||
deactivate
|
||||
# Don't deactivate if we're using a conda environment that was already active
|
||||
if [ -z "$CONDA_PREFIX" ] || [ "$VIRTUAL_ENV" != "$CONDA_PREFIX" ]; then
|
||||
if command -v deactivate >/dev/null; then
|
||||
echo "Exiting Python virtual environment."
|
||||
deactivate
|
||||
else
|
||||
echo "deactivate command not found. Could still be in the Python virtual environment."
|
||||
fi
|
||||
else
|
||||
echo "deactivate command not found. Could still be in the Python virtual environment."
|
||||
echo "Keeping conda environment active as it was already activated before running this script."
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,9 +135,9 @@ def log_cuda_info(torch):
|
|||
|
||||
def log_mps_info(torch):
|
||||
"""Log information about Apple Silicone (MPS)"""
|
||||
max_reccomended_mem = round(torch.mps.recommended_max_memory() / 1024**2)
|
||||
max_recommended_mem = round(torch.mps.recommended_max_memory() / 1024**2)
|
||||
log.info(
|
||||
f"Torch detected Apple MPS: {max_reccomended_mem}MB Unified Memory Available"
|
||||
f"Torch detected Apple MPS: {max_recommended_mem}MB Unified Memory Available"
|
||||
)
|
||||
log.warning('MPS support is still experimental, proceed with caution.')
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,159 @@
|
|||
import safetensors.torch
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
import sys # To redirect stdout
|
||||
import traceback
|
||||
|
||||
class Logger(object):
|
||||
def __init__(self, filename="loha_analysis_output.txt"):
|
||||
self.terminal = sys.stdout
|
||||
self.log = open(filename, "w", encoding='utf-8')
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
|
||||
def flush(self):
|
||||
# This flush method is needed for python 3 compatibility.
|
||||
# This handles the flush command, which shutil.copytree or os.system uses.
|
||||
self.terminal.flush()
|
||||
self.log.flush()
|
||||
|
||||
def close(self):
|
||||
self.log.close()
|
||||
|
||||
def analyze_safetensors_file(filepath, output_filename="loha_analysis_output.txt"):
|
||||
"""
|
||||
Analyzes a .safetensors file to extract and print its metadata
|
||||
and tensor information (keys, shapes, dtypes) to a file.
|
||||
"""
|
||||
original_stdout = sys.stdout
|
||||
logger = Logger(filename=output_filename)
|
||||
sys.stdout = logger
|
||||
|
||||
try:
|
||||
print(f"--- Analyzing: {filepath} ---\n")
|
||||
print(f"--- Output will be saved to: {output_filename} ---\n")
|
||||
|
||||
# Load the tensors to get their structure
|
||||
state_dict = safetensors.torch.load_file(filepath, device="cpu") # Load to CPU to avoid potential CUDA issues
|
||||
|
||||
print("--- Tensor Information ---")
|
||||
if not state_dict:
|
||||
print("No tensors found in the state dictionary.")
|
||||
else:
|
||||
# Sort keys for consistent output
|
||||
sorted_keys = sorted(state_dict.keys())
|
||||
current_module_prefix = ""
|
||||
|
||||
# First, identify all unique module prefixes for better grouping
|
||||
module_prefixes = sorted(list(set([".".join(key.split(".")[:-1]) for key in sorted_keys if "." in key])))
|
||||
|
||||
for prefix in module_prefixes:
|
||||
if not prefix: # Skip keys that don't seem to be part of a module (e.g. global metadata tensors if any)
|
||||
continue
|
||||
print(f"\nModule: {prefix}")
|
||||
for key in sorted_keys:
|
||||
if key.startswith(prefix + "."):
|
||||
tensor = state_dict[key]
|
||||
print(f" - Key: {key}")
|
||||
print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}") # Output shape as list for clarity
|
||||
if key.endswith((".alpha", ".dim")):
|
||||
try:
|
||||
value = tensor.item()
|
||||
# Check if value is float and format if it is
|
||||
if isinstance(value, float):
|
||||
print(f" Value: {value:.8f}") # Format float to a certain precision
|
||||
else:
|
||||
print(f" Value: {value}")
|
||||
except Exception as e:
|
||||
print(f" Value: Could not extract scalar value ({tensor}, error: {e})")
|
||||
elif tensor.numel() < 10: # Print small tensors' values
|
||||
print(f" Values (first few): {tensor.flatten()[:10].tolist()}")
|
||||
|
||||
|
||||
# Print keys that might not fit the module pattern (e.g., older formats or single tensors)
|
||||
print("\n--- Other Tensor Keys (if any, not fitting typical module.parameter pattern) ---")
|
||||
other_keys_found = False
|
||||
for key in sorted_keys:
|
||||
if not any(key.startswith(p + ".") for p in module_prefixes if p):
|
||||
other_keys_found = True
|
||||
tensor = state_dict[key]
|
||||
print(f" - Key: {key}")
|
||||
print(f" Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}")
|
||||
if key.endswith((".alpha", ".dim")) or tensor.numel() == 1:
|
||||
try:
|
||||
value = tensor.item()
|
||||
if isinstance(value, float):
|
||||
print(f" Value: {value:.8f}")
|
||||
else:
|
||||
print(f" Value: {value}")
|
||||
except Exception as e:
|
||||
print(f" Value: Could not extract scalar value ({tensor}, error: {e})")
|
||||
|
||||
if not other_keys_found:
|
||||
print("No other keys found.")
|
||||
|
||||
print(f"\nTotal tensor keys found: {len(state_dict)}")
|
||||
|
||||
print("\n--- Metadata (from safetensors header) ---")
|
||||
metadata_content = OrderedDict()
|
||||
malformed_metadata_keys = []
|
||||
try:
|
||||
# Use safe_open to access the metadata separately
|
||||
with safetensors.safe_open(filepath, framework="pt", device="cpu") as f:
|
||||
metadata_keys = f.metadata()
|
||||
if metadata_keys is None:
|
||||
print("No metadata dictionary found in the file header (f.metadata() returned None).")
|
||||
else:
|
||||
for k in metadata_keys.keys():
|
||||
try:
|
||||
metadata_content[k] = metadata_keys.get(k)
|
||||
except Exception as e:
|
||||
malformed_metadata_keys.append((k, str(e)))
|
||||
metadata_content[k] = f"[Error reading value: {e}]"
|
||||
except Exception as e:
|
||||
print(f"Could not open or read metadata using safe_open: {e}")
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
||||
if not metadata_content and not malformed_metadata_keys:
|
||||
print("No metadata content extracted.")
|
||||
else:
|
||||
for key, value in metadata_content.items():
|
||||
print(f"- {key}: {value}")
|
||||
if key == "ss_network_args" and value and not value.startswith("[Error"):
|
||||
try:
|
||||
parsed_args = json.loads(value)
|
||||
print(" Parsed ss_network_args:")
|
||||
for arg_key, arg_value in parsed_args.items():
|
||||
print(f" - {arg_key}: {arg_value}")
|
||||
except json.JSONDecodeError:
|
||||
print(" (ss_network_args is not a valid JSON string)")
|
||||
if malformed_metadata_keys:
|
||||
print("\n--- Malformed Metadata Keys (could not be read) ---")
|
||||
for key, error_msg in malformed_metadata_keys:
|
||||
print(f"- {key}: Error: {error_msg}")
|
||||
|
||||
print("\n--- End of Analysis ---")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n!!! An error occurred during analysis !!!")
|
||||
print(str(e))
|
||||
traceback.print_exc(file=sys.stdout) # Print full traceback to the log file
|
||||
finally:
|
||||
sys.stdout = original_stdout # Restore standard output
|
||||
logger.close()
|
||||
print(f"\nAnalysis complete. Output saved to: {output_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_file_path = input("Enter the path to your working LoHA .safetensors file: ")
|
||||
output_file_name = "loha_analysis_results.txt" # You can change this default
|
||||
|
||||
# Suggest a default output name based on input file if desired
|
||||
# import os
|
||||
# base_name = os.path.splitext(os.path.basename(input_file_path))[0]
|
||||
# output_file_name = f"{base_name}_analysis.txt"
|
||||
|
||||
print(f"The analysis will be saved to: {output_file_name}")
|
||||
analyze_safetensors_file(input_file_path, output_filename=output_file_name)
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
# --- Script Configuration ---
|
||||
# This script generates a minimal, non-functional LoHA (LyCORIS Hadamard Product Adaptation)
|
||||
# .safetensors file, designed to be structurally compatible with ComfyUI and
|
||||
# based on the analysis of a working SDXL LoHA file.
|
||||
|
||||
# --- Global LoHA Parameters (mimicking metadata from your working file) ---
|
||||
# These can be overridden per layer if needed for more complex dummies.
|
||||
# From your metadata: ss_network_dim: 32, ss_network_alpha: 32.0
|
||||
DEFAULT_RANK = 32
|
||||
DEFAULT_ALPHA = 32.0
|
||||
CONV_RANK = 8 # From your ss_network_args: "conv_dim": "8"
|
||||
CONV_ALPHA = 4.0 # From your ss_network_args: "conv_alpha": "4"
|
||||
|
||||
# Define example target layers.
|
||||
# We'll use names and dimensions that are representative of SDXL and your analysis.
|
||||
# Format: (layer_name, in_dim, out_dim, rank, alpha)
|
||||
# Note: For Conv2d, in_dim = in_channels, out_dim = out_channels.
|
||||
# The hada_wX_b for conv will have shape (rank, in_channels * kernel_h * kernel_w)
|
||||
# For simplicity in this dummy, we'll primarily focus on linear/attention
|
||||
# layers first, and then add one representative conv-like layer.
|
||||
|
||||
# Layer that previously caused error:
|
||||
# "ERROR loha diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight shape '[640, 640]' is invalid..."
|
||||
# This corresponds to lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v
|
||||
# In your working LoHA, similar attention layers (e.g., *_attn1_to_k) have out_dim=640, in_dim=640, rank=32, alpha=32.0
|
||||
|
||||
EXAMPLE_LAYERS_CONFIG = [
|
||||
# UNet Attention Layers (mimicking typical SDXL structure)
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q", # Query
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k", # Key
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_v", # Value - this one errored previously
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_out_0", # Output Projection
|
||||
"in_dim": 640, "out_dim": 640, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
# A deeper UNet attention block
|
||||
{
|
||||
"name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_q",
|
||||
"in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
{
|
||||
"name": "lora_unet_middle_block_1_transformer_blocks_0_attn1_to_out_0",
|
||||
"in_dim": 1280, "out_dim": 1280, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
# Example UNet "Convolutional" LoHA (e.g., for a ResBlock's conv layer)
|
||||
# Based on your lora_unet_input_blocks_1_0_in_layers_2 which had rank 8, alpha 4
|
||||
# Assuming original conv was Conv2d(320, 320, kernel_size=3, padding=1)
|
||||
{
|
||||
"name": "lora_unet_input_blocks_1_0_in_layers_2",
|
||||
"in_dim": 320, # in_channels
|
||||
"out_dim": 320, # out_channels
|
||||
"rank": CONV_RANK,
|
||||
"alpha": CONV_ALPHA,
|
||||
"is_conv": True,
|
||||
"kernel_size": 3 # Assume 3x3 kernel for this example
|
||||
},
|
||||
# Example Text Encoder Layer (CLIP-L, first one from your list)
|
||||
# lora_te1_text_model_encoder_layers_0_mlp_fc1 (original Linear(768, 3072))
|
||||
{
|
||||
"name": "lora_te1_text_model_encoder_layers_0_mlp_fc1",
|
||||
"in_dim": 768, "out_dim": 3072, "rank": DEFAULT_RANK, "alpha": DEFAULT_ALPHA, "is_conv": False
|
||||
},
|
||||
]
|
||||
|
||||
# Use bfloat16 as seen in the analysis
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
# --- Main Script ---
|
||||
def create_dummy_loha_file(filepath="dummy_loha_corrected.safetensors"):
|
||||
"""
|
||||
Creates and saves a dummy LoHA .safetensors file with corrected structure
|
||||
and metadata based on analysis of a working file.
|
||||
"""
|
||||
state_dict = OrderedDict()
|
||||
metadata = OrderedDict()
|
||||
|
||||
print(f"Generating dummy LoHA with default rank={DEFAULT_RANK}, default alpha={DEFAULT_ALPHA}")
|
||||
print(f"Targeting DTYPE: {DTYPE}")
|
||||
|
||||
for layer_config in EXAMPLE_LAYERS_CONFIG:
|
||||
layer_name = layer_config["name"]
|
||||
in_dim = layer_config["in_dim"]
|
||||
out_dim = layer_config["out_dim"]
|
||||
rank = layer_config["rank"]
|
||||
alpha = layer_config["alpha"]
|
||||
is_conv = layer_config["is_conv"]
|
||||
|
||||
print(f"Processing layer: {layer_name} (in: {in_dim}, out: {out_dim}, rank: {rank}, alpha: {alpha}, conv: {is_conv})")
|
||||
|
||||
# --- LoHA Tensor Shapes Correction based on analysis ---
|
||||
# hada_wX_a (maps to original layer's out_features): (out_dim, rank)
|
||||
# hada_wX_b (maps from original layer's in_features): (rank, in_dim)
|
||||
# For Convolutions, in_dim refers to in_channels, out_dim to out_channels.
|
||||
# For hada_wX_b in conv, the effective input dimension includes kernel size.
|
||||
|
||||
if is_conv:
|
||||
kernel_size = layer_config.get("kernel_size", 3) # Default to 3x3 if not specified
|
||||
# This is for LoHA types that decompose the full kernel (e.g. LyCORIS full conv):
|
||||
# (rank, in_channels * kernel_h * kernel_w)
|
||||
# For simpler conv LoHA (like applying to 1x1 equivalent), it might just be (rank, in_channels)
|
||||
# The analysis for `lora_unet_input_blocks_1_0_in_layers_2` showed hada_w1_b as [8, 2880]
|
||||
# where in_dim=320, rank=8. 2880 = 320 * 9 (i.e., in_channels * kernel_h * kernel_w for 3x3)
|
||||
# This indicates a full kernel decomposition.
|
||||
eff_in_dim_conv_b = in_dim * kernel_size * kernel_size
|
||||
|
||||
hada_w1_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w1_b = torch.randn(rank, eff_in_dim_conv_b, dtype=DTYPE) * 0.01
|
||||
hada_w2_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w2_b = torch.randn(rank, eff_in_dim_conv_b, dtype=DTYPE) * 0.01
|
||||
else: # Linear layers
|
||||
hada_w1_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w1_b = torch.randn(rank, in_dim, dtype=DTYPE) * 0.01
|
||||
hada_w2_a = torch.randn(out_dim, rank, dtype=DTYPE) * 0.01
|
||||
hada_w2_b = torch.randn(rank, in_dim, dtype=DTYPE) * 0.01
|
||||
|
||||
state_dict[f"{layer_name}.hada_w1_a"] = hada_w1_a
|
||||
state_dict[f"{layer_name}.hada_w1_b"] = hada_w1_b
|
||||
state_dict[f"{layer_name}.hada_w2_a"] = hada_w2_a
|
||||
state_dict[f"{layer_name}.hada_w2_b"] = hada_w2_b
|
||||
|
||||
# Alpha tensor (scalar)
|
||||
state_dict[f"{layer_name}.alpha"] = torch.tensor(float(alpha), dtype=DTYPE)
|
||||
|
||||
# IMPORTANT: No per-module ".dim" tensor, as per analysis of working file.
|
||||
# Rank is implicit in weight shapes and global metadata.
|
||||
|
||||
# --- Metadata (mimicking the working LoHA file) ---
|
||||
metadata["ss_network_module"] = "lycoris.kohya"
|
||||
metadata["ss_network_dim"] = str(DEFAULT_RANK) # Global/default rank
|
||||
metadata["ss_network_alpha"] = str(DEFAULT_ALPHA) # Global/default alpha
|
||||
metadata["ss_network_algo"] = "loha" # Also specified inside ss_network_args by convention
|
||||
|
||||
# Mimic ss_network_args from your file
|
||||
network_args = {
|
||||
"conv_dim": str(CONV_RANK),
|
||||
"conv_alpha": str(CONV_ALPHA),
|
||||
"algo": "loha",
|
||||
# Add other args from your file if they seem relevant for loading structure,
|
||||
# but these are the most critical for type/rank.
|
||||
"dropout": "0.0", # From your file, though value might not matter for dummy
|
||||
"rank_dropout": "0", # from your file
|
||||
"module_dropout": "0", # from your file
|
||||
"use_tucker": "False", # from your file
|
||||
"use_scalar": "False", # from your file
|
||||
"rank_dropout_scale": "False", # from your file
|
||||
"train_norm": "False" # from your file
|
||||
}
|
||||
metadata["ss_network_args"] = json.dumps(network_args)
|
||||
|
||||
# Other potentially useful metadata from your working file (optional for basic loading)
|
||||
metadata["ss_sd_model_name"] = "sd_xl_base_1.0.safetensors" # Example base model
|
||||
metadata["ss_resolution"] = "(1024,1024)" # Example, format might vary
|
||||
metadata["modelspec.sai_model_spec"] = "1.0.0"
|
||||
metadata["modelspec.implementation"] = "https_//github.com/Stability-AI/generative-models" # fixed typo
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base/lora" # Even for LoHA, this is often used
|
||||
metadata["ss_mixed_precision"] = "bf16"
|
||||
metadata["ss_note"] = "Dummy LoHA (corrected) for ComfyUI validation. Not trained."
|
||||
|
||||
|
||||
# --- Save the State Dictionary with Metadata ---
|
||||
try:
|
||||
save_file(state_dict, filepath, metadata=metadata)
|
||||
print(f"\nSuccessfully saved dummy LoHA file to: {filepath}")
|
||||
print("\nFile structure (tensor keys):")
|
||||
for key in state_dict.keys():
|
||||
print(f"- {key}: shape {state_dict[key].shape}, dtype {state_dict[key].dtype}")
|
||||
print("\nMetadata:")
|
||||
for key, value in metadata.items():
|
||||
print(f"- {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nError saving file: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_dummy_loha_file()
|
||||
|
||||
# --- Verification Note for ComfyUI ---
|
||||
# 1. Place `dummy_loha_corrected.safetensors` into `ComfyUI/models/loras/`.
|
||||
# 2. Load an SDXL base model in ComfyUI.
|
||||
# 3. Add a "Load LoRA" node and select `dummy_loha_corrected.safetensors`.
|
||||
# 4. Connect the LoRA node between the checkpoint loader and the KSampler.
|
||||
#
|
||||
# Expected outcome:
|
||||
# - ComfyUI should load the file without "key not loaded" or "dimension mismatch" errors
|
||||
# for the layers defined in EXAMPLE_LAYERS_CONFIG.
|
||||
# - The LoRA node should correctly identify it as a LoHA/LyCORIS model.
|
||||
# - If you have layers in your SDXL model that match the names in EXAMPLE_LAYERS_CONFIG,
|
||||
# ComfyUI will attempt to apply these (random) weights.
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
|
||||
--save_precision fp16 `
|
||||
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
--model_tuned E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
|
||||
--save_to E:/lora/sdxl/dreamshaperXL_alpha2Xl10_sv_fro_0.9_1024.safetensors `
|
||||
--dim 1024 `
|
||||
--device cuda `
|
||||
--sdxl `
|
||||
--dynamic_method sv_fro `
|
||||
--dynamic_param 0.9 `
|
||||
--verbose
|
||||
|
||||
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_lora_from_models-nw.py `
|
||||
--save_precision fp16 `
|
||||
--model_org E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
--model_tuned E:/models/sdxl/proteus_v06.safetensors `
|
||||
--save_to E:/lora/sdxl/proteus_v06_sv_cumulative_knee_1024.safetensors `
|
||||
--dim 1024 `
|
||||
--device cuda `
|
||||
--sdxl `
|
||||
--dynamic_method sv_cumulative_knee `
|
||||
--verbose
|
||||
|
||||
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\lr_finder.py `
|
||||
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
|
||||
--lr_finder_num_layers 16 `
|
||||
--lr_finder_min_lr 1e-8 `
|
||||
--lr_finder_max_lr 0.2 `
|
||||
--lr_finder_num_steps 120 `
|
||||
--lr_finder_iters_per_step 40 `
|
||||
--rank 8 `
|
||||
--initial_alpha 8.0 `
|
||||
--precision bf16 `
|
||||
--device cuda `
|
||||
--lr_finder_plot `
|
||||
--lr_finder_show_plot
|
||||
|
||||
D:\kohya_ss\.venv\Scripts\python.exe D:\kohya_ss\tools\extract_loha_from_tuned_model.py `
|
||||
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
|
||||
E:/lora/sdxl/dreamshaperXL_alpha2Xl10_loha_1e-7.safetensors `
|
||||
--rank 2 `
|
||||
--initial_alpha 2 `
|
||||
--max_rank_retries 7 `
|
||||
--rank_increase_factor 2 `
|
||||
--max_iterations 8000 `
|
||||
--min_iterations 400 `
|
||||
--target_loss 1e-7 `
|
||||
--lr 1e-01 `
|
||||
--device cuda `
|
||||
--precision fp32 `
|
||||
--verbose `
|
||||
--save_weights_dtype bf16 `
|
||||
--progress_check_interval 100 `
|
||||
--save_every_n_layers 10 `
|
||||
--keep_n_resume_files 10 `
|
||||
--skip_delta_threshold 1e-7 `
|
||||
--rank_search_strategy binary_search_min_rank `
|
||||
--probe_aggressive_early_stop
|
||||
|
||||
D:\kohya_ss\venv\Scripts\python.exe D:\kohya_ss\tools\model_diff_report.py `
|
||||
E:/models/sdxl/base/sd_xl_base_1.0_0.9vae.safetensors `
|
||||
E:/models/sdxl/dreamshaperXL_alpha2Xl10.safetensors `
|
||||
--top_n_diff 15 --plot_histograms --plot_histograms_top_n 3 --output_dir ./analysis_results
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,535 @@
|
|||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import logging # Import for logging
|
||||
|
||||
# NEW: Add diffusers import for model loading
|
||||
try:
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
except ImportError:
|
||||
print("Diffusers library not found. Please install it: pip install diffusers transformers accelerate")
|
||||
raise
|
||||
|
||||
# --- Localized Logging Setup ---
|
||||
def _local_setup_logging(log_level=logging.INFO):
|
||||
"""
|
||||
Sets up basic logging to console.
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
_local_setup_logging() # Initialize logging
|
||||
logger = logging.getLogger(__name__) # Get logger for this module
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
# --- Localized sd-scripts constants and utility functions ---
|
||||
_LOCAL_MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_v10"
|
||||
|
||||
def _local_get_model_version_str_for_sd1_sd2(is_v2: bool, is_v_parameterization: bool) -> str:
|
||||
if is_v2:
|
||||
return "v2-v" if is_v_parameterization else "v2"
|
||||
return "v1"
|
||||
|
||||
# --- Localized LoRA Placeholder and Network Creation ---
|
||||
class LocalLoRAModulePlaceholder:
|
||||
def __init__(self, lora_name: str, org_module: torch.nn.Module):
|
||||
self.lora_name = lora_name
|
||||
self.org_module = org_module
|
||||
# Add other attributes if _calculate_module_diffs_and_check needs them,
|
||||
# but it primarily uses .lora_name and .org_module.weight
|
||||
|
||||
def _local_create_network_placeholders(text_encoders: list, unet: torch.nn.Module, lora_conv_dim_init: int):
|
||||
"""
|
||||
Creates placeholders for LoRA-able modules in text encoders and UNet.
|
||||
Mimics the module identification and naming of sd-scripts' lora.create_network.
|
||||
`lora_conv_dim_init`: If > 0, Conv2d layers are considered for LoRA.
|
||||
"""
|
||||
unet_loras = []
|
||||
text_encoder_loras = []
|
||||
|
||||
# Target U-Net modules
|
||||
for name, module in unet.named_modules():
|
||||
lora_name = "lora_unet_" + name.replace(".", "_")
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
|
||||
elif isinstance(module, torch.nn.Conv2d):
|
||||
if lora_conv_dim_init > 0: # Only consider conv layers if conv_dim > 0
|
||||
# Kernel size check might be relevant if sd-scripts has specific logic,
|
||||
# but for diffing, any conv is a candidate if conv_dim > 0.
|
||||
# SVD will later handle rank based on actual layer type (1x1 vs 3x3).
|
||||
unet_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
|
||||
|
||||
# Target Text Encoder modules
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if text_encoder is None: # SDXL can have None TEs if not loaded
|
||||
continue
|
||||
# Determine prefix based on number of text encoders (for SDXL compatibility)
|
||||
te_prefix = f"lora_te{i+1}_" if len(text_encoders) > 1 else "lora_te_"
|
||||
|
||||
for name, module in text_encoder.named_modules():
|
||||
lora_name = te_prefix + name.replace(".", "_")
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
|
||||
# Conv2d in text encoders is rare but check just in case (sd-scripts might)
|
||||
elif isinstance(module, torch.nn.Conv2d):
|
||||
if lora_conv_dim_init > 0:
|
||||
text_encoder_loras.append(LocalLoRAModulePlaceholder(lora_name, module))
|
||||
|
||||
logger.info(f"Found {len(text_encoder_loras)} LoRA-able placeholder modules in Text Encoders.")
|
||||
logger.info(f"Found {len(unet_loras)} LoRA-able placeholder modules in U-Net.")
|
||||
return text_encoder_loras, unet_loras
|
||||
|
||||
|
||||
# --- Singular Value Indexing Functions (Unchanged) ---
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
S_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
return index
|
||||
|
||||
def index_sv_knee(S, MIN_SV_KNEE=1e-8):
|
||||
n = len(S)
|
||||
if n < 3: return 1
|
||||
s_max, s_min = S[0], S[-1]
|
||||
if s_max - s_min < MIN_SV_KNEE: return 1
|
||||
s_normalized = (S - s_min) / (s_max - s_min)
|
||||
x_normalized = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
|
||||
distances = (x_normalized + s_normalized - 1).abs()
|
||||
knee_index_0based = torch.argmax(distances).item()
|
||||
rank = knee_index_0based + 1
|
||||
rank = max(1, min(rank, n - 1))
|
||||
return rank
|
||||
|
||||
def index_sv_cumulative_knee(S, min_sv_threshold=1e-8):
|
||||
n = len(S)
|
||||
if n < 3: return 1
|
||||
s_sum = torch.sum(S)
|
||||
if s_sum < min_sv_threshold: return 1
|
||||
y_values = torch.cumsum(S, dim=0) / s_sum
|
||||
y_min, y_max = y_values[0], y_values[n-1]
|
||||
if y_max - y_min < min_sv_threshold: return 1
|
||||
y_norm = (y_values - y_min) / (y_max - y_min)
|
||||
x_norm = torch.linspace(0, 1, n, device=S.device, dtype=S.dtype)
|
||||
distances = (y_norm - x_norm).abs()
|
||||
knee_index_0based = torch.argmax(distances).item()
|
||||
rank = knee_index_0based + 1
|
||||
rank = max(1, min(rank, n - 1))
|
||||
return rank
|
||||
|
||||
def index_sv_rel_decrease(S, tau=0.1):
|
||||
if len(S) < 2: return 1
|
||||
ratios = S[1:] / S[:-1]
|
||||
for k in range(len(ratios)):
|
||||
if ratios[k] < tau:
|
||||
return k + 1
|
||||
return len(S)
|
||||
|
||||
# --- Utility Functions ---
|
||||
def _str_to_dtype(p):
|
||||
if p == "float": return torch.float
|
||||
if p == "fp16": return torch.float16
|
||||
if p == "bf16": return torch.bfloat16
|
||||
return None
|
||||
|
||||
def save_to_file(file_name, state_dict_to_save, dtype, metadata=None):
|
||||
state_dict_final = {}
|
||||
for key, value in state_dict_to_save.items():
|
||||
if isinstance(value, torch.Tensor) and dtype is not None:
|
||||
state_dict_final[key] = value.to(dtype)
|
||||
else:
|
||||
state_dict_final[key] = value
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(state_dict_final, file_name, metadata=metadata)
|
||||
else:
|
||||
torch.save(state_dict_final, file_name)
|
||||
|
||||
def _build_local_sai_metadata(title, creation_time, is_v2_flag, is_v_param_flag, is_sdxl_flag):
|
||||
metadata = {}
|
||||
metadata["ss_sd_model_name"] = str(title)
|
||||
metadata["ss_creation_time"] = str(int(creation_time))
|
||||
if is_sdxl_flag:
|
||||
metadata["ss_base_model_version"] = "sdxl_v10"
|
||||
metadata["ss_sdxl_model_version"] = "1.0"
|
||||
if is_v_param_flag:
|
||||
metadata["ss_v_parameterization"] = "true"
|
||||
elif is_v2_flag:
|
||||
metadata["ss_base_model_version"] = "sd_v2"
|
||||
if is_v_param_flag:
|
||||
metadata["ss_v_parameterization"] = "true"
|
||||
else:
|
||||
metadata["ss_base_model_version"] = "sd_v1"
|
||||
if is_v_param_flag:
|
||||
metadata["ss_v_parameterization"] = "true"
|
||||
return metadata
|
||||
|
||||
# --- MODIFIED Helper Functions for Model Loading ---
|
||||
def _load_sd_model_components(model_path, is_v2_flag, target_device_override, load_dtype_torch):
|
||||
logger.info(f"Loading SD model using Diffusers.StableDiffusionPipeline from: {model_path}")
|
||||
pipeline = StableDiffusionPipeline.from_single_file(
|
||||
model_path,
|
||||
torch_dtype=load_dtype_torch
|
||||
)
|
||||
eff_device = target_device_override if target_device_override else "cpu"
|
||||
text_encoder = pipeline.text_encoder.to(eff_device)
|
||||
unet = pipeline.unet.to(eff_device)
|
||||
text_encoders = [text_encoder]
|
||||
logger.info(f"Loaded SD model components. UNet device: {unet.device}, TextEncoder device: {text_encoder.device}")
|
||||
return text_encoders, unet
|
||||
|
||||
def _load_sdxl_model_components(model_path, target_device_override, load_dtype_torch):
|
||||
actual_load_device = target_device_override if target_device_override else "cpu"
|
||||
logger.info(f"Loading SDXL model using Diffusers.StableDiffusionXLPipeline from: {model_path} to device: {actual_load_device}")
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
model_path,
|
||||
torch_dtype=load_dtype_torch
|
||||
)
|
||||
pipeline.to(actual_load_device)
|
||||
text_encoder = pipeline.text_encoder
|
||||
text_encoder_2 = pipeline.text_encoder_2
|
||||
unet = pipeline.unet
|
||||
text_encoders = [text_encoder, text_encoder_2]
|
||||
logger.info(f"Loaded SDXL model components. UNet device: {unet.device}, TextEncoder1 device: {text_encoder.device}, TextEncoder2 device: {text_encoder_2.device}")
|
||||
return text_encoders, unet
|
||||
|
||||
def _calculate_module_diffs_and_check(module_loras_o, module_loras_t, diff_calc_device, min_diff_thresh, module_type_str):
|
||||
diffs_map = {}
|
||||
is_different_flag = False
|
||||
first_diff_logged = False
|
||||
for lora_o, lora_t in zip(module_loras_o, module_loras_t):
|
||||
lora_name = lora_o.lora_name
|
||||
if lora_o.org_module is None or lora_t.org_module is None or \
|
||||
not hasattr(lora_o.org_module, 'weight') or lora_o.org_module.weight is None or \
|
||||
not hasattr(lora_t.org_module, 'weight') or lora_t.org_module.weight is None:
|
||||
logger.warning(f"Skipping {lora_name} in {module_type_str} due to missing org_module or weight.")
|
||||
continue
|
||||
weight_o = lora_o.org_module.weight
|
||||
weight_t = lora_t.org_module.weight
|
||||
if str(weight_o.device) != str(diff_calc_device): weight_o = weight_o.to(diff_calc_device)
|
||||
if str(weight_t.device) != str(diff_calc_device): weight_t = weight_t.to(diff_calc_device)
|
||||
diff = weight_t - weight_o
|
||||
diffs_map[lora_name] = diff
|
||||
current_max_diff = torch.max(torch.abs(diff))
|
||||
if not is_different_flag and current_max_diff > min_diff_thresh:
|
||||
is_different_flag = True
|
||||
if not first_diff_logged:
|
||||
logger.info(f"{module_type_str} '{lora_name}' differs: max diff {current_max_diff} > {min_diff_thresh}")
|
||||
first_diff_logged = True
|
||||
return diffs_map, is_different_flag
|
||||
|
||||
def _determine_rank(S_values, dynamic_method_name, dynamic_param_value, max_rank_limit,
|
||||
module_eff_in_dim, module_eff_out_dim, min_sv_threshold=MIN_SV):
|
||||
if not S_values.numel() or S_values[0] <= min_sv_threshold: return 1
|
||||
rank = 0
|
||||
if dynamic_method_name == "sv_ratio": rank = index_sv_ratio(S_values, dynamic_param_value)
|
||||
elif dynamic_method_name == "sv_cumulative": rank = index_sv_cumulative(S_values, dynamic_param_value)
|
||||
elif dynamic_method_name == "sv_fro": rank = index_sv_fro(S_values, dynamic_param_value)
|
||||
elif dynamic_method_name == "sv_knee": rank = index_sv_knee(S_values, min_sv_threshold)
|
||||
elif dynamic_method_name == "sv_cumulative_knee": rank = index_sv_cumulative_knee(S_values, min_sv_threshold)
|
||||
elif dynamic_method_name == "sv_rel_decrease": rank = index_sv_rel_decrease(S_values, dynamic_param_value)
|
||||
else: rank = max_rank_limit
|
||||
rank = min(rank, max_rank_limit, module_eff_in_dim, module_eff_out_dim, len(S_values))
|
||||
rank = max(1, rank)
|
||||
return rank
|
||||
|
||||
def _construct_lora_weights_from_svd_components(U_full, S_all_values, Vh_full, rank,
|
||||
clamp_quantile_val, is_conv2d, is_conv2d_3x3,
|
||||
conv_kernel_size,
|
||||
module_out_channels, module_in_channels,
|
||||
target_device_for_final_weights, target_dtype_for_final_weights):
|
||||
S_k = S_all_values[:rank]
|
||||
U_k = U_full[:, :rank]
|
||||
Vh_k = Vh_full[:rank, :]
|
||||
S_k_non_negative = torch.clamp(S_k, min=0.0)
|
||||
s_sqrt = torch.sqrt(S_k_non_negative)
|
||||
U_final = U_k * s_sqrt.unsqueeze(0)
|
||||
Vh_final = Vh_k * s_sqrt.unsqueeze(1)
|
||||
dist = torch.cat([U_final.flatten(), Vh_final.flatten()])
|
||||
hi_val = torch.quantile(dist, clamp_quantile_val)
|
||||
if hi_val == 0 and torch.max(torch.abs(dist)) > 1e-9:
|
||||
logger.debug(f"Clamping hi_val is zero for non-zero distribution. Max abs val: {torch.max(torch.abs(dist))}. Quantile: {clamp_quantile_val}")
|
||||
U_clamped = U_final.clamp(-hi_val, hi_val)
|
||||
Vh_clamped = Vh_final.clamp(-hi_val, hi_val)
|
||||
if is_conv2d:
|
||||
U_clamped = U_clamped.reshape(module_out_channels, rank, 1, 1)
|
||||
if is_conv2d_3x3:
|
||||
Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, *conv_kernel_size)
|
||||
else:
|
||||
Vh_clamped = Vh_clamped.reshape(rank, module_in_channels, 1, 1)
|
||||
U_clamped = U_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
|
||||
Vh_clamped = Vh_clamped.to(target_device_for_final_weights, dtype=target_dtype_for_final_weights).contiguous()
|
||||
return U_clamped, Vh_clamped
|
||||
|
||||
def _log_svd_stats(lora_module_name, S_all_values, rank_used, min_sv_for_calc=MIN_SV):
|
||||
if not S_all_values.numel():
|
||||
logger.info(f"{lora_module_name:75} | rank: {rank_used}, SVD not performed (empty singular values).")
|
||||
return
|
||||
S_cpu = S_all_values.to('cpu')
|
||||
s_sum_total = float(torch.sum(S_cpu))
|
||||
s_sum_rank = float(torch.sum(S_cpu[:rank_used]))
|
||||
fro_orig_total = float(torch.sqrt(torch.sum(S_cpu.pow(2))))
|
||||
fro_reconstructed_rank = float(torch.sqrt(torch.sum(S_cpu[:rank_used].pow(2))))
|
||||
ratio_sv = float('inf')
|
||||
if rank_used > 0 and S_cpu[rank_used - 1].abs() > min_sv_for_calc:
|
||||
ratio_sv = S_cpu[0] / S_cpu[rank_used - 1]
|
||||
sum_s_retained_percentage = (s_sum_rank / s_sum_total) if s_sum_total > min_sv_for_calc else 1.0
|
||||
fro_retained_percentage = (fro_reconstructed_rank / fro_orig_total) if fro_orig_total > min_sv_for_calc else 1.0
|
||||
logger.info(
|
||||
f"{lora_module_name:75} | rank: {rank_used}, "
|
||||
f"sum(S) retained: {sum_s_retained_percentage:.2%}, "
|
||||
f"Frobenius norm retained: {fro_retained_percentage:.2%}, "
|
||||
f"max_retained_sv/min_retained_sv ratio: {ratio_sv:.2f}"
|
||||
)
|
||||
|
||||
def _prepare_lora_metadata(output_path, is_v2_flag, kohya_base_model_version_str, network_conv_dim_val,
|
||||
use_dynamic_method_flag, network_dim_config_val,
|
||||
is_v_param_flag, is_sdxl_flag, skip_sai_meta):
|
||||
net_kwargs = {"conv_dim": str(network_conv_dim_val), "conv_alpha": str(float(network_conv_dim_val))} if network_conv_dim_val is not None else {}
|
||||
if use_dynamic_method_flag:
|
||||
network_dim_meta = "Dynamic"
|
||||
network_alpha_meta = "Dynamic"
|
||||
else:
|
||||
network_dim_meta = str(network_dim_config_val)
|
||||
network_alpha_meta = str(float(network_dim_config_val))
|
||||
final_metadata = {
|
||||
"ss_v2": str(is_v2_flag),
|
||||
"ss_base_model_version": kohya_base_model_version_str,
|
||||
"ss_network_module": "networks.lora", # This remains for compatibility with tools expecting it
|
||||
"ss_network_dim": network_dim_meta,
|
||||
"ss_network_alpha": network_alpha_meta,
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
"ss_lowram": "False",
|
||||
"ss_num_train_images": "N/A",
|
||||
}
|
||||
if not skip_sai_meta:
|
||||
title = os.path.splitext(os.path.basename(output_path))[0]
|
||||
current_time = time.time()
|
||||
sai_metadata_content = _build_local_sai_metadata(
|
||||
title=title, creation_time=current_time, is_v2_flag=is_v2_flag,
|
||||
is_v_param_flag=is_v_param_flag, is_sdxl_flag=is_sdxl_flag
|
||||
)
|
||||
final_metadata.update(sai_metadata_content)
|
||||
return final_metadata
|
||||
|
||||
# --- Main SVD Function ---
|
||||
def svd(
|
||||
model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None,
|
||||
conv_dim=None, v_parameterization=None, device=None, save_precision=None,
|
||||
clamp_quantile=0.99, min_diff=0.01, no_metadata=False, load_precision=None,
|
||||
load_original_model_to=None, load_tuned_model_to=None,
|
||||
dynamic_method=None, dynamic_param=None, verbose=False,
|
||||
):
|
||||
actual_v_parameterization = v2 if v_parameterization is None else v_parameterization
|
||||
load_dtype_torch = _str_to_dtype(load_precision)
|
||||
save_dtype_torch = _str_to_dtype(save_precision) if save_precision else torch.float
|
||||
|
||||
svd_computation_device = torch.device(device if device else "cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Using SVD computation device: {svd_computation_device}")
|
||||
diff_calculation_device = torch.device("cpu")
|
||||
logger.info(f"Calculating weight differences on: {diff_calculation_device}")
|
||||
final_weights_device = torch.device("cpu")
|
||||
|
||||
if not sdxl:
|
||||
text_encoders_o, unet_o = _load_sd_model_components(model_org, v2, load_original_model_to, load_dtype_torch)
|
||||
text_encoders_t, unet_t = _load_sd_model_components(model_tuned, v2, load_tuned_model_to, load_dtype_torch)
|
||||
kohya_model_version = _local_get_model_version_str_for_sd1_sd2(v2, actual_v_parameterization)
|
||||
else:
|
||||
text_encoders_o, unet_o = _load_sdxl_model_components(model_org, load_original_model_to, load_dtype_torch)
|
||||
text_encoders_t, unet_t = _load_sdxl_model_components(model_tuned, load_tuned_model_to, load_dtype_torch)
|
||||
kohya_model_version = _LOCAL_MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# Determine lora_conv_dim_init based on conv_dim argument for network creation
|
||||
# The original script used init_dim_val (1) if conv_dim was None.
|
||||
# Here, conv_dim is already defaulted to args.dim if None by the main block.
|
||||
# So, lora_conv_dim_init will be args.conv_dim (which defaults to args.dim).
|
||||
# If args.conv_dim was explicitly 0, this would be 0.
|
||||
lora_conv_dim_init_val = conv_dim # conv_dim is args.conv_dim (or args.dim)
|
||||
|
||||
# Create LoRA placeholders using the localized function
|
||||
text_encoder_loras_o, unet_loras_o = _local_create_network_placeholders(text_encoders_o, unet_o, lora_conv_dim_init_val)
|
||||
text_encoder_loras_t, unet_loras_t = _local_create_network_placeholders(text_encoders_t, unet_t, lora_conv_dim_init_val) # same conv_dim logic for tuned
|
||||
|
||||
# Group LoRA placeholders for easier processing (mimicking LoraNetwork structure somewhat)
|
||||
class LocalLoraNetworkPlaceholder:
|
||||
def __init__(self, te_loras, unet_loras_list):
|
||||
self.text_encoder_loras = te_loras
|
||||
self.unet_loras = unet_loras_list
|
||||
|
||||
lora_network_o = LocalLoraNetworkPlaceholder(text_encoder_loras_o, unet_loras_o)
|
||||
lora_network_t = LocalLoraNetworkPlaceholder(text_encoder_loras_t, unet_loras_t)
|
||||
|
||||
assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
|
||||
f"Model versions (based on identified LoRA-able TE modules) differ: {len(lora_network_o.text_encoder_loras)} vs {len(lora_network_t.text_encoder_loras)} TEs"
|
||||
|
||||
all_diffs = {}
|
||||
te_diffs, text_encoder_different = _calculate_module_diffs_and_check(
|
||||
lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras,
|
||||
diff_calculation_device, min_diff, "Text Encoder"
|
||||
)
|
||||
|
||||
if text_encoder_different:
|
||||
all_diffs.update(te_diffs)
|
||||
else:
|
||||
logger.warning("Text encoders are considered identical based on min_diff. Not extracting TE LoRA.")
|
||||
# To prevent processing empty list later, ensure it's empty if no diffs
|
||||
lora_network_o.text_encoder_loras = []
|
||||
del text_encoders_t # Free memory early
|
||||
|
||||
unet_diffs, _ = _calculate_module_diffs_and_check(
|
||||
lora_network_o.unet_loras, lora_network_t.unet_loras,
|
||||
diff_calculation_device, min_diff, "U-Net"
|
||||
)
|
||||
all_diffs.update(unet_diffs)
|
||||
del lora_network_t, unet_t # Free memory early
|
||||
|
||||
# Ensure lora_names_to_process only includes modules from lora_network_o
|
||||
# that are actually present (e.g., if TEs were skipped)
|
||||
lora_names_to_process = set()
|
||||
if text_encoder_different: # Only add TE loras if they were deemed different
|
||||
lora_names_to_process.update(p.lora_name for p in lora_network_o.text_encoder_loras)
|
||||
lora_names_to_process.update(p.lora_name for p in lora_network_o.unet_loras)
|
||||
|
||||
logger.info("Extracting and resizing LoRA via SVD")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name in tqdm(lora_names_to_process):
|
||||
if lora_name not in all_diffs:
|
||||
logger.warning(f"Skipping {lora_name} as no diff was calculated for it (e.g., Text Encoders were identical).")
|
||||
continue
|
||||
original_diff_tensor = all_diffs[lora_name]
|
||||
is_conv2d_layer = len(original_diff_tensor.size()) == 4
|
||||
kernel_s = original_diff_tensor.size()[2:4] if is_conv2d_layer else None
|
||||
is_conv2d_3x3_layer = is_conv2d_layer and kernel_s != (1, 1)
|
||||
module_true_out_channels, module_true_in_channels = original_diff_tensor.size()[0:2]
|
||||
mat_for_svd = original_diff_tensor.to(svd_computation_device, dtype=torch.float)
|
||||
if is_conv2d_layer:
|
||||
if is_conv2d_3x3_layer: mat_for_svd = mat_for_svd.flatten(start_dim=1)
|
||||
else: mat_for_svd = mat_for_svd.squeeze()
|
||||
if mat_for_svd.numel() == 0 or mat_for_svd.shape[0] == 0 or mat_for_svd.shape[1] == 0 :
|
||||
logger.warning(f"Skipping SVD for {lora_name} due to empty/invalid shape: {mat_for_svd.shape}")
|
||||
continue
|
||||
try:
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(mat_for_svd)
|
||||
except Exception as e:
|
||||
logger.error(f"SVD failed for {lora_name} with shape {mat_for_svd.shape}. Error: {e}")
|
||||
continue
|
||||
|
||||
# Max rank for SVD is based on 'dim' for linear and 'conv_dim' for conv3x3
|
||||
# The original `current_max_rank` logic was:
|
||||
# current_max_rank = dim if not is_conv2d_3x3_layer or conv_dim is None else conv_dim
|
||||
# Here, `dim` is args.dim and `conv_dim` is args.conv_dim (defaulted to args.dim)
|
||||
module_specific_max_rank = conv_dim if is_conv2d_3x3_layer else dim
|
||||
|
||||
eff_out_dim, eff_in_dim = mat_for_svd.shape[0], mat_for_svd.shape[1]
|
||||
rank = _determine_rank(S_full, dynamic_method, dynamic_param,
|
||||
module_specific_max_rank, eff_in_dim, eff_out_dim, MIN_SV)
|
||||
U_clamped, Vh_clamped = _construct_lora_weights_from_svd_components(
|
||||
U_full, S_full, Vh_full, rank, clamp_quantile,
|
||||
is_conv2d_layer, is_conv2d_3x3_layer, kernel_s,
|
||||
module_true_out_channels, module_true_in_channels,
|
||||
final_weights_device, save_dtype_torch
|
||||
)
|
||||
lora_weights[lora_name] = (U_clamped, Vh_clamped)
|
||||
if verbose: _log_svd_stats(lora_name, S_full, rank, MIN_SV)
|
||||
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
||||
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
||||
# Alpha is set to the rank (dim of down_weight's 0th axis, which is rank)
|
||||
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0], dtype=save_dtype_torch, device=final_weights_device)
|
||||
|
||||
del text_encoders_o, unet_o, lora_network_o, all_diffs # Clean up original models and placeholders
|
||||
if 'torch' in sys.modules and hasattr(torch, 'cuda') and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not os.path.exists(os.path.dirname(save_to)) and os.path.dirname(save_to) != "":
|
||||
os.makedirs(os.path.dirname(save_to), exist_ok=True)
|
||||
|
||||
metadata_to_save = _prepare_lora_metadata(
|
||||
output_path=save_to,
|
||||
is_v2_flag=v2,
|
||||
kohya_base_model_version_str=kohya_model_version,
|
||||
network_conv_dim_val=conv_dim, # This is args.conv_dim (defaulted to args.dim)
|
||||
use_dynamic_method_flag=bool(dynamic_method),
|
||||
network_dim_config_val=dim, # This is args.dim
|
||||
is_v_param_flag=actual_v_parameterization,
|
||||
is_sdxl_flag=sdxl,
|
||||
skip_sai_meta=no_metadata
|
||||
)
|
||||
|
||||
save_to_file(save_to, lora_sd, save_dtype_torch, metadata_to_save)
|
||||
logger.info(f"LoRA saved to: {save_to}")
|
||||
|
||||
def setup_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action="store_true", help="Load Stable Diffusion v2.x model")
|
||||
parser.add_argument("--v_parameterization", action="store_true", help="Set v-parameterization metadata (defaults to v2 if --v2 is set)")
|
||||
parser.add_argument("--sdxl", action="store_true", help="Load Stable Diffusion SDXL base model")
|
||||
parser.add_argument("--load_precision", type=str, choices=["float", "fp16", "bf16"], default=None, help="Precision for loading models (applied after initial load)")
|
||||
parser.add_argument("--save_precision", type=str, choices=["float", "fp16", "bf16"], default="float", help="Precision for saving LoRA weights")
|
||||
parser.add_argument("--model_org", type=str, required=True, help="Original Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--model_tuned", type=str, required=True, help="Tuned Stable Diffusion model (ckpt/safetensors)")
|
||||
parser.add_argument("--save_to", type=str, required=True, help="Output file name (ckpt/safetensors)")
|
||||
parser.add_argument("--dim", type=int, default=4, help="Max dimension (rank) of LoRA for linear layers")
|
||||
parser.add_argument("--conv_dim", type=int, default=None, help="Max dimension (rank) of LoRA for Conv2d-3x3. Defaults to 'dim' if not set.")
|
||||
parser.add_argument("--device", type=str, default=None, help="Device for SVD computation (e.g., cuda, cpu). Defaults to cuda if available, else cpu.")
|
||||
parser.add_argument("--clamp_quantile", type=float, default=0.99, help="Quantile for clamping weights")
|
||||
parser.add_argument("--min_diff", type=float, default=0.01, help="Minimum weight difference to extract LoRA for a module")
|
||||
parser.add_argument("--no_metadata", action="store_true", help="Omit detailed metadata from SAI and Kohya_ss")
|
||||
parser.add_argument("--load_original_model_to", type=str, default=None, help="Device for original model (e.g. 'cpu', 'cuda:0'). Defaults to CPU for SD1/2, honored for SDXL.")
|
||||
parser.add_argument("--load_tuned_model_to", type=str, default=None, help="Device for tuned model (e.g. 'cpu', 'cuda:0'). Defaults to CPU for SD1/2, honored for SDXL.")
|
||||
parser.add_argument("--dynamic_param", type=float, help="Parameter for dynamic rank reduction")
|
||||
parser.add_argument("--verbose", action="store_true", help="Show detailed rank reduction info for each module")
|
||||
parser.add_argument(
|
||||
"--dynamic_method", type=str,
|
||||
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative", "sv_knee", "sv_rel_decrease", "sv_cumulative_knee"],
|
||||
default=None, help="Dynamic rank reduction method"
|
||||
)
|
||||
return parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.conv_dim is None:
|
||||
args.conv_dim = args.dim # Default conv_dim to dim if not provided
|
||||
logger.info(f"--conv_dim not set, using value of --dim: {args.conv_dim}")
|
||||
|
||||
methods_requiring_param = ["sv_ratio", "sv_fro", "sv_cumulative", "sv_rel_decrease"]
|
||||
if args.dynamic_method in methods_requiring_param and args.dynamic_param is None:
|
||||
parser.error(f"Dynamic method '{args.dynamic_method}' requires --dynamic_param to be set.")
|
||||
|
||||
if not args.dynamic_method: # Ranks must be positive if not using dynamic method
|
||||
if args.dim <= 0: parser.error(f"--dim (rank) must be > 0. Got {args.dim}")
|
||||
if args.conv_dim <=0: parser.error(f"--conv_dim (rank) must be > 0. Got {args.conv_dim}") # Check after defaulting
|
||||
|
||||
if MIN_SV <= 0: logger.warning(f"Global MIN_SV ({MIN_SV}) should be positive.")
|
||||
|
||||
svd_args = vars(args).copy()
|
||||
svd(**svd_args)
|
||||
|
|
@ -1,360 +0,0 @@
|
|||
# extract approximating LoRA by svd from two SD models
|
||||
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo!
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, model_util, sdxl_model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# CLAMP_QUANTILE = 0.99
|
||||
# MIN_DIFF = 1e-1
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def svd(
|
||||
model_org=None,
|
||||
model_tuned=None,
|
||||
save_to=None,
|
||||
dim=4,
|
||||
v2=None,
|
||||
sdxl=None,
|
||||
conv_dim=None,
|
||||
v_parameterization=None,
|
||||
device=None,
|
||||
save_precision=None,
|
||||
clamp_quantile=0.99,
|
||||
min_diff=0.01,
|
||||
no_metadata=False,
|
||||
load_precision=None,
|
||||
load_original_model_to=None,
|
||||
load_tuned_model_to=None,
|
||||
):
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
||||
if v_parameterization is None:
|
||||
v_parameterization = v2
|
||||
|
||||
load_dtype = str_to_dtype(load_precision) if load_precision else None
|
||||
save_dtype = str_to_dtype(save_precision)
|
||||
work_device = "cpu"
|
||||
|
||||
# load models
|
||||
if not sdxl:
|
||||
logger.info(f"loading original SD model : {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o = text_encoder_o.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"loading tuned SD model : {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t = text_encoder_t.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||
else:
|
||||
device_org = load_original_model_to if load_original_model_to else "cpu"
|
||||
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
||||
|
||||
logger.info(f"loading original SDXL model : {model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o1 = text_encoder_o1.to(load_dtype)
|
||||
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"loading original SDXL model : {model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t1 = text_encoder_t1.to(load_dtype)
|
||||
text_encoder_t2 = text_encoder_t2.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
if conv_dim is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
|
||||
|
||||
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(
|
||||
lora_network_t.text_encoder_loras
|
||||
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||
|
||||
# get diffs
|
||||
diffs = {}
|
||||
text_encoder_different = False
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
# Text Encoder might be same
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# clear target Text Encoder to save memory
|
||||
for text_encoder in text_encoders_t:
|
||||
del text_encoder
|
||||
|
||||
if not text_encoder_different:
|
||||
logger.warning("Text encoder is same. Extract U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs = {} # clear diffs
|
||||
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# clear LoRA network, target U-Net to save memory
|
||||
del lora_network_o
|
||||
del lora_network_t
|
||||
del unet_t
|
||||
|
||||
# make LoRA with svd
|
||||
logger.info("calculating by svd")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
mat = mat.to(torch.float) # calc by float
|
||||
|
||||
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
|
||||
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
|
||||
# logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
|
||||
if conv2d:
|
||||
if conv2d_3x3:
|
||||
mat = mat.flatten(start_dim=1)
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, clamp_quantile)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
U = U.to(work_device, dtype=save_dtype).contiguous()
|
||||
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
# make state dict for LoRA
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
||||
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
||||
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
|
||||
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
logger.info(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
dir_name = os.path.dirname(save_to)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
# minimum metadata
|
||||
net_kwargs = {}
|
||||
if conv_dim is not None:
|
||||
net_kwargs["conv_dim"] = str(conv_dim)
|
||||
net_kwargs["conv_alpha"] = str(float(conv_dim))
|
||||
|
||||
metadata = {
|
||||
"ss_v2": str(v2),
|
||||
"ss_base_model_version": model_version,
|
||||
"ss_network_module": "networks.lora",
|
||||
"ss_network_dim": str(dim),
|
||||
"ss_network_alpha": str(float(dim)),
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
}
|
||||
|
||||
if not no_metadata:
|
||||
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||
logger.info(f"LoRA weights are saved to: {save_to}")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||
parser.add_argument(
|
||||
"--v_parameterization",
|
||||
action="store_true",
|
||||
default=None,
|
||||
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_org",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_tuned",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||
parser.add_argument(
|
||||
"--conv_dim",
|
||||
type=int,
|
||||
default=None,
|
||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||
)
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument(
|
||||
"--clamp_quantile",
|
||||
type=float,
|
||||
default=0.99,
|
||||
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_diff",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
||||
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_original_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_tuned_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
svd(**vars(args))
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import argparse # Import argparse
|
||||
|
||||
def extract_model_differences(base_model_path, finetuned_model_path, output_delta_path=None, save_dtype_str="float32"):
|
||||
"""
|
||||
Calculates the difference between the state dictionaries of a fine-tuned model
|
||||
and a base model.
|
||||
|
||||
Args:
|
||||
base_model_path (str): Path to the base model .safetensors file.
|
||||
finetuned_model_path (str): Path to the fine-tuned model .safetensors file.
|
||||
output_delta_path (str, optional): Path to save the resulting delta weights
|
||||
.safetensors file. If None, not saved.
|
||||
save_dtype_str (str, optional): Data type to save the delta weights ('float32', 'float16', 'bfloat16').
|
||||
Defaults to 'float32'.
|
||||
Returns:
|
||||
OrderedDict: A state dictionary containing the delta weights.
|
||||
Returns None if loading fails or other critical errors.
|
||||
"""
|
||||
print(f"Loading base model from: {base_model_path}")
|
||||
try:
|
||||
# Ensure model is loaded to CPU to avoid CUDA issues if not needed for diffing
|
||||
base_state_dict = load_file(base_model_path, device="cpu")
|
||||
print(f"Base model loaded. Found {len(base_state_dict)} tensors.")
|
||||
except Exception as e:
|
||||
print(f"Error loading base model: {e}")
|
||||
return None
|
||||
|
||||
print(f"\nLoading fine-tuned model from: {finetuned_model_path}")
|
||||
try:
|
||||
finetuned_state_dict = load_file(finetuned_model_path, device="cpu")
|
||||
print(f"Fine-tuned model loaded. Found {len(finetuned_state_dict)} tensors.")
|
||||
except Exception as e:
|
||||
print(f"Error loading fine-tuned model: {e}")
|
||||
return None
|
||||
|
||||
delta_state_dict = OrderedDict()
|
||||
diff_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
unique_to_finetuned_count = 0
|
||||
unique_to_base_count = 0
|
||||
|
||||
print("\nCalculating differences...")
|
||||
|
||||
# Keys in finetuned model
|
||||
finetuned_keys = set(finetuned_state_dict.keys())
|
||||
base_keys = set(base_state_dict.keys())
|
||||
|
||||
common_keys = finetuned_keys.intersection(base_keys)
|
||||
keys_only_in_finetuned = finetuned_keys - base_keys
|
||||
keys_only_in_base = base_keys - finetuned_keys
|
||||
|
||||
for key in common_keys:
|
||||
ft_tensor = finetuned_state_dict[key]
|
||||
base_tensor = base_state_dict[key]
|
||||
|
||||
if not (ft_tensor.is_floating_point() and base_tensor.is_floating_point()):
|
||||
# print(f"Skipping key '{key}': Non-floating point tensors (FT: {ft_tensor.dtype}, Base: {base_tensor.dtype}).")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
if ft_tensor.shape != base_tensor.shape:
|
||||
print(f"Skipping key '{key}': Shape mismatch (FT: {ft_tensor.shape}, Base: {base_tensor.shape}).")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
# Calculate difference in float32 for precision, then cast to save_dtype
|
||||
delta_tensor = ft_tensor.to(dtype=torch.float32) - base_tensor.to(dtype=torch.float32)
|
||||
delta_state_dict[key] = delta_tensor
|
||||
diff_count += 1
|
||||
except Exception as e:
|
||||
print(f"Error calculating difference for key '{key}': {e}")
|
||||
error_count += 1
|
||||
|
||||
for key in keys_only_in_finetuned:
|
||||
print(f"Warning: Key '{key}' (Shape: {finetuned_state_dict[key].shape}, Dtype: {finetuned_state_dict[key].dtype}) is present in fine-tuned model but not in base model. Storing as is.")
|
||||
delta_state_dict[key] = finetuned_state_dict[key] # Store the tensor from the finetuned model
|
||||
unique_to_finetuned_count += 1
|
||||
|
||||
if keys_only_in_base:
|
||||
print(f"\nWarning: {len(keys_only_in_base)} key(s) are present only in the base model and will not be in the delta file.")
|
||||
for key in list(keys_only_in_base)[:5]: # Print first 5 as examples
|
||||
print(f" - Example key only in base: {key}")
|
||||
if len(keys_only_in_base) > 5:
|
||||
print(f" ... and {len(keys_only_in_base) - 5} more.")
|
||||
|
||||
|
||||
print(f"\nDifference calculation complete.")
|
||||
print(f" {diff_count} layers successfully diffed.")
|
||||
print(f" {unique_to_finetuned_count} layers unique to fine-tuned model (added as is).")
|
||||
print(f" {skipped_count} common layers skipped (shape/type mismatch).")
|
||||
print(f" {error_count} common layers had errors during diffing.")
|
||||
|
||||
if output_delta_path and delta_state_dict:
|
||||
save_dtype = torch.float32 # Default
|
||||
if save_dtype_str == "float16":
|
||||
save_dtype = torch.float16
|
||||
elif save_dtype_str == "bfloat16":
|
||||
save_dtype = torch.bfloat16
|
||||
elif save_dtype_str != "float32":
|
||||
print(f"Warning: Invalid save_dtype '{save_dtype_str}'. Defaulting to float32.")
|
||||
save_dtype_str = "float32" # for print message
|
||||
|
||||
print(f"\nPreparing to save delta weights with dtype: {save_dtype_str}")
|
||||
|
||||
final_save_dict = OrderedDict()
|
||||
for k, v_tensor in delta_state_dict.items():
|
||||
if v_tensor.is_floating_point():
|
||||
final_save_dict[k] = v_tensor.to(dtype=save_dtype)
|
||||
else:
|
||||
final_save_dict[k] = v_tensor # Keep non-float as is (e.g. int tensors if any)
|
||||
|
||||
try:
|
||||
save_file(final_save_dict, output_delta_path)
|
||||
print(f"Delta weights saved to: {output_delta_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving delta weights: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return delta_state_dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Extract weight differences between a fine-tuned and a base SDXL model.")
|
||||
parser.add_argument("base_model_path", type=str, help="File path for the BASE SDXL model (.safetensors).")
|
||||
parser.add_argument("finetuned_model_path", type=str, help="File path for the FINE-TUNED SDXL model (.safetensors).")
|
||||
parser.add_argument("--output_path", type=str, default=None,
|
||||
help="Optional: File path to save the delta weights (.safetensors). "
|
||||
"If not provided, defaults to 'model_deltas/delta_[finetuned_model_name].safetensors'.")
|
||||
parser.add_argument("--save_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"],
|
||||
help="Data type for saving the delta weights. Choose from 'float32', 'float16', 'bfloat16'. "
|
||||
"Defaults to 'float32'.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("--- Model Difference Extraction Script ---")
|
||||
|
||||
if not os.path.exists(args.base_model_path):
|
||||
print(f"Error: Base model file not found at {args.base_model_path}")
|
||||
exit(1)
|
||||
if not os.path.exists(args.finetuned_model_path):
|
||||
print(f"Error: Fine-tuned model file not found at {args.finetuned_model_path}")
|
||||
exit(1)
|
||||
|
||||
output_delta_file = args.output_path
|
||||
if output_delta_file is None:
|
||||
output_dir = "model_deltas"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
finetuned_basename = os.path.splitext(os.path.basename(args.finetuned_model_path))[0]
|
||||
output_delta_file = os.path.join(output_dir, f"delta_{finetuned_basename}.safetensors")
|
||||
|
||||
# Ensure the output directory exists if a full path is given
|
||||
if output_delta_file:
|
||||
output_dir_for_file = os.path.dirname(output_delta_file)
|
||||
if output_dir_for_file and not os.path.exists(output_dir_for_file):
|
||||
os.makedirs(output_dir_for_file, exist_ok=True)
|
||||
|
||||
|
||||
differences = extract_model_differences(
|
||||
args.base_model_path,
|
||||
args.finetuned_model_path,
|
||||
output_delta_path=output_delta_file,
|
||||
save_dtype_str=args.save_dtype
|
||||
)
|
||||
|
||||
if differences:
|
||||
print(f"\nExtraction process finished. {len(differences)} total keys in the delta state_dict.")
|
||||
else:
|
||||
print("\nCould not extract differences due to errors during model loading.")
|
||||
Loading…
Reference in New Issue