diff --git a/assets/style.css b/assets/style.css index d4d77f6..38c22d4 100644 --- a/assets/style.css +++ b/assets/style.css @@ -35,13 +35,13 @@ } #myTensorButton { - background: radial-gradient(ellipse, #007bff, #00b0ff); + background: radial-gradient(ellipse, #3a99ff, #52c8ff); color: white; border: none; } #myTensorButtonStop { - background: radial-gradient(ellipse, #00b0ff, #007bff); + background: radial-gradient(ellipse, #52c8ff, #3a99ff); color: black; border: none; } \ No newline at end of file diff --git a/kohya_gui/class_command_executor.py b/kohya_gui/class_command_executor.py index 355c673..87d9f77 100644 --- a/kohya_gui/class_command_executor.py +++ b/kohya_gui/class_command_executor.py @@ -1,6 +1,8 @@ import subprocess import psutil +import os import gradio as gr +import shlex from .custom_logging import setup_logging # Set up logging @@ -29,7 +31,27 @@ class CommandExecutor: if self.process and self.process.poll() is None: log.info("The command is already running. Please wait for it to finish.") else: - self.process = subprocess.Popen(run_cmd, **kwargs) + if os.name == 'nt': + run_cmd = run_cmd.replace('\\', '/') + + # Split the command string into components + parts = shlex.split(run_cmd) + # The first part is the executable, and it doesn't need quoting + executable = parts[0] + + # The remaining parts are the arguments, which we will quote for safety + safe_args = [shlex.quote(part) for part in parts[1:]] + + # Log the executable and arguments to debug path issues + log.info(f"Executable: {executable}") + log.info(f"Arguments: {' '.join(safe_args)}") + + # Reconstruct the safe command string for display + command_to_run = ' '.join([executable] + safe_args) + log.info(f"Executing command: {command_to_run}") + + # Execute the command securely + self.process = subprocess.Popen([executable] + safe_args, **kwargs) def kill_command(self): """ diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 1b3e0fc..a8ffc88 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -9,6 +9,7 @@ import gradio as gr import sys import json import math +import shutil # Set up logging log = setup_logging() @@ -19,6 +20,8 @@ save_style_symbol = "\U0001f4be" # 💾 document_symbol = "\U0001F4C4" # 📄 scriptdir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) +if os.name == 'nt': + scriptdir = scriptdir.replace('\\', '/') # insert sd-scripts path into PYTHONPATH sys.path.insert(0, os.path.join(scriptdir, "sd-scripts")) @@ -54,6 +57,27 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDX ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"] +def get_executable_path(executable_name: str = None) -> str: + """ + Retrieve and sanitize the path to an executable in the system's PATH. + + Args: + executable_name (str): The name of the executable to find. + + Returns: + str: The full, sanitized path to the executable if found, otherwise an empty string. + """ + if executable_name: + executable_path = shutil.which(executable_name) + if executable_path: + # Replace backslashes with forward slashes on Windows + if os.name == 'nt': + executable_path = executable_path.replace('\\', '/') + return executable_path + else: + return "" # Return empty string if the executable is not found + else: + return "" # Return empty string if no executable name is provided def calculate_max_train_steps( total_steps: int, diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index a8381c5..290eff4 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -5,6 +5,7 @@ import os import sys from datetime import datetime from .common_gui import ( + get_executable_path, get_file_path, get_saveasfile_path, color_aug_changed, @@ -563,8 +564,9 @@ def train_model( lr_warmup_steps = 0 log.info(f"lr_warmup_steps = {lr_warmup_steps}") - # run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"' - run_cmd = "accelerate launch" + accelerate_path = get_executable_path("accelerate") + + run_cmd = f"{accelerate_path} launch" run_cmd += AccelerateLaunch.run_cmd( num_processes=num_processes,