use webui's models path as the init models folder

pull/186/head
kabachuha 2023-07-02 18:27:07 +03:00
parent a9cbf65dc6
commit c606463e44
2 changed files with 19 additions and 3 deletions

View File

@ -6,6 +6,8 @@ from types import SimpleNamespace
from t2v_helpers.video_audio_utils import find_ffmpeg_binary
from samplers.samplers_common import available_samplers
import os
import modules.paths as ph
from general_utils import get_model_location
from modules.shared import opts
welcome_text_videocrafter = '''
@ -84,11 +86,11 @@ def setup_text2video_settings_dictionary():
def refresh_all_models(model):
models = []
if os.path.isdir(os.path.join(os.getcwd(), 'models/ModelScope/t2v')):
if os.path.isdir(os.path.join(ph.models_dir, 'ModelScope/t2v')):
models.append('<modelscope>')
if os.path.isdir(os.path.join(os.getcwd(), 'models/VideoCrafter')):
if os.path.isdir(os.path.join(ph.models_dir, 'VideoCrafter')):
models.append('<videocrafter>')
models_dir = os.path.join(os.getcwd(), 'models/text2video/')
models_dir = os.path.join(ph.models_dir, 'text2video/')
for subdir in os.listdir(models_dir):
if os.path.isdir(subdir):
models.append(subdir)
@ -214,6 +216,8 @@ def T2VArgs():
def T2VArgs_sanity_check(t2v_args):
try:
if t2v_args.model is not None and not os.path.isdir(get_model_location(t2v_args.model)):
raise ValueError(f'Model "{t2v_args.model}" not found in {get_model_location(t2v_args.model)}!')
if t2v_args.frames < 1:
raise ValueError('Frames count cannot be lower than 1!')
if t2v_args.batch_count < 1:

View File

@ -1,6 +1,8 @@
# Copyright (C) 2023 by Artem Khrapov (kabachuha)
# Read LICENSE for usage terms.
from modules.prompt_parser import reconstruct_cond_batch
import os
import modules.paths as ph
def get_t2v_version():
from modules import extensions as mext
@ -12,6 +14,16 @@ def get_t2v_version():
except:
return "Unknown"
def get_model_location(model_name):
assert model_name is not None
if model_name == "<modelscope>":
return os.path.join(ph.models_path, 'models/ModelScope/t2v')
elif model_name == "<videocrafter>":
return os.path.join(ph.models_path, 'models/VideoCrafter')
else:
return os.path.join(ph.models_path, 'models/text2video/', model_name)
def reconstruct_conds(cond, uncond, step):
c = reconstruct_cond_batch(cond, step)
uc = reconstruct_cond_batch(uncond, step)