654 lines
28 KiB
Python
654 lines
28 KiB
Python
import base64
|
|
import base64
|
|
import functools
|
|
import hashlib
|
|
import io
|
|
import json
|
|
import os
|
|
import traceback
|
|
import zipfile
|
|
from pathlib import Path
|
|
|
|
import gradio as gr
|
|
from PIL import Image
|
|
from fastapi import FastAPI, Response, Query, Body
|
|
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
|
|
from pydantic import BaseModel, Field
|
|
from pydantic.dataclasses import Union
|
|
from pydantic.types import List
|
|
|
|
from extensions.sd_dreambooth_extension.dreambooth import db_shared
|
|
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file, DreamboothConfig
|
|
from extensions.sd_dreambooth_extension.dreambooth.db_shared import DreamState
|
|
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
|
|
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter, generate_classifiers
|
|
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
|
|
from extensions.sd_dreambooth_extension.dreambooth.secret import get_secret
|
|
from extensions.sd_dreambooth_extension.dreambooth.utils import get_images
|
|
from extensions.sd_dreambooth_extension.scripts import dreambooth
|
|
from extensions.sd_dreambooth_extension.scripts.dreambooth import ui_samples
|
|
from modules import sd_models
|
|
|
|
|
|
class InstanceData(BaseModel):
    # One image entry of the /dreambooth/upload request body (see DbImagesRequest).
    # All three fields are required.

    # Base64-encoded file contents; upload_db_images decodes it with base64_to_pil.
    data: str = Field(title="File data", description="Base64 representation of the file")
    # File name the image is saved under inside the model/instance upload directory.
    name: str = Field(title="File name")
    # Caption text; upload_db_images writes it to a .txt file with the image's stem.
    txt: str = Field(title="Prompt")
|
|
|
|
|
|
class ImageData:
    """
    Lightweight record pairing an image's file name, caption and encoded data.

    Serialized via ``dict()`` into the same ``name``/``data``/``txt`` shape that
    ``InstanceData`` accepts.
    """

    def __init__(self, name, prompt, data):
        self.name, self.prompt, self.data = name, prompt, data

    def dict(self):
        """Return the record as ``{"name": ..., "data": ..., "txt": <prompt>}``."""
        # `dict` below is the builtin: class attributes are not in scope inside methods.
        return dict(name=self.name, data=self.data, txt=self.prompt)
|
|
|
|
|
|
class DbImagesRequest(BaseModel):
    # Request body for POST /dreambooth/upload: {"imageList": [InstanceData, ...]}.
    # upload_db_images iterates imageList and saves each image plus its caption.
    imageList: List[InstanceData] = Field(title="Images",
                                          description="List of images to work on. Must be Base64 strings")
|
|
|
|
|
|
# API Representation of concept data
|
|
class DreamboothConcept(BaseModel):
    # One entry of DreamboothParameters.concepts_list. Every field has a default,
    # so clients may send only the keys they need. Fields typed Union[str, None]
    # also accept JSON null.
    # NOTE(review): field names presumably mirror the concept entries that
    # DreamboothConfig stores — verify against db_config before renaming.

    # Image directories. get_classifiers reads a concept's class_data_dir and falls
    # back to "<model_dir>/classifiers_<idx>" when it is empty, None, or equal to
    # db_shared.script_path.
    instance_data_dir: str = ""
    class_data_dir: str = ""
    # Prompt / token text.
    instance_prompt: str = ""
    class_prompt: Union[str, None] = ""
    save_sample_prompt: Union[str, None] = ""
    save_sample_template: Union[str, None] = ""
    instance_token: Union[str, None] = ""
    class_token: Union[str, None] = ""
    # Class-image generation settings (count, negative prompt, guidance scale,
    # inference steps) — presumably consumed by generate_classifiers; confirm there.
    num_class_images_per: int = 0
    class_negative_prompt: Union[str, None] = ""
    class_guidance_scale: float = 7.5
    class_infer_steps: int = 60
    # Sample-preview settings. NOTE(review): sample_seed = -1 presumably means
    # "random seed" — TODO confirm in the sampling code.
    save_sample_negative_prompt: Union[str, None] = ""
    n_save_sample: int = 1
    sample_seed: int = -1
    save_guidance_scale: float = 7.5
    save_infer_steps: int = 60
|
|
|
|
|
|
# API Representation of db config
|
|
class DreamboothParameters(BaseModel):
    # Request body for POST /dreambooth/model_config. set_model_config copies each
    # field onto a fresh DreamboothConfig only when that config already has an
    # attribute with the same name; any other field is silently ignored.
    # concepts_list is the only required field; everything else has a default.

    # Required: one entry per training concept.
    concepts_list: List[DreamboothConcept]
    attention: str = "default"
    cache_latents: bool = True
    center_crop: bool = False
    clip_skip: int = 1
    concepts_path: Union[str, None] = ""
    custom_model_name: Union[str, None] = ""
    epoch_pause_frequency: int = 0
    epoch_pause_time: int = 60
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = True
    gradient_set_to_none: bool = True
    graph_smoothing: int = 50
    # get_checkpoint passes the saved config's half_model to compile_checkpoint.
    half_model: bool = False
    hflip: bool = True
    learning_rate: float = 0.000002
    learning_rate_min: float = 0.000001
    # LoRA settings — presumably only used when use_lora is true; confirm in training code.
    lora_learning_rate: float = 0.0002
    lora_model_name: str = ""
    lora_rank: int = 4
    lora_txt_learning_rate: float = 0.0002
    lora_txt_weight: int = 1
    lora_weight: int = 1
    # Learning-rate scheduler settings.
    lr_cycles: int = 1
    lr_factor: float = 0.5
    lr_power: float = 1.0
    lr_scale_pos: float = 0.5
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 500
    max_token_length: int = 75
    mixed_precision: str = "no"
    adamw_weight_decay: float = 1e-2
    model_name: str = ""
    num_train_epochs: int = 100
    pad_tokens: bool = True
    pretrained_vae_name_or_path: Union[str, None] = ""
    prior_loss_weight: float = 1.0
    resolution: int = 512
    # get_checkpoint uses the saved config's revision as the step suffix when
    # looking up "<name>_<revision>.ckpt".
    revision: int = 0
    sample_batch_size: int = 1
    sanity_prompt: str = ""
    sanity_seed: int = 420420
    # Save/checkpoint cadence flags (after / on cancel / during training).
    save_ckpt_after: bool = True
    save_ckpt_cancel: bool = False
    save_ckpt_during: bool = True
    save_embedding_every: int = 25
    save_lora_after: bool = True
    save_lora_cancel: bool = False
    save_lora_during: bool = True
    save_preview_every: int = 5
    save_state_after: bool = False
    save_state_cancel: bool = False
    save_state_during: bool = False
    src: Union[str, None] = ""
    shuffle_tags: bool = False
    train_batch_size: int = 1
    train_imagic: bool = False
    stop_text_encoder: float = 0
    use_8bit_adam: bool = False
    use_concepts: bool = False
    use_ema: bool = True
    use_lora: bool = False
    # When true, get_checkpoint looks for the checkpoint inside a
    # "<models_path>/<save_model_name>/" subdirectory.
    use_subdir: bool = True
|
|
|
|
|
|
import asyncio
|
|
|
|
# Strong references to in-flight background tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected before it
# finishes.
_background_tasks = set()


def _on_background_task_done(task):
    """Drop a finished task from the registry and report any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Nobody awaits these tasks, so surface the failure here instead of
        # leaving it to an "exception was never retrieved" warning at GC time.
        traceback.print_exception(type(exc), exc, exc.__traceback__)


def run_in_background(func, *args, **kwargs):
    """
    Schedule a blocking callable on the default executor without awaiting it.

    Must be called while an event loop is running (e.g. from an ``async`` endpoint).

    Args:
        func: Synchronous callable to execute.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        The ``asyncio.Task`` wrapping the job. Callers may ignore it: a strong
        reference is kept until the task completes.
    """
    call = functools.partial(func, *args, **kwargs)

    async def wrapper():
        await asyncio.get_running_loop().run_in_executor(None, call)

    task = asyncio.create_task(wrapper())
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task
|
|
def _zip_arcname(path):
    """Archive name for *path*: its immediate parent directory name plus file name."""
    p = Path(path)
    return os.path.join(p.parent.name, p.name)


def zip_files(db_model_name, files, name_part=""):
    """
    Bundle images into an in-memory ZIP archive and return it as a download.

    Args:
        db_model_name: Prefix of the download file name.
        files: Iterable of items. ``str`` items are image paths on disk; any other
            item must behave like a PIL image (it needs ``save`` and ``tobytes``).
        name_part: Optional suffix inserted before ``_images.zip``.

    A path item is stored as ``<parent dir name>/<file name>``. If a caption file
    with the same stem and a ``.txt`` extension exists, it is added next to the
    image. Paths that are not regular files are skipped. In-memory images are
    encoded as PNG and named after the SHA-1 of their raw pixel bytes.

    Returns:
        A ``StreamingResponse`` with the archive bytes.
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
        for file in files:
            if isinstance(file, str):
                print(f"Zipping img: {file}")
                if not os.path.isfile(file):
                    continue
                # Relative arcname: avoids embedding the server's absolute directory layout.
                zip_file.write(file, arcname=_zip_arcname(file))
                # Caption sits beside the image: "dir/img.png" -> "dir/img.txt".
                check_txt = os.path.splitext(file)[0] + ".txt"
                if os.path.isfile(check_txt):
                    print(f"Zipping txt: {check_txt}")
                    zip_file.write(check_txt, arcname=_zip_arcname(check_txt))
            else:
                img_byte_arr = io.BytesIO()
                file.save(img_byte_arr, format='PNG')
                file_name = hashlib.sha1(file.tobytes()).hexdigest()
                zip_file.writestr(f"{file_name}.png", img_byte_arr.getvalue())

    return StreamingResponse(
        iter([zip_buffer.getvalue()]),
        media_type="application/x-zip-compressed",
        # Quote the file name so model names containing spaces survive the header.
        headers={"Content-Disposition": f'attachment; filename="{db_model_name}{name_part}_images.zip"'}
    )
|
|
|
|
|
|
def check_api_key(key):
    """
    Validate a client-supplied API key against the configured secret.

    Args:
        key: The ``api_key`` query value sent by the client (may be ``None`` or "").

    Returns:
        ``None`` when the request may proceed, meaning no secret is configured or
        the key matches. Otherwise a ``JSONResponse`` for the endpoint to return:
        401 if no key was supplied, 403 if it does not match.
    """
    import hmac

    current_key = get_secret()
    if current_key is None or current_key == "":
        return None
    if key is None or key == "":
        return JSONResponse(status_code=401, content={"message": "API Key Required."})
    # Constant-time comparison, so response timing does not reveal how much of
    # the key was correct.
    if not hmac.compare_digest(str(key).encode("utf-8"), str(current_key).encode("utf-8")):
        return JSONResponse(status_code=403, content={"message": "Invalid API Key."})
    return None
|
|
|
|
|
|
def base64_to_pil(im_b64) -> "Image.Image":
    """
    Decode a base64 string into a PIL image.

    Accepts raw base64, or a data URL such as ``"data:image/png;base64,iVBOR..."``
    whose ``data:...,`` header is stripped before decoding.

    Args:
        im_b64: Base64 text (str).

    Returns:
        The opened image. ``Image.open`` reads lazily from an in-memory buffer.

    Raises:
        binascii.Error: From ``base64.b64decode`` on malformed padding.
        Whatever ``Image.open`` raises when the bytes are not a recognised image.
    """
    if im_b64.startswith("data:"):
        # Browsers commonly send data URLs; keep only the payload after the comma.
        _, _, im_b64 = im_b64.partition(",")
    im_bytes = base64.b64decode(bytes(im_b64, 'utf-8'))
    return Image.open(io.BytesIO(im_bytes))
|
|
|
|
|
|
def file_to_base64(file_path) -> str:
    """
    Return the raw bytes of a file, base64-encoded, as text.

    Args:
        file_path: Path to the file (``str`` or ``Path``).
    """
    raw = Path(file_path).read_bytes()
    return base64.b64encode(raw).decode('utf-8')
|
|
|
|
|
|
def dreambooth_api(_: gr.Blocks, app: FastAPI):
    """
    Register the Dreambooth REST endpoints on ``app``.

    Every endpoint accepts an ``api_key`` query parameter that is validated with
    ``check_api_key`` before any work is done.

    Args:
        _: Gradio Blocks instance supplied by the webui callback (unused).
        app: FastAPI application the routes are attached to.
    """

    def _is_plain_name(value):
        # True only for a single path component: non-empty, not "." or "..", and
        # free of directory separators. Keeps client-supplied names from escaping
        # the upload directory.
        return bool(value) and value not in (".", "..") and os.path.basename(value) == value

    @app.post("/dreambooth/createModel")
    async def create_model(
            new_model_name: str = Query(None, description="The name of the model to create.", ),
            new_model_src: str = Query(None, description="The source checkpoint to extract to create this model.", ),
            new_model_scheduler: str = Query(None, description="The scheduler to use. V2+ models ignore this.", ),
            create_from_hub: bool = Query(False, description="Create this model from the hub", ),
            new_model_url: str = Query(None,
                                       description="The hub URL to use for this model. Must contain diffusers model.", ),
            new_model_token: str = Query(None, description="Your huggingface hub token.", ),
            new_model_extract_ema: bool = Query(False, description="Whether to extract EMA weights if present.", ),
            api_key: str = Query("", description="If an API key is set, this must be present.", ),
    ):
        """
        Create a new Dreambooth model from a source checkpoint or a hub model.

        Extraction runs synchronously; the response is sent once it has returned.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check

        if new_model_name is None or new_model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})

        if db_shared.status.job_count != 0:
            print("Something is already running.")
            return JSONResponse(content={"message": "Job already in progress.", "status": db_shared.status.dict()})

        print("Creating new Checkpoint: " + new_model_name)
        extract_checkpoint(new_model_name,
                           new_model_src,
                           new_model_scheduler,
                           create_from_hub,
                           new_model_url,
                           new_model_token,
                           new_model_extract_ema)
        # Answer with an explicit body instead of an empty (null) response.
        return JSONResponse(content={"message": "Model creation finished.", "model_name": new_model_name})

    @app.post("/dreambooth/start_training")
    async def start_training(
            model_name: str = Query(None,
                                    description="The model name to load params for.", ),
            use_tx2img: bool = Query(True, description="Use txt2img to generate class images.", ),
            api_key: str = Query("", description="If an API key is set, this must be present.", ),
    ):
        """
        Start training dreambooth.

        Training is scheduled with ``run_in_background``; the response returns immediately.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check

        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        if db_shared.status.job_count != 0:
            print("Something is already running.")
            return JSONResponse(content={"message": "Job already in progress.", "status": db_shared.status.dict()})

        print("Starting Training")
        db_shared.status.begin()
        run_in_background(dreambooth.start_training, model_name, use_tx2img)
        return {"Status": "Training started."}

    @app.get("/dreambooth/cancel")
    async def cancel_jobs(
            api_key: str = Query("", description="If an API key is set, this must be present.", )) -> \
            Union[DreamState, JSONResponse]:
        """
        Request cancellation of the running Dreambooth job.

        Only sets ``db_shared.status.interrupted``. The running job is presumably
        expected to notice the flag and stop; nothing is stopped here directly.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if db_shared.status.job_count == 0:
            return JSONResponse(content={"message": "Nothing to cancel."})
        db_shared.status.interrupted = True
        return JSONResponse(content={"message": "Processes cancelled."})

    @app.get("/dreambooth/status")
    async def check_status(
            api_key: str = Query("", description="If an API key is set, this must be present.", )) -> \
            Union[DreamState, JSONResponse]:
        """
        Report the current state of Dreambooth processes.

        The state is returned as a JSON-encoded *string* under ``current_state``,
        so clients must decode it a second time.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        return JSONResponse(content={"current_state": f"{json.dumps(db_shared.status.dict())}"})

    @app.get("/dreambooth/status_images")
    async def check_status_images(
            api_key: str = Query("", description="If an API key is set, this must be present.", )) -> JSONResponse:
        """
        Retrieve any images that may currently be present in the state.

        Args:
            api_key: An API key, if one has been set in the UI.

        Returns:
            A single PNG or a zip of images, depending on how many exist, or a
            JSON message when there are none.
        """
        # This endpoint previously skipped authentication; check it like the others.
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check

        db_shared.status.set_current_image()
        images = db_shared.status.current_image
        if not isinstance(images, list):
            images = [images] if images is not None else []

        if len(images) == 0:
            return JSONResponse(content={"message": "No images."})
        if len(images) > 1:
            return zip_files("status", images, "_sample")

        file = images[0]
        if isinstance(file, str):
            file = Image.open(file)
        img_byte_arr = io.BytesIO()
        file.save(img_byte_arr, format='PNG')
        return Response(content=img_byte_arr.getvalue(), media_type="image/png")

    @app.get("/dreambooth/model_config")
    async def get_model_config(
            model_name: str = Query(None, description="The model name to fetch config for."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ) -> Union[DreamboothConfig, JSONResponse]:
        """
        Get a specified model config as JSON (refused while a job is running).
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        if db_shared.status.job_count != 0:
            print("Something is already running.")
            return JSONResponse(content={"message": "Job already in progress.", "status": db_shared.status.dict()})
        return JSONResponse(content=config.__dict__)

    @app.post("/dreambooth/model_config")
    async def set_model_config(
            model_cfg: DreamboothParameters = Body(description="The config to save"),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Save a model config from JSON.

        Starts from a default ``DreamboothConfig`` and overwrites only the
        attributes that config already defines; other request fields are ignored.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check

        try:
            print("Create config")
            config = DreamboothConfig()
            # Serialize the request model once rather than on every loop iteration.
            params = model_cfg.dict()
            for key, value in params.items():
                if key in config.__dict__:
                    config.__dict__[key] = value
            config.save()
            print("Saved?")
            return JSONResponse(content=config.__dict__)
        except Exception as e:
            traceback.print_exc()
            return {"Exception saving model": f"{e}"}

    @app.get("/dreambooth/get_checkpoint")
    async def get_checkpoint(
            model_name: str = Query(description="The model name of the checkpoint to get."),
            skip_build: bool = Query(True, description="Set to false to re-compile the checkpoint before retrieval."),
            lora_model_name: str = Query("",
                                         description="The (optional) name of the lora model to merge with the checkpoint."),
            save_model_name: str = Query("", description="A custom name to use when generating the checkpoint."),
            lora_weight: int = Query(1, description="The weight of the lora UNET when merged with the checkpoint."),
            lora_text_weight: int = Query(1,
                                          description="The weight of the lora Text Encoder when merged with the checkpoint."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Return a model's checkpoint file, compiling it first if needed.

        With ``skip_build`` (the default), an existing ``<save_model_name>_<revision>.ckpt``
        is returned when found. Otherwise, or when ``skip_build`` is false, the
        checkpoint is compiled with ``compile_checkpoint`` first.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        if db_shared.status.job_count != 0:
            print("Something is already running.")
            return JSONResponse(content={"message": "Job already in progress.", "status": db_shared.status.dict()})

        path = None
        if save_model_name == "" or save_model_name is None:
            save_model_name = model_name

        if skip_build:
            models_path = db_shared.ckpt_dir
            if models_path is None:
                models_path = os.path.join(db_shared.models_path, "Stable-diffusion")

            # The config is accessed by attribute everywhere else; item access
            # (config["use_subdir"]) is not guaranteed to be supported.
            use_subdir = config.__dict__.get("use_subdir", False)
            ckpt_name = f"{save_model_name}_{config.revision}.ckpt"
            if use_subdir:
                checkpoint_path = os.path.join(models_path, save_model_name, ckpt_name)
            else:
                checkpoint_path = os.path.join(models_path, ckpt_name)
            print(f"Looking for checkpoint at {checkpoint_path}")
            if os.path.exists(checkpoint_path):
                print("Existing checkpoint found, returning.")
                path = checkpoint_path
            else:
                skip_build = False

        if not skip_build:
            ckpt_result = compile_checkpoint(model_name, config.half_model, False, lora_model_name, lora_weight,
                                             lora_text_weight, save_model_name, False, True)
            # compile_checkpoint reports the output path inside its status message.
            if "Checkpoint compiled successfully" in ckpt_result:
                path = ckpt_result.replace("Checkpoint compiled successfully:", "").strip()
                print(f"Checkpoint saved to path: {path}")

        if path is not None and os.path.exists(path):
            print(f"Returning file response: {path}-{os.path.splitext(path)}")
            return FileResponse(path)

        return {"exception": "Unable to find or compile checkpoint."}

    @app.get("/dreambooth/list_checkpoints")
    async def get_checkpoints(
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Collect the current list of available source checkpoints.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        sd_models.list_models()
        return JSONResponse(content=sd_models.checkpoints_list)

    @app.get("/dreambooth/samples")
    async def generate_samples(
            model_name: str = Query(description="The model name to use for generating samples."),
            sample_prompt: str = Query(description="The prompt to use to generate sample images."),
            num_images: int = Query(1, description="The number of sample images to generate."),
            batch_size: int = Query(1, description="How many images to generate at once."),
            lora_model_path: str = Query("", description="The path to a lora model to use when generating images."),
            lora_rank: int = Query(1,
                                   description="LORA rank when training, or something.."),
            lora_weight: float = Query(1.0,
                                       description="The weight of the lora unet when merging with the base model."),
            lora_txt_weight: float = Query(1.0,
                                           description="The weight of the lora text encoder when merging with the base model"),
            negative_prompt: str = Query("", description="An optional negative prompt to use when generating images."),
            seed: int = Query(-1, description="The seed to use when generating samples"),
            steps: int = Query(60, description="Number of sampling steps to use when generating images."),
            scale: float = Query(7.5, description="CFG scale to use when generating images."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Generate sample images for a specified model.

        Returns a single PNG, a zip when several images were produced, or a JSON
        message when none were.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if db_shared.status.job_count != 0:
            print("Something is already running.")
            return JSONResponse(content={"message": "Job already in progress.", "status": db_shared.status.dict()})

        images, msg, _status = ui_samples(
            model_dir=model_name,
            save_sample_prompt=sample_prompt,
            num_samples=num_images,
            sample_batch_size=batch_size,
            lora_model_path=lora_model_path,
            lora_rank=lora_rank,
            lora_weight=lora_weight,
            lora_txt_weight=lora_txt_weight,
            negative_prompt=negative_prompt,
            seed=seed,
            steps=steps,
            scale=scale
        )
        # Guard against an empty result instead of failing on images[0].
        if not images:
            return JSONResponse(content={"message": "No images.", "details": f"{msg}"})
        if len(images) > 1:
            return zip_files(model_name, images, "_sample")

        img_byte_arr = io.BytesIO()
        images[0].save(img_byte_arr, format='PNG')
        return Response(content=img_byte_arr.getvalue(), media_type="image/png")

    @app.post("/dreambooth/classifiers")
    async def generate_classes(
            model_name: str = Query(description="The model name to generate classifiers for."),
            # Default was "", which is not a bool; False keeps the same falsy meaning.
            use_txt2img: bool = Query(False, description="Use Txt2Image to generate classifiers."),
            api_key: str = Query("", description="If an API key is set, this must be present.")
    ):
        """
        Generate classification images for a model based on a saved config.

        Generation is scheduled with ``run_in_background``; the response returns immediately.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        if db_shared.status.job_count != 0:
            print("Something is already running.")
            return JSONResponse(content={"message": "Job already in progress.", "status": db_shared.status.dict()})

        db_shared.status.begin()
        run_in_background(generate_classifiers, config, use_txt2img)
        return JSONResponse(content={"message": "Generating classifiers..."})

    @app.get("/dreambooth/classifiers")
    async def get_classifiers(
            model_name: str = Query(description="The model name to retrieve classifiers for."),
            concept_idx: int = Query(-1,
                                     description="If set, will retrieve the specified concept's class images. Otherwise, all class images will be retrieved."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Retrieve generated classifier images from a saved model config as a zip.

        A concept whose class_data_dir is unset (empty, None, or the script path)
        falls back to ``<model_dir>/classifiers_<concept index>``.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})

        concepts = config.concepts_list
        if concept_idx >= 0:
            if concept_idx >= len(concepts):
                return {"Exception": f"Concept index {concept_idx} out of range."}
            print(f"Returning class images for concept {concept_idx}")
            selected = {concept_idx: concepts[concept_idx]}
        else:
            # Key every concept by its own index. The old code never advanced its
            # counter, so all concepts collapsed onto key 0 and only the last survived.
            selected = dict(enumerate(concepts))

        out_images = []
        for concept_key, concept in selected.items():
            class_images_dir = concept["class_data_dir"]
            if class_images_dir == "" or class_images_dir is None or class_images_dir == db_shared.script_path:
                class_images_dir = os.path.join(config.model_dir, f"classifiers_{concept_key}")
                print(f"Class image dir is not set, defaulting to {class_images_dir}")
            if os.path.exists(class_images_dir):
                out_images.extend(str(image) for image in get_images(class_images_dir))

        if out_images:
            return zip_files(model_name, out_images, "_class")
        return {"Result": "No images found."}

    @app.post("/dreambooth/upload")
    async def upload_db_images(
            model_name: str = Query(description="The model name to upload images for."),
            instance_name: str = Query(description="The concept/instance name the images are for."),
            images: DbImagesRequest = Body(description="A dictionary of images, filenames, and prompts to save."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Upload images for training.

        Request body should be a JSON Object. Primary key is 'imageList'.

        'imageList' is a list of objects. Each object should have three values:
        'data' - A base64-encoded string containing the binary data of the image.
        'name' - The filename to store the image under.
        'txt' - The caption for the image. Will be stored in a text file beside the image.

        model_name, instance_name and every 'name' must be plain names (no path
        separators, "." or ".."); otherwise the request is rejected with 422
        before anything is written.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check

        # These values come straight from the client and are joined into
        # filesystem paths, so validate all of them up front (path traversal).
        if not _is_plain_name(model_name):
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        if not _is_plain_name(instance_name):
            return JSONResponse(status_code=422, content={"message": "Invalid instance name."})
        for img_data in images.imageList:
            if not _is_plain_name(img_data.name):
                return JSONResponse(status_code=422, content={"message": f"Invalid file name: {img_data.name}"})

        root_img_path = os.path.join(db_shared.script_path, "..", "InstanceImages")
        if not os.path.exists(root_img_path):
            print(f"Creating root instance dir: {root_img_path}")
            os.makedirs(root_img_path, exist_ok=True)

        image_dir = os.path.join(root_img_path, model_name, instance_name)
        os.makedirs(image_dir, exist_ok=True)

        image_paths = []
        for img_data in images.imageList:
            img = base64_to_pil(img_data.data)
            image_path = os.path.join(image_dir, img_data.name)
            text_path = f"{os.path.splitext(image_path)[0]}.txt"
            print(f"Saving image to: {image_path}")
            img.save(image_path)
            print(f"Saving prompt to: {text_path}")
            with open(text_path, "w") as tx_file:
                tx_file.write(img_data.txt)
            image_paths.append(image_path)

        # A list keeps the upload order; the previous set was unordered.
        return {"Status": f"Saved {len(image_paths)} images.", "Images": image_paths}

    @app.get("/dreambooth/testimg")
    async def generate_test_data(
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Debug helper: return images and captions from a fixed local directory.

        Each entry has the same ``name``/``data``/``txt`` shape as the ``imageList``
        items accepted by ``/dreambooth/upload``.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check

        # NOTE(review): hard-coded developer path; presumably only meaningful on
        # the original author's machine.
        model_dir = "E:\\dev\\sd_db\\mj_5"
        text_getter = FilenameTextGetter(False)
        inst_datas = []
        for x in get_images(model_dir):
            image_bytes = file_to_base64(x)
            # Path.name already includes the extension; appending suffix doubled it.
            name = x.name
            txt = text_getter.read_text(x)
            inst_datas.append(ImageData(name, txt, image_bytes).dict())
        return JSONResponse(content=inst_datas)
|
|
|
|
|
|
# Register the API with the webui when running inside it.
try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(dreambooth_api)
    print("SD-Webui API layer loaded")
except ImportError:
    # Not running inside the webui (module imported standalone): there is no app
    # to attach the routes to, so skip registration quietly.
    pass
except Exception:
    # Registration failed for another reason. Report it rather than hiding it,
    # but keep the module importable. KeyboardInterrupt/SystemExit now propagate.
    traceback.print_exc()
|