Add better python, model and data dir validation

pull/2055/head
bmaltais 2024-03-09 20:50:25 -05:00
parent 8c760b5e03
commit 33ccb84446
10 changed files with 60 additions and 32 deletions

View File

@ -1 +1 @@
v23.0.1
v23.0.1

View File

@ -1,10 +1,15 @@
@echo off
set PYTHON_VER=3.10.9
:: Deactivate the virtual environment
call .\venv\Scripts\deactivate.bat
:: Calling external python program to check for local modules
:: python .\setup\check_local_modules.py --no_question
:: Check if Python version meets the recommended version
python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
if errorlevel 1 (
echo Warning: Python version %PYTHON_VER% is required. Kohya_ss GUI will most likely fail to run.
)
:: Activate the virtual environment
call .\venv\Scripts\activate.bat

View File

@ -15,6 +15,18 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
default_models = [
'stabilityai/stable-diffusion-xl-base-1.0',
'stabilityai/stable-diffusion-xl-refiner-1.0',
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
]
class SourceModel:
def __init__(
@ -39,19 +51,6 @@ class SourceModel:
self.save_model_as_choices = save_model_as_choices
self.finetuning = finetuning
default_models = [
'stabilityai/stable-diffusion-xl-base-1.0',
'stabilityai/stable-diffusion-xl-refiner-1.0',
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
]
from .common_gui import create_refresh_button
default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs")

View File

@ -560,23 +560,23 @@ def train_model(
log.info(f"Start training LoRA {LoRA_type} ...")
headless_bool = True if headless.get("label") == "True" else False
if pretrained_model_name_or_path == "":
output_message(
msg="Source model information is missing", headless=headless_bool
)
return
from .class_source_model import default_models
if train_data_dir == "":
output_message(msg="Image folder path is missing", headless=headless_bool)
# Check if the pretrained_model_name_or_path is valid
if pretrained_model_name_or_path not in default_models:
# If not one of the default models, check if it's a valid path
if not pretrained_model_name_or_path or not os.path.exists(pretrained_model_name_or_path):
log.error(f"Source model path '{pretrained_model_name_or_path}' is missing or does not exist")
return
# Check if train_data_dir is valid
if not train_data_dir or not os.path.exists(train_data_dir):
log.error(f"Image folder path '{train_data_dir}' is missing or does not exist")
return
# Check if there are files with the same filename but different image extension... warn the user if it is the case.
check_duplicate_filenames(train_data_dir)
if not os.path.exists(train_data_dir):
output_message(msg="Image folder does not exist", headless=headless_bool)
return
if not verify_image_folder_pattern(train_data_dir):
return

View File

@ -5,7 +5,7 @@ set PYTHON_VER=3.10.9
:: Check if Python version meets the recommended version
python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
if errorlevel 1 (
echo Warning: Python version %PYTHON_VER% is recommended.
echo Warning: Python version %PYTHON_VER% is required. Kohya_ss GUI will most likely fail to run.
)
IF NOT EXIST venv (

View File

@ -10,9 +10,30 @@ import datetime
import platform
import pkg_resources
from packaging import version
errors = 0 # Define the 'errors' variable before using it
log = logging.getLogger('sd')
def check_python_version():
"""
Check if the current Python version is >= 3.10.9 and < 3.11.0
Returns:
bool: True if the current Python version is valid, False otherwise.
"""
min_version = (3, 10, 9)
max_version = (3, 11, 0)
current_version = sys.version_info
log.info(f"Python version is {sys.version}")
if not (min_version <= current_version < max_version):
log.error(f"The current version of python is not appropriate to run Kohya_ss GUI")
log.error("The python version need to be greater or equal to 3.10.9 and less than 3.11.0")
return (min_version <= current_version < max_version)
def update_submodule():
"""
Ensure the submodule is initialized and updated.
@ -352,7 +373,7 @@ def check_repo_version(): # pylint: disable=unused-argument
with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
release= file.read()
log.info(f'Version: {release}')
log.info(f'Kohya_ss GUI version: {release}')
else:
log.debug('Could not read release...')

View File

@ -25,6 +25,7 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce
if __name__ == '__main__':
python_ver = setup_common.check_python_version()
setup_common.ensure_base_requirements()
setup_common.setup_logging()

View File

@ -59,6 +59,7 @@ def main_menu(platform_requirements_file):
if __name__ == '__main__':
python_ver = setup_common.check_python_version()
setup_common.ensure_base_requirements()
setup_common.setup_logging()

View File

@ -223,6 +223,7 @@ def main_menu():
if __name__ == "__main__":
python_ver = setup_common.check_python_version()
setup_common.ensure_base_requirements()
setup_common.setup_logging()
main_menu()

View File

@ -1,5 +1,4 @@
import os
import re
import sys
import shutil
import argparse
@ -88,8 +87,7 @@ def check_torch():
except Exception as e:
log.error(f'Could not load torch: {e}')
sys.exit(1)
def main():
setup_common.check_repo_version()
# Parse command line arguments
@ -107,6 +105,8 @@ def main():
torch_ver = check_torch()
python_ver = setup_common.check_python_version()
setup_common.update_submodule()
if args.requirements: