mirror of https://github.com/bmaltais/kohya_ss
Format with black
parent
c827268bf3
commit
49f76343b5
|
|
@ -1,7 +1,13 @@
|
|||
import gradio as gr
|
||||
from easygui import msgbox
|
||||
import subprocess
|
||||
from .common_gui import get_folder_path, add_pre_postfix, find_replace, scriptdir, list_dirs
|
||||
from .common_gui import (
|
||||
get_folder_path,
|
||||
add_pre_postfix,
|
||||
find_replace,
|
||||
scriptdir,
|
||||
list_dirs,
|
||||
)
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
@ -12,6 +18,7 @@ log = setup_logging()
|
|||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
||||
def caption_images(
|
||||
caption_text: str,
|
||||
images_dir: str,
|
||||
|
|
@ -41,26 +48,26 @@ def caption_images(
|
|||
# Check if images_dir is provided
|
||||
if not images_dir:
|
||||
msgbox(
|
||||
'Image folder is missing. Please provide the directory containing the images to caption.'
|
||||
"Image folder is missing. Please provide the directory containing the images to caption."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if caption_ext is provided
|
||||
if not caption_ext:
|
||||
msgbox('Please provide an extension for the caption files.')
|
||||
msgbox("Please provide an extension for the caption files.")
|
||||
return
|
||||
|
||||
# Log the captioning process
|
||||
if caption_text:
|
||||
log.info(f'Captioning files in {images_dir} with {caption_text}...')
|
||||
log.info(f"Captioning files in {images_dir} with {caption_text}...")
|
||||
|
||||
# Build the command to run caption.py
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/caption.py"'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/caption.py"'
|
||||
run_cmd += f' --caption_text="{caption_text}"'
|
||||
|
||||
# Add optional flags to the command
|
||||
if overwrite:
|
||||
run_cmd += f' --overwrite'
|
||||
run_cmd += f" --overwrite"
|
||||
if caption_ext:
|
||||
run_cmd += f' --caption_file_ext="{caption_ext}"'
|
||||
|
||||
|
|
@ -71,7 +78,9 @@ def caption_images(
|
|||
|
||||
# Set the environment variable for the Python path
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/tools{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/tools{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command based on the operating system
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
|
@ -102,7 +111,7 @@ def caption_images(
|
|||
)
|
||||
|
||||
# Log the end of the captioning process
|
||||
log.info('Captioning done.')
|
||||
log.info("Captioning done.")
|
||||
|
||||
|
||||
# Gradio UI
|
||||
|
|
@ -121,7 +130,11 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
from .common_gui import create_refresh_button
|
||||
|
||||
# Set default images directory if not provided
|
||||
default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data")
|
||||
default_images_dir = (
|
||||
default_images_dir
|
||||
if default_images_dir is not None
|
||||
else os.path.join(scriptdir, "data")
|
||||
)
|
||||
current_images_dir = default_images_dir
|
||||
|
||||
# Function to list directories
|
||||
|
|
@ -141,26 +154,34 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
return list(list_dirs(path))
|
||||
|
||||
# Gradio tab for basic captioning
|
||||
with gr.Tab('Basic Captioning'):
|
||||
with gr.Tab("Basic Captioning"):
|
||||
# Markdown description
|
||||
gr.Markdown(
|
||||
'This utility allows you to create simple caption files for each image in a folder.'
|
||||
"This utility allows you to create simple caption files for each image in a folder."
|
||||
)
|
||||
# Group and row for image folder selection
|
||||
with gr.Group(), gr.Row():
|
||||
# Dropdown for image folder
|
||||
images_dir = gr.Dropdown(
|
||||
label='Image folder to caption (containing the images to caption)',
|
||||
label="Image folder to caption (containing the images to caption)",
|
||||
choices=[""] + list_images_dirs(default_images_dir),
|
||||
value="",
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# Refresh button for image folder
|
||||
create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small")
|
||||
create_refresh_button(
|
||||
images_dir,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_images_dirs(current_images_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
# Button to open folder
|
||||
folder_button = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
# Event handler for button click
|
||||
folder_button.click(
|
||||
|
|
@ -170,14 +191,14 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
)
|
||||
# Textbox for caption file extension
|
||||
caption_ext = gr.Textbox(
|
||||
label='Caption file extension',
|
||||
placeholder='Extension for caption file (e.g., .caption, .txt)',
|
||||
value='.txt',
|
||||
label="Caption file extension",
|
||||
placeholder="Extension for caption file (e.g., .caption, .txt)",
|
||||
value=".txt",
|
||||
interactive=True,
|
||||
)
|
||||
# Checkbox to overwrite existing captions
|
||||
overwrite = gr.Checkbox(
|
||||
label='Overwrite existing captions in folder',
|
||||
label="Overwrite existing captions in folder",
|
||||
interactive=True,
|
||||
value=False,
|
||||
)
|
||||
|
|
@ -185,41 +206,41 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
with gr.Row():
|
||||
# Textbox for caption prefix
|
||||
prefix = gr.Textbox(
|
||||
label='Prefix to add to caption',
|
||||
placeholder='(Optional)',
|
||||
label="Prefix to add to caption",
|
||||
placeholder="(Optional)",
|
||||
interactive=True,
|
||||
)
|
||||
# Textbox for caption text
|
||||
caption_text = gr.Textbox(
|
||||
label='Caption text',
|
||||
label="Caption text",
|
||||
placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.',
|
||||
interactive=True,
|
||||
lines=2,
|
||||
)
|
||||
# Textbox for caption postfix
|
||||
postfix = gr.Textbox(
|
||||
label='Postfix to add to caption',
|
||||
placeholder='(Optional)',
|
||||
label="Postfix to add to caption",
|
||||
placeholder="(Optional)",
|
||||
interactive=True,
|
||||
)
|
||||
# Group and row for find and replace text
|
||||
with gr.Group(), gr.Row():
|
||||
# Textbox for find text
|
||||
find_text = gr.Textbox(
|
||||
label='Find text',
|
||||
label="Find text",
|
||||
placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.',
|
||||
interactive=True,
|
||||
lines=2,
|
||||
)
|
||||
# Textbox for replace text
|
||||
replace_text = gr.Textbox(
|
||||
label='Replacement text',
|
||||
label="Replacement text",
|
||||
placeholder='e.g., "by some artist". Leave empty if you want to replace with nothing.',
|
||||
interactive=True,
|
||||
lines=2,
|
||||
)
|
||||
# Button to caption images
|
||||
caption_button = gr.Button('Caption images')
|
||||
caption_button = gr.Button("Caption images")
|
||||
# Event handler for button click
|
||||
caption_button.click(
|
||||
caption_images,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import gradio as gr
|
||||
import os
|
||||
|
||||
from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs
|
||||
from .common_gui import get_folder_path, scriptdir, list_dirs
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
|
|
|
|||
|
|
@ -74,7 +74,9 @@ def caption_images(
|
|||
|
||||
# Set up the environment
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command in the sd-scripts folder context
|
||||
subprocess.run(run_cmd, shell=True, env=env, cwd=f"{scriptdir}/sd-scripts")
|
||||
|
|
|
|||
|
|
@ -69,12 +69,12 @@ class AccelerateLaunch:
|
|||
|
||||
def run_cmd(**kwargs):
|
||||
run_cmd = ""
|
||||
|
||||
|
||||
if "extra_accelerate_launch_args" in kwargs:
|
||||
extra_accelerate_launch_args = kwargs.get("extra_accelerate_launch_args")
|
||||
if extra_accelerate_launch_args != "":
|
||||
run_cmd += fr' {extra_accelerate_launch_args}'
|
||||
|
||||
run_cmd += rf" {extra_accelerate_launch_args}"
|
||||
|
||||
if "gpu_ids" in kwargs:
|
||||
gpu_ids = kwargs.get("gpu_ids")
|
||||
if not gpu_ids == "":
|
||||
|
|
@ -109,4 +109,4 @@ class AccelerateLaunch:
|
|||
f" --num_cpu_threads_per_process={int(num_cpu_threads_per_process)}"
|
||||
)
|
||||
|
||||
return run_cmd
|
||||
return run_cmd
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from .common_gui import (
|
|||
list_files,
|
||||
list_dirs,
|
||||
create_refresh_button,
|
||||
document_symbol
|
||||
document_symbol,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -221,12 +221,20 @@ class AdvancedTraining:
|
|||
choices=["none", "sdpa", "xformers"],
|
||||
value="xformers",
|
||||
)
|
||||
self.color_aug = gr.Checkbox(label="Color augmentation", value=False, info="Enable weak color augmentation")
|
||||
self.flip_aug = gr.Checkbox(label="Flip augmentation", value=False, info="Enable horizontal flip augmentation")
|
||||
self.color_aug = gr.Checkbox(
|
||||
label="Color augmentation",
|
||||
value=False,
|
||||
info="Enable weak color augmentation",
|
||||
)
|
||||
self.flip_aug = gr.Checkbox(
|
||||
label="Flip augmentation",
|
||||
value=False,
|
||||
info="Enable horizontal flip augmentation",
|
||||
)
|
||||
self.masked_loss = gr.Checkbox(
|
||||
label="Masked loss",
|
||||
value=False,
|
||||
info="Apply mask for calculating loss. conditioning_data_dir is required for dataset"
|
||||
info="Apply mask for calculating loss. conditioning_data_dir is required for dataset",
|
||||
)
|
||||
with gr.Row():
|
||||
self.scale_v_pred_loss_like_noise_pred = gr.Checkbox(
|
||||
|
|
@ -296,7 +304,7 @@ class AdvancedTraining:
|
|||
"Multires",
|
||||
],
|
||||
value="Original",
|
||||
scale=1
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row(visible=True) as self.noise_offset_original:
|
||||
self.noise_offset = gr.Slider(
|
||||
|
|
@ -305,12 +313,12 @@ class AdvancedTraining:
|
|||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
info='Recommended values are 0.05 - 0.15',
|
||||
info="Recommended values are 0.05 - 0.15",
|
||||
)
|
||||
self.noise_offset_random_strength = gr.Checkbox(
|
||||
self.noise_offset_random_strength = gr.Checkbox(
|
||||
label="Noise offset random strength",
|
||||
value=False,
|
||||
info='Use random strength between 0~noise_offset for noise offset',
|
||||
info="Use random strength between 0~noise_offset for noise offset",
|
||||
)
|
||||
self.adaptive_noise_scale = gr.Slider(
|
||||
label="Adaptive noise scale",
|
||||
|
|
@ -327,7 +335,7 @@ class AdvancedTraining:
|
|||
minimum=0,
|
||||
maximum=64,
|
||||
step=1,
|
||||
info='Enable multires noise (recommended values are 6-10)',
|
||||
info="Enable multires noise (recommended values are 6-10)",
|
||||
)
|
||||
self.multires_noise_discount = gr.Slider(
|
||||
label="Multires noise discount",
|
||||
|
|
@ -335,7 +343,7 @@ class AdvancedTraining:
|
|||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
info='Recommended values are 0.8. For LoRAs with small datasets, 0.1-0.3',
|
||||
info="Recommended values are 0.8. For LoRAs with small datasets, 0.1-0.3",
|
||||
)
|
||||
with gr.Row(visible=True):
|
||||
self.ip_noise_gamma = gr.Slider(
|
||||
|
|
@ -344,12 +352,12 @@ class AdvancedTraining:
|
|||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
info='enable input perturbation noise. used for regularization. recommended value: around 0.1',
|
||||
info="enable input perturbation noise. used for regularization. recommended value: around 0.1",
|
||||
)
|
||||
self.ip_noise_gamma_random_strength = gr.Checkbox(
|
||||
self.ip_noise_gamma_random_strength = gr.Checkbox(
|
||||
label="IP noise gamma random strength",
|
||||
value=False,
|
||||
info='Use random strength between 0~ip_noise_gamma for input perturbation noise',
|
||||
info="Use random strength between 0~ip_noise_gamma for input perturbation noise",
|
||||
)
|
||||
self.noise_offset_type.change(
|
||||
noise_offset_type_change,
|
||||
|
|
@ -371,9 +379,10 @@ class AdvancedTraining:
|
|||
)
|
||||
with gr.Group(), gr.Row():
|
||||
self.save_state = gr.Checkbox(label="Save training state", value=False)
|
||||
|
||||
|
||||
self.save_state_on_train_end = gr.Checkbox(label="Save training state at end of training", value=False)
|
||||
|
||||
self.save_state_on_train_end = gr.Checkbox(
|
||||
label="Save training state at end of training", value=False
|
||||
)
|
||||
|
||||
def list_state_dirs(path):
|
||||
self.current_state_dir = path if not path == "" else "."
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
|
||||
class CommandExecutor:
|
||||
"""
|
||||
A class to execute and manage commands.
|
||||
|
|
@ -43,7 +44,9 @@ class CommandExecutor:
|
|||
log.info("The running process has been terminated.")
|
||||
except psutil.NoSuchProcess:
|
||||
# Explicitly handle the case where the process does not exist
|
||||
log.info("The process does not exist. It might have terminated before the kill command was issued.")
|
||||
log.info(
|
||||
"The process does not exist. It might have terminated before the kill command was issued."
|
||||
)
|
||||
except Exception as e:
|
||||
# General exception handling for any other errors
|
||||
log.info(f"Error when terminating process: {e}")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ class ConfigurationFile:
|
|||
A class to handle configuration file operations in the GUI.
|
||||
"""
|
||||
|
||||
def __init__(self, headless: bool = False, config_dir: str = None, config:dict = {}):
|
||||
def __init__(
|
||||
self, headless: bool = False, config_dir: str = None, config: dict = {}
|
||||
):
|
||||
"""
|
||||
Initialize the ConfigurationFile class.
|
||||
|
||||
|
|
@ -22,11 +24,13 @@ class ConfigurationFile:
|
|||
"""
|
||||
|
||||
self.headless = headless
|
||||
|
||||
|
||||
self.config = config
|
||||
|
||||
# Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory.
|
||||
self.current_config_dir = self.config.get('config_dir', os.path.join(scriptdir, "presets"))
|
||||
self.current_config_dir = self.config.get(
|
||||
"config_dir", os.path.join(scriptdir, "presets")
|
||||
)
|
||||
|
||||
# Initialize the GUI components for configuration.
|
||||
self.create_config_gui()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
from .common_gui import get_folder_path, scriptdir, list_dirs, list_files, create_refresh_button
|
||||
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button
|
||||
|
||||
|
||||
class Folders:
|
||||
"""
|
||||
A class to handle folder operations in the GUI.
|
||||
"""
|
||||
def __init__(self, finetune: bool = False, headless: bool = False, config:dict = {}):
|
||||
|
||||
def __init__(
|
||||
self, finetune: bool = False, headless: bool = False, config: dict = {}
|
||||
):
|
||||
"""
|
||||
Initialize the Folders class.
|
||||
|
||||
|
|
@ -21,9 +25,15 @@ class Folders:
|
|||
self.config = config
|
||||
|
||||
# Set default directories if not provided
|
||||
self.current_output_dir = self.config.get('output_dir', os.path.join(scriptdir, "outputs"))
|
||||
self.current_logging_dir = self.config.get('logging_dir', os.path.join(scriptdir, "logs"))
|
||||
self.current_reg_data_dir = self.config.get('reg_data_dir', os.path.join(scriptdir, "reg"))
|
||||
self.current_output_dir = self.config.get(
|
||||
"output_dir", os.path.join(scriptdir, "outputs")
|
||||
)
|
||||
self.current_logging_dir = self.config.get(
|
||||
"logging_dir", os.path.join(scriptdir, "logs")
|
||||
)
|
||||
self.current_reg_data_dir = self.config.get(
|
||||
"reg_data_dir", os.path.join(scriptdir, "reg")
|
||||
)
|
||||
|
||||
# Create directories if they don't exist
|
||||
self.create_directory_if_not_exists(self.current_output_dir)
|
||||
|
|
@ -39,10 +49,13 @@ class Folders:
|
|||
Parameters:
|
||||
- directory (str): The directory to create.
|
||||
"""
|
||||
if directory is not None and directory.strip() != "" and not os.path.exists(directory):
|
||||
if (
|
||||
directory is not None
|
||||
and directory.strip() != ""
|
||||
and not os.path.exists(directory)
|
||||
):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
def list_output_dirs(self, path: str) -> list:
|
||||
"""
|
||||
List directories in the output directory.
|
||||
|
|
@ -96,10 +109,20 @@ class Folders:
|
|||
allow_custom_value=True,
|
||||
)
|
||||
# Refresh button for output directory
|
||||
create_refresh_button(self.output_dir, lambda: None, lambda: {"choices": [""] + self.list_output_dirs(self.current_output_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
self.output_dir,
|
||||
lambda: None,
|
||||
lambda: {
|
||||
"choices": [""] + self.list_output_dirs(self.current_output_dir)
|
||||
},
|
||||
"open_folder_small",
|
||||
)
|
||||
# Output directory button
|
||||
self.output_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
# Output directory button click event
|
||||
self.output_dir_folder.click(
|
||||
|
|
@ -110,17 +133,31 @@ class Folders:
|
|||
|
||||
# Regularisation directory dropdown
|
||||
self.reg_data_dir = gr.Dropdown(
|
||||
label='Regularisation directory (Optional. containing regularisation images)' if not self.finetune else 'Train config directory (Optional. where config files will be saved)',
|
||||
label=(
|
||||
"Regularisation directory (Optional. containing regularisation images)"
|
||||
if not self.finetune
|
||||
else "Train config directory (Optional. where config files will be saved)"
|
||||
),
|
||||
choices=[""] + self.list_reg_data_dirs(self.current_reg_data_dir),
|
||||
value="",
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# Refresh button for regularisation directory
|
||||
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
self.reg_data_dir,
|
||||
lambda: None,
|
||||
lambda: {
|
||||
"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)
|
||||
},
|
||||
"open_folder_small",
|
||||
)
|
||||
# Regularisation directory button
|
||||
self.reg_data_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
# Regularisation directory button click event
|
||||
self.reg_data_dir_folder.click(
|
||||
|
|
@ -131,17 +168,27 @@ class Folders:
|
|||
with gr.Row():
|
||||
# Logging directory dropdown
|
||||
self.logging_dir = gr.Dropdown(
|
||||
label='Logging directory (Optional. to enable logging and output Tensorboard log)',
|
||||
label="Logging directory (Optional. to enable logging and output Tensorboard log)",
|
||||
choices=[""] + self.list_logging_dirs(self.current_logging_dir),
|
||||
value="",
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# Refresh button for logging directory
|
||||
create_refresh_button(self.logging_dir, lambda: None, lambda: {"choices": [""] + self.list_logging_dirs(self.current_logging_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
self.logging_dir,
|
||||
lambda: None,
|
||||
lambda: {
|
||||
"choices": [""] + self.list_logging_dirs(self.current_logging_dir)
|
||||
},
|
||||
"open_folder_small",
|
||||
)
|
||||
# Logging directory button
|
||||
self.logging_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
# Logging directory button click event
|
||||
self.logging_dir_folder.click(
|
||||
|
|
@ -159,14 +206,18 @@ class Folders:
|
|||
)
|
||||
# Change event for regularisation directory dropdown
|
||||
self.reg_data_dir.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + self.list_reg_data_dirs(path)),
|
||||
fn=lambda path: gr.Dropdown(
|
||||
choices=[""] + self.list_reg_data_dirs(path)
|
||||
),
|
||||
inputs=self.reg_data_dir,
|
||||
outputs=self.reg_data_dir,
|
||||
show_progress=False,
|
||||
)
|
||||
# Change event for logging directory dropdown
|
||||
self.logging_dir.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + self.list_logging_dirs(path)),
|
||||
fn=lambda path: gr.Dropdown(
|
||||
choices=[""] + self.list_logging_dirs(path)
|
||||
),
|
||||
inputs=self.logging_dir,
|
||||
outputs=self.logging_dir,
|
||||
show_progress=False,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
|
||||
class KohyaSSGUIConfig:
|
||||
"""
|
||||
A class to handle the configuration for the Kohya SS GUI.
|
||||
|
|
@ -30,7 +31,9 @@ class KohyaSSGUIConfig:
|
|||
except FileNotFoundError:
|
||||
# If the config file is not found, initialize `config` as an empty dictionary to handle missing configurations gracefully.
|
||||
config = {}
|
||||
log.debug(f"No configuration file found at {config_file_path}. Initializing empty configuration.")
|
||||
log.debug(
|
||||
f"No configuration file found at {config_file_path}. Initializing empty configuration."
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
|
@ -66,7 +69,9 @@ class KohyaSSGUIConfig:
|
|||
log.debug(k)
|
||||
# If the key is not found in the current data, return the default value
|
||||
if k not in data:
|
||||
log.debug(f"Key '{key}' not found in configuration. Returning default value.")
|
||||
log.debug(
|
||||
f"Key '{key}' not found in configuration. Returning default value."
|
||||
)
|
||||
return default
|
||||
|
||||
# Update `data` to the value associated with the current key
|
||||
|
|
|
|||
|
|
@ -14,9 +14,7 @@ class LoRATools:
|
|||
def __init__(self, headless: bool = False):
|
||||
self.headless = headless
|
||||
|
||||
gr.Markdown(
|
||||
'This section provide various LoRA tools...'
|
||||
)
|
||||
gr.Markdown("This section provide various LoRA tools...")
|
||||
gradio_extract_dylora_tab(headless=headless)
|
||||
gradio_convert_lcm_tab(headless=headless)
|
||||
gradio_extract_lora_tab(headless=headless)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
from easygui import msgbox
|
||||
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
|
||||
###
|
||||
|
|
@ -38,10 +37,10 @@ def run_cmd_sample(
|
|||
Returns:
|
||||
str: The command string for sampling images.
|
||||
"""
|
||||
output_dir = os.path.join(output_dir, 'sample')
|
||||
output_dir = os.path.join(output_dir, "sample")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
run_cmd = ''
|
||||
run_cmd = ""
|
||||
|
||||
if sample_every_n_epochs is None:
|
||||
sample_every_n_epochs = 0
|
||||
|
|
@ -53,19 +52,19 @@ def run_cmd_sample(
|
|||
return run_cmd
|
||||
|
||||
# Create the prompt file and get its path
|
||||
sample_prompts_path = os.path.join(output_dir, 'prompt.txt')
|
||||
sample_prompts_path = os.path.join(output_dir, "prompt.txt")
|
||||
|
||||
with open(sample_prompts_path, 'w') as f:
|
||||
with open(sample_prompts_path, "w") as f:
|
||||
f.write(sample_prompts)
|
||||
|
||||
run_cmd += f' --sample_sampler={sample_sampler}'
|
||||
run_cmd += f" --sample_sampler={sample_sampler}"
|
||||
run_cmd += f' --sample_prompts="{sample_prompts_path}"'
|
||||
|
||||
if sample_every_n_epochs != 0:
|
||||
run_cmd += f' --sample_every_n_epochs={sample_every_n_epochs}'
|
||||
run_cmd += f" --sample_every_n_epochs={sample_every_n_epochs}"
|
||||
|
||||
if sample_every_n_steps != 0:
|
||||
run_cmd += f' --sample_every_n_steps={sample_every_n_steps}'
|
||||
run_cmd += f" --sample_every_n_steps={sample_every_n_steps}"
|
||||
|
||||
return run_cmd
|
||||
|
||||
|
|
@ -89,45 +88,45 @@ class SampleImages:
|
|||
"""
|
||||
with gr.Row():
|
||||
self.sample_every_n_steps = gr.Number(
|
||||
label='Sample every n steps',
|
||||
label="Sample every n steps",
|
||||
value=0,
|
||||
precision=0,
|
||||
interactive=True,
|
||||
)
|
||||
self.sample_every_n_epochs = gr.Number(
|
||||
label='Sample every n epochs',
|
||||
label="Sample every n epochs",
|
||||
value=0,
|
||||
precision=0,
|
||||
interactive=True,
|
||||
)
|
||||
self.sample_sampler = gr.Dropdown(
|
||||
label='Sample sampler',
|
||||
label="Sample sampler",
|
||||
choices=[
|
||||
'ddim',
|
||||
'pndm',
|
||||
'lms',
|
||||
'euler',
|
||||
'euler_a',
|
||||
'heun',
|
||||
'dpm_2',
|
||||
'dpm_2_a',
|
||||
'dpmsolver',
|
||||
'dpmsolver++',
|
||||
'dpmsingle',
|
||||
'k_lms',
|
||||
'k_euler',
|
||||
'k_euler_a',
|
||||
'k_dpm_2',
|
||||
'k_dpm_2_a',
|
||||
"ddim",
|
||||
"pndm",
|
||||
"lms",
|
||||
"euler",
|
||||
"euler_a",
|
||||
"heun",
|
||||
"dpm_2",
|
||||
"dpm_2_a",
|
||||
"dpmsolver",
|
||||
"dpmsolver++",
|
||||
"dpmsingle",
|
||||
"k_lms",
|
||||
"k_euler",
|
||||
"k_euler_a",
|
||||
"k_dpm_2",
|
||||
"k_dpm_2_a",
|
||||
],
|
||||
value='euler_a',
|
||||
value="euler_a",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
self.sample_prompts = gr.Textbox(
|
||||
lines=5,
|
||||
label='Sample prompts',
|
||||
label="Sample prompts",
|
||||
interactive=True,
|
||||
placeholder='masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28',
|
||||
info='Enter one sample prompt per line to generate multiple samples per cycle. Optional specifiers include: --w (width), --h (height), --d (seed), --l (cfg scale), --s (sampler steps) and --n (negative prompt). To modify sample prompts during training, edit the prompt.txt file in the samples directory.',
|
||||
placeholder="masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28",
|
||||
info="Enter one sample prompt per line to generate multiple samples per cycle. Optional specifiers include: --w (width), --h (height), --d (seed), --l (cfg scale), --s (sampler steps) and --n (negative prompt). To modify sample prompts during training, edit the prompt.txt file in the samples directory.",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import gradio as gr
|
||||
|
||||
|
||||
class SDXLParameters:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -7,25 +8,23 @@ class SDXLParameters:
|
|||
show_sdxl_cache_text_encoder_outputs: bool = True,
|
||||
):
|
||||
self.sdxl_checkbox = sdxl_checkbox
|
||||
self.show_sdxl_cache_text_encoder_outputs = (
|
||||
show_sdxl_cache_text_encoder_outputs
|
||||
)
|
||||
self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs
|
||||
self.initialize_accordion()
|
||||
|
||||
def initialize_accordion(self):
|
||||
with gr.Accordion(
|
||||
visible=False, open=True, label='SDXL Specific Parameters'
|
||||
visible=False, open=True, label="SDXL Specific Parameters"
|
||||
) as self.sdxl_row:
|
||||
with gr.Row():
|
||||
self.sdxl_cache_text_encoder_outputs = gr.Checkbox(
|
||||
label='Cache text encoder outputs',
|
||||
info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.',
|
||||
label="Cache text encoder outputs",
|
||||
info="Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.",
|
||||
value=False,
|
||||
visible=self.show_sdxl_cache_text_encoder_outputs,
|
||||
)
|
||||
self.sdxl_no_half_vae = gr.Checkbox(
|
||||
label='No half VAE',
|
||||
info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.',
|
||||
label="No half VAE",
|
||||
info="Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.",
|
||||
value=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -62,8 +62,9 @@ class SourceModel:
|
|||
self.current_train_data_dir = self.config.get(
|
||||
"train_data_dir", os.path.join(scriptdir, "data")
|
||||
)
|
||||
self.current_dataset_config_dir = self.config.get('dataset_config_dir', os.path.join(scriptdir, "dataset_config"))
|
||||
|
||||
self.current_dataset_config_dir = self.config.get(
|
||||
"dataset_config_dir", os.path.join(scriptdir, "dataset_config")
|
||||
)
|
||||
|
||||
model_checkpoints = list(
|
||||
list_files(
|
||||
|
|
@ -82,7 +83,7 @@ class SourceModel:
|
|||
def list_train_data_dirs(path):
|
||||
self.current_train_data_dir = path if not path == "" else "."
|
||||
return list(list_dirs(path))
|
||||
|
||||
|
||||
def list_dataset_config_dirs(path: str) -> list:
|
||||
"""
|
||||
List directories and toml files in the dataset_config directory.
|
||||
|
|
@ -95,8 +96,9 @@ class SourceModel:
|
|||
"""
|
||||
current_dataset_config_dir = path if not path == "" else "."
|
||||
# Lists all .json files in the current configuration directory, used for populating dropdown choices.
|
||||
return list(list_files(current_dataset_config_dir, exts=[".toml"], all=True))
|
||||
|
||||
return list(
|
||||
list_files(current_dataset_config_dir, exts=[".toml"], all=True)
|
||||
)
|
||||
|
||||
with gr.Accordion("Model", open=True):
|
||||
with gr.Column(), gr.Group():
|
||||
|
|
@ -143,7 +145,7 @@ class SourceModel:
|
|||
outputs=self.pretrained_model_name_or_path,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
||||
with gr.Column(), gr.Row():
|
||||
self.output_name = gr.Textbox(
|
||||
label="Trained Model output name",
|
||||
|
|
@ -188,29 +190,49 @@ class SourceModel:
|
|||
with gr.Column(), gr.Row():
|
||||
# Toml directory dropdown
|
||||
self.dataset_config = gr.Dropdown(
|
||||
label='Dataset config file (Optional. Select the toml configuration file to use for the dataset)',
|
||||
choices=[""] + list_dataset_config_dirs(self.current_dataset_config_dir),
|
||||
label="Dataset config file (Optional. Select the toml configuration file to use for the dataset)",
|
||||
choices=[""]
|
||||
+ list_dataset_config_dirs(self.current_dataset_config_dir),
|
||||
value="",
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# Refresh button for dataset_config directory
|
||||
create_refresh_button(self.dataset_config, lambda: None, lambda: {"choices": [""] + list_dataset_config_dirs(self.current_dataset_config_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
self.dataset_config,
|
||||
lambda: None,
|
||||
lambda: {
|
||||
"choices": [""]
|
||||
+ list_dataset_config_dirs(
|
||||
self.current_dataset_config_dir
|
||||
)
|
||||
},
|
||||
"open_folder_small",
|
||||
)
|
||||
# Toml directory button
|
||||
self.dataset_config_folder = gr.Button(
|
||||
document_symbol, elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
|
||||
|
||||
# Toml directory button click event
|
||||
self.dataset_config_folder.click(
|
||||
get_file_path,
|
||||
inputs=[self.dataset_config, gr.Textbox(value='*.toml', visible=False), gr.Textbox(value='Dataset config types', visible=False)],
|
||||
inputs=[
|
||||
self.dataset_config,
|
||||
gr.Textbox(value="*.toml", visible=False),
|
||||
gr.Textbox(value="Dataset config types", visible=False),
|
||||
],
|
||||
outputs=self.dataset_config,
|
||||
show_progress=False,
|
||||
)
|
||||
# Change event for dataset_config directory dropdown
|
||||
self.dataset_config.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_dataset_config_dirs(path)),
|
||||
fn=lambda path: gr.Dropdown(
|
||||
choices=[""] + list_dataset_config_dirs(path)
|
||||
),
|
||||
inputs=self.dataset_config,
|
||||
outputs=self.dataset_config,
|
||||
show_progress=False,
|
||||
|
|
@ -273,7 +295,9 @@ class SourceModel:
|
|||
)
|
||||
|
||||
self.train_data_dir.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)),
|
||||
fn=lambda path: gr.Dropdown(
|
||||
choices=[""] + list_train_data_dirs(path)
|
||||
),
|
||||
inputs=self.train_data_dir,
|
||||
outputs=self.train_data_dir,
|
||||
show_progress=False,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from .custom_logging import setup_logging
|
|||
import os
|
||||
import re
|
||||
import gradio as gr
|
||||
import shutil
|
||||
import sys
|
||||
import json
|
||||
import math
|
||||
|
|
@ -55,6 +54,7 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDX
|
|||
|
||||
ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"]
|
||||
|
||||
|
||||
def calculate_max_train_steps(
|
||||
total_steps: int,
|
||||
train_batch_size: int,
|
||||
|
|
@ -72,6 +72,7 @@ def calculate_max_train_steps(
|
|||
)
|
||||
)
|
||||
|
||||
|
||||
def check_if_model_exist(
|
||||
output_name: str, output_dir: str, save_model_as: str, headless: bool = False
|
||||
) -> bool:
|
||||
|
|
@ -1097,14 +1098,14 @@ def run_cmd_advanced_training(**kwargs):
|
|||
|
||||
if kwargs.get("gradient_checkpointing"):
|
||||
run_cmd += " --gradient_checkpointing"
|
||||
|
||||
|
||||
if kwargs.get("ip_noise_gamma"):
|
||||
if float(kwargs["ip_noise_gamma"]) > 0:
|
||||
run_cmd += f' --ip_noise_gamma={kwargs["ip_noise_gamma"]}'
|
||||
|
||||
|
||||
if kwargs.get("ip_noise_gamma_random_strength"):
|
||||
if kwargs["ip_noise_gamma_random_strength"]:
|
||||
run_cmd += f' --ip_noise_gamma_random_strength'
|
||||
run_cmd += f" --ip_noise_gamma_random_strength"
|
||||
|
||||
if "keep_tokens" in kwargs and int(kwargs["keep_tokens"]) > 0:
|
||||
run_cmd += f' --keep_tokens="{int(kwargs["keep_tokens"])}"'
|
||||
|
|
@ -1180,7 +1181,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += f' --lr_warmup_steps="{lr_warmup_steps}"'
|
||||
|
||||
if "masked_loss" in kwargs:
|
||||
if kwargs.get("masked_loss"): # Test if the value is true as it could be false
|
||||
if kwargs.get("masked_loss"): # Test if the value is true as it could be false
|
||||
run_cmd += " --masked_loss"
|
||||
|
||||
if "max_data_loader_n_workers" in kwargs:
|
||||
|
|
@ -1194,7 +1195,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += f' --max_grad_norm="{max_grad_norm}"'
|
||||
|
||||
if "max_resolution" in kwargs:
|
||||
run_cmd += fr' --resolution="{kwargs.get("max_resolution")}"'
|
||||
run_cmd += rf' --resolution="{kwargs.get("max_resolution")}"'
|
||||
|
||||
if "max_timestep" in kwargs:
|
||||
max_timestep = kwargs.get("max_timestep")
|
||||
|
|
@ -1217,7 +1218,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += f' --max_train_steps="{max_train_steps}"'
|
||||
|
||||
if "mem_eff_attn" in kwargs:
|
||||
if kwargs.get("mem_eff_attn"): # Test if the value is true as it could be false
|
||||
if kwargs.get("mem_eff_attn"): # Test if the value is true as it could be false
|
||||
run_cmd += " --mem_eff_attn"
|
||||
|
||||
if "min_snr_gamma" in kwargs:
|
||||
|
|
@ -1229,20 +1230,20 @@ def run_cmd_advanced_training(**kwargs):
|
|||
min_timestep = kwargs.get("min_timestep")
|
||||
if int(min_timestep) > -1:
|
||||
run_cmd += f" --min_timestep={int(min_timestep)}"
|
||||
|
||||
|
||||
if "mixed_precision" in kwargs:
|
||||
run_cmd += rf' --mixed_precision="{kwargs.get("mixed_precision")}"'
|
||||
|
||||
if "network_alpha" in kwargs:
|
||||
run_cmd += fr' --network_alpha="{kwargs.get("network_alpha")}"'
|
||||
run_cmd += rf' --network_alpha="{kwargs.get("network_alpha")}"'
|
||||
|
||||
if "network_args" in kwargs:
|
||||
network_args = kwargs.get("network_args")
|
||||
if network_args != "":
|
||||
run_cmd += f' --network_args{network_args}'
|
||||
run_cmd += f" --network_args{network_args}"
|
||||
|
||||
if "network_dim" in kwargs:
|
||||
run_cmd += fr' --network_dim={kwargs.get("network_dim")}'
|
||||
run_cmd += rf' --network_dim={kwargs.get("network_dim")}'
|
||||
|
||||
if "network_dropout" in kwargs:
|
||||
network_dropout = kwargs.get("network_dropout")
|
||||
|
|
@ -1252,7 +1253,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
if "network_module" in kwargs:
|
||||
network_module = kwargs.get("network_module")
|
||||
if network_module != "":
|
||||
run_cmd += f' --network_module={network_module}'
|
||||
run_cmd += f" --network_module={network_module}"
|
||||
|
||||
if "network_train_text_encoder_only" in kwargs:
|
||||
if kwargs.get("network_train_text_encoder_only"):
|
||||
|
|
@ -1263,11 +1264,13 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += " --network_train_unet_only"
|
||||
|
||||
if "no_half_vae" in kwargs:
|
||||
if kwargs.get("no_half_vae"): # Test if the value is true as it could be false
|
||||
if kwargs.get("no_half_vae"): # Test if the value is true as it could be false
|
||||
run_cmd += " --no_half_vae"
|
||||
|
||||
if "no_token_padding" in kwargs:
|
||||
if kwargs.get("no_token_padding"): # Test if the value is true as it could be false
|
||||
if kwargs.get(
|
||||
"no_token_padding"
|
||||
): # Test if the value is true as it could be false
|
||||
run_cmd += " --no_token_padding"
|
||||
|
||||
if "noise_offset_type" in kwargs:
|
||||
|
|
@ -1283,18 +1286,24 @@ def run_cmd_advanced_training(**kwargs):
|
|||
adaptive_noise_scale = float(kwargs.get("adaptive_noise_scale", 0))
|
||||
if adaptive_noise_scale != 0 and noise_offset > 0:
|
||||
run_cmd += f" --adaptive_noise_scale={adaptive_noise_scale}"
|
||||
|
||||
|
||||
if "noise_offset_random_strength" in kwargs:
|
||||
if kwargs.get("noise_offset_random_strength"):
|
||||
run_cmd += f" --noise_offset_random_strength"
|
||||
elif noise_offset_type == "Multires":
|
||||
if "multires_noise_iterations" in kwargs:
|
||||
multires_noise_iterations = int(kwargs.get("multires_noise_iterations", 0))
|
||||
multires_noise_iterations = int(
|
||||
kwargs.get("multires_noise_iterations", 0)
|
||||
)
|
||||
if multires_noise_iterations > 0:
|
||||
run_cmd += f' --multires_noise_iterations="{multires_noise_iterations}"'
|
||||
run_cmd += (
|
||||
f' --multires_noise_iterations="{multires_noise_iterations}"'
|
||||
)
|
||||
|
||||
if "multires_noise_discount" in kwargs:
|
||||
multires_noise_discount = float(kwargs.get("multires_noise_discount", 0))
|
||||
multires_noise_discount = float(
|
||||
kwargs.get("multires_noise_discount", 0)
|
||||
)
|
||||
if multires_noise_discount > 0:
|
||||
run_cmd += f' --multires_noise_discount="{multires_noise_discount}"'
|
||||
|
||||
|
|
@ -1304,7 +1313,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += f" --optimizer_args {optimizer_args}"
|
||||
|
||||
if "optimizer" in kwargs:
|
||||
run_cmd += fr' --optimizer_type="{kwargs.get("optimizer")}"'
|
||||
run_cmd += rf' --optimizer_type="{kwargs.get("optimizer")}"'
|
||||
|
||||
if "output_dir" in kwargs:
|
||||
output_dir = kwargs.get("output_dir")
|
||||
|
|
@ -1323,9 +1332,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += " --persistent_data_loader_workers"
|
||||
|
||||
if "pretrained_model_name_or_path" in kwargs:
|
||||
run_cmd += (
|
||||
rf' --pretrained_model_name_or_path="{kwargs.get("pretrained_model_name_or_path")}"'
|
||||
)
|
||||
run_cmd += rf' --pretrained_model_name_or_path="{kwargs.get("pretrained_model_name_or_path")}"'
|
||||
|
||||
if "prior_loss_weight" in kwargs:
|
||||
prior_loss_weight = kwargs.get("prior_loss_weight")
|
||||
|
|
@ -1376,12 +1383,12 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += f" --save_model_as={save_model_as}"
|
||||
|
||||
if "save_precision" in kwargs:
|
||||
run_cmd += fr' --save_precision="{kwargs.get("save_precision")}"'
|
||||
run_cmd += rf' --save_precision="{kwargs.get("save_precision")}"'
|
||||
|
||||
if "save_state" in kwargs:
|
||||
if kwargs.get("save_state"):
|
||||
run_cmd += " --save_state"
|
||||
|
||||
|
||||
if "save_state_on_train_end" in kwargs:
|
||||
if kwargs.get("save_state_on_train_end"):
|
||||
run_cmd += " --save_state_on_train_end"
|
||||
|
|
@ -1415,7 +1422,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += f" --text_encoder_lr={text_encoder_lr}"
|
||||
|
||||
if "train_batch_size" in kwargs:
|
||||
run_cmd += fr' --train_batch_size="{kwargs.get("train_batch_size")}"'
|
||||
run_cmd += rf' --train_batch_size="{kwargs.get("train_batch_size")}"'
|
||||
|
||||
training_comment = kwargs.get("training_comment")
|
||||
if training_comment and len(training_comment):
|
||||
|
|
|
|||
|
|
@ -22,17 +22,12 @@ document_symbol = "\U0001F4C4" # 📄
|
|||
PYTHON = sys.executable
|
||||
|
||||
|
||||
def convert_lcm(
|
||||
name,
|
||||
model_path,
|
||||
lora_scale,
|
||||
model_type
|
||||
):
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"'
|
||||
def convert_lcm(name, model_path, lora_scale, model_type):
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"'
|
||||
|
||||
# Check if source model exist
|
||||
if not os.path.isfile(model_path):
|
||||
log.error('The provided DyLoRA model is not a file')
|
||||
log.error("The provided DyLoRA model is not a file")
|
||||
return
|
||||
|
||||
if os.path.dirname(name) == "":
|
||||
|
|
@ -46,12 +41,11 @@ def convert_lcm(
|
|||
path, ext = os.path.splitext(save_to)
|
||||
save_to = f"{path}_lcm{ext}"
|
||||
|
||||
|
||||
# Construct the command to run the script
|
||||
run_cmd += f" --lora-scale {lora_scale}"
|
||||
run_cmd += f' --model "{model_path}"'
|
||||
run_cmd += f' --name "{name}"'
|
||||
|
||||
|
||||
if model_type == "SDXL":
|
||||
run_cmd += f" --sdxl"
|
||||
if model_type == "SSD-1B":
|
||||
|
|
@ -60,7 +54,9 @@ def convert_lcm(
|
|||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
|
@ -98,11 +94,16 @@ def gradio_convert_lcm_tab(headless=False):
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(model_path, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
model_path,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_model_path_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=['tool'],
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_model_path_file.click(
|
||||
|
|
@ -119,11 +120,16 @@ def gradio_convert_lcm_tab(headless=False):
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
name,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_name = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=['tool'],
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_name.click(
|
||||
|
|
@ -154,7 +160,7 @@ def gradio_convert_lcm_tab(headless=False):
|
|||
value=1.0,
|
||||
interactive=True,
|
||||
)
|
||||
# with gr.Row():
|
||||
# with gr.Row():
|
||||
# no_half = gr.Checkbox(label="Convert the new LCM model to FP32", value=False)
|
||||
model_type = gr.Radio(
|
||||
label="Model type", choices=["SD15", "SDXL", "SD-1B"], value="SD15"
|
||||
|
|
@ -164,11 +170,6 @@ def gradio_convert_lcm_tab(headless=False):
|
|||
|
||||
extract_button.click(
|
||||
convert_lcm,
|
||||
inputs=[
|
||||
name,
|
||||
model_path,
|
||||
lora_scale,
|
||||
model_type
|
||||
],
|
||||
inputs=[name, model_path, lora_scale, model_type],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import gradio as gr
|
|||
from easygui import msgbox
|
||||
import subprocess
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from .common_gui import get_folder_path, get_file_path, scriptdir, list_files, list_dirs
|
||||
|
||||
|
|
@ -11,10 +10,10 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
|
@ -29,52 +28,51 @@ def convert_model(
|
|||
unet_use_linear_projection,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if source_model_type == '':
|
||||
msgbox('Invalid source model type')
|
||||
if source_model_type == "":
|
||||
msgbox("Invalid source model type")
|
||||
return
|
||||
|
||||
# Check if source model exist
|
||||
if os.path.isfile(source_model_input):
|
||||
log.info('The provided source model is a file')
|
||||
log.info("The provided source model is a file")
|
||||
elif os.path.isdir(source_model_input):
|
||||
log.info('The provided model is a folder')
|
||||
log.info("The provided model is a folder")
|
||||
else:
|
||||
msgbox('The provided source model is neither a file nor a folder')
|
||||
msgbox("The provided source model is neither a file nor a folder")
|
||||
return
|
||||
|
||||
# Check if source model exist
|
||||
if os.path.isdir(target_model_folder_input):
|
||||
log.info('The provided model folder exist')
|
||||
log.info("The provided model folder exist")
|
||||
else:
|
||||
msgbox('The provided target folder does not exist')
|
||||
msgbox("The provided target folder does not exist")
|
||||
return
|
||||
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py"'
|
||||
run_cmd = (
|
||||
rf'"{PYTHON}" "{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py"'
|
||||
)
|
||||
|
||||
v1_models = [
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
'CompVis/stable-diffusion-v1-4',
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
|
||||
# check if v1 models
|
||||
if str(source_model_type) in v1_models:
|
||||
log.info('SD v1 model specified. Setting --v1 parameter')
|
||||
run_cmd += ' --v1'
|
||||
log.info("SD v1 model specified. Setting --v1 parameter")
|
||||
run_cmd += " --v1"
|
||||
else:
|
||||
log.info('SD v2 model specified. Setting --v2 parameter')
|
||||
run_cmd += ' --v2'
|
||||
log.info("SD v2 model specified. Setting --v2 parameter")
|
||||
run_cmd += " --v2"
|
||||
|
||||
if not target_save_precision_type == 'unspecified':
|
||||
run_cmd += f' --{target_save_precision_type}'
|
||||
if not target_save_precision_type == "unspecified":
|
||||
run_cmd += f" --{target_save_precision_type}"
|
||||
|
||||
if (
|
||||
target_model_type == 'diffuser'
|
||||
or target_model_type == 'diffuser_safetensors'
|
||||
):
|
||||
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
|
||||
run_cmd += f' --reference_model="{source_model_type}"'
|
||||
|
||||
if target_model_type == 'diffuser_safetensors':
|
||||
run_cmd += ' --use_safetensors'
|
||||
if target_model_type == "diffuser_safetensors":
|
||||
run_cmd += " --use_safetensors"
|
||||
|
||||
# Fix for stabilityAI diffusers format. When saving v2 models in Diffusers format in training scripts and conversion scripts,
|
||||
# it was found that the U-Net configuration is different from those of Hugging Face's stabilityai models (this repository is
|
||||
|
|
@ -82,14 +80,11 @@ def convert_model(
|
|||
# when using the weight files directly.
|
||||
|
||||
if unet_use_linear_projection:
|
||||
run_cmd += ' --unet_use_linear_projection'
|
||||
run_cmd += " --unet_use_linear_projection"
|
||||
|
||||
run_cmd += f' "{source_model_input}"'
|
||||
|
||||
if (
|
||||
target_model_type == 'diffuser'
|
||||
or target_model_type == 'diffuser_safetensors'
|
||||
):
|
||||
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
|
||||
target_model_path = os.path.join(
|
||||
target_model_folder_input, target_model_name_input
|
||||
)
|
||||
|
|
@ -97,74 +92,20 @@ def convert_model(
|
|||
else:
|
||||
target_model_path = os.path.join(
|
||||
target_model_folder_input,
|
||||
f'{target_model_name_input}.{target_model_type}',
|
||||
f"{target_model_name_input}.{target_model_type}",
|
||||
)
|
||||
run_cmd += f' "{target_model_path}"'
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
||||
# if (
|
||||
# not target_model_type == 'diffuser'
|
||||
# or target_model_type == 'diffuser_safetensors'
|
||||
# ):
|
||||
|
||||
# v2_models = [
|
||||
# 'stabilityai/stable-diffusion-2-1-base',
|
||||
# 'stabilityai/stable-diffusion-2-base',
|
||||
# ]
|
||||
# v_parameterization = [
|
||||
# 'stabilityai/stable-diffusion-2-1',
|
||||
# 'stabilityai/stable-diffusion-2',
|
||||
# ]
|
||||
|
||||
# if str(source_model_type) in v2_models:
|
||||
# inference_file = os.path.join(
|
||||
# target_model_folder_input, f'{target_model_name_input}.yaml'
|
||||
# )
|
||||
# log.info(f'Saving v2-inference.yaml as {inference_file}')
|
||||
# shutil.copy(
|
||||
# fr'{scriptdir}/v2_inference/v2-inference.yaml',
|
||||
# f'{inference_file}',
|
||||
# )
|
||||
|
||||
# if str(source_model_type) in v_parameterization:
|
||||
# inference_file = os.path.join(
|
||||
# target_model_folder_input, f'{target_model_name_input}.yaml'
|
||||
# )
|
||||
# log.info(f'Saving v2-inference-v.yaml as {inference_file}')
|
||||
# shutil.copy(
|
||||
# fr'{scriptdir}/v2_inference/v2-inference-v.yaml',
|
||||
# f'{inference_file}',
|
||||
# )
|
||||
|
||||
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument("--v1", action='store_true',
|
||||
# help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
|
||||
# parser.add_argument("--v2", action='store_true',
|
||||
# help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
|
||||
# parser.add_argument("--fp16", action='store_true',
|
||||
# help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
|
||||
# parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
|
||||
# parser.add_argument("--float", action='store_true',
|
||||
# help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
|
||||
# parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
|
||||
# parser.add_argument("--global_step", type=int, default=0,
|
||||
# help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
|
||||
# parser.add_argument("--reference_model", type=str, default=None,
|
||||
# help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
|
||||
|
||||
# parser.add_argument("model_to_load", type=str, default=None,
|
||||
# help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
|
||||
# parser.add_argument("model_to_save", type=str, default=None,
|
||||
# help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
|
||||
|
||||
|
||||
###
|
||||
# Gradio UI
|
||||
|
|
@ -189,124 +130,136 @@ def gradio_convert_model_tab(headless=False):
|
|||
current_target_folder = path
|
||||
return list(list_dirs(path))
|
||||
|
||||
with gr.Tab('Convert model'):
|
||||
with gr.Tab("Convert model"):
|
||||
gr.Markdown(
|
||||
'This utility can be used to convert from one stable diffusion model format to another.'
|
||||
"This utility can be used to convert from one stable diffusion model format to another."
|
||||
)
|
||||
|
||||
model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
|
||||
model_ext_name = gr.Textbox(value='Model types', visible=False)
|
||||
model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
|
||||
model_ext_name = gr.Textbox(value="Model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
with gr.Column(), gr.Row():
|
||||
source_model_input = gr.Dropdown(
|
||||
label='Source model (path to source model folder of file to convert...)',
|
||||
interactive=True,
|
||||
choices=[""] + list_source_model(default_source_model),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(source_model_input, lambda: None, lambda: {"choices": list_source_model(current_source_model)}, "open_folder_small")
|
||||
button_source_model_dir = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_source_model_dir.click(
|
||||
get_folder_path,
|
||||
outputs=source_model_input,
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Column(), gr.Row():
|
||||
source_model_input = gr.Dropdown(
|
||||
label="Source model (path to source model folder of file to convert...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_source_model(default_source_model),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
source_model_input,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_source_model(current_source_model)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_source_model_dir = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_source_model_dir.click(
|
||||
get_folder_path,
|
||||
outputs=source_model_input,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_source_model_file = gr.Button(
|
||||
document_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_source_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[source_model_input, model_ext, model_ext_name],
|
||||
outputs=source_model_input,
|
||||
show_progress=False,
|
||||
)
|
||||
button_source_model_file = gr.Button(
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_source_model_file.click(
|
||||
get_file_path,
|
||||
inputs=[source_model_input, model_ext, model_ext_name],
|
||||
outputs=source_model_input,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
source_model_input.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_source_model(path)),
|
||||
inputs=source_model_input,
|
||||
outputs=source_model_input,
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Column(), gr.Row():
|
||||
source_model_type = gr.Dropdown(
|
||||
label='Source model type',
|
||||
choices=[
|
||||
'stabilityai/stable-diffusion-2-1-base',
|
||||
'stabilityai/stable-diffusion-2-base',
|
||||
'stabilityai/stable-diffusion-2-1',
|
||||
'stabilityai/stable-diffusion-2',
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
'CompVis/stable-diffusion-v1-4',
|
||||
],
|
||||
)
|
||||
source_model_input.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_source_model(path)),
|
||||
inputs=source_model_input,
|
||||
outputs=source_model_input,
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Column(), gr.Row():
|
||||
source_model_type = gr.Dropdown(
|
||||
label="Source model type",
|
||||
choices=[
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"stabilityai/stable-diffusion-2-base",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
with gr.Group(), gr.Row():
|
||||
with gr.Column(), gr.Row():
|
||||
target_model_folder_input = gr.Dropdown(
|
||||
label='Target model folder (path to target model folder of file name to create...)',
|
||||
interactive=True,
|
||||
choices=[""] + list_target_folder(default_target_folder),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(target_model_folder_input, lambda: None, lambda: {"choices": list_target_folder(current_target_folder)},"open_folder_small")
|
||||
button_target_model_folder = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_target_model_folder.click(
|
||||
get_folder_path,
|
||||
outputs=target_model_folder_input,
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Column(), gr.Row():
|
||||
target_model_folder_input = gr.Dropdown(
|
||||
label="Target model folder (path to target model folder of file name to create...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_target_folder(default_target_folder),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
target_model_folder_input,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_target_folder(current_target_folder)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_target_model_folder = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_target_model_folder.click(
|
||||
get_folder_path,
|
||||
outputs=target_model_folder_input,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
target_model_folder_input.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_target_folder(path)),
|
||||
inputs=target_model_folder_input,
|
||||
outputs=target_model_folder_input,
|
||||
show_progress=False,
|
||||
)
|
||||
target_model_folder_input.change(
|
||||
fn=lambda path: gr.Dropdown(
|
||||
choices=[""] + list_target_folder(path)
|
||||
),
|
||||
inputs=target_model_folder_input,
|
||||
outputs=target_model_folder_input,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Column(), gr.Row():
|
||||
target_model_name_input = gr.Textbox(
|
||||
label='Target model name',
|
||||
placeholder='target model name...',
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(), gr.Row():
|
||||
target_model_name_input = gr.Textbox(
|
||||
label="Target model name",
|
||||
placeholder="target model name...",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
target_model_type = gr.Dropdown(
|
||||
label='Target model type',
|
||||
label="Target model type",
|
||||
choices=[
|
||||
'diffuser',
|
||||
'diffuser_safetensors',
|
||||
'ckpt',
|
||||
'safetensors',
|
||||
"diffuser",
|
||||
"diffuser_safetensors",
|
||||
"ckpt",
|
||||
"safetensors",
|
||||
],
|
||||
)
|
||||
target_save_precision_type = gr.Dropdown(
|
||||
label='Target model precision',
|
||||
choices=['unspecified', 'fp16', 'bf16', 'float'],
|
||||
value='unspecified',
|
||||
label="Target model precision",
|
||||
choices=["unspecified", "fp16", "bf16", "float"],
|
||||
value="unspecified",
|
||||
)
|
||||
unet_use_linear_projection = gr.Checkbox(
|
||||
label='UNet linear projection',
|
||||
label="UNet linear projection",
|
||||
value=False,
|
||||
info="Enable for Hugging Face's stabilityai models",
|
||||
)
|
||||
|
||||
convert_button = gr.Button('Convert model')
|
||||
convert_button = gr.Button("Convert model")
|
||||
|
||||
convert_button.click(
|
||||
convert_model,
|
||||
|
|
|
|||
|
|
@ -11,34 +11,71 @@ from rich.traceback import install as traceback_install
|
|||
|
||||
log = None
|
||||
|
||||
|
||||
def setup_logging(clean=False, debug=False):
|
||||
global log
|
||||
|
||||
|
||||
if log is not None:
|
||||
return log
|
||||
|
||||
|
||||
try:
|
||||
if clean and os.path.isfile('setup.log'):
|
||||
os.remove('setup.log')
|
||||
time.sleep(0.1) # prevent race condition
|
||||
if clean and os.path.isfile("setup.log"):
|
||||
os.remove("setup.log")
|
||||
time.sleep(0.1) # prevent race condition
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', filename='setup.log', filemode='a', encoding='utf-8', force=True)
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s | %(levelname)s | %(pathname)s | %(message)s",
|
||||
filename="setup.log",
|
||||
filemode="a",
|
||||
encoding="utf-8",
|
||||
force=True,
|
||||
)
|
||||
else:
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', filename='setup.log', filemode='a', force=True)
|
||||
|
||||
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
|
||||
"traceback.border": "black",
|
||||
"traceback.border.syntax_error": "black",
|
||||
"inspect.value.border": "black",
|
||||
}))
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s | %(levelname)s | %(pathname)s | %(message)s",
|
||||
filename="setup.log",
|
||||
filemode="a",
|
||||
force=True,
|
||||
)
|
||||
|
||||
console = Console(
|
||||
log_time=True,
|
||||
log_time_format="%H:%M:%S-%f",
|
||||
theme=Theme(
|
||||
{
|
||||
"traceback.border": "black",
|
||||
"traceback.border.syntax_error": "black",
|
||||
"inspect.value.border": "black",
|
||||
}
|
||||
),
|
||||
)
|
||||
pretty_install(console=console)
|
||||
traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
|
||||
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=logging.DEBUG if debug else logging.INFO, console=console)
|
||||
traceback_install(
|
||||
console=console,
|
||||
extra_lines=1,
|
||||
width=console.width,
|
||||
word_wrap=False,
|
||||
indent_guides=False,
|
||||
suppress=[],
|
||||
)
|
||||
rh = RichHandler(
|
||||
show_time=True,
|
||||
omit_repeated_times=False,
|
||||
show_level=True,
|
||||
show_path=False,
|
||||
markup=False,
|
||||
rich_tracebacks=True,
|
||||
log_time_format="%H:%M:%S-%f",
|
||||
level=logging.DEBUG if debug else logging.INFO,
|
||||
console=console,
|
||||
)
|
||||
rh.set_name(logging.DEBUG if debug else logging.INFO)
|
||||
log = logging.getLogger("sd")
|
||||
log.addHandler(rh)
|
||||
|
||||
|
||||
return log
|
||||
|
|
|
|||
|
|
@ -9,29 +9,22 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
# def select_folder():
|
||||
# # Open a file dialog to select a directory
|
||||
# folder = filedialog.askdirectory()
|
||||
|
||||
# # Update the GUI to display the selected folder
|
||||
# selected_folder_label.config(text=folder)
|
||||
|
||||
|
||||
def dataset_balancing(concept_repeats, folder, insecure):
|
||||
|
||||
if not concept_repeats > 0:
|
||||
# Display an error message if the total number of repeats is not a valid integer
|
||||
msgbox('Please enter a valid integer for the total number of repeats.')
|
||||
msgbox("Please enter a valid integer for the total number of repeats.")
|
||||
return
|
||||
|
||||
concept_repeats = int(concept_repeats)
|
||||
|
||||
# Check if folder exist
|
||||
if folder == '' or not os.path.isdir(folder):
|
||||
msgbox('Please enter a valid folder for balancing.')
|
||||
if folder == "" or not os.path.isdir(folder):
|
||||
msgbox("Please enter a valid folder for balancing.")
|
||||
return
|
||||
|
||||
pattern = re.compile(r'^\d+_.+$')
|
||||
pattern = re.compile(r"^\d+_.+$")
|
||||
|
||||
# Iterate over the subdirectories in the selected folder
|
||||
for subdir in os.listdir(folder):
|
||||
|
|
@ -44,7 +37,7 @@ def dataset_balancing(concept_repeats, folder, insecure):
|
|||
image_files = [
|
||||
f
|
||||
for f in files
|
||||
if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))
|
||||
if f.endswith((".jpg", ".jpeg", ".png", ".gif", ".webp"))
|
||||
]
|
||||
|
||||
# Count the number of image files
|
||||
|
|
@ -52,20 +45,18 @@ def dataset_balancing(concept_repeats, folder, insecure):
|
|||
|
||||
if images == 0:
|
||||
log.info(
|
||||
f'No images of type .jpg, .jpeg, .png, .gif, .webp were found in {os.listdir(os.path.join(folder, subdir))}'
|
||||
f"No images of type .jpg, .jpeg, .png, .gif, .webp were found in {os.listdir(os.path.join(folder, subdir))}"
|
||||
)
|
||||
|
||||
# Check if the subdirectory name starts with a number inside braces,
|
||||
# indicating that the repeats value should be multiplied
|
||||
match = re.match(r'^\{(\d+\.?\d*)\}', subdir)
|
||||
match = re.match(r"^\{(\d+\.?\d*)\}", subdir)
|
||||
if match:
|
||||
# Multiply the repeats value by the number inside the braces
|
||||
if not images == 0:
|
||||
repeats = max(
|
||||
1,
|
||||
round(
|
||||
concept_repeats / images * float(match.group(1))
|
||||
),
|
||||
round(concept_repeats / images * float(match.group(1))),
|
||||
)
|
||||
else:
|
||||
repeats = 0
|
||||
|
|
@ -77,32 +68,30 @@ def dataset_balancing(concept_repeats, folder, insecure):
|
|||
repeats = 0
|
||||
|
||||
# Check if the subdirectory name already has a number at the beginning
|
||||
match = re.match(r'^\d+_', subdir)
|
||||
match = re.match(r"^\d+_", subdir)
|
||||
if match:
|
||||
# Replace the existing number with the new number
|
||||
old_name = os.path.join(folder, subdir)
|
||||
new_name = os.path.join(
|
||||
folder, f'{repeats}_{subdir[match.end():]}'
|
||||
)
|
||||
new_name = os.path.join(folder, f"{repeats}_{subdir[match.end():]}")
|
||||
else:
|
||||
# Add the new number at the beginning of the name
|
||||
old_name = os.path.join(folder, subdir)
|
||||
new_name = os.path.join(folder, f'{repeats}_{subdir}')
|
||||
new_name = os.path.join(folder, f"{repeats}_{subdir}")
|
||||
|
||||
os.rename(old_name, new_name)
|
||||
else:
|
||||
log.info(
|
||||
f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...'
|
||||
f"Skipping folder {subdir} because it does not match kohya_ss expected syntax..."
|
||||
)
|
||||
|
||||
msgbox('Dataset balancing completed...')
|
||||
msgbox("Dataset balancing completed...")
|
||||
|
||||
|
||||
def warning(insecure):
|
||||
if insecure:
|
||||
if boolbox(
|
||||
f'WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?',
|
||||
choices=('Yes, I like danger', 'No, get me out of here'),
|
||||
f"WARNING!!! You have asked to rename non kohya_ss <num>_<text> folders...\n\nAre you sure you want to do that?",
|
||||
choices=("Yes, I like danger", "No, get me out of here"),
|
||||
):
|
||||
return True
|
||||
else:
|
||||
|
|
@ -113,12 +102,12 @@ def gradio_dataset_balancing_tab(headless=False):
|
|||
|
||||
current_dataset_dir = os.path.join(scriptdir, "data")
|
||||
|
||||
with gr.Tab('Dreambooth/LoRA Dataset balancing'):
|
||||
with gr.Tab("Dreambooth/LoRA Dataset balancing"):
|
||||
gr.Markdown(
|
||||
'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.'
|
||||
"This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training."
|
||||
)
|
||||
gr.Markdown(
|
||||
'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!'
|
||||
"WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!"
|
||||
)
|
||||
with gr.Group(), gr.Row():
|
||||
|
||||
|
|
@ -128,15 +117,23 @@ def gradio_dataset_balancing_tab(headless=False):
|
|||
return list(list_dirs(path))
|
||||
|
||||
select_dataset_folder_input = gr.Dropdown(
|
||||
label='Dataset folder (folder containing the concepts folders to balance...)',
|
||||
label="Dataset folder (folder containing the concepts folders to balance...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_dataset_dirs(current_dataset_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(select_dataset_folder_input, lambda: None, lambda: {"choices": list_dataset_dirs(current_dataset_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
select_dataset_folder_input,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_dataset_dirs(current_dataset_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
select_dataset_folder_button = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
select_dataset_folder_button.click(
|
||||
get_folder_path,
|
||||
|
|
@ -147,7 +144,7 @@ def gradio_dataset_balancing_tab(headless=False):
|
|||
total_repeats_number = gr.Number(
|
||||
value=1000,
|
||||
interactive=True,
|
||||
label='Training steps per concept per epoch',
|
||||
label="Training steps per concept per epoch",
|
||||
)
|
||||
select_dataset_folder_input.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_dataset_dirs(path)),
|
||||
|
|
@ -156,13 +153,13 @@ def gradio_dataset_balancing_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Accordion('Advanced options', open=False):
|
||||
with gr.Accordion("Advanced options", open=False):
|
||||
insecure = gr.Checkbox(
|
||||
value=False,
|
||||
label='DANGER!!! -- Insecure folder renaming -- DANGER!!!',
|
||||
label="DANGER!!! -- Insecure folder renaming -- DANGER!!!",
|
||||
)
|
||||
insecure.change(warning, inputs=insecure, outputs=insecure)
|
||||
balance_button = gr.Button('Balance dataset')
|
||||
balance_button = gr.Button("Balance dataset")
|
||||
balance_button.click(
|
||||
dataset_balancing,
|
||||
inputs=[
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import gradio as gr
|
||||
from easygui import diropenbox, msgbox
|
||||
from easygui import msgbox
|
||||
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button
|
||||
import shutil
|
||||
import os
|
||||
|
|
@ -12,13 +12,13 @@ log = setup_logging()
|
|||
|
||||
|
||||
def copy_info_to_Folders_tab(training_folder):
|
||||
img_folder = os.path.join(training_folder, 'img')
|
||||
if os.path.exists(os.path.join(training_folder, 'reg')):
|
||||
reg_folder = os.path.join(training_folder, 'reg')
|
||||
img_folder = os.path.join(training_folder, "img")
|
||||
if os.path.exists(os.path.join(training_folder, "reg")):
|
||||
reg_folder = os.path.join(training_folder, "reg")
|
||||
else:
|
||||
reg_folder = ''
|
||||
model_folder = os.path.join(training_folder, 'model')
|
||||
log_folder = os.path.join(training_folder, 'log')
|
||||
reg_folder = ""
|
||||
model_folder = os.path.join(training_folder, "model")
|
||||
log_folder = os.path.join(training_folder, "log")
|
||||
|
||||
return img_folder, reg_folder, model_folder, log_folder
|
||||
|
||||
|
|
@ -44,17 +44,17 @@ def dreambooth_folder_preparation(
|
|||
os.makedirs(util_training_dir_output, exist_ok=True)
|
||||
|
||||
# Check for instance prompt
|
||||
if util_instance_prompt_input == '':
|
||||
msgbox('Instance prompt missing...')
|
||||
if util_instance_prompt_input == "":
|
||||
msgbox("Instance prompt missing...")
|
||||
return
|
||||
|
||||
# Check for class prompt
|
||||
if util_class_prompt_input == '':
|
||||
msgbox('Class prompt missing...')
|
||||
if util_class_prompt_input == "":
|
||||
msgbox("Class prompt missing...")
|
||||
return
|
||||
|
||||
# Create the training_dir path
|
||||
if util_training_images_dir_input == '':
|
||||
if util_training_images_dir_input == "":
|
||||
log.info(
|
||||
"Training images directory is missing... can't perform the required task..."
|
||||
)
|
||||
|
|
@ -62,60 +62,54 @@ def dreambooth_folder_preparation(
|
|||
else:
|
||||
training_dir = os.path.join(
|
||||
util_training_dir_output,
|
||||
f'img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}',
|
||||
f"img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}",
|
||||
)
|
||||
|
||||
# Remove folders if they exist
|
||||
if os.path.exists(training_dir):
|
||||
log.info(f'Removing existing directory {training_dir}...')
|
||||
log.info(f"Removing existing directory {training_dir}...")
|
||||
shutil.rmtree(training_dir)
|
||||
|
||||
# Copy the training images to their respective directories
|
||||
log.info(f'Copy {util_training_images_dir_input} to {training_dir}...')
|
||||
log.info(f"Copy {util_training_images_dir_input} to {training_dir}...")
|
||||
shutil.copytree(util_training_images_dir_input, training_dir)
|
||||
|
||||
if not util_regularization_images_dir_input == '':
|
||||
if not util_regularization_images_dir_input == "":
|
||||
# Create the regularization_dir path
|
||||
if not util_regularization_images_repeat_input > 0:
|
||||
log.info(
|
||||
'Repeats is missing... not copying regularisation images...'
|
||||
)
|
||||
log.info("Repeats is missing... not copying regularisation images...")
|
||||
else:
|
||||
regularization_dir = os.path.join(
|
||||
util_training_dir_output,
|
||||
f'reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}',
|
||||
f"reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}",
|
||||
)
|
||||
|
||||
# Remove folders if they exist
|
||||
if os.path.exists(regularization_dir):
|
||||
log.info(
|
||||
f'Removing existing directory {regularization_dir}...'
|
||||
)
|
||||
log.info(f"Removing existing directory {regularization_dir}...")
|
||||
shutil.rmtree(regularization_dir)
|
||||
|
||||
# Copy the regularisation images to their respective directories
|
||||
log.info(
|
||||
f'Copy {util_regularization_images_dir_input} to {regularization_dir}...'
|
||||
)
|
||||
shutil.copytree(
|
||||
util_regularization_images_dir_input, regularization_dir
|
||||
f"Copy {util_regularization_images_dir_input} to {regularization_dir}..."
|
||||
)
|
||||
shutil.copytree(util_regularization_images_dir_input, regularization_dir)
|
||||
else:
|
||||
log.info(
|
||||
'Regularization images directory is missing... not copying regularisation images...'
|
||||
"Regularization images directory is missing... not copying regularisation images..."
|
||||
)
|
||||
|
||||
# create log and model folder
|
||||
# Check if the log folder exists and create it if it doesn't
|
||||
if not os.path.exists(os.path.join(util_training_dir_output, 'log')):
|
||||
os.makedirs(os.path.join(util_training_dir_output, 'log'))
|
||||
if not os.path.exists(os.path.join(util_training_dir_output, "log")):
|
||||
os.makedirs(os.path.join(util_training_dir_output, "log"))
|
||||
|
||||
# Check if the model folder exists and create it if it doesn't
|
||||
if not os.path.exists(os.path.join(util_training_dir_output, 'model')):
|
||||
os.makedirs(os.path.join(util_training_dir_output, 'model'))
|
||||
if not os.path.exists(os.path.join(util_training_dir_output, "model")):
|
||||
os.makedirs(os.path.join(util_training_dir_output, "model"))
|
||||
|
||||
log.info(
|
||||
f'Done creating kohya_ss training folder structure at {util_training_dir_output}...'
|
||||
f"Done creating kohya_ss training folder structure at {util_training_dir_output}..."
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -132,22 +126,22 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
current_reg_data_dir = os.path.join(scriptdir, "data")
|
||||
current_train_output_dir = os.path.join(scriptdir, "data")
|
||||
|
||||
with gr.Tab('Dreambooth/LoRA Folder preparation'):
|
||||
with gr.Tab("Dreambooth/LoRA Folder preparation"):
|
||||
gr.Markdown(
|
||||
'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth/LoRA method to function correctly.'
|
||||
"This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth/LoRA method to function correctly."
|
||||
)
|
||||
with gr.Row():
|
||||
util_instance_prompt_input = gr.Textbox(
|
||||
label='Instance prompt',
|
||||
placeholder='Eg: asd',
|
||||
label="Instance prompt",
|
||||
placeholder="Eg: asd",
|
||||
interactive=True,
|
||||
value = config.get(key="dataset_preparation.instance_prompt", default="")
|
||||
value=config.get(key="dataset_preparation.instance_prompt", default=""),
|
||||
)
|
||||
util_class_prompt_input = gr.Textbox(
|
||||
label='Class prompt',
|
||||
placeholder='Eg: person',
|
||||
label="Class prompt",
|
||||
placeholder="Eg: person",
|
||||
interactive=True,
|
||||
value = config.get(key="dataset_preparation.class_prompt", default=""),
|
||||
value=config.get(key="dataset_preparation.class_prompt", default=""),
|
||||
)
|
||||
with gr.Group(), gr.Row():
|
||||
|
||||
|
|
@ -157,15 +151,26 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
return list(list_dirs(path))
|
||||
|
||||
util_training_images_dir_input = gr.Dropdown(
|
||||
label='Training images (directory containing the training images)',
|
||||
label="Training images (directory containing the training images)",
|
||||
interactive=True,
|
||||
choices=[config.get(key="dataset_preparation.images_folder", default="")] + list_train_data_dirs(current_train_data_dir),
|
||||
choices=[
|
||||
config.get(key="dataset_preparation.images_folder", default="")
|
||||
]
|
||||
+ list_train_data_dirs(current_train_data_dir),
|
||||
value=config.get(key="dataset_preparation.images_folder", default=""),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(util_training_images_dir_input, lambda: None, lambda: {"choices": list_train_data_dirs(current_train_data_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
util_training_images_dir_input,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_train_data_dirs(current_train_data_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_util_training_images_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_util_training_images_dir_input.click(
|
||||
get_folder_path,
|
||||
|
|
@ -173,10 +178,10 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
show_progress=False,
|
||||
)
|
||||
util_training_images_repeat_input = gr.Number(
|
||||
label='Repeats',
|
||||
label="Repeats",
|
||||
value=40,
|
||||
interactive=True,
|
||||
elem_id='number_input',
|
||||
elem_id="number_input",
|
||||
)
|
||||
util_training_images_dir_input.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)),
|
||||
|
|
@ -186,21 +191,35 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
|
||||
def list_reg_data_dirs(path):
|
||||
nonlocal current_reg_data_dir
|
||||
current_reg_data_dir = path
|
||||
return list(list_dirs(path))
|
||||
|
||||
util_regularization_images_dir_input = gr.Dropdown(
|
||||
label='Regularisation images (Optional. directory containing the regularisation images)',
|
||||
label="Regularisation images (Optional. directory containing the regularisation images)",
|
||||
interactive=True,
|
||||
choices=[config.get(key="dataset_preparation.reg_images_folder", default="")] + list_reg_data_dirs(current_reg_data_dir),
|
||||
value=config.get(key="dataset_preparation.reg_images_folder", default=""),
|
||||
choices=[
|
||||
config.get(key="dataset_preparation.reg_images_folder", default="")
|
||||
]
|
||||
+ list_reg_data_dirs(current_reg_data_dir),
|
||||
value=config.get(
|
||||
key="dataset_preparation.reg_images_folder", default=""
|
||||
),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(util_regularization_images_dir_input, lambda: None, lambda: {"choices": list_reg_data_dirs(current_reg_data_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
util_regularization_images_dir_input,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_reg_data_dirs(current_reg_data_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_util_regularization_images_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_util_regularization_images_dir_input.click(
|
||||
get_folder_path,
|
||||
|
|
@ -208,10 +227,10 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
show_progress=False,
|
||||
)
|
||||
util_regularization_images_repeat_input = gr.Number(
|
||||
label='Repeats',
|
||||
label="Repeats",
|
||||
value=1,
|
||||
interactive=True,
|
||||
elem_id='number_input',
|
||||
elem_id="number_input",
|
||||
)
|
||||
util_regularization_images_dir_input.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_reg_data_dirs(path)),
|
||||
|
|
@ -220,32 +239,44 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
show_progress=False,
|
||||
)
|
||||
with gr.Group(), gr.Row():
|
||||
|
||||
def list_train_output_dirs(path):
|
||||
nonlocal current_train_output_dir
|
||||
current_train_output_dir = path
|
||||
return list(list_dirs(path))
|
||||
|
||||
util_training_dir_output = gr.Dropdown(
|
||||
label='Destination training directory (where formatted training and regularisation folders will be placed)',
|
||||
label="Destination training directory (where formatted training and regularisation folders will be placed)",
|
||||
interactive=True,
|
||||
choices=[config.get(key="train_data_dir", default="")] + list_train_output_dirs(current_train_output_dir),
|
||||
choices=[config.get(key="train_data_dir", default="")]
|
||||
+ list_train_output_dirs(current_train_output_dir),
|
||||
value=config.get(key="train_data_dir", default=""),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(util_training_dir_output, lambda: None, lambda: {"choices": list_train_output_dirs(current_train_output_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
util_training_dir_output,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_train_output_dirs(current_train_output_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_util_training_dir_output = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_util_training_dir_output.click(
|
||||
get_folder_path, outputs=util_training_dir_output
|
||||
)
|
||||
util_training_dir_output.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_train_output_dirs(path)),
|
||||
fn=lambda path: gr.Dropdown(
|
||||
choices=[""] + list_train_output_dirs(path)
|
||||
),
|
||||
inputs=util_training_dir_output,
|
||||
outputs=util_training_dir_output,
|
||||
show_progress=False,
|
||||
)
|
||||
button_prepare_training_data = gr.Button('Prepare training data')
|
||||
button_prepare_training_data = gr.Button("Prepare training data")
|
||||
button_prepare_training_data.click(
|
||||
dreambooth_folder_preparation,
|
||||
inputs=[
|
||||
|
|
@ -259,15 +290,3 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
],
|
||||
show_progress=False,
|
||||
)
|
||||
# button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab')
|
||||
# button_copy_info_to_Folders_tab.click(
|
||||
# copy_info_to_Folders_tab,
|
||||
# inputs=[util_training_dir_output],
|
||||
# outputs=[
|
||||
# train_data_dir_input,
|
||||
# reg_data_dir_input,
|
||||
# output_dir_input,
|
||||
# logging_dir_input,
|
||||
# ],
|
||||
# show_progress=False,
|
||||
# )
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import json
|
|||
import math
|
||||
import os
|
||||
import sys
|
||||
import pathlib
|
||||
from datetime import datetime
|
||||
from .common_gui import (
|
||||
get_file_path,
|
||||
|
|
@ -25,7 +24,6 @@ from .class_basic_training import BasicTraining
|
|||
from .class_advanced_training import AdvancedTraining
|
||||
from .class_folders import Folders
|
||||
from .class_command_executor import CommandExecutor
|
||||
from .class_sdxl_parameters import SDXLParameters
|
||||
from .tensorboard_gui import (
|
||||
gradio_tensorboard,
|
||||
start_tensorboard,
|
||||
|
|
@ -459,7 +457,9 @@ def train_model(
|
|||
return
|
||||
|
||||
if dataset_config:
|
||||
log.info("Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations...")
|
||||
log.info(
|
||||
"Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
|
||||
)
|
||||
else:
|
||||
# Get a list of all subfolders in train_data_dir, excluding hidden folders
|
||||
subfolders = [
|
||||
|
|
@ -509,7 +509,9 @@ def train_model(
|
|||
log.info(f"Folder {folder} : steps {steps}")
|
||||
|
||||
if total_steps == 0:
|
||||
log.info(f"No images were found in folder {train_data_dir}... please rectify!")
|
||||
log.info(
|
||||
f"No images were found in folder {train_data_dir}... please rectify!"
|
||||
)
|
||||
return
|
||||
|
||||
# Print the result
|
||||
|
|
@ -541,7 +543,9 @@ def train_model(
|
|||
# calculate stop encoder training
|
||||
if int(stop_text_encoder_training_pct) == -1:
|
||||
stop_text_encoder_training = -1
|
||||
elif stop_text_encoder_training_pct == None or (not max_train_steps == "" or not max_train_steps == "0"):
|
||||
elif stop_text_encoder_training_pct == None or (
|
||||
not max_train_steps == "" or not max_train_steps == "0"
|
||||
):
|
||||
stop_text_encoder_training = 0
|
||||
else:
|
||||
stop_text_encoder_training = math.ceil(
|
||||
|
|
@ -729,7 +733,7 @@ def dreambooth_tab(
|
|||
|
||||
with gr.Tab("Training"), gr.Column(variant="compact"):
|
||||
gr.Markdown("Train a custom model using kohya dreambooth python code...")
|
||||
|
||||
|
||||
with gr.Accordion("Accelerate launch", open=False), gr.Column():
|
||||
accelerate_launch = AccelerateLaunch()
|
||||
|
||||
|
|
@ -738,7 +742,7 @@ def dreambooth_tab(
|
|||
|
||||
with gr.Accordion("Folders", open=False), gr.Group():
|
||||
folders = Folders(headless=headless, config=config)
|
||||
|
||||
|
||||
with gr.Accordion("Parameters", open=False), gr.Column():
|
||||
with gr.Accordion("Basic", open="True"):
|
||||
with gr.Group(elem_id="basic_tab"):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import subprocess
|
|||
import os
|
||||
import sys
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_file_path,
|
||||
scriptdir,
|
||||
list_files,
|
||||
|
|
@ -16,10 +15,10 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
|
@ -30,13 +29,13 @@ def extract_dylora(
|
|||
unit,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if model == '':
|
||||
msgbox('Invalid DyLoRA model file')
|
||||
if model == "":
|
||||
msgbox("Invalid DyLoRA model file")
|
||||
return
|
||||
|
||||
# Check if source model exist
|
||||
if not os.path.isfile(model):
|
||||
msgbox('The provided DyLoRA model is not a file')
|
||||
msgbox("The provided DyLoRA model is not a file")
|
||||
return
|
||||
|
||||
if os.path.dirname(save_to) == "":
|
||||
|
|
@ -51,21 +50,23 @@ def extract_dylora(
|
|||
save_to = f"{path}_tmp{ext}"
|
||||
|
||||
run_cmd = (
|
||||
fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py"'
|
||||
rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py"'
|
||||
)
|
||||
run_cmd += fr' --save_to "{save_to}"'
|
||||
run_cmd += fr' --model "{model}"'
|
||||
run_cmd += f' --unit {unit}'
|
||||
run_cmd += rf' --save_to "{save_to}"'
|
||||
run_cmd += rf' --model "{model}"'
|
||||
run_cmd += f" --unit {unit}"
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
||||
log.info('Done extracting DyLoRA...')
|
||||
log.info("Done extracting DyLoRA...")
|
||||
|
||||
|
||||
###
|
||||
|
|
@ -77,12 +78,10 @@ def gradio_extract_dylora_tab(headless=False):
|
|||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||
|
||||
with gr.Tab('Extract DyLoRA'):
|
||||
gr.Markdown(
|
||||
'This utility can extract a DyLoRA network from a finetuned model.'
|
||||
)
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
with gr.Tab("Extract DyLoRA"):
|
||||
gr.Markdown("This utility can extract a DyLoRA network from a finetuned model.")
|
||||
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
|
||||
def list_models(path):
|
||||
nonlocal current_model_dir
|
||||
|
|
@ -96,17 +95,22 @@ def gradio_extract_dylora_tab(headless=False):
|
|||
|
||||
with gr.Group(), gr.Row():
|
||||
model = gr.Dropdown(
|
||||
label='DyLoRA model (path to the DyLoRA model to extract from)',
|
||||
label="DyLoRA model (path to the DyLoRA model to extract from)",
|
||||
interactive=True,
|
||||
choices=[""] + list_models(current_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_model_file.click(
|
||||
|
|
@ -117,17 +121,22 @@ def gradio_extract_dylora_tab(headless=False):
|
|||
)
|
||||
|
||||
save_to = gr.Dropdown(
|
||||
label='Save to (path where to save the extracted LoRA model...)',
|
||||
label="Save to (path where to save the extracted LoRA model...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_save_to(current_save_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
save_to,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
unit = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=256,
|
||||
label='Network Dimension (Rank)',
|
||||
label="Network Dimension (Rank)",
|
||||
value=1,
|
||||
step=1,
|
||||
interactive=True,
|
||||
|
|
@ -146,7 +155,7 @@ def gradio_extract_dylora_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
extract_button = gr.Button('Extract LoRA model')
|
||||
extract_button = gr.Button("Extract LoRA model")
|
||||
|
||||
extract_button.click(
|
||||
extract_dylora,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import gradio as gr
|
||||
from easygui import msgbox
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
|
|
@ -17,10 +16,10 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
|
@ -42,21 +41,21 @@ def extract_lora(
|
|||
load_precision,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if model_tuned == '':
|
||||
log.info('Invalid finetuned model file')
|
||||
if model_tuned == "":
|
||||
log.info("Invalid finetuned model file")
|
||||
return
|
||||
|
||||
if model_org == '':
|
||||
log.info('Invalid base model file')
|
||||
if model_org == "":
|
||||
log.info("Invalid base model file")
|
||||
return
|
||||
|
||||
# Check if source model exist
|
||||
if not os.path.isfile(model_tuned):
|
||||
log.info('The provided finetuned model is not a file')
|
||||
log.info("The provided finetuned model is not a file")
|
||||
return
|
||||
|
||||
if not os.path.isfile(model_org):
|
||||
log.info('The provided base model is not a file')
|
||||
log.info("The provided base model is not a file")
|
||||
return
|
||||
|
||||
if os.path.dirname(save_to) == "":
|
||||
|
|
@ -74,31 +73,33 @@ def extract_lora(
|
|||
return
|
||||
|
||||
run_cmd = (
|
||||
fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_models.py"'
|
||||
rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_models.py"'
|
||||
)
|
||||
run_cmd += f' --load_precision {load_precision}'
|
||||
run_cmd += f' --save_precision {save_precision}'
|
||||
run_cmd += fr' --save_to "{save_to}"'
|
||||
run_cmd += fr' --model_org "{model_org}"'
|
||||
run_cmd += fr' --model_tuned "{model_tuned}"'
|
||||
run_cmd += f' --dim {dim}'
|
||||
run_cmd += f' --device {device}'
|
||||
run_cmd += f" --load_precision {load_precision}"
|
||||
run_cmd += f" --save_precision {save_precision}"
|
||||
run_cmd += rf' --save_to "{save_to}"'
|
||||
run_cmd += rf' --model_org "{model_org}"'
|
||||
run_cmd += rf' --model_tuned "{model_tuned}"'
|
||||
run_cmd += f" --dim {dim}"
|
||||
run_cmd += f" --device {device}"
|
||||
if conv_dim > 0:
|
||||
run_cmd += f' --conv_dim {conv_dim}'
|
||||
run_cmd += f" --conv_dim {conv_dim}"
|
||||
if v2:
|
||||
run_cmd += f' --v2'
|
||||
run_cmd += f" --v2"
|
||||
if sdxl:
|
||||
run_cmd += f' --sdxl'
|
||||
run_cmd += f' --clamp_quantile {clamp_quantile}'
|
||||
run_cmd += f' --min_diff {min_diff}'
|
||||
run_cmd += f" --sdxl"
|
||||
run_cmd += f" --clamp_quantile {clamp_quantile}"
|
||||
run_cmd += f" --min_diff {min_diff}"
|
||||
if sdxl:
|
||||
run_cmd += f' --load_original_model_to {load_original_model_to}'
|
||||
run_cmd += f' --load_tuned_model_to {load_tuned_model_to}'
|
||||
run_cmd += f" --load_original_model_to {load_original_model_to}"
|
||||
run_cmd += f" --load_tuned_model_to {load_tuned_model_to}"
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
|
@ -132,29 +133,31 @@ def gradio_extract_lora_tab(headless=False):
|
|||
def change_sdxl(sdxl):
|
||||
return gr.Dropdown(visible=sdxl), gr.Dropdown(visible=sdxl)
|
||||
|
||||
|
||||
with gr.Tab('Extract LoRA'):
|
||||
gr.Markdown(
|
||||
'This utility can extract a LoRA network from a finetuned model.'
|
||||
)
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
|
||||
model_ext_name = gr.Textbox(value='Model types', visible=False)
|
||||
with gr.Tab("Extract LoRA"):
|
||||
gr.Markdown("This utility can extract a LoRA network from a finetuned model.")
|
||||
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
model_ext = gr.Textbox(value="*.ckpt *.safetensors", visible=False)
|
||||
model_ext_name = gr.Textbox(value="Model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
model_tuned = gr.Dropdown(
|
||||
label='Finetuned model (path to the finetuned model to extract)',
|
||||
label="Finetuned model (path to the finetuned model to extract)",
|
||||
interactive=True,
|
||||
choices=[""] + list_models(current_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(model_tuned, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
model_tuned,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_model_tuned_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_model_tuned_file.click(
|
||||
|
|
@ -164,25 +167,31 @@ def gradio_extract_lora_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
load_tuned_model_to = gr.Radio(
|
||||
label='Load finetuned model to',
|
||||
choices=['cpu', 'cuda', 'cuda:0'],
|
||||
value='cpu',
|
||||
interactive=True, scale=1,
|
||||
label="Load finetuned model to",
|
||||
choices=["cpu", "cuda", "cuda:0"],
|
||||
value="cpu",
|
||||
interactive=True,
|
||||
scale=1,
|
||||
info="only for SDXL",
|
||||
visible=False,
|
||||
)
|
||||
model_org = gr.Dropdown(
|
||||
label='Stable Diffusion base model (original model: ckpt or safetensors file)',
|
||||
label="Stable Diffusion base model (original model: ckpt or safetensors file)",
|
||||
interactive=True,
|
||||
choices=[""] + list_org_models(current_model_org_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(model_org, lambda: None, lambda: {"choices": list_org_models(current_model_org_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
model_org,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_org_models(current_model_org_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_model_org_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_model_org_file.click(
|
||||
|
|
@ -192,27 +201,33 @@ def gradio_extract_lora_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
load_original_model_to = gr.Dropdown(
|
||||
label='Load Stable Diffusion base model to',
|
||||
choices=['cpu', 'cuda', 'cuda:0'],
|
||||
value='cpu',
|
||||
interactive=True, scale=1,
|
||||
label="Load Stable Diffusion base model to",
|
||||
choices=["cpu", "cuda", "cuda:0"],
|
||||
value="cpu",
|
||||
interactive=True,
|
||||
scale=1,
|
||||
info="only for SDXL",
|
||||
visible=False,
|
||||
)
|
||||
with gr.Group(), gr.Row():
|
||||
save_to = gr.Dropdown(
|
||||
label='Save to (path where to save the extracted LoRA model...)',
|
||||
label="Save to (path where to save the extracted LoRA model...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_save_to(current_save_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
save_to,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_save_to.click(
|
||||
|
|
@ -222,16 +237,18 @@ def gradio_extract_lora_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
save_precision = gr.Radio(
|
||||
label='Save precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='fp16',
|
||||
interactive=True, scale=1,
|
||||
label="Save precision",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
value="fp16",
|
||||
interactive=True,
|
||||
scale=1,
|
||||
)
|
||||
load_precision = gr.Radio(
|
||||
label='Load precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='fp16',
|
||||
interactive=True, scale=1,
|
||||
label="Load precision",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
value="fp16",
|
||||
interactive=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
model_tuned.change(
|
||||
|
|
@ -256,7 +273,7 @@ def gradio_extract_lora_tab(headless=False):
|
|||
dim = gr.Slider(
|
||||
minimum=4,
|
||||
maximum=1024,
|
||||
label='Network Dimension (Rank)',
|
||||
label="Network Dimension (Rank)",
|
||||
value=128,
|
||||
step=1,
|
||||
interactive=True,
|
||||
|
|
@ -264,13 +281,13 @@ def gradio_extract_lora_tab(headless=False):
|
|||
conv_dim = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1024,
|
||||
label='Conv Dimension (Rank)',
|
||||
label="Conv Dimension (Rank)",
|
||||
value=128,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
clamp_quantile = gr.Number(
|
||||
label='Clamp Quantile',
|
||||
label="Clamp Quantile",
|
||||
value=0.99,
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
|
|
@ -278,7 +295,7 @@ def gradio_extract_lora_tab(headless=False):
|
|||
interactive=True,
|
||||
)
|
||||
min_diff = gr.Number(
|
||||
label='Minimum difference',
|
||||
label="Minimum difference",
|
||||
value=0.01,
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
|
|
@ -286,21 +303,25 @@ def gradio_extract_lora_tab(headless=False):
|
|||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
|
||||
sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True)
|
||||
v2 = gr.Checkbox(label="v2", value=False, interactive=True)
|
||||
sdxl = gr.Checkbox(label="SDXL", value=False, interactive=True)
|
||||
device = gr.Radio(
|
||||
label='Device',
|
||||
label="Device",
|
||||
choices=[
|
||||
'cpu',
|
||||
'cuda',
|
||||
"cpu",
|
||||
"cuda",
|
||||
],
|
||||
value='cuda',
|
||||
value="cuda",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
sdxl.change(change_sdxl, inputs=sdxl, outputs=[load_tuned_model_to, load_original_model_to])
|
||||
|
||||
extract_button = gr.Button('Extract LoRA model')
|
||||
sdxl.change(
|
||||
change_sdxl,
|
||||
inputs=sdxl,
|
||||
outputs=[load_tuned_model_to, load_original_model_to],
|
||||
)
|
||||
|
||||
extract_button = gr.Button("Extract LoRA model")
|
||||
|
||||
extract_button.click(
|
||||
extract_lora,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import os
|
|||
import sys
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_any_file_path,
|
||||
get_file_path,
|
||||
scriptdir,
|
||||
list_files,
|
||||
|
|
@ -74,7 +73,7 @@ def extract_lycoris_locon(
|
|||
path, ext = os.path.splitext(output_name)
|
||||
output_name = f"{path}_tmp{ext}"
|
||||
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/lycoris_locon_extract.py"'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lycoris_locon_extract.py"'
|
||||
if is_sdxl:
|
||||
run_cmd += f" --is_sdxl"
|
||||
if is_v2:
|
||||
|
|
@ -99,19 +98,21 @@ def extract_lycoris_locon(
|
|||
run_cmd += f" --sparsity {sparsity}"
|
||||
if disable_cp:
|
||||
run_cmd += f" --disable_cp"
|
||||
run_cmd += fr' "{base_model}"'
|
||||
run_cmd += fr' "{db_model}"'
|
||||
run_cmd += fr' "{output_name}"'
|
||||
run_cmd += rf' "{base_model}"'
|
||||
run_cmd += rf' "{db_model}"'
|
||||
run_cmd += rf' "{output_name}"'
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
||||
log.info('Done extracting...')
|
||||
log.info("Done extracting...")
|
||||
|
||||
|
||||
###
|
||||
|
|
@ -185,11 +186,16 @@ def gradio_extract_lycoris_locon_tab(headless=False):
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(db_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
db_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_db_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=['tool'],
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_db_model_file.click(
|
||||
|
|
@ -205,11 +211,16 @@ def gradio_extract_lycoris_locon_tab(headless=False):
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(base_model, lambda: None, lambda: {"choices": list_base_models(current_base_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
base_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_base_models(current_base_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_base_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=['tool'],
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_base_model_file.click(
|
||||
|
|
@ -227,11 +238,16 @@ def gradio_extract_lycoris_locon_tab(headless=False):
|
|||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
create_refresh_button(output_name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
output_name,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_output_name = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=['tool'],
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_output_name.click(
|
||||
|
|
@ -270,7 +286,9 @@ def gradio_extract_lycoris_locon_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
is_sdxl = gr.Checkbox(label="is SDXL", value=False, interactive=True, scale=1)
|
||||
is_sdxl = gr.Checkbox(
|
||||
label="is SDXL", value=False, interactive=True, scale=1
|
||||
)
|
||||
|
||||
is_v2 = gr.Checkbox(label="is v2", value=False, interactive=True, scale=1)
|
||||
with gr.Row():
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import math
|
|||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import pathlib
|
||||
from datetime import datetime
|
||||
from .common_gui import (
|
||||
get_file_path,
|
||||
|
|
@ -50,7 +49,8 @@ document_symbol = "\U0001F4C4" # 📄
|
|||
|
||||
PYTHON = sys.executable
|
||||
|
||||
presets_dir = fr'{scriptdir}/presets'
|
||||
presets_dir = rf"{scriptdir}/presets"
|
||||
|
||||
|
||||
def save_configuration(
|
||||
save_as,
|
||||
|
|
@ -319,7 +319,7 @@ def open_configuration(
|
|||
# Check if we are "applying" a preset or a config
|
||||
if apply_preset:
|
||||
log.info(f"Applying preset {training_preset}...")
|
||||
file_path = fr'{presets_dir}/finetune/{training_preset}.json'
|
||||
file_path = rf"{presets_dir}/finetune/{training_preset}.json"
|
||||
else:
|
||||
# If not applying a preset, set the `training_preset` field to an empty string
|
||||
# Find the index of the `training_preset` parameter using the `index()` method
|
||||
|
|
@ -482,32 +482,38 @@ def train_model(
|
|||
logging_dir=logging_dir,
|
||||
log_tracker_config=log_tracker_config,
|
||||
resume=resume,
|
||||
dataset_config=dataset_config
|
||||
dataset_config=dataset_config,
|
||||
):
|
||||
return
|
||||
|
||||
if not print_only_bool and check_if_model_exist(output_name, output_dir, save_model_as, headless_bool):
|
||||
if not print_only_bool and check_if_model_exist(
|
||||
output_name, output_dir, save_model_as, headless_bool
|
||||
):
|
||||
return
|
||||
|
||||
if dataset_config:
|
||||
log.info("Dataset config toml file used, skipping caption json file, image buckets, total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps creation...")
|
||||
log.info(
|
||||
"Dataset config toml file used, skipping caption json file, image buckets, total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps creation..."
|
||||
)
|
||||
else:
|
||||
# create caption json file
|
||||
if generate_caption_database:
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/merge_captions_to_metadata.py"'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/merge_captions_to_metadata.py"'
|
||||
if caption_extension == "":
|
||||
run_cmd += f' --caption_extension=".caption"'
|
||||
else:
|
||||
run_cmd += f" --caption_extension={caption_extension}"
|
||||
run_cmd += fr' "{image_folder}"'
|
||||
run_cmd += fr' "{train_dir}/{caption_metadata_filename}"'
|
||||
run_cmd += rf' "{image_folder}"'
|
||||
run_cmd += rf' "{train_dir}/{caption_metadata_filename}"'
|
||||
if full_path:
|
||||
run_cmd += f" --full_path"
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
if not print_only_bool:
|
||||
# Run the command
|
||||
|
|
@ -515,11 +521,11 @@ def train_model(
|
|||
|
||||
# create images buckets
|
||||
if generate_image_buckets:
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/prepare_buckets_latents.py"'
|
||||
run_cmd += fr' "{image_folder}"'
|
||||
run_cmd += fr' "{train_dir}/{caption_metadata_filename}"'
|
||||
run_cmd += fr' "{train_dir}/{latent_metadata_filename}"'
|
||||
run_cmd += fr' "{pretrained_model_name_or_path}"'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/prepare_buckets_latents.py"'
|
||||
run_cmd += rf' "{image_folder}"'
|
||||
run_cmd += rf' "{train_dir}/{caption_metadata_filename}"'
|
||||
run_cmd += rf' "{train_dir}/{latent_metadata_filename}"'
|
||||
run_cmd += rf' "{pretrained_model_name_or_path}"'
|
||||
run_cmd += f" --batch_size={batch_size}"
|
||||
run_cmd += f" --max_resolution={max_resolution}"
|
||||
run_cmd += f" --min_bucket_reso={min_bucket_reso}"
|
||||
|
|
@ -530,13 +536,17 @@ def train_model(
|
|||
if full_path:
|
||||
run_cmd += f" --full_path"
|
||||
if sdxl_checkbox and sdxl_no_half_vae:
|
||||
log.info("Using mixed_precision = no because no half vae is selected...")
|
||||
log.info(
|
||||
"Using mixed_precision = no because no half vae is selected..."
|
||||
)
|
||||
run_cmd += f' --mixed_precision="no"'
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
if not print_only_bool:
|
||||
# Run the command
|
||||
|
|
@ -591,14 +601,14 @@ def train_model(
|
|||
)
|
||||
|
||||
if sdxl_checkbox:
|
||||
run_cmd += fr' "{scriptdir}/sd-scripts/sdxl_train.py"'
|
||||
run_cmd += rf' "{scriptdir}/sd-scripts/sdxl_train.py"'
|
||||
else:
|
||||
run_cmd += fr' "{scriptdir}/sd-scripts/fine_tune.py"'
|
||||
run_cmd += rf' "{scriptdir}/sd-scripts/fine_tune.py"'
|
||||
|
||||
in_json = (
|
||||
fr"{train_dir}/{latent_metadata_filename}"
|
||||
rf"{train_dir}/{latent_metadata_filename}"
|
||||
if use_latent_files == "Yes"
|
||||
else fr"{train_dir}/{caption_metadata_filename}"
|
||||
else rf"{train_dir}/{caption_metadata_filename}"
|
||||
)
|
||||
cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs
|
||||
no_half_vae = sdxl_checkbox and sdxl_no_half_vae
|
||||
|
|
@ -732,7 +742,9 @@ def train_model(
|
|||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
executor.execute_command(run_cmd=run_cmd, env=env)
|
||||
|
|
@ -744,12 +756,14 @@ def finetune_tab(headless=False, config: dict = {}):
|
|||
dummy_headless = gr.Label(value=headless, visible=False)
|
||||
with gr.Tab("Training"), gr.Column(variant="compact"):
|
||||
gr.Markdown("Train a custom model using kohya finetune python code...")
|
||||
|
||||
|
||||
with gr.Accordion("Accelerate launch", open=False), gr.Column():
|
||||
accelerate_launch = AccelerateLaunch()
|
||||
|
||||
with gr.Column():
|
||||
source_model = SourceModel(headless=headless, finetuning=True, config=config)
|
||||
source_model = SourceModel(
|
||||
headless=headless, finetuning=True, config=config
|
||||
)
|
||||
image_folder = source_model.train_data_dir
|
||||
output_name = source_model.output_name
|
||||
|
||||
|
|
@ -803,14 +817,17 @@ def finetune_tab(headless=False, config: dict = {}):
|
|||
with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
|
||||
with gr.Row():
|
||||
gradient_accumulation_steps = gr.Number(
|
||||
label="Gradient accumulate steps", value="1",
|
||||
label="Gradient accumulate steps",
|
||||
value="1",
|
||||
)
|
||||
block_lr = gr.Textbox(
|
||||
label="Block LR (SDXL)",
|
||||
placeholder="(Optional)",
|
||||
info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3",
|
||||
)
|
||||
advanced_training = AdvancedTraining(headless=headless, finetuning=True, config=config)
|
||||
advanced_training = AdvancedTraining(
|
||||
headless=headless, finetuning=True, config=config
|
||||
)
|
||||
advanced_training.color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[advanced_training.color_aug],
|
||||
|
|
@ -866,7 +883,6 @@ def finetune_tab(headless=False, config: dict = {}):
|
|||
with gr.Accordion("Configuration", open=False):
|
||||
configuration = ConfigurationFile(headless=headless, config=config)
|
||||
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
|
@ -1005,7 +1021,9 @@ def finetune_tab(headless=False, config: dict = {}):
|
|||
inputs=[dummy_db_true, dummy_db_false, configuration.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[configuration.config_file_name] + settings_list + [training_preset],
|
||||
outputs=[configuration.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
|
@ -1021,7 +1039,9 @@ def finetune_tab(headless=False, config: dict = {}):
|
|||
inputs=[dummy_db_false, dummy_db_false, configuration.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[configuration.config_file_name] + settings_list + [training_preset],
|
||||
outputs=[configuration.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
|
@ -1062,16 +1082,16 @@ def finetune_tab(headless=False, config: dict = {}):
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
#config.button_save_as_config.click(
|
||||
# config.button_save_as_config.click(
|
||||
# save_configuration,
|
||||
# inputs=[dummy_db_true, config.config_file_name] + settings_list,
|
||||
# outputs=[config.config_file_name],
|
||||
# show_progress=False,
|
||||
#)
|
||||
# )
|
||||
|
||||
with gr.Tab("Guides"):
|
||||
gr.Markdown("This section provide Various Finetuning guides and information...")
|
||||
top_level_path = fr"{scriptdir}/docs/Finetuning/top_level.md"
|
||||
top_level_path = rf"{scriptdir}/docs/Finetuning/top_level.md"
|
||||
if os.path.exists(top_level_path):
|
||||
with open(os.path.join(top_level_path), "r", encoding="utf8") as file:
|
||||
guides_top_level = file.read() + "\n"
|
||||
|
|
|
|||
|
|
@ -24,31 +24,31 @@ def caption_images(
|
|||
postfix,
|
||||
):
|
||||
# Check for images_dir_input
|
||||
if train_data_dir == '':
|
||||
msgbox('Image folder is missing...')
|
||||
if train_data_dir == "":
|
||||
msgbox("Image folder is missing...")
|
||||
return
|
||||
|
||||
if caption_ext == '':
|
||||
msgbox('Please provide an extension for the caption files.')
|
||||
if caption_ext == "":
|
||||
msgbox("Please provide an extension for the caption files.")
|
||||
return
|
||||
|
||||
log.info(f'GIT captioning files in {train_data_dir}...')
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"'
|
||||
if not model_id == '':
|
||||
log.info(f"GIT captioning files in {train_data_dir}...")
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"'
|
||||
if not model_id == "":
|
||||
run_cmd += f' --model_id="{model_id}"'
|
||||
run_cmd += f' --batch_size="{int(batch_size)}"'
|
||||
run_cmd += (
|
||||
f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
|
||||
)
|
||||
run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
|
||||
run_cmd += f' --max_length="{int(max_length)}"'
|
||||
if caption_ext != '':
|
||||
if caption_ext != "":
|
||||
run_cmd += f' --caption_extension="{caption_ext}"'
|
||||
run_cmd += f' "{train_data_dir}"'
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
|
@ -61,7 +61,7 @@ def caption_images(
|
|||
postfix=postfix,
|
||||
)
|
||||
|
||||
log.info('...captioning done')
|
||||
log.info("...captioning done")
|
||||
|
||||
|
||||
###
|
||||
|
|
@ -72,7 +72,11 @@ def caption_images(
|
|||
def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
|
||||
from .common_gui import create_refresh_button
|
||||
|
||||
default_train_dir = default_train_dir if default_train_dir is not None else os.path.join(scriptdir, "data")
|
||||
default_train_dir = (
|
||||
default_train_dir
|
||||
if default_train_dir is not None
|
||||
else os.path.join(scriptdir, "data")
|
||||
)
|
||||
current_train_dir = default_train_dir
|
||||
|
||||
def list_train_dirs(path):
|
||||
|
|
@ -80,21 +84,29 @@ def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
|
|||
current_train_dir = path
|
||||
return list(list_dirs(path))
|
||||
|
||||
with gr.Tab('GIT Captioning'):
|
||||
with gr.Tab("GIT Captioning"):
|
||||
gr.Markdown(
|
||||
'This utility will use GIT to caption files for each images in a folder.'
|
||||
"This utility will use GIT to caption files for each images in a folder."
|
||||
)
|
||||
with gr.Group(), gr.Row():
|
||||
train_data_dir = gr.Dropdown(
|
||||
label='Image folder to caption (containing the images to caption)',
|
||||
label="Image folder to caption (containing the images to caption)",
|
||||
choices=[""] + list_train_dirs(default_train_dir),
|
||||
value="",
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(train_data_dir, lambda: None, lambda: {"choices": list_train_dirs(current_train_dir)},"open_folder_small")
|
||||
create_refresh_button(
|
||||
train_data_dir,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_train_dirs(current_train_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_train_data_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_train_data_dir_input.click(
|
||||
get_folder_path,
|
||||
|
|
@ -103,42 +115,38 @@ def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
|
|||
)
|
||||
with gr.Row():
|
||||
caption_ext = gr.Textbox(
|
||||
label='Caption file extension',
|
||||
placeholder='Extension for caption file (e.g., .caption, .txt)',
|
||||
value='.txt',
|
||||
label="Caption file extension",
|
||||
placeholder="Extension for caption file (e.g., .caption, .txt)",
|
||||
value=".txt",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
prefix = gr.Textbox(
|
||||
label='Prefix to add to GIT caption',
|
||||
placeholder='(Optional)',
|
||||
label="Prefix to add to GIT caption",
|
||||
placeholder="(Optional)",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
postfix = gr.Textbox(
|
||||
label='Postfix to add to GIT caption',
|
||||
placeholder='(Optional)',
|
||||
label="Postfix to add to GIT caption",
|
||||
placeholder="(Optional)",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
batch_size = gr.Number(
|
||||
value=1, label='Batch size', interactive=True
|
||||
)
|
||||
batch_size = gr.Number(value=1, label="Batch size", interactive=True)
|
||||
|
||||
with gr.Row():
|
||||
max_data_loader_n_workers = gr.Number(
|
||||
value=2, label='Number of workers', interactive=True
|
||||
)
|
||||
max_length = gr.Number(
|
||||
value=75, label='Max length', interactive=True
|
||||
value=2, label="Number of workers", interactive=True
|
||||
)
|
||||
max_length = gr.Number(value=75, label="Max length", interactive=True)
|
||||
model_id = gr.Textbox(
|
||||
label='Model',
|
||||
placeholder='(Optional) model id for GIT in Hugging Face',
|
||||
label="Model",
|
||||
placeholder="(Optional) model id for GIT in Hugging Face",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
caption_button = gr.Button('Caption images')
|
||||
caption_button = gr.Button("Caption images")
|
||||
|
||||
caption_button.click(
|
||||
caption_images,
|
||||
|
|
|
|||
|
|
@ -22,38 +22,40 @@ def group_images(
|
|||
generate_captions,
|
||||
caption_ext,
|
||||
):
|
||||
if input_folder == '':
|
||||
msgbox('Input folder is missing...')
|
||||
if input_folder == "":
|
||||
msgbox("Input folder is missing...")
|
||||
return
|
||||
|
||||
if output_folder == '':
|
||||
msgbox('Please provide an output folder.')
|
||||
if output_folder == "":
|
||||
msgbox("Please provide an output folder.")
|
||||
return
|
||||
|
||||
log.info(f'Grouping images in {input_folder}...')
|
||||
log.info(f"Grouping images in {input_folder}...")
|
||||
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/group_images.py"'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/group_images.py"'
|
||||
run_cmd += f' "{input_folder}"'
|
||||
run_cmd += f' "{output_folder}"'
|
||||
run_cmd += f' {(group_size)}'
|
||||
run_cmd += f" {(group_size)}"
|
||||
if include_subfolders:
|
||||
run_cmd += f' --include_subfolders'
|
||||
run_cmd += f" --include_subfolders"
|
||||
if do_not_copy_other_files:
|
||||
run_cmd += f' --do_not_copy_other_files'
|
||||
run_cmd += f" --do_not_copy_other_files"
|
||||
if generate_captions:
|
||||
run_cmd += f' --caption'
|
||||
run_cmd += f" --caption"
|
||||
if caption_ext:
|
||||
run_cmd += f' --caption_ext={caption_ext}'
|
||||
run_cmd += f" --caption_ext={caption_ext}"
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
||||
log.info('...grouping done')
|
||||
log.info("...grouping done")
|
||||
|
||||
|
||||
def gradio_group_images_gui_tab(headless=False):
|
||||
|
|
@ -72,22 +74,30 @@ def gradio_group_images_gui_tab(headless=False):
|
|||
current_output_folder = path
|
||||
return list(list_dirs(path))
|
||||
|
||||
with gr.Tab('Group Images'):
|
||||
with gr.Tab("Group Images"):
|
||||
gr.Markdown(
|
||||
'This utility will group images in a folder based on their aspect ratio.'
|
||||
"This utility will group images in a folder based on their aspect ratio."
|
||||
)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
input_folder = gr.Dropdown(
|
||||
label='Input folder (containing the images to group)',
|
||||
label="Input folder (containing the images to group)",
|
||||
interactive=True,
|
||||
choices=[""] + list_input_dirs(current_input_folder),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(input_folder, lambda: None, lambda: {"choices": list_input_dirs(current_input_folder)},"open_folder_small")
|
||||
create_refresh_button(
|
||||
input_folder,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_input_dirs(current_input_folder)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -96,15 +106,23 @@ def gradio_group_images_gui_tab(headless=False):
|
|||
)
|
||||
|
||||
output_folder = gr.Dropdown(
|
||||
label='Output folder (where the grouped images will be stored)',
|
||||
label="Output folder (where the grouped images will be stored)",
|
||||
interactive=True,
|
||||
choices=[""] + list_output_dirs(current_output_folder),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(output_folder, lambda: None, lambda: {"choices": list_output_dirs(current_output_folder)},"open_folder_small")
|
||||
create_refresh_button(
|
||||
output_folder,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_output_dirs(current_output_folder)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_output_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_output_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -126,9 +144,9 @@ def gradio_group_images_gui_tab(headless=False):
|
|||
)
|
||||
with gr.Row():
|
||||
group_size = gr.Slider(
|
||||
label='Group size',
|
||||
info='Number of images to group together',
|
||||
value='4',
|
||||
label="Group size",
|
||||
info="Number of images to group together",
|
||||
value="4",
|
||||
minimum=1,
|
||||
maximum=64,
|
||||
step=1,
|
||||
|
|
@ -136,31 +154,31 @@ def gradio_group_images_gui_tab(headless=False):
|
|||
)
|
||||
|
||||
include_subfolders = gr.Checkbox(
|
||||
label='Include Subfolders',
|
||||
label="Include Subfolders",
|
||||
value=False,
|
||||
info='Include images in subfolders as well',
|
||||
info="Include images in subfolders as well",
|
||||
)
|
||||
|
||||
do_not_copy_other_files = gr.Checkbox(
|
||||
label='Do not copy other files',
|
||||
label="Do not copy other files",
|
||||
value=False,
|
||||
info='Do not copy other files in the input folder to the output folder',
|
||||
info="Do not copy other files in the input folder to the output folder",
|
||||
)
|
||||
|
||||
generate_captions = gr.Checkbox(
|
||||
label='Generate Captions',
|
||||
label="Generate Captions",
|
||||
value=False,
|
||||
info='Generate caption files for the grouped images based on their folder name',
|
||||
info="Generate caption files for the grouped images based on their folder name",
|
||||
)
|
||||
|
||||
caption_ext = gr.Textbox(
|
||||
label='Caption Extension',
|
||||
placeholder='Caption file extension (e.g., .txt)',
|
||||
value='.txt',
|
||||
label="Caption Extension",
|
||||
placeholder="Caption file extension (e.g., .txt)",
|
||||
value=".txt",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
group_images_button = gr.Button('Group images')
|
||||
group_images_button = gr.Button("Group images")
|
||||
|
||||
group_images_button.click(
|
||||
group_images,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ localizationMap = {}
|
|||
|
||||
def load_localizations():
|
||||
localizationMap.clear()
|
||||
dirname = './localizations'
|
||||
dirname = "./localizations"
|
||||
for file in os.listdir(dirname):
|
||||
fn, ext = os.path.splitext(file)
|
||||
if ext.lower() != ".json":
|
||||
|
|
@ -28,4 +28,4 @@ def load_language_js(language_name: str) -> str:
|
|||
return f"window.localization = {json.dumps(data)}"
|
||||
|
||||
|
||||
load_localizations()
|
||||
load_localizations()
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@ import kohya_gui.localization as localization
|
|||
|
||||
|
||||
def file_path(fn):
|
||||
return f'file={os.path.abspath(fn)}?{os.path.getmtime(fn)}'
|
||||
return f"file={os.path.abspath(fn)}?{os.path.getmtime(fn)}"
|
||||
|
||||
|
||||
def js_html_str(language):
|
||||
head = f'<script type="text/javascript">{localization.load_language_js(language)}</script>\n'
|
||||
head += f'<script type="text/javascript" src="{file_path("js/script.js")}"></script>\n'
|
||||
head += (
|
||||
f'<script type="text/javascript" src="{file_path("js/script.js")}"></script>\n'
|
||||
)
|
||||
head += f'<script type="text/javascript" src="{file_path("js/localization.js")}"></script>\n'
|
||||
return head
|
||||
|
||||
|
|
@ -22,12 +24,12 @@ def add_javascript(language):
|
|||
|
||||
def template_response(*args, **kwargs):
|
||||
res = localization.GrRoutesTemplateResponse(*args, **kwargs)
|
||||
res.body = res.body.replace(b'</head>', f'{jsStr}</head>'.encode("utf8"))
|
||||
res.body = res.body.replace(b"</head>", f"{jsStr}</head>".encode("utf8"))
|
||||
res.init_headers()
|
||||
return res
|
||||
|
||||
gr.routes.templates.TemplateResponse = template_response
|
||||
|
||||
|
||||
if not hasattr(localization, 'GrRoutesTemplateResponse'):
|
||||
if not hasattr(localization, "GrRoutesTemplateResponse"):
|
||||
localization.GrRoutesTemplateResponse = gr.routes.templates.TemplateResponse
|
||||
|
|
|
|||
|
|
@ -905,22 +905,21 @@ def train_model(
|
|||
)
|
||||
# Convert learning rates to float once and store the result for re-use
|
||||
learning_rate = float(learning_rate) if learning_rate is not None else 0.0
|
||||
text_encoder_lr_float = float(text_encoder_lr) if text_encoder_lr is not None else 0.0
|
||||
text_encoder_lr_float = (
|
||||
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
|
||||
)
|
||||
unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
|
||||
|
||||
# Determine the training configuration based on learning rate values
|
||||
# Sets flags for training specific components based on the provided learning rates.
|
||||
if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
|
||||
output_message(
|
||||
msg="Please input learning rate values.", headless=headless_bool
|
||||
)
|
||||
output_message(msg="Please input learning rate values.", headless=headless_bool)
|
||||
return
|
||||
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
|
||||
network_train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0
|
||||
# Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
|
||||
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
|
||||
|
||||
|
||||
# Define a dictionary of parameters
|
||||
run_cmd_params = {
|
||||
"adaptive_noise_scale": adaptive_noise_scale,
|
||||
|
|
@ -1087,10 +1086,10 @@ def lora_tab(
|
|||
gr.Markdown(
|
||||
"Train a custom model using kohya train network LoRA python code..."
|
||||
)
|
||||
|
||||
|
||||
with gr.Accordion("Accelerate launch", open=False), gr.Column():
|
||||
accelerate_launch = AccelerateLaunch()
|
||||
|
||||
|
||||
with gr.Column():
|
||||
source_model = SourceModel(
|
||||
save_model_as_choices=[
|
||||
|
|
@ -1124,7 +1123,7 @@ def lora_tab(
|
|||
json_files.append(os.path.join("user_presets", preset_name))
|
||||
|
||||
return json_files
|
||||
|
||||
|
||||
with gr.Accordion("Basic", open="True"):
|
||||
training_preset = gr.Dropdown(
|
||||
label="Presets",
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from .custom_logging import setup_logging
|
|||
log = setup_logging()
|
||||
|
||||
IMAGES_TO_SHOW = 5
|
||||
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')
|
||||
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
|
||||
auto_save = True
|
||||
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ def _get_quick_tags(quick_tags_text):
|
|||
"""
|
||||
Gets a list of tags from the quick tags text box
|
||||
"""
|
||||
quick_tags = [t.strip() for t in quick_tags_text.split(',') if t.strip()]
|
||||
quick_tags = [t.strip() for t in quick_tags_text.split(",") if t.strip()]
|
||||
quick_tags_set = set(quick_tags)
|
||||
return quick_tags, quick_tags_set
|
||||
|
||||
|
|
@ -38,34 +38,31 @@ def _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set):
|
|||
Updates a list of caption checkboxes to show possible tags and tags
|
||||
already included in the caption
|
||||
"""
|
||||
caption_tags_have = [c.strip() for c in caption.split(',') if c.strip()]
|
||||
caption_tags_unique = [
|
||||
t for t in caption_tags_have if t not in quick_tags_set
|
||||
]
|
||||
caption_tags_have = [c.strip() for c in caption.split(",") if c.strip()]
|
||||
caption_tags_unique = [t for t in caption_tags_have if t not in quick_tags_set]
|
||||
caption_tags_all = quick_tags + caption_tags_unique
|
||||
return gr.CheckboxGroup(
|
||||
choices=caption_tags_all, value=caption_tags_have
|
||||
)
|
||||
return gr.CheckboxGroup(choices=caption_tags_all, value=caption_tags_have)
|
||||
|
||||
|
||||
def paginate_go(page, max_page):
|
||||
try:
|
||||
page = float(page)
|
||||
except:
|
||||
msgbox(f'Invalid page num: {page}')
|
||||
msgbox(f"Invalid page num: {page}")
|
||||
return
|
||||
return paginate(page, max_page, 0)
|
||||
|
||||
|
||||
def paginate(page, max_page, page_change):
|
||||
return int(max(min(page + page_change, max_page), 1))
|
||||
|
||||
|
||||
def save_caption(caption, caption_ext, image_file, images_dir):
|
||||
caption_path = _get_caption_path(image_file, images_dir, caption_ext)
|
||||
with open(caption_path, 'w+', encoding='utf8') as f:
|
||||
with open(caption_path, "w+", encoding="utf8") as f:
|
||||
f.write(caption)
|
||||
|
||||
log.info(f'Wrote captions to {caption_path}')
|
||||
log.info(f"Wrote captions to {caption_path}")
|
||||
|
||||
|
||||
def update_quick_tags(quick_tags_text, *image_caption_texts):
|
||||
|
|
@ -101,7 +98,7 @@ def update_image_tags(
|
|||
output_tags = [t for t in quick_tags if t in selected_tags_set] + [
|
||||
t for t in selected_tags if t not in quick_tags_set
|
||||
]
|
||||
caption = ', '.join(output_tags)
|
||||
caption = ", ".join(output_tags)
|
||||
|
||||
if auto_save:
|
||||
save_caption(caption, caption_ext, image_file, images_dir)
|
||||
|
|
@ -122,50 +119,46 @@ def import_tags_from_captions(
|
|||
|
||||
# Check for images_dir
|
||||
if not images_dir:
|
||||
msgbox('Image folder is missing...')
|
||||
msgbox("Image folder is missing...")
|
||||
return empty_return()
|
||||
|
||||
if not os.path.exists(images_dir):
|
||||
msgbox('Image folder does not exist...')
|
||||
msgbox("Image folder does not exist...")
|
||||
return empty_return()
|
||||
|
||||
if not caption_ext:
|
||||
msgbox('Please provide an extension for the caption files.')
|
||||
msgbox("Please provide an extension for the caption files.")
|
||||
return empty_return()
|
||||
|
||||
if quick_tags_text:
|
||||
if not boolbox(
|
||||
f'Are you sure you wish to overwrite the current quick tags?',
|
||||
choices=('Yes', 'No'),
|
||||
f"Are you sure you wish to overwrite the current quick tags?",
|
||||
choices=("Yes", "No"),
|
||||
):
|
||||
return empty_return()
|
||||
|
||||
images_list = os.listdir(images_dir)
|
||||
image_files = [
|
||||
f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)
|
||||
]
|
||||
image_files = [f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)]
|
||||
|
||||
# Use a set for lookup but store order with list
|
||||
tags = []
|
||||
tags_set = set()
|
||||
for image_file in image_files:
|
||||
caption_file_path = _get_caption_path(
|
||||
image_file, images_dir, caption_ext
|
||||
)
|
||||
caption_file_path = _get_caption_path(image_file, images_dir, caption_ext)
|
||||
if os.path.exists(caption_file_path):
|
||||
with open(caption_file_path, 'r', encoding='utf8') as f:
|
||||
with open(caption_file_path, "r", encoding="utf8") as f:
|
||||
caption = f.read()
|
||||
for tag in caption.split(','):
|
||||
for tag in caption.split(","):
|
||||
tag = tag.strip()
|
||||
tag_key = tag.lower()
|
||||
if not tag_key in tags_set:
|
||||
# Ignore extra spaces
|
||||
total_words = len(re.findall(r'\s+', tag)) + 1
|
||||
total_words = len(re.findall(r"\s+", tag)) + 1
|
||||
if total_words <= ignore_load_tags_word_count:
|
||||
tags.append(tag)
|
||||
tags_set.add(tag_key)
|
||||
|
||||
return ', '.join(tags)
|
||||
return ", ".join(tags)
|
||||
|
||||
|
||||
def load_images(images_dir, caption_ext, loaded_images_dir, page, max_page):
|
||||
|
|
@ -180,15 +173,15 @@ def load_images(images_dir, caption_ext, loaded_images_dir, page, max_page):
|
|||
|
||||
# Check for images_dir
|
||||
if not images_dir:
|
||||
msgbox('Image folder is missing...')
|
||||
msgbox("Image folder is missing...")
|
||||
return empty_return()
|
||||
|
||||
if not os.path.exists(images_dir):
|
||||
msgbox('Image folder does not exist...')
|
||||
msgbox("Image folder does not exist...")
|
||||
return empty_return()
|
||||
|
||||
if not caption_ext:
|
||||
msgbox('Please provide an extension for the caption files.')
|
||||
msgbox("Please provide an extension for the caption files.")
|
||||
return empty_return()
|
||||
|
||||
# Load Images
|
||||
|
|
@ -212,12 +205,10 @@ def update_images(
|
|||
|
||||
# Load Images
|
||||
images_list = os.listdir(images_dir)
|
||||
image_files = [
|
||||
f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)
|
||||
]
|
||||
image_files = [f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)]
|
||||
|
||||
# Quick tags
|
||||
quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text or '')
|
||||
quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text or "")
|
||||
|
||||
# Display Images
|
||||
rows = []
|
||||
|
|
@ -231,22 +222,18 @@ def update_images(
|
|||
show_row = image_index < len(image_files)
|
||||
|
||||
image_path = None
|
||||
caption = ''
|
||||
caption = ""
|
||||
tag_checkboxes = None
|
||||
if show_row:
|
||||
image_file = image_files[image_index]
|
||||
image_path = os.path.join(images_dir, image_file)
|
||||
|
||||
caption_file_path = _get_caption_path(
|
||||
image_file, images_dir, caption_ext
|
||||
)
|
||||
caption_file_path = _get_caption_path(image_file, images_dir, caption_ext)
|
||||
if os.path.exists(caption_file_path):
|
||||
with open(caption_file_path, 'r', encoding='utf8') as f:
|
||||
with open(caption_file_path, "r", encoding="utf8") as f:
|
||||
caption = f.read()
|
||||
|
||||
tag_checkboxes = _get_tag_checkbox_updates(
|
||||
caption, quick_tags, quick_tags_set
|
||||
)
|
||||
tag_checkboxes = _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set)
|
||||
rows.append(gr.Row(visible=show_row))
|
||||
image_paths.append(image_path)
|
||||
captions.append(caption)
|
||||
|
|
@ -266,7 +253,11 @@ def update_images(
|
|||
def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
|
||||
from .common_gui import create_refresh_button
|
||||
|
||||
default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data")
|
||||
default_images_dir = (
|
||||
default_images_dir
|
||||
if default_images_dir is not None
|
||||
else os.path.join(scriptdir, "data")
|
||||
)
|
||||
current_images_dir = default_images_dir
|
||||
|
||||
# Function to list directories
|
||||
|
|
@ -276,39 +267,45 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
current_images_dir = path
|
||||
return list(list_dirs(path))
|
||||
|
||||
with gr.Tab('Manual Captioning'):
|
||||
gr.Markdown(
|
||||
'This utility allows quick captioning and tagging of images.'
|
||||
)
|
||||
with gr.Tab("Manual Captioning"):
|
||||
gr.Markdown("This utility allows quick captioning and tagging of images.")
|
||||
page = gr.Number(-1, visible=False)
|
||||
max_page = gr.Number(1, visible=False)
|
||||
loaded_images_dir = gr.Text(visible=False)
|
||||
with gr.Group(), gr.Row():
|
||||
images_dir = gr.Dropdown(
|
||||
label='Image folder to caption (containing the images to caption)',
|
||||
label="Image folder to caption (containing the images to caption)",
|
||||
choices=[""] + list_images_dirs(default_images_dir),
|
||||
value="",
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small")
|
||||
create_refresh_button(
|
||||
images_dir,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_images_dirs(current_images_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
folder_button = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
folder_button.click(
|
||||
get_folder_path,
|
||||
outputs=images_dir,
|
||||
show_progress=False,
|
||||
)
|
||||
load_images_button = gr.Button('Load', elem_id='open_folder')
|
||||
load_images_button = gr.Button("Load", elem_id="open_folder")
|
||||
caption_ext = gr.Textbox(
|
||||
label='Caption file extension',
|
||||
placeholder='Extension for caption file (e.g., .caption, .txt)',
|
||||
value='.txt',
|
||||
label="Caption file extension",
|
||||
placeholder="Extension for caption file (e.g., .caption, .txt)",
|
||||
value=".txt",
|
||||
interactive=True,
|
||||
)
|
||||
auto_save = gr.Checkbox(
|
||||
label='Autosave', info='Options', value=True, interactive=True
|
||||
label="Autosave", info="Options", value=True, interactive=True
|
||||
)
|
||||
|
||||
images_dir.change(
|
||||
|
|
@ -321,39 +318,39 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
# Caption Section
|
||||
with gr.Group(), gr.Row():
|
||||
quick_tags_text = gr.Textbox(
|
||||
label='Quick Tags',
|
||||
placeholder='Comma separated list of tags',
|
||||
label="Quick Tags",
|
||||
placeholder="Comma separated list of tags",
|
||||
interactive=True,
|
||||
)
|
||||
import_tags_button = gr.Button('Import', elem_id='open_folder')
|
||||
import_tags_button = gr.Button("Import", elem_id="open_folder")
|
||||
ignore_load_tags_word_count = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
value=3,
|
||||
step=1,
|
||||
label='Ignore Imported Tags Above Word Count',
|
||||
label="Ignore Imported Tags Above Word Count",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
# Next/Prev section generator
|
||||
def render_pagination():
|
||||
gr.Button('< Prev', elem_id='open_folder').click(
|
||||
gr.Button("< Prev", elem_id="open_folder").click(
|
||||
paginate,
|
||||
inputs=[page, max_page, gr.Number(-1, visible=False)],
|
||||
outputs=[page],
|
||||
)
|
||||
page_count = gr.Label('Page 1', label='Page')
|
||||
page_count = gr.Label("Page 1", label="Page")
|
||||
page_goto_text = gr.Textbox(
|
||||
label='Goto page',
|
||||
placeholder='Page Number',
|
||||
label="Goto page",
|
||||
placeholder="Page Number",
|
||||
interactive=True,
|
||||
)
|
||||
gr.Button('Go >', elem_id='open_folder').click(
|
||||
gr.Button("Go >", elem_id="open_folder").click(
|
||||
paginate_go,
|
||||
inputs=[page_goto_text, max_page],
|
||||
outputs=[page],
|
||||
)
|
||||
gr.Button('Next >', elem_id='open_folder').click(
|
||||
gr.Button("Next >", elem_id="open_folder").click(
|
||||
paginate,
|
||||
inputs=[page, max_page, gr.Number(1, visible=False)],
|
||||
outputs=[page],
|
||||
|
|
@ -374,19 +371,20 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
with gr.Row(visible=False) as row:
|
||||
image_file = gr.Text(visible=False)
|
||||
image_files.append(image_file)
|
||||
image_image = gr.Image(type='filepath')
|
||||
image_image = gr.Image(type="filepath")
|
||||
image_images.append(image_image)
|
||||
image_caption_text = gr.TextArea(
|
||||
label='Captions',
|
||||
placeholder='Input captions',
|
||||
label="Captions",
|
||||
placeholder="Input captions",
|
||||
interactive=True,
|
||||
)
|
||||
image_caption_texts.append(image_caption_text)
|
||||
tag_checkboxes = gr.CheckboxGroup(
|
||||
[], label='Tags', interactive=True
|
||||
)
|
||||
tag_checkboxes = gr.CheckboxGroup([], label="Tags", interactive=True)
|
||||
save_button = gr.Button(
|
||||
'💾', elem_id='open_folder_small', elem_classes=['tool'], visible=False
|
||||
"💾",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=False,
|
||||
)
|
||||
save_buttons.append(save_button)
|
||||
|
||||
|
|
@ -485,9 +483,9 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
)
|
||||
# Update the key on page and image dir change
|
||||
listener_kwargs = {
|
||||
'fn': lambda p, i: f'{p}-{i}',
|
||||
'inputs': [page, loaded_images_dir],
|
||||
'outputs': image_update_key,
|
||||
"fn": lambda p, i: f"{p}-{i}",
|
||||
"inputs": [page, loaded_images_dir],
|
||||
"outputs": image_update_key,
|
||||
}
|
||||
page.change(**listener_kwargs)
|
||||
loaded_images_dir.change(**listener_kwargs)
|
||||
|
|
@ -495,15 +493,14 @@ def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
|
|||
# Save buttons visibility
|
||||
# (on auto-save on/off)
|
||||
auto_save.change(
|
||||
lambda auto_save: [gr.Button(visible=not auto_save)]
|
||||
* IMAGES_TO_SHOW,
|
||||
lambda auto_save: [gr.Button(visible=not auto_save)] * IMAGES_TO_SHOW,
|
||||
inputs=auto_save,
|
||||
outputs=save_buttons,
|
||||
)
|
||||
|
||||
# Page Count
|
||||
page.change(
|
||||
lambda page, max_page: [f'Page {int(page)} / {int(max_page)}'] * 2,
|
||||
lambda page, max_page: [f"Page {int(page)} / {int(max_page)}"] * 2,
|
||||
inputs=[page, max_page],
|
||||
outputs=[page_count1, page_count2],
|
||||
show_progress=False,
|
||||
|
|
|
|||
|
|
@ -9,16 +9,22 @@ import gradio as gr
|
|||
from easygui import msgbox
|
||||
|
||||
# Local module imports
|
||||
from .common_gui import get_saveasfilename_path, get_file_path, scriptdir, list_files, create_refresh_button
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_file_path,
|
||||
scriptdir,
|
||||
list_files,
|
||||
create_refresh_button,
|
||||
)
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
|
@ -27,7 +33,7 @@ def check_model(model):
|
|||
if not model:
|
||||
return True
|
||||
if not os.path.isfile(model):
|
||||
msgbox(f'The provided {model} is not a file')
|
||||
msgbox(f"The provided {model} is not a file")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
|
@ -47,14 +53,14 @@ class GradioMergeLoRaTab:
|
|||
self.build_tab()
|
||||
|
||||
def save_inputs_to_json(self, file_path, inputs):
|
||||
with open(file_path, 'w') as file:
|
||||
with open(file_path, "w") as file:
|
||||
json.dump(inputs, file)
|
||||
log.info(f'Saved inputs to {file_path}')
|
||||
log.info(f"Saved inputs to {file_path}")
|
||||
|
||||
def load_inputs_from_json(self, file_path):
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
inputs = json.load(file)
|
||||
log.info(f'Loaded inputs from {file_path}')
|
||||
log.info(f"Loaded inputs from {file_path}")
|
||||
return inputs
|
||||
|
||||
def build_tab(self):
|
||||
|
|
@ -95,29 +101,34 @@ class GradioMergeLoRaTab:
|
|||
current_save_dir = path
|
||||
return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
|
||||
|
||||
with gr.Tab('Merge LoRA'):
|
||||
with gr.Tab("Merge LoRA"):
|
||||
gr.Markdown(
|
||||
'This utility can merge up to 4 LoRA together or alternatively merge up to 4 LoRA into a SD checkpoint.'
|
||||
"This utility can merge up to 4 LoRA together or alternatively merge up to 4 LoRA into a SD checkpoint."
|
||||
)
|
||||
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
|
||||
ckpt_ext_name = gr.Textbox(value='SD model types', visible=False)
|
||||
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
ckpt_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
|
||||
ckpt_ext_name = gr.Textbox(value="SD model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
sd_model = gr.Dropdown(
|
||||
label='SD Model (Optional. Stable Diffusion model path, if you want to merge it with LoRA files)',
|
||||
label="SD Model (Optional. Stable Diffusion model path, if you want to merge it with LoRA files)",
|
||||
interactive=True,
|
||||
choices=[""] + list_sd_models(current_sd_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(sd_model, lambda: None, lambda: {"choices": list_sd_models(current_sd_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
sd_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_sd_models(current_sd_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
sd_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
sd_model_file.click(
|
||||
|
|
@ -126,7 +137,7 @@ class GradioMergeLoRaTab:
|
|||
outputs=sd_model,
|
||||
show_progress=False,
|
||||
)
|
||||
sdxl_model = gr.Checkbox(label='SDXL model', value=False)
|
||||
sdxl_model = gr.Checkbox(label="SDXL model", value=False)
|
||||
|
||||
sd_model.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_sd_models(path)),
|
||||
|
|
@ -143,11 +154,16 @@ class GradioMergeLoRaTab:
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_a_model, lambda: None, lambda: {"choices": list_a_models(current_a_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_a_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_a_models(current_a_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_a_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
|
|
@ -164,11 +180,16 @@ class GradioMergeLoRaTab:
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_b_model, lambda: None, lambda: {"choices": list_b_models(current_b_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_b_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_b_models(current_b_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_b_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
|
|
@ -193,7 +214,7 @@ class GradioMergeLoRaTab:
|
|||
|
||||
with gr.Row():
|
||||
ratio_a = gr.Slider(
|
||||
label='Model A merge ratio (eg: 0.5 mean 50%)',
|
||||
label="Model A merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -202,7 +223,7 @@ class GradioMergeLoRaTab:
|
|||
)
|
||||
|
||||
ratio_b = gr.Slider(
|
||||
label='Model B merge ratio (eg: 0.5 mean 50%)',
|
||||
label="Model B merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -218,11 +239,16 @@ class GradioMergeLoRaTab:
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_c_model, lambda: None, lambda: {"choices": list_c_models(current_c_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_c_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_c_models(current_c_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_c_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_lora_c_model_file.click(
|
||||
|
|
@ -239,11 +265,16 @@ class GradioMergeLoRaTab:
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_d_model, lambda: None, lambda: {"choices": list_d_models(current_d_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_d_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_d_models(current_d_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_d_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_lora_d_model_file.click(
|
||||
|
|
@ -267,7 +298,7 @@ class GradioMergeLoRaTab:
|
|||
|
||||
with gr.Row():
|
||||
ratio_c = gr.Slider(
|
||||
label='Model C merge ratio (eg: 0.5 mean 50%)',
|
||||
label="Model C merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -276,7 +307,7 @@ class GradioMergeLoRaTab:
|
|||
)
|
||||
|
||||
ratio_d = gr.Slider(
|
||||
label='Model D merge ratio (eg: 0.5 mean 50%)',
|
||||
label="Model D merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -286,17 +317,22 @@ class GradioMergeLoRaTab:
|
|||
|
||||
with gr.Group(), gr.Row():
|
||||
save_to = gr.Dropdown(
|
||||
label='Save to (path for the file to save...)',
|
||||
label="Save to (path for the file to save...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_save_to(current_d_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
save_to,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
button_save_to.click(
|
||||
|
|
@ -306,15 +342,15 @@ class GradioMergeLoRaTab:
|
|||
show_progress=False,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label='Merge precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='float',
|
||||
label="Merge precision",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
value="float",
|
||||
interactive=True,
|
||||
)
|
||||
save_precision = gr.Radio(
|
||||
label='Save precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='fp16',
|
||||
label="Save precision",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
value="fp16",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
|
|
@ -325,7 +361,7 @@ class GradioMergeLoRaTab:
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
merge_button = gr.Button('Merge model')
|
||||
merge_button = gr.Button("Merge model")
|
||||
|
||||
merge_button.click(
|
||||
self.merge_lora,
|
||||
|
|
@ -364,7 +400,7 @@ class GradioMergeLoRaTab:
|
|||
save_precision,
|
||||
):
|
||||
|
||||
log.info('Merge model...')
|
||||
log.info("Merge model...")
|
||||
models = [
|
||||
sd_model,
|
||||
lora_a_model,
|
||||
|
|
@ -377,7 +413,7 @@ class GradioMergeLoRaTab:
|
|||
|
||||
if not verify_conditions(sd_model, lora_models):
|
||||
log.info(
|
||||
'Warning: Either provide at least one LoRa model along with the sd_model or at least two LoRa models if no sd_model is provided.'
|
||||
"Warning: Either provide at least one LoRa model along with the sd_model or at least two LoRa models if no sd_model is provided."
|
||||
)
|
||||
return
|
||||
|
||||
|
|
@ -386,36 +422,36 @@ class GradioMergeLoRaTab:
|
|||
return
|
||||
|
||||
if not sdxl_model:
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/merge_lora.py"'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/merge_lora.py"'
|
||||
else:
|
||||
run_cmd = (
|
||||
fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py"'
|
||||
rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py"'
|
||||
)
|
||||
if sd_model:
|
||||
run_cmd += fr' --sd_model "{sd_model}"'
|
||||
run_cmd += f' --save_precision {save_precision}'
|
||||
run_cmd += f' --precision {precision}'
|
||||
run_cmd += fr' --save_to "{save_to}"'
|
||||
run_cmd += rf' --sd_model "{sd_model}"'
|
||||
run_cmd += f" --save_precision {save_precision}"
|
||||
run_cmd += f" --precision {precision}"
|
||||
run_cmd += rf' --save_to "{save_to}"'
|
||||
|
||||
# Create a space-separated string of non-empty models (from the second element onwards), enclosed in double quotes
|
||||
models_cmd = ' '.join([fr'"{model}"' for model in lora_models if model])
|
||||
models_cmd = " ".join([rf'"{model}"' for model in lora_models if model])
|
||||
|
||||
# Create a space-separated string of non-zero ratios corresponding to non-empty LoRa models
|
||||
valid_ratios = [
|
||||
ratios[i] for i, model in enumerate(lora_models) if model
|
||||
]
|
||||
ratios_cmd = ' '.join([str(ratio) for ratio in valid_ratios])
|
||||
valid_ratios = [ratios[i] for i, model in enumerate(lora_models) if model]
|
||||
ratios_cmd = " ".join([str(ratio) for ratio in valid_ratios])
|
||||
|
||||
if models_cmd:
|
||||
run_cmd += f' --models {models_cmd}'
|
||||
run_cmd += f' --ratios {ratios_cmd}'
|
||||
run_cmd += f" --models {models_cmd}"
|
||||
run_cmd += f" --ratios {ratios_cmd}"
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
||||
log.info('Done merging...')
|
||||
log.info("Done merging...")
|
||||
|
|
|
|||
|
|
@ -16,10 +16,10 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
|
@ -34,29 +34,31 @@ def merge_lycoris(
|
|||
is_sdxl,
|
||||
is_v2,
|
||||
):
|
||||
log.info('Merge model...')
|
||||
log.info("Merge model...")
|
||||
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/merge_lycoris.py"'
|
||||
run_cmd += fr' "{base_model}"'
|
||||
run_cmd += fr' "{lycoris_model}"'
|
||||
run_cmd += fr' "{output_name}"'
|
||||
run_cmd += f' --weight {weight}'
|
||||
run_cmd += f' --device {device}'
|
||||
run_cmd += f' --dtype {dtype}'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/merge_lycoris.py"'
|
||||
run_cmd += rf' "{base_model}"'
|
||||
run_cmd += rf' "{lycoris_model}"'
|
||||
run_cmd += rf' "{output_name}"'
|
||||
run_cmd += f" --weight {weight}"
|
||||
run_cmd += f" --device {device}"
|
||||
run_cmd += f" --dtype {dtype}"
|
||||
if is_sdxl:
|
||||
run_cmd += f' --is_sdxl'
|
||||
run_cmd += f" --is_sdxl"
|
||||
if is_v2:
|
||||
run_cmd += f' --is_v2'
|
||||
run_cmd += f" --is_v2"
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
||||
log.info('Done merging...')
|
||||
log.info("Done merging...")
|
||||
|
||||
|
||||
###
|
||||
|
|
@ -84,30 +86,33 @@ def gradio_merge_lycoris_tab(headless=False):
|
|||
current_save_dir = path
|
||||
return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
|
||||
|
||||
with gr.Tab('Merge LyCORIS'):
|
||||
gr.Markdown(
|
||||
'This utility can merge a LyCORIS model into a SD checkpoint.'
|
||||
)
|
||||
with gr.Tab("Merge LyCORIS"):
|
||||
gr.Markdown("This utility can merge a LyCORIS model into a SD checkpoint.")
|
||||
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
|
||||
ckpt_ext_name = gr.Textbox(value='SD model types', visible=False)
|
||||
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
ckpt_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
|
||||
ckpt_ext_name = gr.Textbox(value="SD model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
base_model = gr.Dropdown(
|
||||
label='SD Model (Optional Stable Diffusion base model)',
|
||||
label="SD Model (Optional Stable Diffusion base model)",
|
||||
interactive=True,
|
||||
info='Provide a SD file path that you want to merge with the LyCORIS file',
|
||||
info="Provide a SD file path that you want to merge with the LyCORIS file",
|
||||
choices=[""] + list_models(current_save_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(base_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
base_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
base_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
base_model_file.click(
|
||||
|
|
@ -118,7 +123,7 @@ def gradio_merge_lycoris_tab(headless=False):
|
|||
)
|
||||
|
||||
lycoris_model = gr.Dropdown(
|
||||
label='LyCORIS model (path to the LyCORIS model)',
|
||||
label="LyCORIS model (path to the LyCORIS model)",
|
||||
interactive=True,
|
||||
choices=[""] + list_lycoris_model(current_save_dir),
|
||||
value="",
|
||||
|
|
@ -126,8 +131,8 @@ def gradio_merge_lycoris_tab(headless=False):
|
|||
)
|
||||
button_lycoris_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_lycoris_model_file.click(
|
||||
|
|
@ -152,7 +157,7 @@ def gradio_merge_lycoris_tab(headless=False):
|
|||
|
||||
with gr.Row():
|
||||
weight = gr.Slider(
|
||||
label='Model A merge ratio (eg: 0.5 mean 50%)',
|
||||
label="Model A merge ratio (eg: 0.5 mean 50%)",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -162,17 +167,22 @@ def gradio_merge_lycoris_tab(headless=False):
|
|||
|
||||
with gr.Group(), gr.Row():
|
||||
output_name = gr.Dropdown(
|
||||
label='Save to (path for the checkpoint file to save...)',
|
||||
label="Save to (path for the checkpoint file to save...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_save_to(current_save_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(output_name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
output_name,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_output_name = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_output_name.click(
|
||||
|
|
@ -182,26 +192,26 @@ def gradio_merge_lycoris_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
dtype = gr.Radio(
|
||||
label='Save dtype',
|
||||
label="Save dtype",
|
||||
choices=[
|
||||
'float',
|
||||
'float16',
|
||||
'float32',
|
||||
'float64',
|
||||
'bfloat',
|
||||
'bfloat16',
|
||||
"float",
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
"bfloat",
|
||||
"bfloat16",
|
||||
],
|
||||
value='float16',
|
||||
value="float16",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
device = gr.Radio(
|
||||
label='Device',
|
||||
label="Device",
|
||||
choices=[
|
||||
'cpu',
|
||||
'cuda',
|
||||
"cpu",
|
||||
"cuda",
|
||||
],
|
||||
value='cpu',
|
||||
value="cpu",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
|
|
@ -213,10 +223,10 @@ def gradio_merge_lycoris_tab(headless=False):
|
|||
)
|
||||
|
||||
with gr.Row():
|
||||
is_sdxl = gr.Checkbox(label='is SDXL', value=False, interactive=True)
|
||||
is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True)
|
||||
is_sdxl = gr.Checkbox(label="is SDXL", value=False, interactive=True)
|
||||
is_v2 = gr.Checkbox(label="is v2", value=False, interactive=True)
|
||||
|
||||
merge_button = gr.Button('Merge model')
|
||||
merge_button = gr.Button("Merge model")
|
||||
|
||||
merge_button.click(
|
||||
merge_lycoris,
|
||||
|
|
|
|||
|
|
@ -3,17 +3,23 @@ from easygui import msgbox
|
|||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
from .common_gui import get_saveasfilename_path, get_file_path, scriptdir, list_files, create_refresh_button
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_file_path,
|
||||
scriptdir,
|
||||
list_files,
|
||||
create_refresh_button,
|
||||
)
|
||||
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
|
@ -29,57 +35,57 @@ def resize_lora(
|
|||
verbose,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if model == '':
|
||||
msgbox('Invalid model file')
|
||||
if model == "":
|
||||
msgbox("Invalid model file")
|
||||
return
|
||||
|
||||
# Check if source model exist
|
||||
if not os.path.isfile(model):
|
||||
msgbox('The provided model is not a file')
|
||||
msgbox("The provided model is not a file")
|
||||
return
|
||||
|
||||
if dynamic_method == 'sv_ratio':
|
||||
if dynamic_method == "sv_ratio":
|
||||
if float(dynamic_param) < 2:
|
||||
msgbox(
|
||||
f'Dynamic parameter for {dynamic_method} need to be 2 or greater...'
|
||||
)
|
||||
msgbox(f"Dynamic parameter for {dynamic_method} need to be 2 or greater...")
|
||||
return
|
||||
|
||||
if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
|
||||
if dynamic_method == "sv_fro" or dynamic_method == "sv_cumulative":
|
||||
if float(dynamic_param) < 0 or float(dynamic_param) > 1:
|
||||
msgbox(
|
||||
f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...'
|
||||
f"Dynamic parameter for {dynamic_method} need to be between 0 and 1..."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if save_to end with one of the defines extension. If not add .safetensors.
|
||||
if not save_to.endswith(('.pt', '.safetensors')):
|
||||
save_to += '.safetensors'
|
||||
if not save_to.endswith((".pt", ".safetensors")):
|
||||
save_to += ".safetensors"
|
||||
|
||||
if device == '':
|
||||
device = 'cuda'
|
||||
if device == "":
|
||||
device = "cuda"
|
||||
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/resize_lora.py"'
|
||||
run_cmd += f' --save_precision {save_precision}'
|
||||
run_cmd += fr' --save_to "{save_to}"'
|
||||
run_cmd += fr' --model "{model}"'
|
||||
run_cmd += f' --new_rank {new_rank}'
|
||||
run_cmd += f' --device {device}'
|
||||
if not dynamic_method == 'None':
|
||||
run_cmd += f' --dynamic_method {dynamic_method}'
|
||||
run_cmd += f' --dynamic_param {dynamic_param}'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/resize_lora.py"'
|
||||
run_cmd += f" --save_precision {save_precision}"
|
||||
run_cmd += rf' --save_to "{save_to}"'
|
||||
run_cmd += rf' --model "{model}"'
|
||||
run_cmd += f" --new_rank {new_rank}"
|
||||
run_cmd += f" --device {device}"
|
||||
if not dynamic_method == "None":
|
||||
run_cmd += f" --dynamic_method {dynamic_method}"
|
||||
run_cmd += f" --dynamic_param {dynamic_param}"
|
||||
if verbose:
|
||||
run_cmd += f' --verbose'
|
||||
run_cmd += f" --verbose"
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
||||
log.info('Done resizing...')
|
||||
log.info("Done resizing...")
|
||||
|
||||
|
||||
###
|
||||
|
|
@ -101,25 +107,30 @@ def gradio_resize_lora_tab(headless=False):
|
|||
current_save_dir = path
|
||||
return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
|
||||
|
||||
with gr.Tab('Resize LoRA'):
|
||||
gr.Markdown('This utility can resize a LoRA.')
|
||||
with gr.Tab("Resize LoRA"):
|
||||
gr.Markdown("This utility can resize a LoRA.")
|
||||
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
model = gr.Dropdown(
|
||||
label='Source LoRA (path to the LoRA to resize)',
|
||||
label="Source LoRA (path to the LoRA to resize)",
|
||||
interactive=True,
|
||||
choices=[""] + list_models(current_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_a_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
|
|
@ -129,17 +140,22 @@ def gradio_resize_lora_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
save_to = gr.Dropdown(
|
||||
label='Save to (path for the LoRA file to save...)',
|
||||
label="Save to (path for the LoRA file to save...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_save_to(current_save_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
save_to,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_save_to.click(
|
||||
|
|
@ -162,7 +178,7 @@ def gradio_resize_lora_tab(headless=False):
|
|||
)
|
||||
with gr.Row():
|
||||
new_rank = gr.Slider(
|
||||
label='Desired LoRA rank',
|
||||
label="Desired LoRA rank",
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
step=1,
|
||||
|
|
@ -170,37 +186,37 @@ def gradio_resize_lora_tab(headless=False):
|
|||
interactive=True,
|
||||
)
|
||||
dynamic_method = gr.Radio(
|
||||
choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'],
|
||||
value='sv_fro',
|
||||
label='Dynamic method',
|
||||
choices=["None", "sv_ratio", "sv_fro", "sv_cumulative"],
|
||||
value="sv_fro",
|
||||
label="Dynamic method",
|
||||
interactive=True,
|
||||
)
|
||||
dynamic_param = gr.Textbox(
|
||||
label='Dynamic parameter',
|
||||
value='0.9',
|
||||
label="Dynamic parameter",
|
||||
value="0.9",
|
||||
interactive=True,
|
||||
placeholder='Value for the dynamic method selected.',
|
||||
placeholder="Value for the dynamic method selected.",
|
||||
)
|
||||
with gr.Row():
|
||||
|
||||
verbose = gr.Checkbox(label='Verbose logging', value=True)
|
||||
verbose = gr.Checkbox(label="Verbose logging", value=True)
|
||||
save_precision = gr.Radio(
|
||||
label='Save precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='fp16',
|
||||
label="Save precision",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
value="fp16",
|
||||
interactive=True,
|
||||
)
|
||||
device = gr.Radio(
|
||||
label='Device',
|
||||
label="Device",
|
||||
choices=[
|
||||
'cpu',
|
||||
'cuda',
|
||||
"cpu",
|
||||
"cuda",
|
||||
],
|
||||
value='cuda',
|
||||
value="cuda",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
convert_button = gr.Button('Resize model')
|
||||
convert_button = gr.Button("Resize model")
|
||||
|
||||
convert_button.click(
|
||||
resize_lora,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import os
|
|||
import sys
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_any_file_path,
|
||||
get_file_path,
|
||||
scriptdir,
|
||||
list_files,
|
||||
|
|
@ -17,10 +16,10 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
||||
|
|
@ -53,49 +52,51 @@ def svd_merge_lora(
|
|||
ratio_c /= total_ratio
|
||||
ratio_d /= total_ratio
|
||||
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/svd_merge_lora.py"'
|
||||
run_cmd += f' --save_precision {save_precision}'
|
||||
run_cmd += f' --precision {precision}'
|
||||
run_cmd += fr' --save_to "{save_to}"'
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/svd_merge_lora.py"'
|
||||
run_cmd += f" --save_precision {save_precision}"
|
||||
run_cmd += f" --precision {precision}"
|
||||
run_cmd += rf' --save_to "{save_to}"'
|
||||
|
||||
run_cmd_models = ' --models'
|
||||
run_cmd_ratios = ' --ratios'
|
||||
run_cmd_models = " --models"
|
||||
run_cmd_ratios = " --ratios"
|
||||
# Add non-empty models and their ratios to the command
|
||||
if lora_a_model:
|
||||
if not os.path.isfile(lora_a_model):
|
||||
msgbox('The provided model A is not a file')
|
||||
msgbox("The provided model A is not a file")
|
||||
return
|
||||
run_cmd_models += fr' "{lora_a_model}"'
|
||||
run_cmd_ratios += f' {ratio_a}'
|
||||
run_cmd_models += rf' "{lora_a_model}"'
|
||||
run_cmd_ratios += f" {ratio_a}"
|
||||
if lora_b_model:
|
||||
if not os.path.isfile(lora_b_model):
|
||||
msgbox('The provided model B is not a file')
|
||||
msgbox("The provided model B is not a file")
|
||||
return
|
||||
run_cmd_models += fr' "{lora_b_model}"'
|
||||
run_cmd_ratios += f' {ratio_b}'
|
||||
run_cmd_models += rf' "{lora_b_model}"'
|
||||
run_cmd_ratios += f" {ratio_b}"
|
||||
if lora_c_model:
|
||||
if not os.path.isfile(lora_c_model):
|
||||
msgbox('The provided model C is not a file')
|
||||
msgbox("The provided model C is not a file")
|
||||
return
|
||||
run_cmd_models += fr' "{lora_c_model}"'
|
||||
run_cmd_ratios += f' {ratio_c}'
|
||||
run_cmd_models += rf' "{lora_c_model}"'
|
||||
run_cmd_ratios += f" {ratio_c}"
|
||||
if lora_d_model:
|
||||
if not os.path.isfile(lora_d_model):
|
||||
msgbox('The provided model D is not a file')
|
||||
msgbox("The provided model D is not a file")
|
||||
return
|
||||
run_cmd_models += fr' "{lora_d_model}"'
|
||||
run_cmd_ratios += f' {ratio_d}'
|
||||
run_cmd_models += rf' "{lora_d_model}"'
|
||||
run_cmd_ratios += f" {ratio_d}"
|
||||
|
||||
run_cmd += run_cmd_models
|
||||
run_cmd += run_cmd_ratios
|
||||
run_cmd += f' --device {device}'
|
||||
run_cmd += f" --device {device}"
|
||||
run_cmd += f' --new_rank "{new_rank}"'
|
||||
run_cmd += f' --new_conv_rank "{new_conv_rank}"'
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, shell=True, env=env)
|
||||
|
|
@ -138,13 +139,13 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
current_save_dir = path
|
||||
return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
|
||||
|
||||
with gr.Tab('Merge LoRA (SVD)'):
|
||||
with gr.Tab("Merge LoRA (SVD)"):
|
||||
gr.Markdown(
|
||||
'This utility can merge two LoRA networks together into a new LoRA.'
|
||||
"This utility can merge two LoRA networks together into a new LoRA."
|
||||
)
|
||||
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
lora_a_model = gr.Dropdown(
|
||||
|
|
@ -154,11 +155,16 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_a_model, lambda: None, lambda: {"choices": list_a_models(current_a_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_a_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_a_models(current_a_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_a_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
|
|
@ -175,11 +181,16 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_b_model, lambda: None, lambda: {"choices": list_b_models(current_b_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_b_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_b_models(current_b_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_b_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
|
|
@ -202,7 +213,7 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
)
|
||||
with gr.Row():
|
||||
ratio_a = gr.Slider(
|
||||
label='Merge ratio model A',
|
||||
label="Merge ratio model A",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -210,7 +221,7 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
interactive=True,
|
||||
)
|
||||
ratio_b = gr.Slider(
|
||||
label='Merge ratio model B',
|
||||
label="Merge ratio model B",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -225,11 +236,16 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_c_model, lambda: None, lambda: {"choices": list_c_models(current_c_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_c_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_c_models(current_c_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_c_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_lora_c_model_file.click(
|
||||
|
|
@ -246,11 +262,16 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_d_model, lambda: None, lambda: {"choices": list_d_models(current_d_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_d_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_d_models(current_d_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_d_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_lora_d_model_file.click(
|
||||
|
|
@ -274,7 +295,7 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
)
|
||||
with gr.Row():
|
||||
ratio_c = gr.Slider(
|
||||
label='Merge ratio model C',
|
||||
label="Merge ratio model C",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -282,7 +303,7 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
interactive=True,
|
||||
)
|
||||
ratio_d = gr.Slider(
|
||||
label='Merge ratio model D',
|
||||
label="Merge ratio model D",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
|
|
@ -291,7 +312,7 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
)
|
||||
with gr.Row():
|
||||
new_rank = gr.Slider(
|
||||
label='New Rank',
|
||||
label="New Rank",
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
step=1,
|
||||
|
|
@ -299,7 +320,7 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
interactive=True,
|
||||
)
|
||||
new_conv_rank = gr.Slider(
|
||||
label='New Conv Rank',
|
||||
label="New Conv Rank",
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
step=1,
|
||||
|
|
@ -309,17 +330,22 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
|
||||
with gr.Group(), gr.Row():
|
||||
save_to = gr.Dropdown(
|
||||
label='Save to (path for the new LoRA file to save...)',
|
||||
label="Save to (path for the new LoRA file to save...)",
|
||||
interactive=True,
|
||||
choices=[""] + list_save_to(current_d_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
save_to,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_save_to(current_save_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_save_to.click(
|
||||
|
|
@ -336,28 +362,28 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
)
|
||||
with gr.Group(), gr.Row():
|
||||
precision = gr.Radio(
|
||||
label='Merge precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='float',
|
||||
label="Merge precision",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
value="float",
|
||||
interactive=True,
|
||||
)
|
||||
save_precision = gr.Radio(
|
||||
label='Save precision',
|
||||
choices=['fp16', 'bf16', 'float'],
|
||||
value='float',
|
||||
label="Save precision",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
value="float",
|
||||
interactive=True,
|
||||
)
|
||||
device = gr.Radio(
|
||||
label='Device',
|
||||
label="Device",
|
||||
choices=[
|
||||
'cpu',
|
||||
'cuda',
|
||||
"cpu",
|
||||
"cuda",
|
||||
],
|
||||
value='cuda',
|
||||
value="cuda",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
convert_button = gr.Button('Merge model')
|
||||
convert_button = gr.Button("Merge model")
|
||||
|
||||
convert_button.click(
|
||||
svd_merge_lora,
|
||||
|
|
|
|||
|
|
@ -10,55 +10,56 @@ from .custom_logging import setup_logging
|
|||
log = setup_logging()
|
||||
|
||||
tensorboard_proc = None
|
||||
TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe'
|
||||
TENSORBOARD = "tensorboard" if os.name == "posix" else "tensorboard.exe"
|
||||
|
||||
# Set the default tensorboard port
|
||||
DEFAULT_TENSORBOARD_PORT = 6006
|
||||
|
||||
|
||||
def start_tensorboard(headless, logging_dir, wait_time=5):
|
||||
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
||||
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
global tensorboard_proc
|
||||
|
||||
headless_bool = True if headless.get('label') == 'True' else False
|
||||
|
||||
headless_bool = True if headless.get("label") == "True" else False
|
||||
|
||||
# Read the TENSORBOARD_PORT from the environment, or use the default
|
||||
tensorboard_port = os.environ.get('TENSORBOARD_PORT', DEFAULT_TENSORBOARD_PORT)
|
||||
|
||||
tensorboard_port = os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT)
|
||||
|
||||
# Check if logging directory exists and is not empty; if not, warn the user and exit
|
||||
if not os.path.exists(logging_dir) or not os.listdir(logging_dir):
|
||||
log.error('Error: logging folder does not exist or does not contain logs.')
|
||||
msgbox(msg='Error: logging folder does not exist or does not contain logs.')
|
||||
log.error("Error: logging folder does not exist or does not contain logs.")
|
||||
msgbox(msg="Error: logging folder does not exist or does not contain logs.")
|
||||
return # Exit the function with an error code
|
||||
|
||||
run_cmd = [
|
||||
TENSORBOARD,
|
||||
'--logdir',
|
||||
"--logdir",
|
||||
logging_dir,
|
||||
'--host',
|
||||
'0.0.0.0',
|
||||
'--port',
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
str(tensorboard_port),
|
||||
]
|
||||
|
||||
log.info(run_cmd)
|
||||
if tensorboard_proc is not None:
|
||||
log.info(
|
||||
'Tensorboard is already running. Terminating existing process before starting new one...'
|
||||
"Tensorboard is already running. Terminating existing process before starting new one..."
|
||||
)
|
||||
stop_tensorboard()
|
||||
|
||||
# Start background process
|
||||
log.info('Starting TensorBoard on port {}'.format(tensorboard_port))
|
||||
log.info("Starting TensorBoard on port {}".format(tensorboard_port))
|
||||
try:
|
||||
# Copy the current environment
|
||||
env = os.environ.copy()
|
||||
|
||||
# Set your specific environment variable
|
||||
env['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
||||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
tensorboard_proc = subprocess.Popen(run_cmd, env=env)
|
||||
except Exception as e:
|
||||
log.error('Failed to start Tensorboard:', e)
|
||||
log.error("Failed to start Tensorboard:", e)
|
||||
return
|
||||
|
||||
if not headless_bool:
|
||||
|
|
@ -66,28 +67,28 @@ def start_tensorboard(headless, logging_dir, wait_time=5):
|
|||
time.sleep(wait_time)
|
||||
|
||||
# Open the TensorBoard URL in the default browser
|
||||
tensorboard_url = f'http://localhost:{tensorboard_port}'
|
||||
log.info(f'Opening TensorBoard URL in browser: {tensorboard_url}')
|
||||
tensorboard_url = f"http://localhost:{tensorboard_port}"
|
||||
log.info(f"Opening TensorBoard URL in browser: {tensorboard_url}")
|
||||
webbrowser.open(tensorboard_url)
|
||||
|
||||
|
||||
def stop_tensorboard():
|
||||
global tensorboard_proc
|
||||
if tensorboard_proc is not None:
|
||||
log.info('Stopping tensorboard process...')
|
||||
log.info("Stopping tensorboard process...")
|
||||
try:
|
||||
tensorboard_proc.terminate()
|
||||
tensorboard_proc = None
|
||||
log.info('...process stopped')
|
||||
log.info("...process stopped")
|
||||
except Exception as e:
|
||||
log.error('Failed to stop Tensorboard:', e)
|
||||
log.error("Failed to stop Tensorboard:", e)
|
||||
else:
|
||||
log.warning('Tensorboard is not running...')
|
||||
log.warning("Tensorboard is not running...")
|
||||
|
||||
|
||||
def gradio_tensorboard():
|
||||
with gr.Row():
|
||||
button_start_tensorboard = gr.Button('Start tensorboard')
|
||||
button_stop_tensorboard = gr.Button('Stop tensorboard')
|
||||
button_start_tensorboard = gr.Button("Start tensorboard")
|
||||
button_stop_tensorboard = gr.Button("Stop tensorboard")
|
||||
|
||||
return (button_start_tensorboard, button_stop_tensorboard)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import gradio as gr
|
|||
import json
|
||||
import math
|
||||
import os
|
||||
import pathlib
|
||||
from datetime import datetime
|
||||
from .common_gui import (
|
||||
get_file_path,
|
||||
|
|
@ -463,7 +462,9 @@ def train_model(
|
|||
return
|
||||
|
||||
if dataset_config:
|
||||
log.info("Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations...")
|
||||
log.info(
|
||||
"Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
|
||||
)
|
||||
else:
|
||||
# Get a list of all subfolders in train_data_dir
|
||||
subfolders = [
|
||||
|
|
@ -526,7 +527,9 @@ def train_model(
|
|||
log.info(f"max_train_steps = {max_train_steps}")
|
||||
|
||||
# calculate stop encoder training
|
||||
if stop_text_encoder_training_pct == None or (not max_train_steps == "" or not max_train_steps == "0"):
|
||||
if stop_text_encoder_training_pct == None or (
|
||||
not max_train_steps == "" or not max_train_steps == "0"
|
||||
):
|
||||
stop_text_encoder_training = 0
|
||||
else:
|
||||
stop_text_encoder_training = math.ceil(
|
||||
|
|
@ -709,7 +712,7 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
|
|||
|
||||
with gr.Accordion("Accelerate launch", open=False), gr.Column():
|
||||
accelerate_launch = AccelerateLaunch()
|
||||
|
||||
|
||||
with gr.Column():
|
||||
source_model = SourceModel(
|
||||
save_model_as_choices=[
|
||||
|
|
@ -722,7 +725,7 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
|
|||
|
||||
with gr.Accordion("Folders", open=False), gr.Group():
|
||||
folders = Folders(headless=headless, config=config)
|
||||
|
||||
|
||||
with gr.Accordion("Parameters", open=False), gr.Column():
|
||||
with gr.Accordion("Basic", open="True"):
|
||||
with gr.Group(elem_id="basic_tab"):
|
||||
|
|
@ -733,7 +736,9 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
|
|||
current_embedding_dir = path
|
||||
return list(
|
||||
list_files(
|
||||
path, exts=[".pt", ".ckpt", ".safetensors"], all=True
|
||||
path,
|
||||
exts=[".pt", ".ckpt", ".safetensors"],
|
||||
all=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,4 @@
|
|||
# v1: initial release
|
||||
# v2: add open and save folder icons
|
||||
# v3: Add new Utilities tab for Dreambooth folder preparation
|
||||
# v3.1: Adding captionning of images to utilities
|
||||
|
||||
import gradio as gr
|
||||
import os
|
||||
|
||||
from .basic_caption_gui import gradio_basic_caption_gui_tab
|
||||
from .convert_model_gui import gradio_convert_model_tab
|
||||
|
|
@ -23,9 +17,9 @@ def utilities_tab(
|
|||
logging_dir_input=gr.Dropdown(),
|
||||
enable_copy_info_button=bool(False),
|
||||
enable_dreambooth_tab=True,
|
||||
headless=False
|
||||
headless=False,
|
||||
):
|
||||
with gr.Tab('Captioning'):
|
||||
with gr.Tab("Captioning"):
|
||||
gradio_basic_caption_gui_tab(headless=headless)
|
||||
gradio_blip_caption_gui_tab(headless=headless)
|
||||
gradio_blip2_caption_gui_tab(headless=headless)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,6 @@ import subprocess
|
|||
import os
|
||||
import sys
|
||||
from .common_gui import (
|
||||
get_saveasfilename_path,
|
||||
get_any_file_path,
|
||||
get_file_path,
|
||||
scriptdir,
|
||||
list_files,
|
||||
|
|
@ -17,37 +15,41 @@ from .custom_logging import setup_logging
|
|||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
||||
|
||||
def verify_lora(
|
||||
lora_model,
|
||||
):
|
||||
# verify for caption_text_input
|
||||
if lora_model == '':
|
||||
msgbox('Invalid model A file')
|
||||
if lora_model == "":
|
||||
msgbox("Invalid model A file")
|
||||
return
|
||||
|
||||
# verify if source model exist
|
||||
if not os.path.isfile(lora_model):
|
||||
msgbox('The provided model A is not a file')
|
||||
msgbox("The provided model A is not a file")
|
||||
return
|
||||
|
||||
run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/check_lora_weights.py" "{lora_model}"'
|
||||
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/check_lora_weights.py" "{lora_model}"'
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
env["PYTHONPATH"] = (
|
||||
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
)
|
||||
|
||||
# Run the command
|
||||
process = subprocess.Popen(
|
||||
run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env,
|
||||
run_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
output, error = process.communicate()
|
||||
|
||||
|
|
@ -67,27 +69,32 @@ def gradio_verify_lora_tab(headless=False):
|
|||
current_model_dir = path
|
||||
return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
|
||||
|
||||
with gr.Tab('Verify LoRA'):
|
||||
with gr.Tab("Verify LoRA"):
|
||||
gr.Markdown(
|
||||
'This utility can verify a LoRA network to make sure it is properly trained.'
|
||||
"This utility can verify a LoRA network to make sure it is properly trained."
|
||||
)
|
||||
|
||||
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
lora_ext = gr.Textbox(value="*.pt *.safetensors", visible=False)
|
||||
lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
|
||||
|
||||
with gr.Group(), gr.Row():
|
||||
lora_model = gr.Dropdown(
|
||||
label='LoRA model (path to the LoRA model to verify)',
|
||||
label="LoRA model (path to the LoRA model to verify)",
|
||||
interactive=True,
|
||||
choices=[""] + list_models(current_model_dir),
|
||||
value="",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(lora_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
|
||||
create_refresh_button(
|
||||
lora_model,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(current_model_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
button_lora_model_file = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
button_lora_model_file.click(
|
||||
|
|
@ -96,7 +103,7 @@ def gradio_verify_lora_tab(headless=False):
|
|||
outputs=lora_model,
|
||||
show_progress=False,
|
||||
)
|
||||
verify_button = gr.Button('Verify', variant='primary')
|
||||
verify_button = gr.Button("Verify", variant="primary")
|
||||
|
||||
lora_model.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
|
||||
|
|
@ -106,16 +113,16 @@ def gradio_verify_lora_tab(headless=False):
|
|||
)
|
||||
|
||||
lora_model_verif_output = gr.Textbox(
|
||||
label='Output',
|
||||
placeholder='Verification output',
|
||||
label="Output",
|
||||
placeholder="Verification output",
|
||||
interactive=False,
|
||||
lines=1,
|
||||
max_lines=10,
|
||||
)
|
||||
|
||||
lora_model_verif_error = gr.Textbox(
|
||||
label='Error',
|
||||
placeholder='Verification error',
|
||||
label="Error",
|
||||
placeholder="Verification error",
|
||||
interactive=False,
|
||||
lines=1,
|
||||
max_lines=10,
|
||||
|
|
|
|||
|
|
@ -164,9 +164,9 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None):
|
|||
)
|
||||
|
||||
caption_extension = gr.Textbox(
|
||||
label='Caption file extension',
|
||||
placeholder='Extension for caption file (e.g., .caption, .txt)',
|
||||
value='.txt',
|
||||
label="Caption file extension",
|
||||
placeholder="Extension for caption file (e.g., .caption, .txt)",
|
||||
value=".txt",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
|
|
@ -266,7 +266,7 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None):
|
|||
)
|
||||
character_threshold = gr.Slider(
|
||||
value=0.35,
|
||||
label='Character threshold',
|
||||
label="Character threshold",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.05,
|
||||
|
|
|
|||
Loading…
Reference in New Issue