diff --git a/scripts/t2v_helpers/args.py b/scripts/t2v_helpers/args.py index 8715f74..acb2718 100644 --- a/scripts/t2v_helpers/args.py +++ b/scripts/t2v_helpers/args.py @@ -6,6 +6,8 @@ from types import SimpleNamespace from t2v_helpers.video_audio_utils import find_ffmpeg_binary from samplers.samplers_common import available_samplers import os +import modules.paths as ph +from general_utils import get_model_location from modules.shared import opts welcome_text_videocrafter = ''' @@ -84,11 +86,11 @@ def setup_text2video_settings_dictionary(): def refresh_all_models(model): models = [] - if os.path.isdir(os.path.join(os.getcwd(), 'models/ModelScope/t2v')): + if os.path.isdir(os.path.join(ph.models_dir, 'ModelScope/t2v')): models.append('') - if os.path.isdir(os.path.join(os.getcwd(), 'models/VideoCrafter')): + if os.path.isdir(os.path.join(ph.models_dir, 'VideoCrafter')): models.append('') - models_dir = os.path.join(os.getcwd(), 'models/text2video/') + models_dir = os.path.join(ph.models_dir, 'text2video/') for subdir in os.listdir(models_dir): if os.path.isdir(subdir): models.append(subdir) @@ -214,6 +216,8 @@ def T2VArgs(): def T2VArgs_sanity_check(t2v_args): try: + if t2v_args.model is not None and not os.path.isdir(get_model_location(t2v_args.model)): + raise ValueError(f'Model "{t2v_args.model}" not found in {get_model_location(t2v_args.model)}!') if t2v_args.frames < 1: raise ValueError('Frames count cannot be lower than 1!') if t2v_args.batch_count < 1: diff --git a/scripts/t2v_helpers/general_utils.py b/scripts/t2v_helpers/general_utils.py index 49dc510..b5ed536 100644 --- a/scripts/t2v_helpers/general_utils.py +++ b/scripts/t2v_helpers/general_utils.py @@ -1,6 +1,8 @@ # Copyright (C) 2023 by Artem Khrapov (kabachuha) # Read LICENSE for usage terms. from modules.prompt_parser import reconstruct_cond_batch +import os +import modules.paths as ph def get_t2v_version(): from modules import extensions as mext @@ -12,6 +14,16 @@ def get_t2v_version(): except: return "Unknown" +def get_model_location(model_name): + assert model_name is not None + + if model_name == "": + return os.path.join(ph.models_path, 'models/ModelScope/t2v') + elif model_name == "": + return os.path.join(ph.models_path, 'models/VideoCrafter') + else: + return os.path.join(ph.models_path, 'models/text2video/', model_name) + def reconstruct_conds(cond, uncond, step): c = reconstruct_cond_batch(cond, step) uc = reconstruct_cond_batch(uncond, step)