mirror of https://github.com/bmaltais/kohya_ss
Add better python, model and data dir validation
parent
8c760b5e03
commit
33ccb84446
9
gui.bat
9
gui.bat
|
|
@ -1,10 +1,15 @@
|
|||
@echo off
|
||||
|
||||
set PYTHON_VER=3.10.9
|
||||
|
||||
:: Deactivate the virtual environment
|
||||
call .\venv\Scripts\deactivate.bat
|
||||
|
||||
:: Calling external python program to check for local modules
|
||||
:: python .\setup\check_local_modules.py --no_question
|
||||
:: Check if Python version meets the recommended version
|
||||
python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
|
||||
if errorlevel 1 (
|
||||
echo Warning: Python version %PYTHON_VER% is required. Kohya_ss GUI will most likely fail to run.
|
||||
)
|
||||
|
||||
:: Activate the virtual environment
|
||||
call .\venv\Scripts\activate.bat
|
||||
|
|
|
|||
|
|
@ -15,6 +15,18 @@ refresh_symbol = '\U0001f504' # 🔄
|
|||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
|
||||
default_models = [
|
||||
'stabilityai/stable-diffusion-xl-base-1.0',
|
||||
'stabilityai/stable-diffusion-xl-refiner-1.0',
|
||||
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
|
||||
'stabilityai/stable-diffusion-2-1-base',
|
||||
'stabilityai/stable-diffusion-2-base',
|
||||
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
|
||||
'stabilityai/stable-diffusion-2-1',
|
||||
'stabilityai/stable-diffusion-2',
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
'CompVis/stable-diffusion-v1-4',
|
||||
]
|
||||
|
||||
class SourceModel:
|
||||
def __init__(
|
||||
|
|
@ -39,19 +51,6 @@ class SourceModel:
|
|||
self.save_model_as_choices = save_model_as_choices
|
||||
self.finetuning = finetuning
|
||||
|
||||
default_models = [
|
||||
'stabilityai/stable-diffusion-xl-base-1.0',
|
||||
'stabilityai/stable-diffusion-xl-refiner-1.0',
|
||||
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
|
||||
'stabilityai/stable-diffusion-2-1-base',
|
||||
'stabilityai/stable-diffusion-2-base',
|
||||
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
|
||||
'stabilityai/stable-diffusion-2-1',
|
||||
'stabilityai/stable-diffusion-2',
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
'CompVis/stable-diffusion-v1-4',
|
||||
]
|
||||
|
||||
from .common_gui import create_refresh_button
|
||||
|
||||
default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs")
|
||||
|
|
|
|||
|
|
@ -560,23 +560,23 @@ def train_model(
|
|||
log.info(f"Start training LoRA {LoRA_type} ...")
|
||||
headless_bool = True if headless.get("label") == "True" else False
|
||||
|
||||
if pretrained_model_name_or_path == "":
|
||||
output_message(
|
||||
msg="Source model information is missing", headless=headless_bool
|
||||
)
|
||||
return
|
||||
from .class_source_model import default_models
|
||||
|
||||
if train_data_dir == "":
|
||||
output_message(msg="Image folder path is missing", headless=headless_bool)
|
||||
# Check if the pretrained_model_name_or_path is valid
|
||||
if pretrained_model_name_or_path not in default_models:
|
||||
# If not one of the default models, check if it's a valid path
|
||||
if not pretrained_model_name_or_path or not os.path.exists(pretrained_model_name_or_path):
|
||||
log.error(f"Source model path '{pretrained_model_name_or_path}' is missing or does not exist")
|
||||
return
|
||||
|
||||
# Check if train_data_dir is valid
|
||||
if not train_data_dir or not os.path.exists(train_data_dir):
|
||||
log.error(f"Image folder path '{train_data_dir}' is missing or does not exist")
|
||||
return
|
||||
|
||||
# Check if there are files with the same filename but different image extension... warn the user if it is the case.
|
||||
check_duplicate_filenames(train_data_dir)
|
||||
|
||||
if not os.path.exists(train_data_dir):
|
||||
output_message(msg="Image folder does not exist", headless=headless_bool)
|
||||
return
|
||||
|
||||
if not verify_image_folder_pattern(train_data_dir):
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ set PYTHON_VER=3.10.9
|
|||
:: Check if Python version meets the recommended version
|
||||
python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
|
||||
if errorlevel 1 (
|
||||
echo Warning: Python version %PYTHON_VER% is recommended.
|
||||
echo Warning: Python version %PYTHON_VER% is required. Kohya_ss GUI will most likely fail to run.
|
||||
)
|
||||
|
||||
IF NOT EXIST venv (
|
||||
|
|
|
|||
|
|
@ -10,9 +10,30 @@ import datetime
|
|||
import platform
|
||||
import pkg_resources
|
||||
|
||||
from packaging import version
|
||||
|
||||
errors = 0 # Define the 'errors' variable before using it
|
||||
log = logging.getLogger('sd')
|
||||
|
||||
def check_python_version():
|
||||
"""
|
||||
Check if the current Python version is >= 3.10.9 and < 3.11.0
|
||||
|
||||
Returns:
|
||||
bool: True if the current Python version is valid, False otherwise.
|
||||
"""
|
||||
min_version = (3, 10, 9)
|
||||
max_version = (3, 11, 0)
|
||||
current_version = sys.version_info
|
||||
|
||||
log.info(f"Python version is {sys.version}")
|
||||
|
||||
if not (min_version <= current_version < max_version):
|
||||
log.error(f"The current version of python is not appropriate to run Kohya_ss GUI")
|
||||
log.error("The python version need to be greater or equal to 3.10.9 and less than 3.11.0")
|
||||
|
||||
return (min_version <= current_version < max_version)
|
||||
|
||||
def update_submodule():
|
||||
"""
|
||||
Ensure the submodule is initialized and updated.
|
||||
|
|
@ -352,7 +373,7 @@ def check_repo_version(): # pylint: disable=unused-argument
|
|||
with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
|
||||
release= file.read()
|
||||
|
||||
log.info(f'Version: {release}')
|
||||
log.info(f'Kohya_ss GUI version: {release}')
|
||||
else:
|
||||
log.debug('Could not read release...')
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
python_ver = setup_common.check_python_version()
|
||||
setup_common.ensure_base_requirements()
|
||||
setup_common.setup_logging()
|
||||
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ def main_menu(platform_requirements_file):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
python_ver = setup_common.check_python_version()
|
||||
setup_common.ensure_base_requirements()
|
||||
setup_common.setup_logging()
|
||||
|
||||
|
|
|
|||
|
|
@ -223,6 +223,7 @@ def main_menu():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
python_ver = setup_common.check_python_version()
|
||||
setup_common.ensure_base_requirements()
|
||||
setup_common.setup_logging()
|
||||
main_menu()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
import shutil
|
||||
import argparse
|
||||
|
|
@ -88,8 +87,7 @@ def check_torch():
|
|||
except Exception as e:
|
||||
log.error(f'Could not load torch: {e}')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
setup_common.check_repo_version()
|
||||
# Parse command line arguments
|
||||
|
|
@ -107,6 +105,8 @@ def main():
|
|||
|
||||
torch_ver = check_torch()
|
||||
|
||||
python_ver = setup_common.check_python_version()
|
||||
|
||||
setup_common.update_submodule()
|
||||
|
||||
if args.requirements:
|
||||
|
|
|
|||
Loading…
Reference in New Issue