Format with black

pull/2219/head
bmaltais 2024-04-03 18:32:25 -04:00
parent c827268bf3
commit 49f76343b5
40 changed files with 1436 additions and 1117 deletions

View File

@ -1,7 +1,13 @@
import gradio as gr
from easygui import msgbox
import subprocess
from .common_gui import get_folder_path, add_pre_postfix, find_replace, scriptdir, list_dirs
from .common_gui import (
get_folder_path,
add_pre_postfix,
find_replace,
scriptdir,
list_dirs,
)
import os
import sys
@ -12,6 +18,7 @@ log = setup_logging()
PYTHON = sys.executable
def caption_images(
caption_text: str,
images_dir: str,
@ -41,26 +48,26 @@ def caption_images(
# Check if images_dir is provided
if not images_dir:
msgbox(
'Image folder is missing. Please provide the directory containing the images to caption.'
"Image folder is missing. Please provide the directory containing the images to caption."
)
return
# Check if caption_ext is provided
if not caption_ext:
msgbox('Please provide an extension for the caption files.')
msgbox("Please provide an extension for the caption files.")
return
# Log the captioning process
if caption_text:
log.info(f'Captioning files in {images_dir} with {caption_text}...')
log.info(f"Captioning files in {images_dir} with {caption_text}...")
# Build the command to run caption.py
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/caption.py"'
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/caption.py"'
run_cmd += f' --caption_text="{caption_text}"'
# Add optional flags to the command
if overwrite:
run_cmd += f' --overwrite'
run_cmd += f" --overwrite"
if caption_ext:
run_cmd += f' --caption_file_ext="{caption_ext}"'
@ -71,7 +78,9 @@ def caption_images(
# Set the environment variable for the Python path
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/tools{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/tools{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command based on the operating system
subprocess.run(run_cmd, shell=True, env=env)
@ -102,7 +111,7 @@ def caption_images(
)
# Log the end of the captioning process
log.info('Captioning done.')
log.info("Captioning done.")
# Gradio UI
@ -121,7 +130,11 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None):
from .common_gui import create_refresh_button
# Set default images directory if not provided
default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data")
default_images_dir = (
default_images_dir
if default_images_dir is not None
else os.path.join(scriptdir, "data")
)
current_images_dir = default_images_dir
# Function to list directories
@ -141,26 +154,34 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None):
return list(list_dirs(path))
# Gradio tab for basic captioning
with gr.Tab('Basic Captioning'):
with gr.Tab("Basic Captioning"):
# Markdown description
gr.Markdown(
'This utility allows you to create simple caption files for each image in a folder.'
"This utility allows you to create simple caption files for each image in a folder."
)
# Group and row for image folder selection
with gr.Group(), gr.Row():
# Dropdown for image folder
images_dir = gr.Dropdown(
label='Image folder to caption (containing the images to caption)',
label="Image folder to caption (containing the images to caption)",
choices=[""] + list_images_dirs(default_images_dir),
value="",
interactive=True,
allow_custom_value=True,
)
# Refresh button for image folder
create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small")
create_refresh_button(
images_dir,
lambda: None,
lambda: {"choices": list_images_dirs(current_images_dir)},
"open_folder_small",
)
# Button to open folder
folder_button = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
# Event handler for button click
folder_button.click(
@ -170,14 +191,14 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None):
)
# Textbox for caption file extension
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extension for caption file (e.g., .caption, .txt)',
value='.txt',
label="Caption file extension",
placeholder="Extension for caption file (e.g., .caption, .txt)",
value=".txt",
interactive=True,
)
# Checkbox to overwrite existing captions
overwrite = gr.Checkbox(
label='Overwrite existing captions in folder',
label="Overwrite existing captions in folder",
interactive=True,
value=False,
)
@ -185,41 +206,41 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None):
with gr.Row():
# Textbox for caption prefix
prefix = gr.Textbox(
label='Prefix to add to caption',
placeholder='(Optional)',
label="Prefix to add to caption",
placeholder="(Optional)",
interactive=True,
)
# Textbox for caption text
caption_text = gr.Textbox(
label='Caption text',
label="Caption text",
placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.',
interactive=True,
lines=2,
)
# Textbox for caption postfix
postfix = gr.Textbox(
label='Postfix to add to caption',
placeholder='(Optional)',
label="Postfix to add to caption",
placeholder="(Optional)",
interactive=True,
)
# Group and row for find and replace text
with gr.Group(), gr.Row():
# Textbox for find text
find_text = gr.Textbox(
label='Find text',
label="Find text",
placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.',
interactive=True,
lines=2,
)
# Textbox for replace text
replace_text = gr.Textbox(
label='Replacement text',
label="Replacement text",
placeholder='e.g., "by some artist". Leave empty if you want to replace with nothing.',
interactive=True,
lines=2,
)
# Button to caption images
caption_button = gr.Button('Caption images')
caption_button = gr.Button("Caption images")
# Event handler for button click
caption_button.click(
caption_images,

View File

@ -4,7 +4,7 @@ import torch
import gradio as gr
import os
from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs
from .common_gui import get_folder_path, scriptdir, list_dirs
from .custom_logging import setup_logging
# Set up logging

View File

@ -74,7 +74,9 @@ def caption_images(
# Set up the environment
env = os.environ.copy()
env["PYTHONPATH"] = f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command in the sd-scripts folder context
subprocess.run(run_cmd, shell=True, env=env, cwd=f"{scriptdir}/sd-scripts")

View File

@ -69,12 +69,12 @@ class AccelerateLaunch:
def run_cmd(**kwargs):
run_cmd = ""
if "extra_accelerate_launch_args" in kwargs:
extra_accelerate_launch_args = kwargs.get("extra_accelerate_launch_args")
if extra_accelerate_launch_args != "":
run_cmd += fr' {extra_accelerate_launch_args}'
run_cmd += rf" {extra_accelerate_launch_args}"
if "gpu_ids" in kwargs:
gpu_ids = kwargs.get("gpu_ids")
if not gpu_ids == "":
@ -109,4 +109,4 @@ class AccelerateLaunch:
f" --num_cpu_threads_per_process={int(num_cpu_threads_per_process)}"
)
return run_cmd
return run_cmd

View File

@ -6,7 +6,7 @@ from .common_gui import (
list_files,
list_dirs,
create_refresh_button,
document_symbol
document_symbol,
)
@ -221,12 +221,20 @@ class AdvancedTraining:
choices=["none", "sdpa", "xformers"],
value="xformers",
)
self.color_aug = gr.Checkbox(label="Color augmentation", value=False, info="Enable weak color augmentation")
self.flip_aug = gr.Checkbox(label="Flip augmentation", value=False, info="Enable horizontal flip augmentation")
self.color_aug = gr.Checkbox(
label="Color augmentation",
value=False,
info="Enable weak color augmentation",
)
self.flip_aug = gr.Checkbox(
label="Flip augmentation",
value=False,
info="Enable horizontal flip augmentation",
)
self.masked_loss = gr.Checkbox(
label="Masked loss",
value=False,
info="Apply mask for calculating loss. conditioning_data_dir is required for dataset"
info="Apply mask for calculating loss. conditioning_data_dir is required for dataset",
)
with gr.Row():
self.scale_v_pred_loss_like_noise_pred = gr.Checkbox(
@ -296,7 +304,7 @@ class AdvancedTraining:
"Multires",
],
value="Original",
scale=1
scale=1,
)
with gr.Row(visible=True) as self.noise_offset_original:
self.noise_offset = gr.Slider(
@ -305,12 +313,12 @@ class AdvancedTraining:
minimum=0,
maximum=1,
step=0.01,
info='Recommended values are 0.05 - 0.15',
info="Recommended values are 0.05 - 0.15",
)
self.noise_offset_random_strength = gr.Checkbox(
self.noise_offset_random_strength = gr.Checkbox(
label="Noise offset random strength",
value=False,
info='Use random strength between 0~noise_offset for noise offset',
info="Use random strength between 0~noise_offset for noise offset",
)
self.adaptive_noise_scale = gr.Slider(
label="Adaptive noise scale",
@ -327,7 +335,7 @@ class AdvancedTraining:
minimum=0,
maximum=64,
step=1,
info='Enable multires noise (recommended values are 6-10)',
info="Enable multires noise (recommended values are 6-10)",
)
self.multires_noise_discount = gr.Slider(
label="Multires noise discount",
@ -335,7 +343,7 @@ class AdvancedTraining:
minimum=0,
maximum=1,
step=0.01,
info='Recommended values are 0.8. For LoRAs with small datasets, 0.1-0.3',
info="Recommended values are 0.8. For LoRAs with small datasets, 0.1-0.3",
)
with gr.Row(visible=True):
self.ip_noise_gamma = gr.Slider(
@ -344,12 +352,12 @@ class AdvancedTraining:
minimum=0,
maximum=1,
step=0.01,
info='enable input perturbation noise. used for regularization. recommended value: around 0.1',
info="enable input perturbation noise. used for regularization. recommended value: around 0.1",
)
self.ip_noise_gamma_random_strength = gr.Checkbox(
self.ip_noise_gamma_random_strength = gr.Checkbox(
label="IP noise gamma random strength",
value=False,
info='Use random strength between 0~ip_noise_gamma for input perturbation noise',
info="Use random strength between 0~ip_noise_gamma for input perturbation noise",
)
self.noise_offset_type.change(
noise_offset_type_change,
@ -371,9 +379,10 @@ class AdvancedTraining:
)
with gr.Group(), gr.Row():
self.save_state = gr.Checkbox(label="Save training state", value=False)
self.save_state_on_train_end = gr.Checkbox(label="Save training state at end of training", value=False)
self.save_state_on_train_end = gr.Checkbox(
label="Save training state at end of training", value=False
)
def list_state_dirs(path):
self.current_state_dir = path if not path == "" else "."

View File

@ -1,5 +1,4 @@
import gradio as gr
import os
from typing import Tuple

View File

@ -5,6 +5,7 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
class CommandExecutor:
"""
A class to execute and manage commands.
@ -43,7 +44,9 @@ class CommandExecutor:
log.info("The running process has been terminated.")
except psutil.NoSuchProcess:
# Explicitly handle the case where the process does not exist
log.info("The process does not exist. It might have terminated before the kill command was issued.")
log.info(
"The process does not exist. It might have terminated before the kill command was issued."
)
except Exception as e:
# General exception handling for any other errors
log.info(f"Error when terminating process: {e}")

View File

@ -12,7 +12,9 @@ class ConfigurationFile:
A class to handle configuration file operations in the GUI.
"""
def __init__(self, headless: bool = False, config_dir: str = None, config:dict = {}):
def __init__(
self, headless: bool = False, config_dir: str = None, config: dict = {}
):
"""
Initialize the ConfigurationFile class.
@ -22,11 +24,13 @@ class ConfigurationFile:
"""
self.headless = headless
self.config = config
# Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory.
self.current_config_dir = self.config.get('config_dir', os.path.join(scriptdir, "presets"))
self.current_config_dir = self.config.get(
"config_dir", os.path.join(scriptdir, "presets")
)
# Initialize the GUI components for configuration.
self.create_config_gui()

View File

@ -1,12 +1,16 @@
import gradio as gr
import os
from .common_gui import get_folder_path, scriptdir, list_dirs, list_files, create_refresh_button
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button
class Folders:
"""
A class to handle folder operations in the GUI.
"""
def __init__(self, finetune: bool = False, headless: bool = False, config:dict = {}):
def __init__(
self, finetune: bool = False, headless: bool = False, config: dict = {}
):
"""
Initialize the Folders class.
@ -21,9 +25,15 @@ class Folders:
self.config = config
# Set default directories if not provided
self.current_output_dir = self.config.get('output_dir', os.path.join(scriptdir, "outputs"))
self.current_logging_dir = self.config.get('logging_dir', os.path.join(scriptdir, "logs"))
self.current_reg_data_dir = self.config.get('reg_data_dir', os.path.join(scriptdir, "reg"))
self.current_output_dir = self.config.get(
"output_dir", os.path.join(scriptdir, "outputs")
)
self.current_logging_dir = self.config.get(
"logging_dir", os.path.join(scriptdir, "logs")
)
self.current_reg_data_dir = self.config.get(
"reg_data_dir", os.path.join(scriptdir, "reg")
)
# Create directories if they don't exist
self.create_directory_if_not_exists(self.current_output_dir)
@ -39,10 +49,13 @@ class Folders:
Parameters:
- directory (str): The directory to create.
"""
if directory is not None and directory.strip() != "" and not os.path.exists(directory):
if (
directory is not None
and directory.strip() != ""
and not os.path.exists(directory)
):
os.makedirs(directory, exist_ok=True)
def list_output_dirs(self, path: str) -> list:
"""
List directories in the output directory.
@ -96,10 +109,20 @@ class Folders:
allow_custom_value=True,
)
# Refresh button for output directory
create_refresh_button(self.output_dir, lambda: None, lambda: {"choices": [""] + self.list_output_dirs(self.current_output_dir)}, "open_folder_small")
create_refresh_button(
self.output_dir,
lambda: None,
lambda: {
"choices": [""] + self.list_output_dirs(self.current_output_dir)
},
"open_folder_small",
)
# Output directory button
self.output_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
# Output directory button click event
self.output_dir_folder.click(
@ -110,17 +133,31 @@ class Folders:
# Regularisation directory dropdown
self.reg_data_dir = gr.Dropdown(
label='Regularisation directory (Optional. containing regularisation images)' if not self.finetune else 'Train config directory (Optional. where config files will be saved)',
label=(
"Regularisation directory (Optional. containing regularisation images)"
if not self.finetune
else "Train config directory (Optional. where config files will be saved)"
),
choices=[""] + self.list_reg_data_dirs(self.current_reg_data_dir),
value="",
interactive=True,
allow_custom_value=True,
)
# Refresh button for regularisation directory
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)}, "open_folder_small")
create_refresh_button(
self.reg_data_dir,
lambda: None,
lambda: {
"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)
},
"open_folder_small",
)
# Regularisation directory button
self.reg_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
# Regularisation directory button click event
self.reg_data_dir_folder.click(
@ -131,17 +168,27 @@ class Folders:
with gr.Row():
# Logging directory dropdown
self.logging_dir = gr.Dropdown(
label='Logging directory (Optional. to enable logging and output Tensorboard log)',
label="Logging directory (Optional. to enable logging and output Tensorboard log)",
choices=[""] + self.list_logging_dirs(self.current_logging_dir),
value="",
interactive=True,
allow_custom_value=True,
)
# Refresh button for logging directory
create_refresh_button(self.logging_dir, lambda: None, lambda: {"choices": [""] + self.list_logging_dirs(self.current_logging_dir)}, "open_folder_small")
create_refresh_button(
self.logging_dir,
lambda: None,
lambda: {
"choices": [""] + self.list_logging_dirs(self.current_logging_dir)
},
"open_folder_small",
)
# Logging directory button
self.logging_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
# Logging directory button click event
self.logging_dir_folder.click(
@ -159,14 +206,18 @@ class Folders:
)
# Change event for regularisation directory dropdown
self.reg_data_dir.change(
fn=lambda path: gr.Dropdown(choices=[""] + self.list_reg_data_dirs(path)),
fn=lambda path: gr.Dropdown(
choices=[""] + self.list_reg_data_dirs(path)
),
inputs=self.reg_data_dir,
outputs=self.reg_data_dir,
show_progress=False,
)
# Change event for logging directory dropdown
self.logging_dir.change(
fn=lambda path: gr.Dropdown(choices=[""] + self.list_logging_dirs(path)),
fn=lambda path: gr.Dropdown(
choices=[""] + self.list_logging_dirs(path)
),
inputs=self.logging_dir,
outputs=self.logging_dir,
show_progress=False,

View File

@ -5,6 +5,7 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
class KohyaSSGUIConfig:
"""
A class to handle the configuration for the Kohya SS GUI.
@ -30,7 +31,9 @@ class KohyaSSGUIConfig:
except FileNotFoundError:
# If the config file is not found, initialize `config` as an empty dictionary to handle missing configurations gracefully.
config = {}
log.debug(f"No configuration file found at {config_file_path}. Initializing empty configuration.")
log.debug(
f"No configuration file found at {config_file_path}. Initializing empty configuration."
)
return config
@ -66,7 +69,9 @@ class KohyaSSGUIConfig:
log.debug(k)
# If the key is not found in the current data, return the default value
if k not in data:
log.debug(f"Key '{key}' not found in configuration. Returning default value.")
log.debug(
f"Key '{key}' not found in configuration. Returning default value."
)
return default
# Update `data` to the value associated with the current key

View File

@ -14,9 +14,7 @@ class LoRATools:
def __init__(self, headless: bool = False):
self.headless = headless
gr.Markdown(
'This section provide various LoRA tools...'
)
gr.Markdown("This section provide various LoRA tools...")
gradio_extract_dylora_tab(headless=headless)
gradio_convert_lcm_tab(headless=headless)
gradio_extract_lora_tab(headless=headless)

View File

@ -1,16 +1,15 @@
import os
import gradio as gr
from easygui import msgbox
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
###
@ -38,10 +37,10 @@ def run_cmd_sample(
Returns:
str: The command string for sampling images.
"""
output_dir = os.path.join(output_dir, 'sample')
output_dir = os.path.join(output_dir, "sample")
os.makedirs(output_dir, exist_ok=True)
run_cmd = ''
run_cmd = ""
if sample_every_n_epochs is None:
sample_every_n_epochs = 0
@ -53,19 +52,19 @@ def run_cmd_sample(
return run_cmd
# Create the prompt file and get its path
sample_prompts_path = os.path.join(output_dir, 'prompt.txt')
sample_prompts_path = os.path.join(output_dir, "prompt.txt")
with open(sample_prompts_path, 'w') as f:
with open(sample_prompts_path, "w") as f:
f.write(sample_prompts)
run_cmd += f' --sample_sampler={sample_sampler}'
run_cmd += f" --sample_sampler={sample_sampler}"
run_cmd += f' --sample_prompts="{sample_prompts_path}"'
if sample_every_n_epochs != 0:
run_cmd += f' --sample_every_n_epochs={sample_every_n_epochs}'
run_cmd += f" --sample_every_n_epochs={sample_every_n_epochs}"
if sample_every_n_steps != 0:
run_cmd += f' --sample_every_n_steps={sample_every_n_steps}'
run_cmd += f" --sample_every_n_steps={sample_every_n_steps}"
return run_cmd
@ -89,45 +88,45 @@ class SampleImages:
"""
with gr.Row():
self.sample_every_n_steps = gr.Number(
label='Sample every n steps',
label="Sample every n steps",
value=0,
precision=0,
interactive=True,
)
self.sample_every_n_epochs = gr.Number(
label='Sample every n epochs',
label="Sample every n epochs",
value=0,
precision=0,
interactive=True,
)
self.sample_sampler = gr.Dropdown(
label='Sample sampler',
label="Sample sampler",
choices=[
'ddim',
'pndm',
'lms',
'euler',
'euler_a',
'heun',
'dpm_2',
'dpm_2_a',
'dpmsolver',
'dpmsolver++',
'dpmsingle',
'k_lms',
'k_euler',
'k_euler_a',
'k_dpm_2',
'k_dpm_2_a',
"ddim",
"pndm",
"lms",
"euler",
"euler_a",
"heun",
"dpm_2",
"dpm_2_a",
"dpmsolver",
"dpmsolver++",
"dpmsingle",
"k_lms",
"k_euler",
"k_euler_a",
"k_dpm_2",
"k_dpm_2_a",
],
value='euler_a',
value="euler_a",
interactive=True,
)
with gr.Row():
self.sample_prompts = gr.Textbox(
lines=5,
label='Sample prompts',
label="Sample prompts",
interactive=True,
placeholder='masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28',
info='Enter one sample prompt per line to generate multiple samples per cycle. Optional specifiers include: --w (width), --h (height), --d (seed), --l (cfg scale), --s (sampler steps) and --n (negative prompt). To modify sample prompts during training, edit the prompt.txt file in the samples directory.',
placeholder="masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28",
info="Enter one sample prompt per line to generate multiple samples per cycle. Optional specifiers include: --w (width), --h (height), --d (seed), --l (cfg scale), --s (sampler steps) and --n (negative prompt). To modify sample prompts during training, edit the prompt.txt file in the samples directory.",
)

View File

@ -1,5 +1,6 @@
import gradio as gr
class SDXLParameters:
def __init__(
self,
@ -7,25 +8,23 @@ class SDXLParameters:
show_sdxl_cache_text_encoder_outputs: bool = True,
):
self.sdxl_checkbox = sdxl_checkbox
self.show_sdxl_cache_text_encoder_outputs = (
show_sdxl_cache_text_encoder_outputs
)
self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs
self.initialize_accordion()
def initialize_accordion(self):
with gr.Accordion(
visible=False, open=True, label='SDXL Specific Parameters'
visible=False, open=True, label="SDXL Specific Parameters"
) as self.sdxl_row:
with gr.Row():
self.sdxl_cache_text_encoder_outputs = gr.Checkbox(
label='Cache text encoder outputs',
info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.',
label="Cache text encoder outputs",
info="Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.",
value=False,
visible=self.show_sdxl_cache_text_encoder_outputs,
)
self.sdxl_no_half_vae = gr.Checkbox(
label='No half VAE',
info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.',
label="No half VAE",
info="Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.",
value=True,
)

View File

@ -62,8 +62,9 @@ class SourceModel:
self.current_train_data_dir = self.config.get(
"train_data_dir", os.path.join(scriptdir, "data")
)
self.current_dataset_config_dir = self.config.get('dataset_config_dir', os.path.join(scriptdir, "dataset_config"))
self.current_dataset_config_dir = self.config.get(
"dataset_config_dir", os.path.join(scriptdir, "dataset_config")
)
model_checkpoints = list(
list_files(
@ -82,7 +83,7 @@ class SourceModel:
def list_train_data_dirs(path):
self.current_train_data_dir = path if not path == "" else "."
return list(list_dirs(path))
def list_dataset_config_dirs(path: str) -> list:
"""
List directories and toml files in the dataset_config directory.
@ -95,8 +96,9 @@ class SourceModel:
"""
current_dataset_config_dir = path if not path == "" else "."
# Lists all .json files in the current configuration directory, used for populating dropdown choices.
return list(list_files(current_dataset_config_dir, exts=[".toml"], all=True))
return list(
list_files(current_dataset_config_dir, exts=[".toml"], all=True)
)
with gr.Accordion("Model", open=True):
with gr.Column(), gr.Group():
@ -143,7 +145,7 @@ class SourceModel:
outputs=self.pretrained_model_name_or_path,
show_progress=False,
)
with gr.Column(), gr.Row():
self.output_name = gr.Textbox(
label="Trained Model output name",
@ -188,29 +190,49 @@ class SourceModel:
with gr.Column(), gr.Row():
# Toml directory dropdown
self.dataset_config = gr.Dropdown(
label='Dataset config file (Optional. Select the toml configuration file to use for the dataset)',
choices=[""] + list_dataset_config_dirs(self.current_dataset_config_dir),
label="Dataset config file (Optional. Select the toml configuration file to use for the dataset)",
choices=[""]
+ list_dataset_config_dirs(self.current_dataset_config_dir),
value="",
interactive=True,
allow_custom_value=True,
)
# Refresh button for dataset_config directory
create_refresh_button(self.dataset_config, lambda: None, lambda: {"choices": [""] + list_dataset_config_dirs(self.current_dataset_config_dir)}, "open_folder_small")
create_refresh_button(
self.dataset_config,
lambda: None,
lambda: {
"choices": [""]
+ list_dataset_config_dirs(
self.current_dataset_config_dir
)
},
"open_folder_small",
)
# Toml directory button
self.dataset_config_folder = gr.Button(
document_symbol, elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
document_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
# Toml directory button click event
self.dataset_config_folder.click(
get_file_path,
inputs=[self.dataset_config, gr.Textbox(value='*.toml', visible=False), gr.Textbox(value='Dataset config types', visible=False)],
inputs=[
self.dataset_config,
gr.Textbox(value="*.toml", visible=False),
gr.Textbox(value="Dataset config types", visible=False),
],
outputs=self.dataset_config,
show_progress=False,
)
# Change event for dataset_config directory dropdown
self.dataset_config.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_dataset_config_dirs(path)),
fn=lambda path: gr.Dropdown(
choices=[""] + list_dataset_config_dirs(path)
),
inputs=self.dataset_config,
outputs=self.dataset_config,
show_progress=False,
@ -273,7 +295,9 @@ class SourceModel:
)
self.train_data_dir.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)),
fn=lambda path: gr.Dropdown(
choices=[""] + list_train_data_dirs(path)
),
inputs=self.train_data_dir,
outputs=self.train_data_dir,
show_progress=False,

View File

@ -6,7 +6,6 @@ from .custom_logging import setup_logging
import os
import re
import gradio as gr
import shutil
import sys
import json
import math
@ -55,6 +54,7 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDX
ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"]
def calculate_max_train_steps(
total_steps: int,
train_batch_size: int,
@ -72,6 +72,7 @@ def calculate_max_train_steps(
)
)
def check_if_model_exist(
output_name: str, output_dir: str, save_model_as: str, headless: bool = False
) -> bool:
@ -1097,14 +1098,14 @@ def run_cmd_advanced_training(**kwargs):
if kwargs.get("gradient_checkpointing"):
run_cmd += " --gradient_checkpointing"
if kwargs.get("ip_noise_gamma"):
if float(kwargs["ip_noise_gamma"]) > 0:
run_cmd += f' --ip_noise_gamma={kwargs["ip_noise_gamma"]}'
if kwargs.get("ip_noise_gamma_random_strength"):
if kwargs["ip_noise_gamma_random_strength"]:
run_cmd += f' --ip_noise_gamma_random_strength'
run_cmd += f" --ip_noise_gamma_random_strength"
if "keep_tokens" in kwargs and int(kwargs["keep_tokens"]) > 0:
run_cmd += f' --keep_tokens="{int(kwargs["keep_tokens"])}"'
@ -1180,7 +1181,7 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += f' --lr_warmup_steps="{lr_warmup_steps}"'
if "masked_loss" in kwargs:
if kwargs.get("masked_loss"): # Test if the value is true as it could be false
if kwargs.get("masked_loss"): # Test if the value is true as it could be false
run_cmd += " --masked_loss"
if "max_data_loader_n_workers" in kwargs:
@ -1194,7 +1195,7 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += f' --max_grad_norm="{max_grad_norm}"'
if "max_resolution" in kwargs:
run_cmd += fr' --resolution="{kwargs.get("max_resolution")}"'
run_cmd += rf' --resolution="{kwargs.get("max_resolution")}"'
if "max_timestep" in kwargs:
max_timestep = kwargs.get("max_timestep")
@ -1217,7 +1218,7 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += f' --max_train_steps="{max_train_steps}"'
if "mem_eff_attn" in kwargs:
if kwargs.get("mem_eff_attn"): # Test if the value is true as it could be false
if kwargs.get("mem_eff_attn"): # Test if the value is true as it could be false
run_cmd += " --mem_eff_attn"
if "min_snr_gamma" in kwargs:
@ -1229,20 +1230,20 @@ def run_cmd_advanced_training(**kwargs):
min_timestep = kwargs.get("min_timestep")
if int(min_timestep) > -1:
run_cmd += f" --min_timestep={int(min_timestep)}"
if "mixed_precision" in kwargs:
run_cmd += rf' --mixed_precision="{kwargs.get("mixed_precision")}"'
if "network_alpha" in kwargs:
run_cmd += fr' --network_alpha="{kwargs.get("network_alpha")}"'
run_cmd += rf' --network_alpha="{kwargs.get("network_alpha")}"'
if "network_args" in kwargs:
network_args = kwargs.get("network_args")
if network_args != "":
run_cmd += f' --network_args{network_args}'
run_cmd += f" --network_args{network_args}"
if "network_dim" in kwargs:
run_cmd += fr' --network_dim={kwargs.get("network_dim")}'
run_cmd += rf' --network_dim={kwargs.get("network_dim")}'
if "network_dropout" in kwargs:
network_dropout = kwargs.get("network_dropout")
@ -1252,7 +1253,7 @@ def run_cmd_advanced_training(**kwargs):
if "network_module" in kwargs:
network_module = kwargs.get("network_module")
if network_module != "":
run_cmd += f' --network_module={network_module}'
run_cmd += f" --network_module={network_module}"
if "network_train_text_encoder_only" in kwargs:
if kwargs.get("network_train_text_encoder_only"):
@ -1263,11 +1264,13 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += " --network_train_unet_only"
if "no_half_vae" in kwargs:
if kwargs.get("no_half_vae"): # Test if the value is true as it could be false
if kwargs.get("no_half_vae"): # Test if the value is true as it could be false
run_cmd += " --no_half_vae"
if "no_token_padding" in kwargs:
if kwargs.get("no_token_padding"): # Test if the value is true as it could be false
if kwargs.get(
"no_token_padding"
): # Test if the value is true as it could be false
run_cmd += " --no_token_padding"
if "noise_offset_type" in kwargs:
@ -1283,18 +1286,24 @@ def run_cmd_advanced_training(**kwargs):
adaptive_noise_scale = float(kwargs.get("adaptive_noise_scale", 0))
if adaptive_noise_scale != 0 and noise_offset > 0:
run_cmd += f" --adaptive_noise_scale={adaptive_noise_scale}"
if "noise_offset_random_strength" in kwargs:
if kwargs.get("noise_offset_random_strength"):
run_cmd += f" --noise_offset_random_strength"
elif noise_offset_type == "Multires":
if "multires_noise_iterations" in kwargs:
multires_noise_iterations = int(kwargs.get("multires_noise_iterations", 0))
multires_noise_iterations = int(
kwargs.get("multires_noise_iterations", 0)
)
if multires_noise_iterations > 0:
run_cmd += f' --multires_noise_iterations="{multires_noise_iterations}"'
run_cmd += (
f' --multires_noise_iterations="{multires_noise_iterations}"'
)
if "multires_noise_discount" in kwargs:
multires_noise_discount = float(kwargs.get("multires_noise_discount", 0))
multires_noise_discount = float(
kwargs.get("multires_noise_discount", 0)
)
if multires_noise_discount > 0:
run_cmd += f' --multires_noise_discount="{multires_noise_discount}"'
@ -1304,7 +1313,7 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += f" --optimizer_args {optimizer_args}"
if "optimizer" in kwargs:
run_cmd += fr' --optimizer_type="{kwargs.get("optimizer")}"'
run_cmd += rf' --optimizer_type="{kwargs.get("optimizer")}"'
if "output_dir" in kwargs:
output_dir = kwargs.get("output_dir")
@ -1323,9 +1332,7 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += " --persistent_data_loader_workers"
if "pretrained_model_name_or_path" in kwargs:
run_cmd += (
rf' --pretrained_model_name_or_path="{kwargs.get("pretrained_model_name_or_path")}"'
)
run_cmd += rf' --pretrained_model_name_or_path="{kwargs.get("pretrained_model_name_or_path")}"'
if "prior_loss_weight" in kwargs:
prior_loss_weight = kwargs.get("prior_loss_weight")
@ -1376,12 +1383,12 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += f" --save_model_as={save_model_as}"
if "save_precision" in kwargs:
run_cmd += fr' --save_precision="{kwargs.get("save_precision")}"'
run_cmd += rf' --save_precision="{kwargs.get("save_precision")}"'
if "save_state" in kwargs:
if kwargs.get("save_state"):
run_cmd += " --save_state"
if "save_state_on_train_end" in kwargs:
if kwargs.get("save_state_on_train_end"):
run_cmd += " --save_state_on_train_end"
@ -1415,7 +1422,7 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += f" --text_encoder_lr={text_encoder_lr}"
if "train_batch_size" in kwargs:
run_cmd += fr' --train_batch_size="{kwargs.get("train_batch_size")}"'
run_cmd += rf' --train_batch_size="{kwargs.get("train_batch_size")}"'
training_comment = kwargs.get("training_comment")
if training_comment and len(training_comment):

View File

@ -22,17 +22,12 @@ document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
def convert_lcm(
name,
model_path,
lora_scale,
model_type
):
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"'
def convert_lcm(name, model_path, lora_scale, model_type):
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"'
# Check if source model exist
if not os.path.isfile(model_path):
log.error('The provided DyLoRA model is not a file')
log.error("The provided DyLoRA model is not a file")
return
if os.path.dirname(name) == "":
@ -46,12 +41,11 @@ def convert_lcm(
path, ext = os.path.splitext(save_to)
save_to = f"{path}_lcm{ext}"
# Construct the command to run the script
run_cmd += f" --lora-scale {lora_scale}"
run_cmd += f' --model "{model_path}"'
run_cmd += f' --name "{name}"'
if model_type == "SDXL":
run_cmd += f" --sdxl"
if model_type == "SSD-1B":
@ -60,7 +54,9 @@ def convert_lcm(
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
@ -98,11 +94,16 @@ def gradio_convert_lcm_tab(headless=False):
value="",
allow_custom_value=True,
)
create_refresh_button(model_path, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
create_refresh_button(
model_path,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
button_model_path_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=['tool'],
elem_classes=["tool"],
visible=(not headless),
)
button_model_path_file.click(
@ -119,11 +120,16 @@ def gradio_convert_lcm_tab(headless=False):
value="",
allow_custom_value=True,
)
create_refresh_button(name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
create_refresh_button(
name,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
button_name = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=['tool'],
elem_classes=["tool"],
visible=(not headless),
)
button_name.click(
@ -154,7 +160,7 @@ def gradio_convert_lcm_tab(headless=False):
value=1.0,
interactive=True,
)
# with gr.Row():
# with gr.Row():
# no_half = gr.Checkbox(label="Convert the new LCM model to FP32", value=False)
model_type = gr.Radio(
label="Model type", choices=["SD15", "SDXL", "SD-1B"], value="SD15"
@ -164,11 +170,6 @@ def gradio_convert_lcm_tab(headless=False):
extract_button.click(
convert_lcm,
inputs=[
name,
model_path,
lora_scale,
model_type
],
inputs=[name, model_path, lora_scale, model_type],
show_progress=False,
)

View File

@ -2,7 +2,6 @@ import gradio as gr
from easygui import msgbox
import subprocess
import os
import shutil
import sys
from .common_gui import get_folder_path, get_file_path, scriptdir, list_files, list_dirs
@ -11,10 +10,10 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
@ -29,52 +28,51 @@ def convert_model(
unet_use_linear_projection,
):
# Check for caption_text_input
if source_model_type == '':
msgbox('Invalid source model type')
if source_model_type == "":
msgbox("Invalid source model type")
return
# Check if source model exist
if os.path.isfile(source_model_input):
log.info('The provided source model is a file')
log.info("The provided source model is a file")
elif os.path.isdir(source_model_input):
log.info('The provided model is a folder')
log.info("The provided model is a folder")
else:
msgbox('The provided source model is neither a file nor a folder')
msgbox("The provided source model is neither a file nor a folder")
return
# Check if source model exist
if os.path.isdir(target_model_folder_input):
log.info('The provided model folder exist')
log.info("The provided model folder exist")
else:
msgbox('The provided target folder does not exist')
msgbox("The provided target folder does not exist")
return
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py"'
run_cmd = (
rf'"{PYTHON}" "{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py"'
)
v1_models = [
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
"runwayml/stable-diffusion-v1-5",
"CompVis/stable-diffusion-v1-4",
]
# check if v1 models
if str(source_model_type) in v1_models:
log.info('SD v1 model specified. Setting --v1 parameter')
run_cmd += ' --v1'
log.info("SD v1 model specified. Setting --v1 parameter")
run_cmd += " --v1"
else:
log.info('SD v2 model specified. Setting --v2 parameter')
run_cmd += ' --v2'
log.info("SD v2 model specified. Setting --v2 parameter")
run_cmd += " --v2"
if not target_save_precision_type == 'unspecified':
run_cmd += f' --{target_save_precision_type}'
if not target_save_precision_type == "unspecified":
run_cmd += f" --{target_save_precision_type}"
if (
target_model_type == 'diffuser'
or target_model_type == 'diffuser_safetensors'
):
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
run_cmd += f' --reference_model="{source_model_type}"'
if target_model_type == 'diffuser_safetensors':
run_cmd += ' --use_safetensors'
if target_model_type == "diffuser_safetensors":
run_cmd += " --use_safetensors"
# Fix for stabilityAI diffusers format. When saving v2 models in Diffusers format in training scripts and conversion scripts,
# it was found that the U-Net configuration is different from those of Hugging Face's stabilityai models (this repository is
@ -82,14 +80,11 @@ def convert_model(
# when using the weight files directly.
if unet_use_linear_projection:
run_cmd += ' --unet_use_linear_projection'
run_cmd += " --unet_use_linear_projection"
run_cmd += f' "{source_model_input}"'
if (
target_model_type == 'diffuser'
or target_model_type == 'diffuser_safetensors'
):
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
target_model_path = os.path.join(
target_model_folder_input, target_model_name_input
)
@ -97,74 +92,20 @@ def convert_model(
else:
target_model_path = os.path.join(
target_model_folder_input,
f'{target_model_name_input}.{target_model_type}',
f"{target_model_name_input}.{target_model_type}",
)
run_cmd += f' "{target_model_path}"'
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
# if (
# not target_model_type == 'diffuser'
# or target_model_type == 'diffuser_safetensors'
# ):
# v2_models = [
# 'stabilityai/stable-diffusion-2-1-base',
# 'stabilityai/stable-diffusion-2-base',
# ]
# v_parameterization = [
# 'stabilityai/stable-diffusion-2-1',
# 'stabilityai/stable-diffusion-2',
# ]
# if str(source_model_type) in v2_models:
# inference_file = os.path.join(
# target_model_folder_input, f'{target_model_name_input}.yaml'
# )
# log.info(f'Saving v2-inference.yaml as {inference_file}')
# shutil.copy(
# fr'{scriptdir}/v2_inference/v2-inference.yaml',
# f'{inference_file}',
# )
# if str(source_model_type) in v_parameterization:
# inference_file = os.path.join(
# target_model_folder_input, f'{target_model_name_input}.yaml'
# )
# log.info(f'Saving v2-inference-v.yaml as {inference_file}')
# shutil.copy(
# fr'{scriptdir}/v2_inference/v2-inference-v.yaml',
# f'{inference_file}',
# )
# parser = argparse.ArgumentParser()
# parser.add_argument("--v1", action='store_true',
# help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
# parser.add_argument("--v2", action='store_true',
# help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
# parser.add_argument("--fp16", action='store_true',
# help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込みDiffusers形式のみ対応、保存するcheckpointのみ対応')
# parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存するcheckpointのみ対応')
# parser.add_argument("--float", action='store_true',
# help='save as float (checkpoint only) / float(float32)形式で保存するcheckpointのみ対応')
# parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
# parser.add_argument("--global_step", type=int, default=0,
# help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
# parser.add_argument("--reference_model", type=str, default=None,
# help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
# parser.add_argument("model_to_load", type=str, default=None,
# help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
# parser.add_argument("model_to_save", type=str, default=None,
# help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
###
# Gradio UI
@ -189,124 +130,136 @@ def gradio_convert_model_tab(headless=False):
current_target_folder = path
return list(list_dirs(path))
with gr.Tab('Convert model'):
with gr.Tab("Convert model"):
gr.Markdown(
'This utility can be used to convert from one stable diffusion model format to another.'
"This utility can be used to convert from one stable diffusion model format to another."
)
model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
model_ext_name = gr.Textbox(value='Model types', visible=False)
model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
model_ext_name = gr.Textbox(value="Model types", visible=False)
with gr.Group(), gr.Row():
with gr.Column(), gr.Row():
source_model_input = gr.Dropdown(
label='Source model (path to source model folder of file to convert...)',
interactive=True,
choices=[""] + list_source_model(default_source_model),
value="",
allow_custom_value=True,
)
create_refresh_button(source_model_input, lambda: None, lambda: {"choices": list_source_model(current_source_model)}, "open_folder_small")
button_source_model_dir = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
visible=(not headless),
)
button_source_model_dir.click(
get_folder_path,
outputs=source_model_input,
show_progress=False,
)
with gr.Column(), gr.Row():
source_model_input = gr.Dropdown(
label="Source model (path to source model folder of file to convert...)",
interactive=True,
choices=[""] + list_source_model(default_source_model),
value="",
allow_custom_value=True,
)
create_refresh_button(
source_model_input,
lambda: None,
lambda: {"choices": list_source_model(current_source_model)},
"open_folder_small",
)
button_source_model_dir = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_source_model_dir.click(
get_folder_path,
outputs=source_model_input,
show_progress=False,
)
button_source_model_file = gr.Button(
document_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
visible=(not headless),
)
button_source_model_file.click(
get_file_path,
inputs=[source_model_input, model_ext, model_ext_name],
outputs=source_model_input,
show_progress=False,
)
button_source_model_file = gr.Button(
document_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_source_model_file.click(
get_file_path,
inputs=[source_model_input, model_ext, model_ext_name],
outputs=source_model_input,
show_progress=False,
)
source_model_input.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_source_model(path)),
inputs=source_model_input,
outputs=source_model_input,
show_progress=False,
)
with gr.Column(), gr.Row():
source_model_type = gr.Dropdown(
label='Source model type',
choices=[
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
],
)
source_model_input.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_source_model(path)),
inputs=source_model_input,
outputs=source_model_input,
show_progress=False,
)
with gr.Column(), gr.Row():
source_model_type = gr.Dropdown(
label="Source model type",
choices=[
"stabilityai/stable-diffusion-2-1-base",
"stabilityai/stable-diffusion-2-base",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2",
"runwayml/stable-diffusion-v1-5",
"CompVis/stable-diffusion-v1-4",
],
)
with gr.Group(), gr.Row():
with gr.Column(), gr.Row():
target_model_folder_input = gr.Dropdown(
label='Target model folder (path to target model folder of file name to create...)',
interactive=True,
choices=[""] + list_target_folder(default_target_folder),
value="",
allow_custom_value=True,
)
create_refresh_button(target_model_folder_input, lambda: None, lambda: {"choices": list_target_folder(current_target_folder)},"open_folder_small")
button_target_model_folder = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
visible=(not headless),
)
button_target_model_folder.click(
get_folder_path,
outputs=target_model_folder_input,
show_progress=False,
)
with gr.Column(), gr.Row():
target_model_folder_input = gr.Dropdown(
label="Target model folder (path to target model folder of file name to create...)",
interactive=True,
choices=[""] + list_target_folder(default_target_folder),
value="",
allow_custom_value=True,
)
create_refresh_button(
target_model_folder_input,
lambda: None,
lambda: {"choices": list_target_folder(current_target_folder)},
"open_folder_small",
)
button_target_model_folder = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_target_model_folder.click(
get_folder_path,
outputs=target_model_folder_input,
show_progress=False,
)
target_model_folder_input.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_target_folder(path)),
inputs=target_model_folder_input,
outputs=target_model_folder_input,
show_progress=False,
)
target_model_folder_input.change(
fn=lambda path: gr.Dropdown(
choices=[""] + list_target_folder(path)
),
inputs=target_model_folder_input,
outputs=target_model_folder_input,
show_progress=False,
)
with gr.Column(), gr.Row():
target_model_name_input = gr.Textbox(
label='Target model name',
placeholder='target model name...',
interactive=True,
)
with gr.Column(), gr.Row():
target_model_name_input = gr.Textbox(
label="Target model name",
placeholder="target model name...",
interactive=True,
)
with gr.Row():
target_model_type = gr.Dropdown(
label='Target model type',
label="Target model type",
choices=[
'diffuser',
'diffuser_safetensors',
'ckpt',
'safetensors',
"diffuser",
"diffuser_safetensors",
"ckpt",
"safetensors",
],
)
target_save_precision_type = gr.Dropdown(
label='Target model precision',
choices=['unspecified', 'fp16', 'bf16', 'float'],
value='unspecified',
label="Target model precision",
choices=["unspecified", "fp16", "bf16", "float"],
value="unspecified",
)
unet_use_linear_projection = gr.Checkbox(
label='UNet linear projection',
label="UNet linear projection",
value=False,
info="Enable for Hugging Face's stabilityai models",
)
convert_button = gr.Button('Convert model')
convert_button = gr.Button("Convert model")
convert_button.click(
convert_model,

View File

@ -11,34 +11,71 @@ from rich.traceback import install as traceback_install
log = None
def setup_logging(clean=False, debug=False):
global log
if log is not None:
return log
try:
if clean and os.path.isfile('setup.log'):
os.remove('setup.log')
time.sleep(0.1) # prevent race condition
if clean and os.path.isfile("setup.log"):
os.remove("setup.log")
time.sleep(0.1) # prevent race condition
except:
pass
if sys.version_info >= (3, 9):
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', filename='setup.log', filemode='a', encoding='utf-8', force=True)
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s | %(levelname)s | %(pathname)s | %(message)s",
filename="setup.log",
filemode="a",
encoding="utf-8",
force=True,
)
else:
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', filename='setup.log', filemode='a', force=True)
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
"traceback.border": "black",
"traceback.border.syntax_error": "black",
"inspect.value.border": "black",
}))
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s | %(levelname)s | %(pathname)s | %(message)s",
filename="setup.log",
filemode="a",
force=True,
)
console = Console(
log_time=True,
log_time_format="%H:%M:%S-%f",
theme=Theme(
{
"traceback.border": "black",
"traceback.border.syntax_error": "black",
"inspect.value.border": "black",
}
),
)
pretty_install(console=console)
traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=logging.DEBUG if debug else logging.INFO, console=console)
traceback_install(
console=console,
extra_lines=1,
width=console.width,
word_wrap=False,
indent_guides=False,
suppress=[],
)
rh = RichHandler(
show_time=True,
omit_repeated_times=False,
show_level=True,
show_path=False,
markup=False,
rich_tracebacks=True,
log_time_format="%H:%M:%S-%f",
level=logging.DEBUG if debug else logging.INFO,
console=console,
)
rh.set_name(logging.DEBUG if debug else logging.INFO)
log = logging.getLogger("sd")
log.addHandler(rh)
return log

View File

@ -9,29 +9,22 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
# def select_folder():
# # Open a file dialog to select a directory
# folder = filedialog.askdirectory()
# # Update the GUI to display the selected folder
# selected_folder_label.config(text=folder)
def dataset_balancing(concept_repeats, folder, insecure):
if not concept_repeats > 0:
# Display an error message if the total number of repeats is not a valid integer
msgbox('Please enter a valid integer for the total number of repeats.')
msgbox("Please enter a valid integer for the total number of repeats.")
return
concept_repeats = int(concept_repeats)
# Check if folder exist
if folder == '' or not os.path.isdir(folder):
msgbox('Please enter a valid folder for balancing.')
if folder == "" or not os.path.isdir(folder):
msgbox("Please enter a valid folder for balancing.")
return
pattern = re.compile(r'^\d+_.+$')
pattern = re.compile(r"^\d+_.+$")
# Iterate over the subdirectories in the selected folder
for subdir in os.listdir(folder):
@ -44,7 +37,7 @@ def dataset_balancing(concept_repeats, folder, insecure):
image_files = [
f
for f in files
if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))
if f.endswith((".jpg", ".jpeg", ".png", ".gif", ".webp"))
]
# Count the number of image files
@ -52,20 +45,18 @@ def dataset_balancing(concept_repeats, folder, insecure):
if images == 0:
log.info(
f'No images of type .jpg, .jpeg, .png, .gif, .webp were found in {os.listdir(os.path.join(folder, subdir))}'
f"No images of type .jpg, .jpeg, .png, .gif, .webp were found in {os.listdir(os.path.join(folder, subdir))}"
)
# Check if the subdirectory name starts with a number inside braces,
# indicating that the repeats value should be multiplied
match = re.match(r'^\{(\d+\.?\d*)\}', subdir)
match = re.match(r"^\{(\d+\.?\d*)\}", subdir)
if match:
# Multiply the repeats value by the number inside the braces
if not images == 0:
repeats = max(
1,
round(
concept_repeats / images * float(match.group(1))
),
round(concept_repeats / images * float(match.group(1))),
)
else:
repeats = 0
@ -77,32 +68,30 @@ def dataset_balancing(concept_repeats, folder, insecure):
repeats = 0
# Check if the subdirectory name already has a number at the beginning
match = re.match(r'^\d+_', subdir)
match = re.match(r"^\d+_", subdir)
if match:
# Replace the existing number with the new number
old_name = os.path.join(folder, subdir)
new_name = os.path.join(
folder, f'{repeats}_{subdir[match.end():]}'
)
new_name = os.path.join(folder, f"{repeats}_{subdir[match.end():]}")
else:
# Add the new number at the beginning of the name
old_name = os.path.join(folder, subdir)
new_name = os.path.join(folder, f'{repeats}_{subdir}')
new_name = os.path.join(folder, f"{repeats}_{subdir}")
os.rename(old_name, new_name)
else:
log.info(
f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
f"Skipping folder {subdir} because it does not match kohya_ss expected syntax..."
)
msgbox('Dataset balancing completed...')
msgbox("Dataset balancing completed...")
def warning(insecure):
if insecure:
if boolbox(
f'WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?',
choices=('Yes, I like danger', 'No, get me out of here'),
f"WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?",
choices=("Yes, I like danger", "No, get me out of here"),
):
return True
else:
@ -113,12 +102,12 @@ def gradio_dataset_balancing_tab(headless=False):
current_dataset_dir = os.path.join(scriptdir, "data")
with gr.Tab('Dreambooth/LoRA Dataset balancing'):
with gr.Tab("Dreambooth/LoRA Dataset balancing"):
gr.Markdown(
'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.'
"This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training."
)
gr.Markdown(
'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!'
"WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!"
)
with gr.Group(), gr.Row():
@ -128,15 +117,23 @@ def gradio_dataset_balancing_tab(headless=False):
return list(list_dirs(path))
select_dataset_folder_input = gr.Dropdown(
label='Dataset folder (folder containing the concepts folders to balance...)',
label="Dataset folder (folder containing the concepts folders to balance...)",
interactive=True,
choices=[""] + list_dataset_dirs(current_dataset_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(select_dataset_folder_input, lambda: None, lambda: {"choices": list_dataset_dirs(current_dataset_dir)}, "open_folder_small")
create_refresh_button(
select_dataset_folder_input,
lambda: None,
lambda: {"choices": list_dataset_dirs(current_dataset_dir)},
"open_folder_small",
)
select_dataset_folder_button = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
select_dataset_folder_button.click(
get_folder_path,
@ -147,7 +144,7 @@ def gradio_dataset_balancing_tab(headless=False):
total_repeats_number = gr.Number(
value=1000,
interactive=True,
label='Training steps per concept per epoch',
label="Training steps per concept per epoch",
)
select_dataset_folder_input.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_dataset_dirs(path)),
@ -156,13 +153,13 @@ def gradio_dataset_balancing_tab(headless=False):
show_progress=False,
)
with gr.Accordion('Advanced options', open=False):
with gr.Accordion("Advanced options", open=False):
insecure = gr.Checkbox(
value=False,
label='DANGER!!! -- Insecure folder renaming -- DANGER!!!',
label="DANGER!!! -- Insecure folder renaming -- DANGER!!!",
)
insecure.change(warning, inputs=insecure, outputs=insecure)
balance_button = gr.Button('Balance dataset')
balance_button = gr.Button("Balance dataset")
balance_button.click(
dataset_balancing,
inputs=[

View File

@ -1,5 +1,5 @@
import gradio as gr
from easygui import diropenbox, msgbox
from easygui import msgbox
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button
import shutil
import os
@ -12,13 +12,13 @@ log = setup_logging()
def copy_info_to_Folders_tab(training_folder):
img_folder = os.path.join(training_folder, 'img')
if os.path.exists(os.path.join(training_folder, 'reg')):
reg_folder = os.path.join(training_folder, 'reg')
img_folder = os.path.join(training_folder, "img")
if os.path.exists(os.path.join(training_folder, "reg")):
reg_folder = os.path.join(training_folder, "reg")
else:
reg_folder = ''
model_folder = os.path.join(training_folder, 'model')
log_folder = os.path.join(training_folder, 'log')
reg_folder = ""
model_folder = os.path.join(training_folder, "model")
log_folder = os.path.join(training_folder, "log")
return img_folder, reg_folder, model_folder, log_folder
@ -44,17 +44,17 @@ def dreambooth_folder_preparation(
os.makedirs(util_training_dir_output, exist_ok=True)
# Check for instance prompt
if util_instance_prompt_input == '':
msgbox('Instance prompt missing...')
if util_instance_prompt_input == "":
msgbox("Instance prompt missing...")
return
# Check for class prompt
if util_class_prompt_input == '':
msgbox('Class prompt missing...')
if util_class_prompt_input == "":
msgbox("Class prompt missing...")
return
# Create the training_dir path
if util_training_images_dir_input == '':
if util_training_images_dir_input == "":
log.info(
"Training images directory is missing... can't perform the required task..."
)
@ -62,60 +62,54 @@ def dreambooth_folder_preparation(
else:
training_dir = os.path.join(
util_training_dir_output,
f'img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}',
f"img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}",
)
# Remove folders if they exist
if os.path.exists(training_dir):
log.info(f'Removing existing directory {training_dir}...')
log.info(f"Removing existing directory {training_dir}...")
shutil.rmtree(training_dir)
# Copy the training images to their respective directories
log.info(f'Copy {util_training_images_dir_input} to {training_dir}...')
log.info(f"Copy {util_training_images_dir_input} to {training_dir}...")
shutil.copytree(util_training_images_dir_input, training_dir)
if not util_regularization_images_dir_input == '':
if not util_regularization_images_dir_input == "":
# Create the regularization_dir path
if not util_regularization_images_repeat_input > 0:
log.info(
'Repeats is missing... not copying regularisation images...'
)
log.info("Repeats is missing... not copying regularisation images...")
else:
regularization_dir = os.path.join(
util_training_dir_output,
f'reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}',
f"reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}",
)
# Remove folders if they exist
if os.path.exists(regularization_dir):
log.info(
f'Removing existing directory {regularization_dir}...'
)
log.info(f"Removing existing directory {regularization_dir}...")
shutil.rmtree(regularization_dir)
# Copy the regularisation images to their respective directories
log.info(
f'Copy {util_regularization_images_dir_input} to {regularization_dir}...'
)
shutil.copytree(
util_regularization_images_dir_input, regularization_dir
f"Copy {util_regularization_images_dir_input} to {regularization_dir}..."
)
shutil.copytree(util_regularization_images_dir_input, regularization_dir)
else:
log.info(
'Regularization images directory is missing... not copying regularisation images...'
"Regularization images directory is missing... not copying regularisation images..."
)
# create log and model folder
# Check if the log folder exists and create it if it doesn't
if not os.path.exists(os.path.join(util_training_dir_output, 'log')):
os.makedirs(os.path.join(util_training_dir_output, 'log'))
if not os.path.exists(os.path.join(util_training_dir_output, "log")):
os.makedirs(os.path.join(util_training_dir_output, "log"))
# Check if the model folder exists and create it if it doesn't
if not os.path.exists(os.path.join(util_training_dir_output, 'model')):
os.makedirs(os.path.join(util_training_dir_output, 'model'))
if not os.path.exists(os.path.join(util_training_dir_output, "model")):
os.makedirs(os.path.join(util_training_dir_output, "model"))
log.info(
f'Done creating kohya_ss training folder structure at {util_training_dir_output}...'
f"Done creating kohya_ss training folder structure at {util_training_dir_output}..."
)
@ -132,22 +126,22 @@ def gradio_dreambooth_folder_creation_tab(
current_reg_data_dir = os.path.join(scriptdir, "data")
current_train_output_dir = os.path.join(scriptdir, "data")
with gr.Tab('Dreambooth/LoRA Folder preparation'):
with gr.Tab("Dreambooth/LoRA Folder preparation"):
gr.Markdown(
'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth/LoRA method to function correctly.'
"This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth/LoRA method to function correctly."
)
with gr.Row():
util_instance_prompt_input = gr.Textbox(
label='Instance prompt',
placeholder='Eg: asd',
label="Instance prompt",
placeholder="Eg: asd",
interactive=True,
value = config.get(key="dataset_preparation.instance_prompt", default="")
value=config.get(key="dataset_preparation.instance_prompt", default=""),
)
util_class_prompt_input = gr.Textbox(
label='Class prompt',
placeholder='Eg: person',
label="Class prompt",
placeholder="Eg: person",
interactive=True,
value = config.get(key="dataset_preparation.class_prompt", default=""),
value=config.get(key="dataset_preparation.class_prompt", default=""),
)
with gr.Group(), gr.Row():
@ -157,15 +151,26 @@ def gradio_dreambooth_folder_creation_tab(
return list(list_dirs(path))
util_training_images_dir_input = gr.Dropdown(
label='Training images (directory containing the training images)',
label="Training images (directory containing the training images)",
interactive=True,
choices=[config.get(key="dataset_preparation.images_folder", default="")] + list_train_data_dirs(current_train_data_dir),
choices=[
config.get(key="dataset_preparation.images_folder", default="")
]
+ list_train_data_dirs(current_train_data_dir),
value=config.get(key="dataset_preparation.images_folder", default=""),
allow_custom_value=True,
)
create_refresh_button(util_training_images_dir_input, lambda: None, lambda: {"choices": list_train_data_dirs(current_train_data_dir)}, "open_folder_small")
create_refresh_button(
util_training_images_dir_input,
lambda: None,
lambda: {"choices": list_train_data_dirs(current_train_data_dir)},
"open_folder_small",
)
button_util_training_images_dir_input = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_util_training_images_dir_input.click(
get_folder_path,
@ -173,10 +178,10 @@ def gradio_dreambooth_folder_creation_tab(
show_progress=False,
)
util_training_images_repeat_input = gr.Number(
label='Repeats',
label="Repeats",
value=40,
interactive=True,
elem_id='number_input',
elem_id="number_input",
)
util_training_images_dir_input.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)),
@ -186,21 +191,35 @@ def gradio_dreambooth_folder_creation_tab(
)
with gr.Group(), gr.Row():
def list_reg_data_dirs(path):
nonlocal current_reg_data_dir
current_reg_data_dir = path
return list(list_dirs(path))
util_regularization_images_dir_input = gr.Dropdown(
label='Regularisation images (Optional. directory containing the regularisation images)',
label="Regularisation images (Optional. directory containing the regularisation images)",
interactive=True,
choices=[config.get(key="dataset_preparation.reg_images_folder", default="")] + list_reg_data_dirs(current_reg_data_dir),
value=config.get(key="dataset_preparation.reg_images_folder", default=""),
choices=[
config.get(key="dataset_preparation.reg_images_folder", default="")
]
+ list_reg_data_dirs(current_reg_data_dir),
value=config.get(
key="dataset_preparation.reg_images_folder", default=""
),
allow_custom_value=True,
)
create_refresh_button(util_regularization_images_dir_input, lambda: None, lambda: {"choices": list_reg_data_dirs(current_reg_data_dir)}, "open_folder_small")
create_refresh_button(
util_regularization_images_dir_input,
lambda: None,
lambda: {"choices": list_reg_data_dirs(current_reg_data_dir)},
"open_folder_small",
)
button_util_regularization_images_dir_input = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_util_regularization_images_dir_input.click(
get_folder_path,
@ -208,10 +227,10 @@ def gradio_dreambooth_folder_creation_tab(
show_progress=False,
)
util_regularization_images_repeat_input = gr.Number(
label='Repeats',
label="Repeats",
value=1,
interactive=True,
elem_id='number_input',
elem_id="number_input",
)
util_regularization_images_dir_input.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_reg_data_dirs(path)),
@ -220,32 +239,44 @@ def gradio_dreambooth_folder_creation_tab(
show_progress=False,
)
with gr.Group(), gr.Row():
def list_train_output_dirs(path):
nonlocal current_train_output_dir
current_train_output_dir = path
return list(list_dirs(path))
util_training_dir_output = gr.Dropdown(
label='Destination training directory (where formatted training and regularisation folders will be placed)',
label="Destination training directory (where formatted training and regularisation folders will be placed)",
interactive=True,
choices=[config.get(key="train_data_dir", default="")] + list_train_output_dirs(current_train_output_dir),
choices=[config.get(key="train_data_dir", default="")]
+ list_train_output_dirs(current_train_output_dir),
value=config.get(key="train_data_dir", default=""),
allow_custom_value=True,
)
create_refresh_button(util_training_dir_output, lambda: None, lambda: {"choices": list_train_output_dirs(current_train_output_dir)}, "open_folder_small")
create_refresh_button(
util_training_dir_output,
lambda: None,
lambda: {"choices": list_train_output_dirs(current_train_output_dir)},
"open_folder_small",
)
button_util_training_dir_output = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_util_training_dir_output.click(
get_folder_path, outputs=util_training_dir_output
)
util_training_dir_output.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_train_output_dirs(path)),
fn=lambda path: gr.Dropdown(
choices=[""] + list_train_output_dirs(path)
),
inputs=util_training_dir_output,
outputs=util_training_dir_output,
show_progress=False,
)
button_prepare_training_data = gr.Button('Prepare training data')
button_prepare_training_data = gr.Button("Prepare training data")
button_prepare_training_data.click(
dreambooth_folder_preparation,
inputs=[
@ -259,15 +290,3 @@ def gradio_dreambooth_folder_creation_tab(
],
show_progress=False,
)
# button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab')
# button_copy_info_to_Folders_tab.click(
# copy_info_to_Folders_tab,
# inputs=[util_training_dir_output],
# outputs=[
# train_data_dir_input,
# reg_data_dir_input,
# output_dir_input,
# logging_dir_input,
# ],
# show_progress=False,
# )

View File

@ -3,7 +3,6 @@ import json
import math
import os
import sys
import pathlib
from datetime import datetime
from .common_gui import (
get_file_path,
@ -25,7 +24,6 @@ from .class_basic_training import BasicTraining
from .class_advanced_training import AdvancedTraining
from .class_folders import Folders
from .class_command_executor import CommandExecutor
from .class_sdxl_parameters import SDXLParameters
from .tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
@ -459,7 +457,9 @@ def train_model(
return
if dataset_config:
log.info("Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations...")
log.info(
"Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
)
else:
# Get a list of all subfolders in train_data_dir, excluding hidden folders
subfolders = [
@ -509,7 +509,9 @@ def train_model(
log.info(f"Folder {folder} : steps {steps}")
if total_steps == 0:
log.info(f"No images were found in folder {train_data_dir}... please rectify!")
log.info(
f"No images were found in folder {train_data_dir}... please rectify!"
)
return
# Print the result
@ -541,7 +543,9 @@ def train_model(
# calculate stop encoder training
if int(stop_text_encoder_training_pct) == -1:
stop_text_encoder_training = -1
elif stop_text_encoder_training_pct == None or (not max_train_steps == "" or not max_train_steps == "0"):
elif stop_text_encoder_training_pct == None or (
not max_train_steps == "" or not max_train_steps == "0"
):
stop_text_encoder_training = 0
else:
stop_text_encoder_training = math.ceil(
@ -729,7 +733,7 @@ def dreambooth_tab(
with gr.Tab("Training"), gr.Column(variant="compact"):
gr.Markdown("Train a custom model using kohya dreambooth python code...")
with gr.Accordion("Accelerate launch", open=False), gr.Column():
accelerate_launch = AccelerateLaunch()
@ -738,7 +742,7 @@ def dreambooth_tab(
with gr.Accordion("Folders", open=False), gr.Group():
folders = Folders(headless=headless, config=config)
with gr.Accordion("Parameters", open=False), gr.Column():
with gr.Accordion("Basic", open="True"):
with gr.Group(elem_id="basic_tab"):

View File

@ -4,7 +4,6 @@ import subprocess
import os
import sys
from .common_gui import (
get_saveasfilename_path,
get_file_path,
scriptdir,
list_files,
@ -16,10 +15,10 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
@ -30,13 +29,13 @@ def extract_dylora(
unit,
):
# Check for caption_text_input
if model == '':
msgbox('Invalid DyLoRA model file')
if model == "":
msgbox("Invalid DyLoRA model file")
return
# Check if source model exist
if not os.path.isfile(model):
msgbox('The provided DyLoRA model is not a file')
msgbox("The provided DyLoRA model is not a file")
return
if os.path.dirname(save_to) == "":
@ -51,21 +50,23 @@ def extract_dylora(
save_to = f"{path}_tmp{ext}"
run_cmd = (
fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py"'
rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py"'
)
run_cmd += fr' --save_to "{save_to}"'
run_cmd += fr' --model "{model}"'
run_cmd += f' --unit {unit}'
run_cmd += rf' --save_to "{save_to}"'
run_cmd += rf' --model "{model}"'
run_cmd += f" --unit {unit}"
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
log.info('Done extracting DyLoRA...')
log.info("Done extracting DyLoRA...")
###
@ -77,12 +78,10 @@ def gradio_extract_dylora_tab(headless=False):
current_model_dir = os.path.join(scriptdir, "outputs")
current_save_dir = os.path.join(scriptdir, "outputs")
with gr.Tab('Extract DyLoRA'):
gr.Markdown(
'This utility can extract a DyLoRA network from a finetuned model.'
)
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
with gr.Tab("Extract DyLoRA"):
gr.Markdown("This utility can extract a DyLoRA network from a finetuned model.")
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
def list_models(path):
nonlocal current_model_dir
@ -96,17 +95,22 @@ def gradio_extract_dylora_tab(headless=False):
with gr.Group(), gr.Row():
model = gr.Dropdown(
label='DyLoRA model (path to the DyLoRA model to extract from)',
label="DyLoRA model (path to the DyLoRA model to extract from)",
interactive=True,
choices=[""] + list_models(current_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
create_refresh_button(
model,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
button_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_model_file.click(
@ -117,17 +121,22 @@ def gradio_extract_dylora_tab(headless=False):
)
save_to = gr.Dropdown(
label='Save to (path where to save the extracted LoRA model...)',
label="Save to (path where to save the extracted LoRA model...)",
interactive=True,
choices=[""] + list_save_to(current_save_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
create_refresh_button(
save_to,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
unit = gr.Slider(
minimum=1,
maximum=256,
label='Network Dimension (Rank)',
label="Network Dimension (Rank)",
value=1,
step=1,
interactive=True,
@ -146,7 +155,7 @@ def gradio_extract_dylora_tab(headless=False):
show_progress=False,
)
extract_button = gr.Button('Extract LoRA model')
extract_button = gr.Button("Extract LoRA model")
extract_button.click(
extract_dylora,

View File

@ -1,5 +1,4 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
import sys
@ -17,10 +16,10 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
@ -42,21 +41,21 @@ def extract_lora(
load_precision,
):
# Check for caption_text_input
if model_tuned == '':
log.info('Invalid finetuned model file')
if model_tuned == "":
log.info("Invalid finetuned model file")
return
if model_org == '':
log.info('Invalid base model file')
if model_org == "":
log.info("Invalid base model file")
return
# Check if source model exist
if not os.path.isfile(model_tuned):
log.info('The provided finetuned model is not a file')
log.info("The provided finetuned model is not a file")
return
if not os.path.isfile(model_org):
log.info('The provided base model is not a file')
log.info("The provided base model is not a file")
return
if os.path.dirname(save_to) == "":
@ -74,31 +73,33 @@ def extract_lora(
return
run_cmd = (
fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_models.py"'
rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_models.py"'
)
run_cmd += f' --load_precision {load_precision}'
run_cmd += f' --save_precision {save_precision}'
run_cmd += fr' --save_to "{save_to}"'
run_cmd += fr' --model_org "{model_org}"'
run_cmd += fr' --model_tuned "{model_tuned}"'
run_cmd += f' --dim {dim}'
run_cmd += f' --device {device}'
run_cmd += f" --load_precision {load_precision}"
run_cmd += f" --save_precision {save_precision}"
run_cmd += rf' --save_to "{save_to}"'
run_cmd += rf' --model_org "{model_org}"'
run_cmd += rf' --model_tuned "{model_tuned}"'
run_cmd += f" --dim {dim}"
run_cmd += f" --device {device}"
if conv_dim > 0:
run_cmd += f' --conv_dim {conv_dim}'
run_cmd += f" --conv_dim {conv_dim}"
if v2:
run_cmd += f' --v2'
run_cmd += f" --v2"
if sdxl:
run_cmd += f' --sdxl'
run_cmd += f' --clamp_quantile {clamp_quantile}'
run_cmd += f' --min_diff {min_diff}'
run_cmd += f" --sdxl"
run_cmd += f" --clamp_quantile {clamp_quantile}"
run_cmd += f" --min_diff {min_diff}"
if sdxl:
run_cmd += f' --load_original_model_to {load_original_model_to}'
run_cmd += f' --load_tuned_model_to {load_tuned_model_to}'
run_cmd += f" --load_original_model_to {load_original_model_to}"
run_cmd += f" --load_tuned_model_to {load_tuned_model_to}"
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
@ -132,29 +133,31 @@ def gradio_extract_lora_tab(headless=False):
def change_sdxl(sdxl):
return gr.Dropdown(visible=sdxl), gr.Dropdown(visible=sdxl)
with gr.Tab('Extract LoRA'):
gr.Markdown(
'This utility can extract a LoRA network from a finetuned model.'
)
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
model_ext_name = gr.Textbox(value='Model types', visible=False)
with gr.Tab("Extract LoRA"):
gr.Markdown("This utility can extract a LoRA network from a finetuned model.")
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
model_ext = gr.Textbox(value="*.ckpt *.safetensors", visible=False)
model_ext_name = gr.Textbox(value="Model types", visible=False)
with gr.Group(), gr.Row():
model_tuned = gr.Dropdown(
label='Finetuned model (path to the finetuned model to extract)',
label="Finetuned model (path to the finetuned model to extract)",
interactive=True,
choices=[""] + list_models(current_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(model_tuned, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
create_refresh_button(
model_tuned,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
button_model_tuned_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_model_tuned_file.click(
@ -164,25 +167,31 @@ def gradio_extract_lora_tab(headless=False):
show_progress=False,
)
load_tuned_model_to = gr.Radio(
label='Load finetuned model to',
choices=['cpu', 'cuda', 'cuda:0'],
value='cpu',
interactive=True, scale=1,
label="Load finetuned model to",
choices=["cpu", "cuda", "cuda:0"],
value="cpu",
interactive=True,
scale=1,
info="only for SDXL",
visible=False,
)
model_org = gr.Dropdown(
label='Stable Diffusion base model (original model: ckpt or safetensors file)',
label="Stable Diffusion base model (original model: ckpt or safetensors file)",
interactive=True,
choices=[""] + list_org_models(current_model_org_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(model_org, lambda: None, lambda: {"choices": list_org_models(current_model_org_dir)}, "open_folder_small")
create_refresh_button(
model_org,
lambda: None,
lambda: {"choices": list_org_models(current_model_org_dir)},
"open_folder_small",
)
button_model_org_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_model_org_file.click(
@ -192,27 +201,33 @@ def gradio_extract_lora_tab(headless=False):
show_progress=False,
)
load_original_model_to = gr.Dropdown(
label='Load Stable Diffusion base model to',
choices=['cpu', 'cuda', 'cuda:0'],
value='cpu',
interactive=True, scale=1,
label="Load Stable Diffusion base model to",
choices=["cpu", "cuda", "cuda:0"],
value="cpu",
interactive=True,
scale=1,
info="only for SDXL",
visible=False,
)
with gr.Group(), gr.Row():
save_to = gr.Dropdown(
label='Save to (path where to save the extracted LoRA model...)',
label="Save to (path where to save the extracted LoRA model...)",
interactive=True,
choices=[""] + list_save_to(current_save_dir),
value="",
allow_custom_value=True,
scale=2,
)
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
create_refresh_button(
save_to,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
button_save_to = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_save_to.click(
@ -222,16 +237,18 @@ def gradio_extract_lora_tab(headless=False):
show_progress=False,
)
save_precision = gr.Radio(
label='Save precision',
choices=['fp16', 'bf16', 'float'],
value='fp16',
interactive=True, scale=1,
label="Save precision",
choices=["fp16", "bf16", "float"],
value="fp16",
interactive=True,
scale=1,
)
load_precision = gr.Radio(
label='Load precision',
choices=['fp16', 'bf16', 'float'],
value='fp16',
interactive=True, scale=1,
label="Load precision",
choices=["fp16", "bf16", "float"],
value="fp16",
interactive=True,
scale=1,
)
model_tuned.change(
@ -256,7 +273,7 @@ def gradio_extract_lora_tab(headless=False):
dim = gr.Slider(
minimum=4,
maximum=1024,
label='Network Dimension (Rank)',
label="Network Dimension (Rank)",
value=128,
step=1,
interactive=True,
@ -264,13 +281,13 @@ def gradio_extract_lora_tab(headless=False):
conv_dim = gr.Slider(
minimum=0,
maximum=1024,
label='Conv Dimension (Rank)',
label="Conv Dimension (Rank)",
value=128,
step=1,
interactive=True,
)
clamp_quantile = gr.Number(
label='Clamp Quantile',
label="Clamp Quantile",
value=0.99,
minimum=0,
maximum=1,
@ -278,7 +295,7 @@ def gradio_extract_lora_tab(headless=False):
interactive=True,
)
min_diff = gr.Number(
label='Minimum difference',
label="Minimum difference",
value=0.01,
minimum=0,
maximum=1,
@ -286,21 +303,25 @@ def gradio_extract_lora_tab(headless=False):
interactive=True,
)
with gr.Row():
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True)
v2 = gr.Checkbox(label="v2", value=False, interactive=True)
sdxl = gr.Checkbox(label="SDXL", value=False, interactive=True)
device = gr.Radio(
label='Device',
label="Device",
choices=[
'cpu',
'cuda',
"cpu",
"cuda",
],
value='cuda',
value="cuda",
interactive=True,
)
sdxl.change(change_sdxl, inputs=sdxl, outputs=[load_tuned_model_to, load_original_model_to])
extract_button = gr.Button('Extract LoRA model')
sdxl.change(
change_sdxl,
inputs=sdxl,
outputs=[load_tuned_model_to, load_original_model_to],
)
extract_button = gr.Button("Extract LoRA model")
extract_button.click(
extract_lora,

View File

@ -5,7 +5,6 @@ import os
import sys
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
scriptdir,
list_files,
@ -74,7 +73,7 @@ def extract_lycoris_locon(
path, ext = os.path.splitext(output_name)
output_name = f"{path}_tmp{ext}"
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/lycoris_locon_extract.py"'
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lycoris_locon_extract.py"'
if is_sdxl:
run_cmd += f" --is_sdxl"
if is_v2:
@ -99,19 +98,21 @@ def extract_lycoris_locon(
run_cmd += f" --sparsity {sparsity}"
if disable_cp:
run_cmd += f" --disable_cp"
run_cmd += fr' "{base_model}"'
run_cmd += fr' "{db_model}"'
run_cmd += fr' "{output_name}"'
run_cmd += rf' "{base_model}"'
run_cmd += rf' "{db_model}"'
run_cmd += rf' "{output_name}"'
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
log.info('Done extracting...')
log.info("Done extracting...")
###
@ -185,11 +186,16 @@ def gradio_extract_lycoris_locon_tab(headless=False):
value="",
allow_custom_value=True,
)
create_refresh_button(db_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
create_refresh_button(
db_model,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
button_db_model_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=['tool'],
elem_classes=["tool"],
visible=(not headless),
)
button_db_model_file.click(
@ -205,11 +211,16 @@ def gradio_extract_lycoris_locon_tab(headless=False):
value="",
allow_custom_value=True,
)
create_refresh_button(base_model, lambda: None, lambda: {"choices": list_base_models(current_base_model_dir)}, "open_folder_small")
create_refresh_button(
base_model,
lambda: None,
lambda: {"choices": list_base_models(current_base_model_dir)},
"open_folder_small",
)
button_base_model_file = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=['tool'],
elem_classes=["tool"],
visible=(not headless),
)
button_base_model_file.click(
@ -227,11 +238,16 @@ def gradio_extract_lycoris_locon_tab(headless=False):
allow_custom_value=True,
scale=2,
)
create_refresh_button(output_name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
create_refresh_button(
output_name,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
button_output_name = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=['tool'],
elem_classes=["tool"],
visible=(not headless),
)
button_output_name.click(
@ -270,7 +286,9 @@ def gradio_extract_lycoris_locon_tab(headless=False):
show_progress=False,
)
is_sdxl = gr.Checkbox(label="is SDXL", value=False, interactive=True, scale=1)
is_sdxl = gr.Checkbox(
label="is SDXL", value=False, interactive=True, scale=1
)
is_v2 = gr.Checkbox(label="is v2", value=False, interactive=True, scale=1)
with gr.Row():

View File

@ -4,7 +4,6 @@ import math
import os
import subprocess
import sys
import pathlib
from datetime import datetime
from .common_gui import (
get_file_path,
@ -50,7 +49,8 @@ document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
presets_dir = fr'{scriptdir}/presets'
presets_dir = rf"{scriptdir}/presets"
def save_configuration(
save_as,
@ -319,7 +319,7 @@ def open_configuration(
# Check if we are "applying" a preset or a config
if apply_preset:
log.info(f"Applying preset {training_preset}...")
file_path = fr'{presets_dir}/finetune/{training_preset}.json'
file_path = rf"{presets_dir}/finetune/{training_preset}.json"
else:
# If not applying a preset, set the `training_preset` field to an empty string
# Find the index of the `training_preset` parameter using the `index()` method
@ -482,32 +482,38 @@ def train_model(
logging_dir=logging_dir,
log_tracker_config=log_tracker_config,
resume=resume,
dataset_config=dataset_config
dataset_config=dataset_config,
):
return
if not print_only_bool and check_if_model_exist(output_name, output_dir, save_model_as, headless_bool):
if not print_only_bool and check_if_model_exist(
output_name, output_dir, save_model_as, headless_bool
):
return
if dataset_config:
log.info("Dataset config toml file used, skipping caption json file, image buckets, total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps creation...")
log.info(
"Dataset config toml file used, skipping caption json file, image buckets, total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps creation..."
)
else:
# create caption json file
if generate_caption_database:
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/merge_captions_to_metadata.py"'
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/merge_captions_to_metadata.py"'
if caption_extension == "":
run_cmd += f' --caption_extension=".caption"'
else:
run_cmd += f" --caption_extension={caption_extension}"
run_cmd += fr' "{image_folder}"'
run_cmd += fr' "{train_dir}/{caption_metadata_filename}"'
run_cmd += rf' "{image_folder}"'
run_cmd += rf' "{train_dir}/{caption_metadata_filename}"'
if full_path:
run_cmd += f" --full_path"
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
if not print_only_bool:
# Run the command
@ -515,11 +521,11 @@ def train_model(
# create images buckets
if generate_image_buckets:
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/prepare_buckets_latents.py"'
run_cmd += fr' "{image_folder}"'
run_cmd += fr' "{train_dir}/{caption_metadata_filename}"'
run_cmd += fr' "{train_dir}/{latent_metadata_filename}"'
run_cmd += fr' "{pretrained_model_name_or_path}"'
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/prepare_buckets_latents.py"'
run_cmd += rf' "{image_folder}"'
run_cmd += rf' "{train_dir}/{caption_metadata_filename}"'
run_cmd += rf' "{train_dir}/{latent_metadata_filename}"'
run_cmd += rf' "{pretrained_model_name_or_path}"'
run_cmd += f" --batch_size={batch_size}"
run_cmd += f" --max_resolution={max_resolution}"
run_cmd += f" --min_bucket_reso={min_bucket_reso}"
@ -530,13 +536,17 @@ def train_model(
if full_path:
run_cmd += f" --full_path"
if sdxl_checkbox and sdxl_no_half_vae:
log.info("Using mixed_precision = no because no half vae is selected...")
log.info(
"Using mixed_precision = no because no half vae is selected..."
)
run_cmd += f' --mixed_precision="no"'
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
if not print_only_bool:
# Run the command
@ -591,14 +601,14 @@ def train_model(
)
if sdxl_checkbox:
run_cmd += fr' "{scriptdir}/sd-scripts/sdxl_train.py"'
run_cmd += rf' "{scriptdir}/sd-scripts/sdxl_train.py"'
else:
run_cmd += fr' "{scriptdir}/sd-scripts/fine_tune.py"'
run_cmd += rf' "{scriptdir}/sd-scripts/fine_tune.py"'
in_json = (
fr"{train_dir}/{latent_metadata_filename}"
rf"{train_dir}/{latent_metadata_filename}"
if use_latent_files == "Yes"
else fr"{train_dir}/{caption_metadata_filename}"
else rf"{train_dir}/{caption_metadata_filename}"
)
cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs
no_half_vae = sdxl_checkbox and sdxl_no_half_vae
@ -732,7 +742,9 @@ def train_model(
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
executor.execute_command(run_cmd=run_cmd, env=env)
@ -744,12 +756,14 @@ def finetune_tab(headless=False, config: dict = {}):
dummy_headless = gr.Label(value=headless, visible=False)
with gr.Tab("Training"), gr.Column(variant="compact"):
gr.Markdown("Train a custom model using kohya finetune python code...")
with gr.Accordion("Accelerate launch", open=False), gr.Column():
accelerate_launch = AccelerateLaunch()
with gr.Column():
source_model = SourceModel(headless=headless, finetuning=True, config=config)
source_model = SourceModel(
headless=headless, finetuning=True, config=config
)
image_folder = source_model.train_data_dir
output_name = source_model.output_name
@ -803,14 +817,17 @@ def finetune_tab(headless=False, config: dict = {}):
with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
with gr.Row():
gradient_accumulation_steps = gr.Number(
label="Gradient accumulate steps", value="1",
label="Gradient accumulate steps",
value="1",
)
block_lr = gr.Textbox(
label="Block LR (SDXL)",
placeholder="(Optional)",
info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3",
)
advanced_training = AdvancedTraining(headless=headless, finetuning=True, config=config)
advanced_training = AdvancedTraining(
headless=headless, finetuning=True, config=config
)
advanced_training.color_aug.change(
color_aug_changed,
inputs=[advanced_training.color_aug],
@ -866,7 +883,6 @@ def finetune_tab(headless=False, config: dict = {}):
with gr.Accordion("Configuration", open=False):
configuration = ConfigurationFile(headless=headless, config=config)
with gr.Column(), gr.Group():
with gr.Row():
button_run = gr.Button("Start training", variant="primary")
@ -1005,7 +1021,9 @@ def finetune_tab(headless=False, config: dict = {}):
inputs=[dummy_db_true, dummy_db_false, configuration.config_file_name]
+ settings_list
+ [training_preset],
outputs=[configuration.config_file_name] + settings_list + [training_preset],
outputs=[configuration.config_file_name]
+ settings_list
+ [training_preset],
show_progress=False,
)
@ -1021,7 +1039,9 @@ def finetune_tab(headless=False, config: dict = {}):
inputs=[dummy_db_false, dummy_db_false, configuration.config_file_name]
+ settings_list
+ [training_preset],
outputs=[configuration.config_file_name] + settings_list + [training_preset],
outputs=[configuration.config_file_name]
+ settings_list
+ [training_preset],
show_progress=False,
)
@ -1062,16 +1082,16 @@ def finetune_tab(headless=False, config: dict = {}):
show_progress=False,
)
#config.button_save_as_config.click(
# config.button_save_as_config.click(
# save_configuration,
# inputs=[dummy_db_true, config.config_file_name] + settings_list,
# outputs=[config.config_file_name],
# show_progress=False,
#)
# )
with gr.Tab("Guides"):
gr.Markdown("This section provide Various Finetuning guides and information...")
top_level_path = fr"{scriptdir}/docs/Finetuning/top_level.md"
top_level_path = rf"{scriptdir}/docs/Finetuning/top_level.md"
if os.path.exists(top_level_path):
with open(os.path.join(top_level_path), "r", encoding="utf8") as file:
guides_top_level = file.read() + "\n"

View File

@ -24,31 +24,31 @@ def caption_images(
postfix,
):
# Check for images_dir_input
if train_data_dir == '':
msgbox('Image folder is missing...')
if train_data_dir == "":
msgbox("Image folder is missing...")
return
if caption_ext == '':
msgbox('Please provide an extension for the caption files.')
if caption_ext == "":
msgbox("Please provide an extension for the caption files.")
return
log.info(f'GIT captioning files in {train_data_dir}...')
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"'
if not model_id == '':
log.info(f"GIT captioning files in {train_data_dir}...")
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"'
if not model_id == "":
run_cmd += f' --model_id="{model_id}"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += (
f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
)
run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
run_cmd += f' --max_length="{int(max_length)}"'
if caption_ext != '':
if caption_ext != "":
run_cmd += f' --caption_extension="{caption_ext}"'
run_cmd += f' "{train_data_dir}"'
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
@ -61,7 +61,7 @@ def caption_images(
postfix=postfix,
)
log.info('...captioning done')
log.info("...captioning done")
###
@ -72,7 +72,11 @@ def caption_images(
def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
from .common_gui import create_refresh_button
default_train_dir = default_train_dir if default_train_dir is not None else os.path.join(scriptdir, "data")
default_train_dir = (
default_train_dir
if default_train_dir is not None
else os.path.join(scriptdir, "data")
)
current_train_dir = default_train_dir
def list_train_dirs(path):
@ -80,21 +84,29 @@ def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
current_train_dir = path
return list(list_dirs(path))
with gr.Tab('GIT Captioning'):
with gr.Tab("GIT Captioning"):
gr.Markdown(
'This utility will use GIT to caption files for each images in a folder.'
"This utility will use GIT to caption files for each images in a folder."
)
with gr.Group(), gr.Row():
train_data_dir = gr.Dropdown(
label='Image folder to caption (containing the images to caption)',
label="Image folder to caption (containing the images to caption)",
choices=[""] + list_train_dirs(default_train_dir),
value="",
interactive=True,
allow_custom_value=True,
)
create_refresh_button(train_data_dir, lambda: None, lambda: {"choices": list_train_dirs(current_train_dir)},"open_folder_small")
create_refresh_button(
train_data_dir,
lambda: None,
lambda: {"choices": list_train_dirs(current_train_dir)},
"open_folder_small",
)
button_train_data_dir_input = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_train_data_dir_input.click(
get_folder_path,
@ -103,42 +115,38 @@ def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
)
with gr.Row():
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extension for caption file (e.g., .caption, .txt)',
value='.txt',
label="Caption file extension",
placeholder="Extension for caption file (e.g., .caption, .txt)",
value=".txt",
interactive=True,
)
prefix = gr.Textbox(
label='Prefix to add to GIT caption',
placeholder='(Optional)',
label="Prefix to add to GIT caption",
placeholder="(Optional)",
interactive=True,
)
postfix = gr.Textbox(
label='Postfix to add to GIT caption',
placeholder='(Optional)',
label="Postfix to add to GIT caption",
placeholder="(Optional)",
interactive=True,
)
batch_size = gr.Number(
value=1, label='Batch size', interactive=True
)
batch_size = gr.Number(value=1, label="Batch size", interactive=True)
with gr.Row():
max_data_loader_n_workers = gr.Number(
value=2, label='Number of workers', interactive=True
)
max_length = gr.Number(
value=75, label='Max length', interactive=True
value=2, label="Number of workers", interactive=True
)
max_length = gr.Number(value=75, label="Max length", interactive=True)
model_id = gr.Textbox(
label='Model',
placeholder='(Optional) model id for GIT in Hugging Face',
label="Model",
placeholder="(Optional) model id for GIT in Hugging Face",
interactive=True,
)
caption_button = gr.Button('Caption images')
caption_button = gr.Button("Caption images")
caption_button.click(
caption_images,

View File

@ -22,38 +22,40 @@ def group_images(
generate_captions,
caption_ext,
):
if input_folder == '':
msgbox('Input folder is missing...')
if input_folder == "":
msgbox("Input folder is missing...")
return
if output_folder == '':
msgbox('Please provide an output folder.')
if output_folder == "":
msgbox("Please provide an output folder.")
return
log.info(f'Grouping images in {input_folder}...')
log.info(f"Grouping images in {input_folder}...")
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/group_images.py"'
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/group_images.py"'
run_cmd += f' "{input_folder}"'
run_cmd += f' "{output_folder}"'
run_cmd += f' {(group_size)}'
run_cmd += f" {(group_size)}"
if include_subfolders:
run_cmd += f' --include_subfolders'
run_cmd += f" --include_subfolders"
if do_not_copy_other_files:
run_cmd += f' --do_not_copy_other_files'
run_cmd += f" --do_not_copy_other_files"
if generate_captions:
run_cmd += f' --caption'
run_cmd += f" --caption"
if caption_ext:
run_cmd += f' --caption_ext={caption_ext}'
run_cmd += f" --caption_ext={caption_ext}"
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
log.info('...grouping done')
log.info("...grouping done")
def gradio_group_images_gui_tab(headless=False):
@ -72,22 +74,30 @@ def gradio_group_images_gui_tab(headless=False):
current_output_folder = path
return list(list_dirs(path))
with gr.Tab('Group Images'):
with gr.Tab("Group Images"):
gr.Markdown(
'This utility will group images in a folder based on their aspect ratio.'
"This utility will group images in a folder based on their aspect ratio."
)
with gr.Group(), gr.Row():
input_folder = gr.Dropdown(
label='Input folder (containing the images to group)',
label="Input folder (containing the images to group)",
interactive=True,
choices=[""] + list_input_dirs(current_input_folder),
value="",
allow_custom_value=True,
)
create_refresh_button(input_folder, lambda: None, lambda: {"choices": list_input_dirs(current_input_folder)},"open_folder_small")
create_refresh_button(
input_folder,
lambda: None,
lambda: {"choices": list_input_dirs(current_input_folder)},
"open_folder_small",
)
button_input_folder = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_input_folder.click(
get_folder_path,
@ -96,15 +106,23 @@ def gradio_group_images_gui_tab(headless=False):
)
output_folder = gr.Dropdown(
label='Output folder (where the grouped images will be stored)',
label="Output folder (where the grouped images will be stored)",
interactive=True,
choices=[""] + list_output_dirs(current_output_folder),
value="",
allow_custom_value=True,
)
create_refresh_button(output_folder, lambda: None, lambda: {"choices": list_output_dirs(current_output_folder)},"open_folder_small")
create_refresh_button(
output_folder,
lambda: None,
lambda: {"choices": list_output_dirs(current_output_folder)},
"open_folder_small",
)
button_output_folder = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_output_folder.click(
get_folder_path,
@ -126,9 +144,9 @@ def gradio_group_images_gui_tab(headless=False):
)
with gr.Row():
group_size = gr.Slider(
label='Group size',
info='Number of images to group together',
value='4',
label="Group size",
info="Number of images to group together",
value="4",
minimum=1,
maximum=64,
step=1,
@ -136,31 +154,31 @@ def gradio_group_images_gui_tab(headless=False):
)
include_subfolders = gr.Checkbox(
label='Include Subfolders',
label="Include Subfolders",
value=False,
info='Include images in subfolders as well',
info="Include images in subfolders as well",
)
do_not_copy_other_files = gr.Checkbox(
label='Do not copy other files',
label="Do not copy other files",
value=False,
info='Do not copy other files in the input folder to the output folder',
info="Do not copy other files in the input folder to the output folder",
)
generate_captions = gr.Checkbox(
label='Generate Captions',
label="Generate Captions",
value=False,
info='Generate caption files for the grouped images based on their folder name',
info="Generate caption files for the grouped images based on their folder name",
)
caption_ext = gr.Textbox(
label='Caption Extension',
placeholder='Caption file extension (e.g., .txt)',
value='.txt',
label="Caption Extension",
placeholder="Caption file extension (e.g., .txt)",
value=".txt",
interactive=True,
)
group_images_button = gr.Button('Group images')
group_images_button = gr.Button("Group images")
group_images_button.click(
group_images,

View File

@ -7,7 +7,7 @@ localizationMap = {}
def load_localizations():
localizationMap.clear()
dirname = './localizations'
dirname = "./localizations"
for file in os.listdir(dirname):
fn, ext = os.path.splitext(file)
if ext.lower() != ".json":
@ -28,4 +28,4 @@ def load_language_js(language_name: str) -> str:
return f"window.localization = {json.dumps(data)}"
load_localizations()
load_localizations()

View File

@ -4,12 +4,14 @@ import kohya_gui.localization as localization
def file_path(fn):
return f'file={os.path.abspath(fn)}?{os.path.getmtime(fn)}'
return f"file={os.path.abspath(fn)}?{os.path.getmtime(fn)}"
def js_html_str(language):
head = f'<script type="text/javascript">{localization.load_language_js(language)}</script>\n'
head += f'<script type="text/javascript" src="{file_path("js/script.js")}"></script>\n'
head += (
f'<script type="text/javascript" src="{file_path("js/script.js")}"></script>\n'
)
head += f'<script type="text/javascript" src="{file_path("js/localization.js")}"></script>\n'
return head
@ -22,12 +24,12 @@ def add_javascript(language):
def template_response(*args, **kwargs):
res = localization.GrRoutesTemplateResponse(*args, **kwargs)
res.body = res.body.replace(b'</head>', f'{jsStr}</head>'.encode("utf8"))
res.body = res.body.replace(b"</head>", f"{jsStr}</head>".encode("utf8"))
res.init_headers()
return res
gr.routes.templates.TemplateResponse = template_response
if not hasattr(localization, 'GrRoutesTemplateResponse'):
if not hasattr(localization, "GrRoutesTemplateResponse"):
localization.GrRoutesTemplateResponse = gr.routes.templates.TemplateResponse

View File

@ -905,22 +905,21 @@ def train_model(
)
# Convert learning rates to float once and store the result for re-use
learning_rate = float(learning_rate) if learning_rate is not None else 0.0
text_encoder_lr_float = float(text_encoder_lr) if text_encoder_lr is not None else 0.0
text_encoder_lr_float = (
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
)
unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
# Determine the training configuration based on learning rate values
# Sets flags for training specific components based on the provided learning rates.
if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
output_message(
msg="Please input learning rate values.", headless=headless_bool
)
output_message(msg="Please input learning rate values.", headless=headless_bool)
return
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
network_train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0
# Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
# Define a dictionary of parameters
run_cmd_params = {
"adaptive_noise_scale": adaptive_noise_scale,
@ -1087,10 +1086,10 @@ def lora_tab(
gr.Markdown(
"Train a custom model using kohya train network LoRA python code..."
)
with gr.Accordion("Accelerate launch", open=False), gr.Column():
accelerate_launch = AccelerateLaunch()
with gr.Column():
source_model = SourceModel(
save_model_as_choices=[
@ -1124,7 +1123,7 @@ def lora_tab(
json_files.append(os.path.join("user_presets", preset_name))
return json_files
with gr.Accordion("Basic", open="True"):
training_preset = gr.Dropdown(
label="Presets",

View File

@ -11,7 +11,7 @@ from .custom_logging import setup_logging
log = setup_logging()
IMAGES_TO_SHOW = 5
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
auto_save = True
@ -28,7 +28,7 @@ def _get_quick_tags(quick_tags_text):
"""
Gets a list of tags from the quick tags text box
"""
quick_tags = [t.strip() for t in quick_tags_text.split(',') if t.strip()]
quick_tags = [t.strip() for t in quick_tags_text.split(",") if t.strip()]
quick_tags_set = set(quick_tags)
return quick_tags, quick_tags_set
@ -38,34 +38,31 @@ def _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set):
Updates a list of caption checkboxes to show possible tags and tags
already included in the caption
"""
caption_tags_have = [c.strip() for c in caption.split(',') if c.strip()]
caption_tags_unique = [
t for t in caption_tags_have if t not in quick_tags_set
]
caption_tags_have = [c.strip() for c in caption.split(",") if c.strip()]
caption_tags_unique = [t for t in caption_tags_have if t not in quick_tags_set]
caption_tags_all = quick_tags + caption_tags_unique
return gr.CheckboxGroup(
choices=caption_tags_all, value=caption_tags_have
)
return gr.CheckboxGroup(choices=caption_tags_all, value=caption_tags_have)
def paginate_go(page, max_page):
try:
page = float(page)
except:
msgbox(f'Invalid page num: {page}')
msgbox(f"Invalid page num: {page}")
return
return paginate(page, max_page, 0)
def paginate(page, max_page, page_change):
return int(max(min(page + page_change, max_page), 1))
def save_caption(caption, caption_ext, image_file, images_dir):
caption_path = _get_caption_path(image_file, images_dir, caption_ext)
with open(caption_path, 'w+', encoding='utf8') as f:
with open(caption_path, "w+", encoding="utf8") as f:
f.write(caption)
log.info(f'Wrote captions to {caption_path}')
log.info(f"Wrote captions to {caption_path}")
def update_quick_tags(quick_tags_text, *image_caption_texts):
@ -101,7 +98,7 @@ def update_image_tags(
output_tags = [t for t in quick_tags if t in selected_tags_set] + [
t for t in selected_tags if t not in quick_tags_set
]
caption = ', '.join(output_tags)
caption = ", ".join(output_tags)
if auto_save:
save_caption(caption, caption_ext, image_file, images_dir)
@ -122,50 +119,46 @@ def import_tags_from_captions(
# Check for images_dir
if not images_dir:
msgbox('Image folder is missing...')
msgbox("Image folder is missing...")
return empty_return()
if not os.path.exists(images_dir):
msgbox('Image folder does not exist...')
msgbox("Image folder does not exist...")
return empty_return()
if not caption_ext:
msgbox('Please provide an extension for the caption files.')
msgbox("Please provide an extension for the caption files.")
return empty_return()
if quick_tags_text:
if not boolbox(
f'Are you sure you wish to overwrite the current quick tags?',
choices=('Yes', 'No'),
f"Are you sure you wish to overwrite the current quick tags?",
choices=("Yes", "No"),
):
return empty_return()
images_list = os.listdir(images_dir)
image_files = [
f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)
]
image_files = [f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)]
# Use a set for lookup but store order with list
tags = []
tags_set = set()
for image_file in image_files:
caption_file_path = _get_caption_path(
image_file, images_dir, caption_ext
)
caption_file_path = _get_caption_path(image_file, images_dir, caption_ext)
if os.path.exists(caption_file_path):
with open(caption_file_path, 'r', encoding='utf8') as f:
with open(caption_file_path, "r", encoding="utf8") as f:
caption = f.read()
for tag in caption.split(','):
for tag in caption.split(","):
tag = tag.strip()
tag_key = tag.lower()
if not tag_key in tags_set:
# Ignore extra spaces
total_words = len(re.findall(r'\s+', tag)) + 1
total_words = len(re.findall(r"\s+", tag)) + 1
if total_words <= ignore_load_tags_word_count:
tags.append(tag)
tags_set.add(tag_key)
return ', '.join(tags)
return ", ".join(tags)
def load_images(images_dir, caption_ext, loaded_images_dir, page, max_page):
@ -180,15 +173,15 @@ def load_images(images_dir, caption_ext, loaded_images_dir, page, max_page):
# Check for images_dir
if not images_dir:
msgbox('Image folder is missing...')
msgbox("Image folder is missing...")
return empty_return()
if not os.path.exists(images_dir):
msgbox('Image folder does not exist...')
msgbox("Image folder does not exist...")
return empty_return()
if not caption_ext:
msgbox('Please provide an extension for the caption files.')
msgbox("Please provide an extension for the caption files.")
return empty_return()
# Load Images
@ -212,12 +205,10 @@ def update_images(
# Load Images
images_list = os.listdir(images_dir)
image_files = [
f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)
]
image_files = [f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)]
# Quick tags
quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text or '')
quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text or "")
# Display Images
rows = []
@ -231,22 +222,18 @@ def update_images(
show_row = image_index < len(image_files)
image_path = None
caption = ''
caption = ""
tag_checkboxes = None
if show_row:
image_file = image_files[image_index]
image_path = os.path.join(images_dir, image_file)
caption_file_path = _get_caption_path(
image_file, images_dir, caption_ext
)
caption_file_path = _get_caption_path(image_file, images_dir, caption_ext)
if os.path.exists(caption_file_path):
with open(caption_file_path, 'r', encoding='utf8') as f:
with open(caption_file_path, "r", encoding="utf8") as f:
caption = f.read()
tag_checkboxes = _get_tag_checkbox_updates(
caption, quick_tags, quick_tags_set
)
tag_checkboxes = _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set)
rows.append(gr.Row(visible=show_row))
image_paths.append(image_path)
captions.append(caption)
@ -266,7 +253,11 @@ def update_images(
def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
from .common_gui import create_refresh_button
default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data")
default_images_dir = (
default_images_dir
if default_images_dir is not None
else os.path.join(scriptdir, "data")
)
current_images_dir = default_images_dir
# Function to list directories
@ -276,39 +267,45 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
current_images_dir = path
return list(list_dirs(path))
with gr.Tab('Manual Captioning'):
gr.Markdown(
'This utility allows quick captioning and tagging of images.'
)
with gr.Tab("Manual Captioning"):
gr.Markdown("This utility allows quick captioning and tagging of images.")
page = gr.Number(-1, visible=False)
max_page = gr.Number(1, visible=False)
loaded_images_dir = gr.Text(visible=False)
with gr.Group(), gr.Row():
images_dir = gr.Dropdown(
label='Image folder to caption (containing the images to caption)',
label="Image folder to caption (containing the images to caption)",
choices=[""] + list_images_dirs(default_images_dir),
value="",
interactive=True,
allow_custom_value=True,
)
create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small")
create_refresh_button(
images_dir,
lambda: None,
lambda: {"choices": list_images_dirs(current_images_dir)},
"open_folder_small",
)
folder_button = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
folder_button.click(
get_folder_path,
outputs=images_dir,
show_progress=False,
)
load_images_button = gr.Button('Load', elem_id='open_folder')
load_images_button = gr.Button("Load", elem_id="open_folder")
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extension for caption file (e.g., .caption, .txt)',
value='.txt',
label="Caption file extension",
placeholder="Extension for caption file (e.g., .caption, .txt)",
value=".txt",
interactive=True,
)
auto_save = gr.Checkbox(
label='Autosave', info='Options', value=True, interactive=True
label="Autosave", info="Options", value=True, interactive=True
)
images_dir.change(
@ -321,39 +318,39 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
# Caption Section
with gr.Group(), gr.Row():
quick_tags_text = gr.Textbox(
label='Quick Tags',
placeholder='Comma separated list of tags',
label="Quick Tags",
placeholder="Comma separated list of tags",
interactive=True,
)
import_tags_button = gr.Button('Import', elem_id='open_folder')
import_tags_button = gr.Button("Import", elem_id="open_folder")
ignore_load_tags_word_count = gr.Slider(
minimum=1,
maximum=100,
value=3,
step=1,
label='Ignore Imported Tags Above Word Count',
label="Ignore Imported Tags Above Word Count",
interactive=True,
)
# Next/Prev section generator
def render_pagination():
gr.Button('< Prev', elem_id='open_folder').click(
gr.Button("< Prev", elem_id="open_folder").click(
paginate,
inputs=[page, max_page, gr.Number(-1, visible=False)],
outputs=[page],
)
page_count = gr.Label('Page 1', label='Page')
page_count = gr.Label("Page 1", label="Page")
page_goto_text = gr.Textbox(
label='Goto page',
placeholder='Page Number',
label="Goto page",
placeholder="Page Number",
interactive=True,
)
gr.Button('Go >', elem_id='open_folder').click(
gr.Button("Go >", elem_id="open_folder").click(
paginate_go,
inputs=[page_goto_text, max_page],
outputs=[page],
)
gr.Button('Next >', elem_id='open_folder').click(
gr.Button("Next >", elem_id="open_folder").click(
paginate,
inputs=[page, max_page, gr.Number(1, visible=False)],
outputs=[page],
@ -374,19 +371,20 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
with gr.Row(visible=False) as row:
image_file = gr.Text(visible=False)
image_files.append(image_file)
image_image = gr.Image(type='filepath')
image_image = gr.Image(type="filepath")
image_images.append(image_image)
image_caption_text = gr.TextArea(
label='Captions',
placeholder='Input captions',
label="Captions",
placeholder="Input captions",
interactive=True,
)
image_caption_texts.append(image_caption_text)
tag_checkboxes = gr.CheckboxGroup(
[], label='Tags', interactive=True
)
tag_checkboxes = gr.CheckboxGroup([], label="Tags", interactive=True)
save_button = gr.Button(
'💾', elem_id='open_folder_small', elem_classes=['tool'], visible=False
"💾",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=False,
)
save_buttons.append(save_button)
@ -485,9 +483,9 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
)
# Update the key on page and image dir change
listener_kwargs = {
'fn': lambda p, i: f'{p}-{i}',
'inputs': [page, loaded_images_dir],
'outputs': image_update_key,
"fn": lambda p, i: f"{p}-{i}",
"inputs": [page, loaded_images_dir],
"outputs": image_update_key,
}
page.change(**listener_kwargs)
loaded_images_dir.change(**listener_kwargs)
@ -495,15 +493,14 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
# Save buttons visibility
# (on auto-save on/off)
auto_save.change(
lambda auto_save: [gr.Button(visible=not auto_save)]
* IMAGES_TO_SHOW,
lambda auto_save: [gr.Button(visible=not auto_save)] * IMAGES_TO_SHOW,
inputs=auto_save,
outputs=save_buttons,
)
# Page Count
page.change(
lambda page, max_page: [f'Page {int(page)} / {int(max_page)}'] * 2,
lambda page, max_page: [f"Page {int(page)} / {int(max_page)}"] * 2,
inputs=[page, max_page],
outputs=[page_count1, page_count2],
show_progress=False,

View File

@ -9,16 +9,22 @@ import gradio as gr
from easygui import msgbox
# Local module imports
from .common_gui import get_saveasfilename_path, get_file_path, scriptdir, list_files, create_refresh_button
from .common_gui import (
get_saveasfilename_path,
get_file_path,
scriptdir,
list_files,
create_refresh_button,
)
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
@ -27,7 +33,7 @@ def check_model(model):
if not model:
return True
if not os.path.isfile(model):
msgbox(f'The provided {model} is not a file')
msgbox(f"The provided {model} is not a file")
return False
return True
@ -47,14 +53,14 @@ class GradioMergeLoRaTab:
self.build_tab()
def save_inputs_to_json(self, file_path, inputs):
with open(file_path, 'w') as file:
with open(file_path, "w") as file:
json.dump(inputs, file)
log.info(f'Saved inputs to {file_path}')
log.info(f"Saved inputs to {file_path}")
def load_inputs_from_json(self, file_path):
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
inputs = json.load(file)
log.info(f'Loaded inputs from {file_path}')
log.info(f"Loaded inputs from {file_path}")
return inputs
def build_tab(self):
@ -95,29 +101,34 @@ class GradioMergeLoRaTab:
current_save_dir = path
return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
with gr.Tab('Merge LoRA'):
with gr.Tab("Merge LoRA"):
gr.Markdown(
'This utility can merge up to 4 LoRA together or alternatively merge up to 4 LoRA into a SD checkpoint.'
"This utility can merge up to 4 LoRA together or alternatively merge up to 4 LoRA into a SD checkpoint."
)
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
ckpt_ext_name = gr.Textbox(value='SD model types', visible=False)
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
ckpt_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
ckpt_ext_name = gr.Textbox(value="SD model types", visible=False)
with gr.Group(), gr.Row():
sd_model = gr.Dropdown(
label='SD Model (Optional. Stable Diffusion model path, if you want to merge it with LoRA files)',
label="SD Model (Optional. Stable Diffusion model path, if you want to merge it with LoRA files)",
interactive=True,
choices=[""] + list_sd_models(current_sd_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(sd_model, lambda: None, lambda: {"choices": list_sd_models(current_sd_model_dir)}, "open_folder_small")
create_refresh_button(
sd_model,
lambda: None,
lambda: {"choices": list_sd_models(current_sd_model_dir)},
"open_folder_small",
)
sd_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
sd_model_file.click(
@ -126,7 +137,7 @@ class GradioMergeLoRaTab:
outputs=sd_model,
show_progress=False,
)
sdxl_model = gr.Checkbox(label='SDXL model', value=False)
sdxl_model = gr.Checkbox(label="SDXL model", value=False)
sd_model.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_sd_models(path)),
@ -143,11 +154,16 @@ class GradioMergeLoRaTab:
value="",
allow_custom_value=True,
)
create_refresh_button(lora_a_model, lambda: None, lambda: {"choices": list_a_models(current_a_model_dir)}, "open_folder_small")
create_refresh_button(
lora_a_model,
lambda: None,
lambda: {"choices": list_a_models(current_a_model_dir)},
"open_folder_small",
)
button_lora_a_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_lora_a_model_file.click(
@ -164,11 +180,16 @@ class GradioMergeLoRaTab:
value="",
allow_custom_value=True,
)
create_refresh_button(lora_b_model, lambda: None, lambda: {"choices": list_b_models(current_b_model_dir)}, "open_folder_small")
create_refresh_button(
lora_b_model,
lambda: None,
lambda: {"choices": list_b_models(current_b_model_dir)},
"open_folder_small",
)
button_lora_b_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_lora_b_model_file.click(
@ -193,7 +214,7 @@ class GradioMergeLoRaTab:
with gr.Row():
ratio_a = gr.Slider(
label='Model A merge ratio (eg: 0.5 mean 50%)',
label="Model A merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=1,
step=0.01,
@ -202,7 +223,7 @@ class GradioMergeLoRaTab:
)
ratio_b = gr.Slider(
label='Model B merge ratio (eg: 0.5 mean 50%)',
label="Model B merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=1,
step=0.01,
@ -218,11 +239,16 @@ class GradioMergeLoRaTab:
value="",
allow_custom_value=True,
)
create_refresh_button(lora_c_model, lambda: None, lambda: {"choices": list_c_models(current_c_model_dir)}, "open_folder_small")
create_refresh_button(
lora_c_model,
lambda: None,
lambda: {"choices": list_c_models(current_c_model_dir)},
"open_folder_small",
)
button_lora_c_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_lora_c_model_file.click(
@ -239,11 +265,16 @@ class GradioMergeLoRaTab:
value="",
allow_custom_value=True,
)
create_refresh_button(lora_d_model, lambda: None, lambda: {"choices": list_d_models(current_d_model_dir)}, "open_folder_small")
create_refresh_button(
lora_d_model,
lambda: None,
lambda: {"choices": list_d_models(current_d_model_dir)},
"open_folder_small",
)
button_lora_d_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_lora_d_model_file.click(
@ -267,7 +298,7 @@ class GradioMergeLoRaTab:
with gr.Row():
ratio_c = gr.Slider(
label='Model C merge ratio (eg: 0.5 mean 50%)',
label="Model C merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=1,
step=0.01,
@ -276,7 +307,7 @@ class GradioMergeLoRaTab:
)
ratio_d = gr.Slider(
label='Model D merge ratio (eg: 0.5 mean 50%)',
label="Model D merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=1,
step=0.01,
@ -286,17 +317,22 @@ class GradioMergeLoRaTab:
with gr.Group(), gr.Row():
save_to = gr.Dropdown(
label='Save to (path for the file to save...)',
label="Save to (path for the file to save...)",
interactive=True,
choices=[""] + list_save_to(current_d_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
create_refresh_button(
save_to,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
button_save_to = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
button_save_to.click(
@ -306,15 +342,15 @@ class GradioMergeLoRaTab:
show_progress=False,
)
precision = gr.Radio(
label='Merge precision',
choices=['fp16', 'bf16', 'float'],
value='float',
label="Merge precision",
choices=["fp16", "bf16", "float"],
value="float",
interactive=True,
)
save_precision = gr.Radio(
label='Save precision',
choices=['fp16', 'bf16', 'float'],
value='fp16',
label="Save precision",
choices=["fp16", "bf16", "float"],
value="fp16",
interactive=True,
)
@ -325,7 +361,7 @@ class GradioMergeLoRaTab:
show_progress=False,
)
merge_button = gr.Button('Merge model')
merge_button = gr.Button("Merge model")
merge_button.click(
self.merge_lora,
@ -364,7 +400,7 @@ class GradioMergeLoRaTab:
save_precision,
):
log.info('Merge model...')
log.info("Merge model...")
models = [
sd_model,
lora_a_model,
@ -377,7 +413,7 @@ class GradioMergeLoRaTab:
if not verify_conditions(sd_model, lora_models):
log.info(
'Warning: Either provide at least one LoRa model along with the sd_model or at least two LoRa models if no sd_model is provided.'
"Warning: Either provide at least one LoRa model along with the sd_model or at least two LoRa models if no sd_model is provided."
)
return
@ -386,36 +422,36 @@ class GradioMergeLoRaTab:
return
if not sdxl_model:
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/merge_lora.py"'
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/merge_lora.py"'
else:
run_cmd = (
fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py"'
rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py"'
)
if sd_model:
run_cmd += fr' --sd_model "{sd_model}"'
run_cmd += f' --save_precision {save_precision}'
run_cmd += f' --precision {precision}'
run_cmd += fr' --save_to "{save_to}"'
run_cmd += rf' --sd_model "{sd_model}"'
run_cmd += f" --save_precision {save_precision}"
run_cmd += f" --precision {precision}"
run_cmd += rf' --save_to "{save_to}"'
# Create a space-separated string of non-empty models (from the second element onwards), enclosed in double quotes
models_cmd = ' '.join([fr'"{model}"' for model in lora_models if model])
models_cmd = " ".join([rf'"{model}"' for model in lora_models if model])
# Create a space-separated string of non-zero ratios corresponding to non-empty LoRa models
valid_ratios = [
ratios[i] for i, model in enumerate(lora_models) if model
]
ratios_cmd = ' '.join([str(ratio) for ratio in valid_ratios])
valid_ratios = [ratios[i] for i, model in enumerate(lora_models) if model]
ratios_cmd = " ".join([str(ratio) for ratio in valid_ratios])
if models_cmd:
run_cmd += f' --models {models_cmd}'
run_cmd += f' --ratios {ratios_cmd}'
run_cmd += f" --models {models_cmd}"
run_cmd += f" --ratios {ratios_cmd}"
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
log.info('Done merging...')
log.info("Done merging...")

View File

@ -16,10 +16,10 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
@ -34,29 +34,31 @@ def merge_lycoris(
is_sdxl,
is_v2,
):
log.info('Merge model...')
log.info("Merge model...")
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/merge_lycoris.py"'
run_cmd += fr' "{base_model}"'
run_cmd += fr' "{lycoris_model}"'
run_cmd += fr' "{output_name}"'
run_cmd += f' --weight {weight}'
run_cmd += f' --device {device}'
run_cmd += f' --dtype {dtype}'
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/merge_lycoris.py"'
run_cmd += rf' "{base_model}"'
run_cmd += rf' "{lycoris_model}"'
run_cmd += rf' "{output_name}"'
run_cmd += f" --weight {weight}"
run_cmd += f" --device {device}"
run_cmd += f" --dtype {dtype}"
if is_sdxl:
run_cmd += f' --is_sdxl'
run_cmd += f" --is_sdxl"
if is_v2:
run_cmd += f' --is_v2'
run_cmd += f" --is_v2"
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
log.info('Done merging...')
log.info("Done merging...")
###
@ -84,30 +86,33 @@ def gradio_merge_lycoris_tab(headless=False):
current_save_dir = path
return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
with gr.Tab('Merge LyCORIS'):
gr.Markdown(
'This utility can merge a LyCORIS model into a SD checkpoint.'
)
with gr.Tab("Merge LyCORIS"):
gr.Markdown("This utility can merge a LyCORIS model into a SD checkpoint.")
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
ckpt_ext_name = gr.Textbox(value='SD model types', visible=False)
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
ckpt_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
ckpt_ext_name = gr.Textbox(value="SD model types", visible=False)
with gr.Group(), gr.Row():
base_model = gr.Dropdown(
label='SD Model (Optional Stable Diffusion base model)',
label="SD Model (Optional Stable Diffusion base model)",
interactive=True,
info='Provide a SD file path that you want to merge with the LyCORIS file',
info="Provide a SD file path that you want to merge with the LyCORIS file",
choices=[""] + list_models(current_save_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(base_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
create_refresh_button(
base_model,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
base_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
base_model_file.click(
@ -118,7 +123,7 @@ def gradio_merge_lycoris_tab(headless=False):
)
lycoris_model = gr.Dropdown(
label='LyCORIS model (path to the LyCORIS model)',
label="LyCORIS model (path to the LyCORIS model)",
interactive=True,
choices=[""] + list_lycoris_model(current_save_dir),
value="",
@ -126,8 +131,8 @@ def gradio_merge_lycoris_tab(headless=False):
)
button_lycoris_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_lycoris_model_file.click(
@ -152,7 +157,7 @@ def gradio_merge_lycoris_tab(headless=False):
with gr.Row():
weight = gr.Slider(
label='Model A merge ratio (eg: 0.5 mean 50%)',
label="Model A merge ratio (eg: 0.5 mean 50%)",
minimum=0,
maximum=1,
step=0.01,
@ -162,17 +167,22 @@ def gradio_merge_lycoris_tab(headless=False):
with gr.Group(), gr.Row():
output_name = gr.Dropdown(
label='Save to (path for the checkpoint file to save...)',
label="Save to (path for the checkpoint file to save...)",
interactive=True,
choices=[""] + list_save_to(current_save_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(output_name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
create_refresh_button(
output_name,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
button_output_name = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_output_name.click(
@ -182,26 +192,26 @@ def gradio_merge_lycoris_tab(headless=False):
show_progress=False,
)
dtype = gr.Radio(
label='Save dtype',
label="Save dtype",
choices=[
'float',
'float16',
'float32',
'float64',
'bfloat',
'bfloat16',
"float",
"float16",
"float32",
"float64",
"bfloat",
"bfloat16",
],
value='float16',
value="float16",
interactive=True,
)
device = gr.Radio(
label='Device',
label="Device",
choices=[
'cpu',
'cuda',
"cpu",
"cuda",
],
value='cpu',
value="cpu",
interactive=True,
)
@ -213,10 +223,10 @@ def gradio_merge_lycoris_tab(headless=False):
)
with gr.Row():
is_sdxl = gr.Checkbox(label='is SDXL', value=False, interactive=True)
is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True)
is_sdxl = gr.Checkbox(label="is SDXL", value=False, interactive=True)
is_v2 = gr.Checkbox(label="is v2", value=False, interactive=True)
merge_button = gr.Button('Merge model')
merge_button = gr.Button("Merge model")
merge_button.click(
merge_lycoris,

View File

@ -3,17 +3,23 @@ from easygui import msgbox
import subprocess
import os
import sys
from .common_gui import get_saveasfilename_path, get_file_path, scriptdir, list_files, create_refresh_button
from .common_gui import (
get_saveasfilename_path,
get_file_path,
scriptdir,
list_files,
create_refresh_button,
)
from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
@ -29,57 +35,57 @@ def resize_lora(
verbose,
):
# Check for caption_text_input
if model == '':
msgbox('Invalid model file')
if model == "":
msgbox("Invalid model file")
return
# Check if source model exist
if not os.path.isfile(model):
msgbox('The provided model is not a file')
msgbox("The provided model is not a file")
return
if dynamic_method == 'sv_ratio':
if dynamic_method == "sv_ratio":
if float(dynamic_param) < 2:
msgbox(
f'Dynamic parameter for {dynamic_method} need to be 2 or greater...'
)
msgbox(f"Dynamic parameter for {dynamic_method} need to be 2 or greater...")
return
if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
if dynamic_method == "sv_fro" or dynamic_method == "sv_cumulative":
if float(dynamic_param) < 0 or float(dynamic_param) > 1:
msgbox(
f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...'
f"Dynamic parameter for {dynamic_method} need to be between 0 and 1..."
)
return
# Check if save_to end with one of the defines extension. If not add .safetensors.
if not save_to.endswith(('.pt', '.safetensors')):
save_to += '.safetensors'
if not save_to.endswith((".pt", ".safetensors")):
save_to += ".safetensors"
if device == '':
device = 'cuda'
if device == "":
device = "cuda"
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/resize_lora.py"'
run_cmd += f' --save_precision {save_precision}'
run_cmd += fr' --save_to "{save_to}"'
run_cmd += fr' --model "{model}"'
run_cmd += f' --new_rank {new_rank}'
run_cmd += f' --device {device}'
if not dynamic_method == 'None':
run_cmd += f' --dynamic_method {dynamic_method}'
run_cmd += f' --dynamic_param {dynamic_param}'
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/resize_lora.py"'
run_cmd += f" --save_precision {save_precision}"
run_cmd += rf' --save_to "{save_to}"'
run_cmd += rf' --model "{model}"'
run_cmd += f" --new_rank {new_rank}"
run_cmd += f" --device {device}"
if not dynamic_method == "None":
run_cmd += f" --dynamic_method {dynamic_method}"
run_cmd += f" --dynamic_param {dynamic_param}"
if verbose:
run_cmd += f' --verbose'
run_cmd += f" --verbose"
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
log.info('Done resizing...')
log.info("Done resizing...")
###
@ -101,25 +107,30 @@ def gradio_resize_lora_tab(headless=False):
current_save_dir = path
return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
with gr.Tab('Resize LoRA'):
gr.Markdown('This utility can resize a LoRA.')
with gr.Tab("Resize LoRA"):
gr.Markdown("This utility can resize a LoRA.")
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
with gr.Group(), gr.Row():
model = gr.Dropdown(
label='Source LoRA (path to the LoRA to resize)',
label="Source LoRA (path to the LoRA to resize)",
interactive=True,
choices=[""] + list_models(current_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
create_refresh_button(
model,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
button_lora_a_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_lora_a_model_file.click(
@ -129,17 +140,22 @@ def gradio_resize_lora_tab(headless=False):
show_progress=False,
)
save_to = gr.Dropdown(
label='Save to (path for the LoRA file to save...)',
label="Save to (path for the LoRA file to save...)",
interactive=True,
choices=[""] + list_save_to(current_save_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
create_refresh_button(
save_to,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
button_save_to = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_save_to.click(
@ -162,7 +178,7 @@ def gradio_resize_lora_tab(headless=False):
)
with gr.Row():
new_rank = gr.Slider(
label='Desired LoRA rank',
label="Desired LoRA rank",
minimum=1,
maximum=1024,
step=1,
@ -170,37 +186,37 @@ def gradio_resize_lora_tab(headless=False):
interactive=True,
)
dynamic_method = gr.Radio(
choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'],
value='sv_fro',
label='Dynamic method',
choices=["None", "sv_ratio", "sv_fro", "sv_cumulative"],
value="sv_fro",
label="Dynamic method",
interactive=True,
)
dynamic_param = gr.Textbox(
label='Dynamic parameter',
value='0.9',
label="Dynamic parameter",
value="0.9",
interactive=True,
placeholder='Value for the dynamic method selected.',
placeholder="Value for the dynamic method selected.",
)
with gr.Row():
verbose = gr.Checkbox(label='Verbose logging', value=True)
verbose = gr.Checkbox(label="Verbose logging", value=True)
save_precision = gr.Radio(
label='Save precision',
choices=['fp16', 'bf16', 'float'],
value='fp16',
label="Save precision",
choices=["fp16", "bf16", "float"],
value="fp16",
interactive=True,
)
device = gr.Radio(
label='Device',
label="Device",
choices=[
'cpu',
'cuda',
"cpu",
"cuda",
],
value='cuda',
value="cuda",
interactive=True,
)
convert_button = gr.Button('Resize model')
convert_button = gr.Button("Resize model")
convert_button.click(
resize_lora,

View File

@ -5,7 +5,6 @@ import os
import sys
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
scriptdir,
list_files,
@ -17,10 +16,10 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
@ -53,49 +52,51 @@ def svd_merge_lora(
ratio_c /= total_ratio
ratio_d /= total_ratio
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/svd_merge_lora.py"'
run_cmd += f' --save_precision {save_precision}'
run_cmd += f' --precision {precision}'
run_cmd += fr' --save_to "{save_to}"'
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/svd_merge_lora.py"'
run_cmd += f" --save_precision {save_precision}"
run_cmd += f" --precision {precision}"
run_cmd += rf' --save_to "{save_to}"'
run_cmd_models = ' --models'
run_cmd_ratios = ' --ratios'
run_cmd_models = " --models"
run_cmd_ratios = " --ratios"
# Add non-empty models and their ratios to the command
if lora_a_model:
if not os.path.isfile(lora_a_model):
msgbox('The provided model A is not a file')
msgbox("The provided model A is not a file")
return
run_cmd_models += fr' "{lora_a_model}"'
run_cmd_ratios += f' {ratio_a}'
run_cmd_models += rf' "{lora_a_model}"'
run_cmd_ratios += f" {ratio_a}"
if lora_b_model:
if not os.path.isfile(lora_b_model):
msgbox('The provided model B is not a file')
msgbox("The provided model B is not a file")
return
run_cmd_models += fr' "{lora_b_model}"'
run_cmd_ratios += f' {ratio_b}'
run_cmd_models += rf' "{lora_b_model}"'
run_cmd_ratios += f" {ratio_b}"
if lora_c_model:
if not os.path.isfile(lora_c_model):
msgbox('The provided model C is not a file')
msgbox("The provided model C is not a file")
return
run_cmd_models += fr' "{lora_c_model}"'
run_cmd_ratios += f' {ratio_c}'
run_cmd_models += rf' "{lora_c_model}"'
run_cmd_ratios += f" {ratio_c}"
if lora_d_model:
if not os.path.isfile(lora_d_model):
msgbox('The provided model D is not a file')
msgbox("The provided model D is not a file")
return
run_cmd_models += fr' "{lora_d_model}"'
run_cmd_ratios += f' {ratio_d}'
run_cmd_models += rf' "{lora_d_model}"'
run_cmd_ratios += f" {ratio_d}"
run_cmd += run_cmd_models
run_cmd += run_cmd_ratios
run_cmd += f' --device {device}'
run_cmd += f" --device {device}"
run_cmd += f' --new_rank "{new_rank}"'
run_cmd += f' --new_conv_rank "{new_conv_rank}"'
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
subprocess.run(run_cmd, shell=True, env=env)
@ -138,13 +139,13 @@ def gradio_svd_merge_lora_tab(headless=False):
current_save_dir = path
return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
with gr.Tab('Merge LoRA (SVD)'):
with gr.Tab("Merge LoRA (SVD)"):
gr.Markdown(
'This utility can merge two LoRA networks together into a new LoRA.'
"This utility can merge two LoRA networks together into a new LoRA."
)
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
with gr.Group(), gr.Row():
lora_a_model = gr.Dropdown(
@ -154,11 +155,16 @@ def gradio_svd_merge_lora_tab(headless=False):
value="",
allow_custom_value=True,
)
create_refresh_button(lora_a_model, lambda: None, lambda: {"choices": list_a_models(current_a_model_dir)}, "open_folder_small")
create_refresh_button(
lora_a_model,
lambda: None,
lambda: {"choices": list_a_models(current_a_model_dir)},
"open_folder_small",
)
button_lora_a_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_lora_a_model_file.click(
@ -175,11 +181,16 @@ def gradio_svd_merge_lora_tab(headless=False):
value="",
allow_custom_value=True,
)
create_refresh_button(lora_b_model, lambda: None, lambda: {"choices": list_b_models(current_b_model_dir)}, "open_folder_small")
create_refresh_button(
lora_b_model,
lambda: None,
lambda: {"choices": list_b_models(current_b_model_dir)},
"open_folder_small",
)
button_lora_b_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_lora_b_model_file.click(
@ -202,7 +213,7 @@ def gradio_svd_merge_lora_tab(headless=False):
)
with gr.Row():
ratio_a = gr.Slider(
label='Merge ratio model A',
label="Merge ratio model A",
minimum=0,
maximum=1,
step=0.01,
@ -210,7 +221,7 @@ def gradio_svd_merge_lora_tab(headless=False):
interactive=True,
)
ratio_b = gr.Slider(
label='Merge ratio model B',
label="Merge ratio model B",
minimum=0,
maximum=1,
step=0.01,
@ -225,11 +236,16 @@ def gradio_svd_merge_lora_tab(headless=False):
value="",
allow_custom_value=True,
)
create_refresh_button(lora_c_model, lambda: None, lambda: {"choices": list_c_models(current_c_model_dir)}, "open_folder_small")
create_refresh_button(
lora_c_model,
lambda: None,
lambda: {"choices": list_c_models(current_c_model_dir)},
"open_folder_small",
)
button_lora_c_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_lora_c_model_file.click(
@ -246,11 +262,16 @@ def gradio_svd_merge_lora_tab(headless=False):
value="",
allow_custom_value=True,
)
create_refresh_button(lora_d_model, lambda: None, lambda: {"choices": list_d_models(current_d_model_dir)}, "open_folder_small")
create_refresh_button(
lora_d_model,
lambda: None,
lambda: {"choices": list_d_models(current_d_model_dir)},
"open_folder_small",
)
button_lora_d_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_lora_d_model_file.click(
@ -274,7 +295,7 @@ def gradio_svd_merge_lora_tab(headless=False):
)
with gr.Row():
ratio_c = gr.Slider(
label='Merge ratio model C',
label="Merge ratio model C",
minimum=0,
maximum=1,
step=0.01,
@ -282,7 +303,7 @@ def gradio_svd_merge_lora_tab(headless=False):
interactive=True,
)
ratio_d = gr.Slider(
label='Merge ratio model D',
label="Merge ratio model D",
minimum=0,
maximum=1,
step=0.01,
@ -291,7 +312,7 @@ def gradio_svd_merge_lora_tab(headless=False):
)
with gr.Row():
new_rank = gr.Slider(
label='New Rank',
label="New Rank",
minimum=1,
maximum=1024,
step=1,
@ -299,7 +320,7 @@ def gradio_svd_merge_lora_tab(headless=False):
interactive=True,
)
new_conv_rank = gr.Slider(
label='New Conv Rank',
label="New Conv Rank",
minimum=1,
maximum=1024,
step=1,
@ -309,17 +330,22 @@ def gradio_svd_merge_lora_tab(headless=False):
with gr.Group(), gr.Row():
save_to = gr.Dropdown(
label='Save to (path for the new LoRA file to save...)',
label="Save to (path for the new LoRA file to save...)",
interactive=True,
choices=[""] + list_save_to(current_d_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
create_refresh_button(
save_to,
lambda: None,
lambda: {"choices": list_save_to(current_save_dir)},
"open_folder_small",
)
button_save_to = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_save_to.click(
@ -336,28 +362,28 @@ def gradio_svd_merge_lora_tab(headless=False):
)
with gr.Group(), gr.Row():
precision = gr.Radio(
label='Merge precision',
choices=['fp16', 'bf16', 'float'],
value='float',
label="Merge precision",
choices=["fp16", "bf16", "float"],
value="float",
interactive=True,
)
save_precision = gr.Radio(
label='Save precision',
choices=['fp16', 'bf16', 'float'],
value='float',
label="Save precision",
choices=["fp16", "bf16", "float"],
value="float",
interactive=True,
)
device = gr.Radio(
label='Device',
label="Device",
choices=[
'cpu',
'cuda',
"cpu",
"cuda",
],
value='cuda',
value="cuda",
interactive=True,
)
convert_button = gr.Button('Merge model')
convert_button = gr.Button("Merge model")
convert_button.click(
svd_merge_lora,

View File

@ -10,55 +10,56 @@ from .custom_logging import setup_logging
log = setup_logging()
tensorboard_proc = None
TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe'
TENSORBOARD = "tensorboard" if os.name == "posix" else "tensorboard.exe"
# Set the default tensorboard port
DEFAULT_TENSORBOARD_PORT = 6006
def start_tensorboard(headless, logging_dir, wait_time=5):
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
global tensorboard_proc
headless_bool = True if headless.get('label') == 'True' else False
headless_bool = True if headless.get("label") == "True" else False
# Read the TENSORBOARD_PORT from the environment, or use the default
tensorboard_port = os.environ.get('TENSORBOARD_PORT', DEFAULT_TENSORBOARD_PORT)
tensorboard_port = os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT)
# Check if logging directory exists and is not empty; if not, warn the user and exit
if not os.path.exists(logging_dir) or not os.listdir(logging_dir):
log.error('Error: logging folder does not exist or does not contain logs.')
msgbox(msg='Error: logging folder does not exist or does not contain logs.')
log.error("Error: logging folder does not exist or does not contain logs.")
msgbox(msg="Error: logging folder does not exist or does not contain logs.")
return # Exit the function with an error code
run_cmd = [
TENSORBOARD,
'--logdir',
"--logdir",
logging_dir,
'--host',
'0.0.0.0',
'--port',
"--host",
"0.0.0.0",
"--port",
str(tensorboard_port),
]
log.info(run_cmd)
if tensorboard_proc is not None:
log.info(
'Tensorboard is already running. Terminating existing process before starting new one...'
"Tensorboard is already running. Terminating existing process before starting new one..."
)
stop_tensorboard()
# Start background process
log.info('Starting TensorBoard on port {}'.format(tensorboard_port))
log.info("Starting TensorBoard on port {}".format(tensorboard_port))
try:
# Copy the current environment
env = os.environ.copy()
# Set your specific environment variable
env['TF_ENABLE_ONEDNN_OPTS'] = '0'
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
tensorboard_proc = subprocess.Popen(run_cmd, env=env)
except Exception as e:
log.error('Failed to start Tensorboard:', e)
log.error("Failed to start Tensorboard:", e)
return
if not headless_bool:
@ -66,28 +67,28 @@ def start_tensorboard(headless, logging_dir, wait_time=5):
time.sleep(wait_time)
# Open the TensorBoard URL in the default browser
tensorboard_url = f'http://localhost:{tensorboard_port}'
log.info(f'Opening TensorBoard URL in browser: {tensorboard_url}')
tensorboard_url = f"http://localhost:{tensorboard_port}"
log.info(f"Opening TensorBoard URL in browser: {tensorboard_url}")
webbrowser.open(tensorboard_url)
def stop_tensorboard():
global tensorboard_proc
if tensorboard_proc is not None:
log.info('Stopping tensorboard process...')
log.info("Stopping tensorboard process...")
try:
tensorboard_proc.terminate()
tensorboard_proc = None
log.info('...process stopped')
log.info("...process stopped")
except Exception as e:
log.error('Failed to stop Tensorboard:', e)
log.error("Failed to stop Tensorboard:", e)
else:
log.warning('Tensorboard is not running...')
log.warning("Tensorboard is not running...")
def gradio_tensorboard():
with gr.Row():
button_start_tensorboard = gr.Button('Start tensorboard')
button_stop_tensorboard = gr.Button('Stop tensorboard')
button_start_tensorboard = gr.Button("Start tensorboard")
button_stop_tensorboard = gr.Button("Stop tensorboard")
return (button_start_tensorboard, button_stop_tensorboard)

View File

@ -2,7 +2,6 @@ import gradio as gr
import json
import math
import os
import pathlib
from datetime import datetime
from .common_gui import (
get_file_path,
@ -463,7 +462,9 @@ def train_model(
return
if dataset_config:
log.info("Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations...")
log.info(
"Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
)
else:
# Get a list of all subfolders in train_data_dir
subfolders = [
@ -526,7 +527,9 @@ def train_model(
log.info(f"max_train_steps = {max_train_steps}")
# calculate stop encoder training
if stop_text_encoder_training_pct == None or (not max_train_steps == "" or not max_train_steps == "0"):
if stop_text_encoder_training_pct == None or (
not max_train_steps == "" or not max_train_steps == "0"
):
stop_text_encoder_training = 0
else:
stop_text_encoder_training = math.ceil(
@ -709,7 +712,7 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
with gr.Accordion("Accelerate launch", open=False), gr.Column():
accelerate_launch = AccelerateLaunch()
with gr.Column():
source_model = SourceModel(
save_model_as_choices=[
@ -722,7 +725,7 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
with gr.Accordion("Folders", open=False), gr.Group():
folders = Folders(headless=headless, config=config)
with gr.Accordion("Parameters", open=False), gr.Column():
with gr.Accordion("Basic", open="True"):
with gr.Group(elem_id="basic_tab"):
@ -733,7 +736,9 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
current_embedding_dir = path
return list(
list_files(
path, exts=[".pt", ".ckpt", ".safetensors"], all=True
path,
exts=[".pt", ".ckpt", ".safetensors"],
all=True,
)
)

View File

@ -1,10 +1,4 @@
# v1: initial release
# v2: add open and save folder icons
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captionning of images to utilities
import gradio as gr
import os
from .basic_caption_gui import gradio_basic_caption_gui_tab
from .convert_model_gui import gradio_convert_model_tab
@ -23,9 +17,9 @@ def utilities_tab(
logging_dir_input=gr.Dropdown(),
enable_copy_info_button=bool(False),
enable_dreambooth_tab=True,
headless=False
headless=False,
):
with gr.Tab('Captioning'):
with gr.Tab("Captioning"):
gradio_basic_caption_gui_tab(headless=headless)
gradio_blip_caption_gui_tab(headless=headless)
gradio_blip2_caption_gui_tab(headless=headless)

View File

@ -4,8 +4,6 @@ import subprocess
import os
import sys
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
scriptdir,
list_files,
@ -17,37 +15,41 @@ from .custom_logging import setup_logging
# Set up logging
log = setup_logging()
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = sys.executable
def verify_lora(
lora_model,
):
# verify for caption_text_input
if lora_model == '':
msgbox('Invalid model A file')
if lora_model == "":
msgbox("Invalid model A file")
return
# verify if source model exist
if not os.path.isfile(lora_model):
msgbox('The provided model A is not a file')
msgbox("The provided model A is not a file")
return
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/check_lora_weights.py" "{lora_model}"'
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/check_lora_weights.py" "{lora_model}"'
log.info(run_cmd)
env = os.environ.copy()
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
env["PYTHONPATH"] = (
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Run the command
process = subprocess.Popen(
run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env,
run_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env,
)
output, error = process.communicate()
@ -67,27 +69,32 @@ def gradio_verify_lora_tab(headless=False):
current_model_dir = path
return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
with gr.Tab('Verify LoRA'):
with gr.Tab("Verify LoRA"):
gr.Markdown(
'This utility can verify a LoRA network to make sure it is properly trained.'
"This utility can verify a LoRA network to make sure it is properly trained."
)
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
lora_ext = gr.Textbox(value="*.pt *.safetensors", visible=False)
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
with gr.Group(), gr.Row():
lora_model = gr.Dropdown(
label='LoRA model (path to the LoRA model to verify)',
label="LoRA model (path to the LoRA model to verify)",
interactive=True,
choices=[""] + list_models(current_model_dir),
value="",
allow_custom_value=True,
)
create_refresh_button(lora_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
create_refresh_button(
lora_model,
lambda: None,
lambda: {"choices": list_models(current_model_dir)},
"open_folder_small",
)
button_lora_model_file = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
button_lora_model_file.click(
@ -96,7 +103,7 @@ def gradio_verify_lora_tab(headless=False):
outputs=lora_model,
show_progress=False,
)
verify_button = gr.Button('Verify', variant='primary')
verify_button = gr.Button("Verify", variant="primary")
lora_model.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
@ -106,16 +113,16 @@ def gradio_verify_lora_tab(headless=False):
)
lora_model_verif_output = gr.Textbox(
label='Output',
placeholder='Verification output',
label="Output",
placeholder="Verification output",
interactive=False,
lines=1,
max_lines=10,
)
lora_model_verif_error = gr.Textbox(
label='Error',
placeholder='Verification error',
label="Error",
placeholder="Verification error",
interactive=False,
lines=1,
max_lines=10,

View File

@ -164,9 +164,9 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None):
)
caption_extension = gr.Textbox(
label='Caption file extension',
placeholder='Extension for caption file (e.g., .caption, .txt)',
value='.txt',
label="Caption file extension",
placeholder="Extension for caption file (e.g., .caption, .txt)",
value=".txt",
interactive=True,
)
@ -266,7 +266,7 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None):
)
character_threshold = gr.Slider(
value=0.35,
label='Character threshold',
label="Character threshold",
minimum=0,
maximum=1,
step=0.05,