backward compatibility

pull/669/head
Akegarasu 2025-01-13 18:29:46 +08:00
parent 7159f93ab0
commit c59a264241
No known key found for this signature in database
GPG Key ID: DACA951FEBA569A2
1 changed files with 10 additions and 8 deletions

View File

@ -85,6 +85,10 @@ async def load_presets():
def get_sample_prompts(config: dict) -> Tuple[Optional[str], str]:
# backward compatibility
if "sample_prompts" in config and "positive_prompts" not in config:
return None, config["sample_prompts"]
train_data_dir = config["train_data_dir"]
sub_dir = [dir for dir in glob(os.path.join(train_data_dir, '*')) if os.path.isdir(dir)]
@ -99,23 +103,21 @@ def get_sample_prompts(config: dict) -> Tuple[Optional[str], str]:
if randomly_choice_prompt:
if len(sub_dir) != 1:
raise ValueError('训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能')
# return None, APIResponseFail(message='训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能')
raise ValueError('训练数据集下有多个子文件夹,无法启用随机选取 Prompt 功能')
txt_files = glob(os.path.join(sub_dir[0], '*.txt'))
if not txt_files:
raise ValueError('训练数据集路径没有 txt 文件')
# return None, APIResponseFail(message='训练数据集路径没有 txt 文件')
try:
sample_prompt_file = random.choice(txt_files)
with open(sample_prompt_file, 'r', encoding='utf-8') as f:
positive_prompts = f.read()
except IOError:
# positive_prompts = config['positive_prompts']
pass
log.error(f"读取 {sample_prompt_file} 文件失败")
return positive_prompts, f'{positive_prompts} --n {negative_prompts} --w {sample_width} --h {sample_height} --l {sample_cfg} --s {sample_steps} --d {sample_seed}'
@router.post("/run")
async def create_toml_file(request: Request):
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
@ -141,7 +143,7 @@ async def create_toml_file(request: Request):
try:
positive_prompt, sample_prompts_arg = get_sample_prompts(config=config)
if positive_prompt is not None and train_utils.is_promopt_like(sample_prompts_arg):
sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt")
with open(sample_prompts_file, "w", encoding="utf-8") as f:
@ -150,9 +152,9 @@ async def create_toml_file(request: Request):
config["sample_prompts"] = sample_prompts_file
log.info(f"Wrote prompts to file {sample_prompts_file}")
except ValueError as e:
log.error(f"Error while processing prompts: {e}")
return APIResponseFail(message=str(e))
with open(toml_file, "w", encoding="utf-8") as f:
f.write(toml.dumps(config))