feat: improve sample prompt option (#608)

* feat: improve sample prompt option

* fix: type
pull/611/head
undefined 2025-01-10 17:02:48 +08:00 committed by GitHub
parent d083810630
commit c18e95208a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 8 deletions

View File

@ -3,8 +3,12 @@ import hashlib
import json
import os
import re
import random
from glob import glob
from datetime import datetime
from pathlib import Path
from typing import Tuple, Optional
import toml
from fastapi import APIRouter, BackgroundTasks, Request
@ -80,6 +84,38 @@ async def load_presets():
avaliable_presets.append(toml.loads(content))
def get_sample_prompts(config: dict) -> Tuple[Optional[str], str]:
train_data_dir = config["train_data_dir"]
sub_dir = [dir for dir in glob(os.path.join(train_data_dir, '*')) if os.path.isdir(dir)]
positive_prompts = config.pop('positive_prompts', None)
negative_prompts = config.pop('negative_prompts', '')
sample_width = config.pop('sample_width', 512) # 默认宽度 512
sample_height = config.pop('sample_height', 512) # 默认高度 512
sample_cfg = config.pop('sample_cfg', 7) # 默认 CFG 值 7.5
sample_seed = config.pop('sample_seed', 2333) # 默认随机种子 42
sample_steps = config.pop('sample_steps', 24) # 默认步数 50
randomly_choice_prompt = config.pop('randomly_choice_prompt', False)
if randomly_choice_prompt:
if len(sub_dir) != 1:
raise ValueError('训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能')
# return None, APIResponseFail(message='训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能')
txt_files = glob(os.path.join(sub_dir[0], '*.txt'))
if not txt_files:
raise ValueError('训练数据集路径没有 txt 文件')
# return None, APIResponseFail(message='训练数据集路径没有 txt 文件')
try:
sample_prompt_file = random.choice(txt_files)
with open(sample_prompt_file, 'r', encoding='utf-8') as f:
positive_prompts = f.read()
except IOError:
# positive_prompts = config['positive_prompts']
pass
return positive_prompts, f'{positive_prompts} --n {negative_prompts} --w {sample_width} --h {sample_height} --l {sample_cfg} --s {sample_steps} --d {sample_seed}'
@router.post("/run")
async def create_toml_file(request: Request):
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
@ -103,13 +139,19 @@ async def create_toml_file(request: Request):
if not validated:
return APIResponseFail(message=message)
sample_prompts = config.get("sample_prompts", None)
if sample_prompts is not None and not os.path.exists(sample_prompts) and train_utils.is_promopt_like(sample_prompts):
sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt")
with open(sample_prompts_file, "w", encoding="utf-8") as f:
f.write(sample_prompts)
config["sample_prompts"] = sample_prompts_file
log.info(f"Wrote promopts to file {sample_prompts_file}")
try:
positive_prompt, sample_prompts_arg = get_sample_prompts(config=config)
if positive_prompt is not None and train_utils.is_promopt_like(sample_prompts_arg):
sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt")
with open(sample_prompts_file, "w", encoding="utf-8") as f:
f.write(sample_prompts_arg)
config["sample_prompts"] = sample_prompts_file
log.info(f"Wrote promopts to file {sample_prompts_file}")
except ValueError as e:
return APIResponseFail(message=str(e))
with open(toml_file, "w", encoding="utf-8") as f:
f.write(toml.dumps(config))

View File

@ -168,7 +168,14 @@
Schema.union([
Schema.object({
enable_preview: Schema.const(true).required(),
sample_prompts: Schema.string().role('textarea').default(SAMPLE_PROMPTS_DEFAULT).description(SAMPLE_PROMPTS_DESCRIPTION),
randomly_choice_prompt: Schema.boolean().default(false).description('随机选择预览图 Prompt'),
positive_prompts: Schema.string().role('textarea').default('masterpiece, best quality, 1girl, solo').description("Prompt"),
negative_prompts: Schema.string().role('textarea').default('lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry').description("Negative Prompt"),
sample_width: Schema.number().default(512).description('预览图宽'),
sample_height: Schema.number().default(512).description('预览图高'),
sample_cfg: Schema.number().min(1).max(30).default(7).description('CFG Scale'),
sample_seed: Schema.number().default(2333).description('种子'),
sample_steps: Schema.number().min(1).max(300).default(24).description('迭代步数'),
sample_sampler: Schema.union(["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver", "dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]).default("euler_a").description("生成预览图所用采样器"),
sample_every_n_epochs: Schema.number().default(2).description("每 N 个 epoch 生成一次预览图"),
}),