mirror of https://github.com/bmaltais/kohya_ss
Fix issue for Dreambooth training... quick fix
parent
e8b54e630c
commit
f34eea41ca
|
|
@ -35,13 +35,13 @@
|
|||
}
|
||||
|
||||
#myTensorButton {
|
||||
background: radial-gradient(ellipse, #007bff, #00b0ff);
|
||||
background: radial-gradient(ellipse, #3a99ff, #52c8ff);
|
||||
color: white;
|
||||
border: none;
|
||||
}
|
||||
|
||||
#myTensorButtonStop {
|
||||
background: radial-gradient(ellipse, #00b0ff, #007bff);
|
||||
background: radial-gradient(ellipse, #52c8ff, #3a99ff);
|
||||
color: black;
|
||||
border: none;
|
||||
}
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
import subprocess
|
||||
import psutil
|
||||
import os
|
||||
import gradio as gr
|
||||
import shlex
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
|
|
@ -29,7 +31,27 @@ class CommandExecutor:
|
|||
if self.process and self.process.poll() is None:
|
||||
log.info("The command is already running. Please wait for it to finish.")
|
||||
else:
|
||||
self.process = subprocess.Popen(run_cmd, **kwargs)
|
||||
if os.name == 'nt':
|
||||
run_cmd = run_cmd.replace('\\', '/')
|
||||
|
||||
# Split the command string into components
|
||||
parts = shlex.split(run_cmd)
|
||||
# The first part is the executable, and it doesn't need quoting
|
||||
executable = parts[0]
|
||||
|
||||
# The remaining parts are the arguments, which we will quote for safety
|
||||
safe_args = [shlex.quote(part) for part in parts[1:]]
|
||||
|
||||
# Log the executable and arguments to debug path issues
|
||||
log.info(f"Executable: {executable}")
|
||||
log.info(f"Arguments: {' '.join(safe_args)}")
|
||||
|
||||
# Reconstruct the safe command string for display
|
||||
command_to_run = ' '.join([executable] + safe_args)
|
||||
log.info(f"Executing command: {command_to_run}")
|
||||
|
||||
# Execute the command securely
|
||||
self.process = subprocess.Popen([executable] + safe_args, **kwargs)
|
||||
|
||||
def kill_command(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import gradio as gr
|
|||
import sys
|
||||
import json
|
||||
import math
|
||||
import shutil
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
|
@ -19,6 +20,8 @@ save_style_symbol = "\U0001f4be" # 💾
|
|||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
scriptdir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
if os.name == 'nt':
|
||||
scriptdir = scriptdir.replace('\\', '/')
|
||||
|
||||
# insert sd-scripts path into PYTHONPATH
|
||||
sys.path.insert(0, os.path.join(scriptdir, "sd-scripts"))
|
||||
|
|
@ -54,6 +57,27 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDX
|
|||
|
||||
ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"]
|
||||
|
||||
def get_executable_path(executable_name: str = None) -> str:
|
||||
"""
|
||||
Retrieve and sanitize the path to an executable in the system's PATH.
|
||||
|
||||
Args:
|
||||
executable_name (str): The name of the executable to find.
|
||||
|
||||
Returns:
|
||||
str: The full, sanitized path to the executable if found, otherwise an empty string.
|
||||
"""
|
||||
if executable_name:
|
||||
executable_path = shutil.which(executable_name)
|
||||
if executable_path:
|
||||
# Replace backslashes with forward slashes on Windows
|
||||
if os.name == 'nt':
|
||||
executable_path = executable_path.replace('\\', '/')
|
||||
return executable_path
|
||||
else:
|
||||
return "" # Return empty string if the executable is not found
|
||||
else:
|
||||
return "" # Return empty string if no executable name is provided
|
||||
|
||||
def calculate_max_train_steps(
|
||||
total_steps: int,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import os
|
|||
import sys
|
||||
from datetime import datetime
|
||||
from .common_gui import (
|
||||
get_executable_path,
|
||||
get_file_path,
|
||||
get_saveasfile_path,
|
||||
color_aug_changed,
|
||||
|
|
@ -563,8 +564,9 @@ def train_model(
|
|||
lr_warmup_steps = 0
|
||||
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
|
||||
|
||||
# run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
|
||||
run_cmd = "accelerate launch"
|
||||
accelerate_path = get_executable_path("accelerate")
|
||||
|
||||
run_cmd = f"{accelerate_path} launch"
|
||||
|
||||
run_cmd += AccelerateLaunch.run_cmd(
|
||||
num_processes=num_processes,
|
||||
|
|
|
|||
Loading…
Reference in New Issue