feat: improve sample prompt option (#608)
* feat: improve sample prompt option * fix: typepull/611/head
parent
d083810630
commit
c18e95208a
|
|
@ -3,8 +3,12 @@ import hashlib
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import random
|
||||
|
||||
from glob import glob
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import toml
|
||||
from fastapi import APIRouter, BackgroundTasks, Request
|
||||
|
|
@ -80,6 +84,38 @@ async def load_presets():
|
|||
avaliable_presets.append(toml.loads(content))
|
||||
|
||||
|
||||
def get_sample_prompts(config: dict) -> Tuple[Optional[str], str]:
|
||||
train_data_dir = config["train_data_dir"]
|
||||
sub_dir = [dir for dir in glob(os.path.join(train_data_dir, '*')) if os.path.isdir(dir)]
|
||||
|
||||
positive_prompts = config.pop('positive_prompts', None)
|
||||
negative_prompts = config.pop('negative_prompts', '')
|
||||
sample_width = config.pop('sample_width', 512) # 默认宽度 512
|
||||
sample_height = config.pop('sample_height', 512) # 默认高度 512
|
||||
sample_cfg = config.pop('sample_cfg', 7) # 默认 CFG 值 7.5
|
||||
sample_seed = config.pop('sample_seed', 2333) # 默认随机种子 42
|
||||
sample_steps = config.pop('sample_steps', 24) # 默认步数 50
|
||||
randomly_choice_prompt = config.pop('randomly_choice_prompt', False)
|
||||
|
||||
if randomly_choice_prompt:
|
||||
if len(sub_dir) != 1:
|
||||
raise ValueError('训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能')
|
||||
# return None, APIResponseFail(message='训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能')
|
||||
|
||||
txt_files = glob(os.path.join(sub_dir[0], '*.txt'))
|
||||
if not txt_files:
|
||||
raise ValueError('训练数据集路径没有 txt 文件')
|
||||
# return None, APIResponseFail(message='训练数据集路径没有 txt 文件')
|
||||
try:
|
||||
sample_prompt_file = random.choice(txt_files)
|
||||
with open(sample_prompt_file, 'r', encoding='utf-8') as f:
|
||||
positive_prompts = f.read()
|
||||
except IOError:
|
||||
# positive_prompts = config['positive_prompts']
|
||||
pass
|
||||
|
||||
return positive_prompts, f'{positive_prompts} --n {negative_prompts} --w {sample_width} --h {sample_height} --l {sample_cfg} --s {sample_steps} --d {sample_seed}'
|
||||
|
||||
@router.post("/run")
|
||||
async def create_toml_file(request: Request):
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
|
|
@ -103,13 +139,19 @@ async def create_toml_file(request: Request):
|
|||
if not validated:
|
||||
return APIResponseFail(message=message)
|
||||
|
||||
sample_prompts = config.get("sample_prompts", None)
|
||||
if sample_prompts is not None and not os.path.exists(sample_prompts) and train_utils.is_promopt_like(sample_prompts):
|
||||
sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt")
|
||||
with open(sample_prompts_file, "w", encoding="utf-8") as f:
|
||||
f.write(sample_prompts)
|
||||
config["sample_prompts"] = sample_prompts_file
|
||||
log.info(f"Wrote promopts to file {sample_prompts_file}")
|
||||
try:
|
||||
positive_prompt, sample_prompts_arg = get_sample_prompts(config=config)
|
||||
|
||||
if positive_prompt is not None and train_utils.is_promopt_like(sample_prompts_arg):
|
||||
sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt")
|
||||
with open(sample_prompts_file, "w", encoding="utf-8") as f:
|
||||
f.write(sample_prompts_arg)
|
||||
|
||||
config["sample_prompts"] = sample_prompts_file
|
||||
log.info(f"Wrote promopts to file {sample_prompts_file}")
|
||||
except ValueError as e:
|
||||
return APIResponseFail(message=str(e))
|
||||
|
||||
|
||||
with open(toml_file, "w", encoding="utf-8") as f:
|
||||
f.write(toml.dumps(config))
|
||||
|
|
|
|||
|
|
@ -168,7 +168,14 @@
|
|||
Schema.union([
|
||||
Schema.object({
|
||||
enable_preview: Schema.const(true).required(),
|
||||
sample_prompts: Schema.string().role('textarea').default(SAMPLE_PROMPTS_DEFAULT).description(SAMPLE_PROMPTS_DESCRIPTION),
|
||||
randomly_choice_prompt: Schema.boolean().default(false).description('随机选择预览图 Prompt'),
|
||||
positive_prompts: Schema.string().role('textarea').default('masterpiece, best quality, 1girl, solo').description("Prompt"),
|
||||
negative_prompts: Schema.string().role('textarea').default('lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry').description("Negative Prompt"),
|
||||
sample_width: Schema.number().default(512).description('预览图宽'),
|
||||
sample_height: Schema.number().default(512).description('预览图高'),
|
||||
sample_cfg: Schema.number().min(1).max(30).default(7).description('CFG Scale'),
|
||||
sample_seed: Schema.number().default(2333).description('种子'),
|
||||
sample_steps: Schema.number().min(1).max(300).default(24).description('迭代步数'),
|
||||
sample_sampler: Schema.union(["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver", "dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]).default("euler_a").description("生成预览图所用采样器"),
|
||||
sample_every_n_epochs: Schema.number().default(2).description("每 N 个 epoch 生成一次预览图"),
|
||||
}),
|
||||
|
|
|
|||
Loading…
Reference in New Issue