diff --git a/mikazuki/app/api.py b/mikazuki/app/api.py index 731ef81..0049469 100644 --- a/mikazuki/app/api.py +++ b/mikazuki/app/api.py @@ -3,8 +3,12 @@ import hashlib import json import os import re +import random + +from glob import glob from datetime import datetime from pathlib import Path +from typing import Tuple, Optional import toml from fastapi import APIRouter, BackgroundTasks, Request @@ -80,6 +84,38 @@ async def load_presets(): avaliable_presets.append(toml.loads(content)) +def get_sample_prompts(config: dict) -> Tuple[Optional[str], str]: + train_data_dir = config["train_data_dir"] + sub_dir = [dir for dir in glob(os.path.join(train_data_dir, '*')) if os.path.isdir(dir)] + + positive_prompts = config.pop('positive_prompts', None) + negative_prompts = config.pop('negative_prompts', '') + sample_width = config.pop('sample_width', 512) # 默认宽度 512 + sample_height = config.pop('sample_height', 512) # 默认高度 512 + sample_cfg = config.pop('sample_cfg', 7) # 默认 CFG 值 7.5 + sample_seed = config.pop('sample_seed', 2333) # 默认随机种子 42 + sample_steps = config.pop('sample_steps', 24) # 默认步数 50 + randomly_choice_prompt = config.pop('randomly_choice_prompt', False) + + if randomly_choice_prompt: + if len(sub_dir) != 1: + raise ValueError('训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能') + # return None, APIResponseFail(message='训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能') + + txt_files = glob(os.path.join(sub_dir[0], '*.txt')) + if not txt_files: + raise ValueError('训练数据集路径没有 txt 文件') + # return None, APIResponseFail(message='训练数据集路径没有 txt 文件') + try: + sample_prompt_file = random.choice(txt_files) + with open(sample_prompt_file, 'r', encoding='utf-8') as f: + positive_prompts = f.read() + except IOError: + # positive_prompts = config['positive_prompts'] + pass + + return positive_prompts, f'{positive_prompts} --n {negative_prompts} --w {sample_width} --h {sample_height} --l {sample_cfg} --s {sample_steps} --d {sample_seed}' + @router.post("/run") async def create_toml_file(request: Request): timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") @@ -103,13 +139,19 @@ async def create_toml_file(request: Request): if not validated: return APIResponseFail(message=message) - sample_prompts = config.get("sample_prompts", None) - if sample_prompts is not None and not os.path.exists(sample_prompts) and train_utils.is_promopt_like(sample_prompts): - sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt") - with open(sample_prompts_file, "w", encoding="utf-8") as f: - f.write(sample_prompts) - config["sample_prompts"] = sample_prompts_file - log.info(f"Wrote promopts to file {sample_prompts_file}") + try: + positive_prompt, sample_prompts_arg = get_sample_prompts(config=config) + + if positive_prompt is not None and train_utils.is_promopt_like(sample_prompts_arg): + sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt") + with open(sample_prompts_file, "w", encoding="utf-8") as f: + f.write(sample_prompts_arg) + + config["sample_prompts"] = sample_prompts_file + log.info(f"Wrote promopts to file {sample_prompts_file}") + except ValueError as e: + return APIResponseFail(message=str(e)) + with open(toml_file, "w", encoding="utf-8") as f: f.write(toml.dumps(config)) diff --git a/mikazuki/schema/shared.ts b/mikazuki/schema/shared.ts index c30bb47..8fa93a8 100644 --- a/mikazuki/schema/shared.ts +++ b/mikazuki/schema/shared.ts @@ -168,7 +168,14 @@ Schema.union([ Schema.object({ enable_preview: Schema.const(true).required(), - sample_prompts: Schema.string().role('textarea').default(SAMPLE_PROMPTS_DEFAULT).description(SAMPLE_PROMPTS_DESCRIPTION), + randomly_choice_prompt: Schema.boolean().default(false).description('随机选择预览图 Prompt'), + positive_prompts: Schema.string().role('textarea').default('masterpiece, best quality, 1girl, solo').description("Prompt"), + negative_prompts: Schema.string().role('textarea').default('lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry').description("Negative Prompt"), + sample_width: Schema.number().default(512).description('预览图宽'), + sample_height: Schema.number().default(512).description('预览图高'), + sample_cfg: Schema.number().min(1).max(30).default(7).description('CFG Scale'), + sample_seed: Schema.number().default(2333).description('种子'), + sample_steps: Schema.number().min(1).max(300).default(24).description('迭代步数'), sample_sampler: Schema.union(["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver", "dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]).default("euler_a").description("生成预览图所用采样器"), sample_every_n_epochs: Schema.number().default(2).description("每 N 个 epoch 生成一次预览图"), }),