sd_dreambooth_extension/scripts/api.py

937 lines
38 KiB
Python

import base64
import functools
import hashlib
import io
import json
import logging
import os
import shutil
import traceback
import zipfile
from pathlib import Path
from typing import List, Union
import requests
from PIL import Image
from fastapi import FastAPI, Response, Query, Body, Form, Header
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from pydantic import BaseModel, Field
from starlette import status
from starlette.requests import Request
try:
from dreambooth import shared
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
from dreambooth.diff_to_sd import compile_checkpoint
from dreambooth.secret import get_secret
from dreambooth.shared import DreamState
from dreambooth.ui_functions import create_model, generate_samples, \
start_training
from dreambooth.utils.gen_utils import generate_classifiers
from dreambooth.utils.image_utils import get_images
from dreambooth.utils.model_utils import get_db_models, get_lora_models
except:
print("Exception importing api")
traceback.print_exc()
# Any non-empty DEBUG_API environment value turns on DEBUG-level root logging.
if os.environ.get("DEBUG_API", False):
    logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Load confirmation goes through the module logger instead of an unconditional
# stdout print, so it only shows up when debug logging is enabled.
logger.debug("Dreambooth API module loaded.")
class InstanceData(BaseModel):
    """One uploaded training image: its payload, the file name to save it as, and its caption.

    ``/dreambooth/upload`` decodes ``data`` as base64 image bytes, while
    ``/dreambooth/upload_url`` treats ``data`` as a URL to download. ``txt`` is written
    to a ``.txt`` file beside the saved image.
    """
    data: str = Field(title="File data", description="Base64 representation of the file or URL")
    name: str = Field(title="File name", description="File name to save image as")
    txt: str = Field(title="Prompt", description="Training prompt for image")
class ImageData:
    """Lightweight record pairing an image's file name, caption prompt and payload.

    ``dict()`` renders it with the same keys the upload request items use
    (``name`` / ``data`` / ``txt``); the caption is exposed under ``txt``.
    """

    def __init__(self, name, prompt, data):
        self.name, self.prompt, self.data = name, prompt, data

    def dict(self):
        """Return ``{"name": ..., "data": ..., "txt": ...}`` for this record."""
        fields = (("name", self.name), ("data", self.data), ("txt", self.prompt))
        return {key: value for key, value in fields}
class DbImagesRequest(BaseModel):
    """Request body for the image-upload endpoints: a list of ``InstanceData`` items under ``imageList``."""
    # NOTE(review): /dreambooth/upload_url reuses this model with URLs in each item's
    # `data`, so the "Must be Base64 strings" description only holds for /dreambooth/upload.
    imageList: List[InstanceData] = Field(title="Images",
                                          description="List of images to work on. Must be Base64 strings")
import asyncio
# Module-wide "a job is in progress" flag. Some endpoints set it to True around
# their work and run_in_background clears it when its task ends; is_running()
# below does not consult it.
active = False


def is_running():
    """Tell an endpoint whether another job should block the current request.

    Callers return a truthy result directly as their response. The check is
    currently disabled and always answers ``False``.
    """
    blocked = False
    return blocked
# Strong references to in-flight background tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


def run_in_background(func, *args, **kwargs):
    """
    Schedule a blocking callable on the event loop's default executor without awaiting it.

    Must be called while an event loop is running (e.g. from an ``async`` endpoint).
    ``func(*args, **kwargs)`` runs in the default thread pool. When it finishes,
    successfully or not, the module-level ``active`` flag is reset to ``False``.
    Exceptions raised by ``func`` are logged rather than silently dropped.

    Returns:
        The scheduled ``asyncio.Task``. Existing callers ignore it.
    """
    async def wrapper():
        global active
        job = functools.partial(func, *args, **kwargs)
        try:
            await asyncio.get_running_loop().run_in_executor(None, job)
        except Exception:
            # Nobody awaits this task, so an unlogged exception would vanish.
            logger.exception("Background job %r failed.", getattr(func, "__name__", func))
        finally:
            active = False

    task = asyncio.create_task(wrapper())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task
def zip_files(db_model_name, files, name_part=""):
    """
    Bundle images into an in-memory ZIP archive and return it as a download.

    Args:
        db_model_name: Prefix of the attachment file name.
        files: Items to archive.
            - A ``str`` is treated as a path on disk. An existing file is stored as
              ``<parent dir name>/<file name>``, together with a ``<stem>.txt`` caption
              beside it when one exists. Paths that are not files are skipped.
            - Any other item is treated as a PIL image. It is stored as
              ``<sha1 of its pixel bytes>.png``.
        name_part: Optional suffix inserted before ``_images.zip``.

    Returns:
        A ``StreamingResponse`` with the archive, served as
        ``{db_model_name}{name_part}_images.zip``.
    """
    def _arcname(path):
        # Keep only the immediate parent directory. Storing the full path would
        # reproduce (and leak) the server's directory layout inside the archive.
        p = Path(path)
        return os.path.join(p.parent.name, p.name)

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
        for file in files:
            if isinstance(file, str):
                logger.debug(f"Zipping img: {file}")
                if os.path.isfile(file):
                    zip_file.write(file, arcname=_arcname(file))
                    # Captions live beside the image as "<stem>.txt" (not "<stem>/.txt").
                    check_txt = os.path.splitext(file)[0] + ".txt"
                    if os.path.isfile(check_txt):
                        logger.debug(f"Zipping txt: {check_txt}")
                        zip_file.write(check_txt, arcname=_arcname(check_txt))
            else:
                img_byte_arr = io.BytesIO()
                file.save(img_byte_arr, format='PNG')
                file_name = hashlib.sha1(file.tobytes()).hexdigest()
                zip_file.writestr(f"{file_name}.png", img_byte_arr.getvalue())
    # The `with` block has closed the archive, so the buffer now holds a complete ZIP.
    return StreamingResponse(
        iter([zip_buffer.getvalue()]),
        media_type="application/x-zip-compressed",
        headers={"Content-Disposition": f"attachment; filename={db_model_name}{name_part}_images.zip"}
    )
def check_api_key(key):
    """
    Validate a caller-supplied API key against the configured secret.

    Returns ``None`` when access is allowed, meaning either no secret is configured
    or *key* matches it. Otherwise returns a ready-to-return ``JSONResponse``:
    401 when no key was supplied, 403 when the key does not match.
    """
    import hmac

    current_key = get_secret()
    if current_key is None or current_key == "":
        return None
    if key is None or key == "":
        return JSONResponse(status_code=401, content={"message": "API Key Required."})
    # Constant-time comparison, so response timing does not reveal how much of the key matched.
    if not hmac.compare_digest(str(key).encode("utf-8"), str(current_key).encode("utf-8")):
        return JSONResponse(status_code=403, content={"message": "Invalid API Key."})
    return None
def base64_to_pil(im_b64) -> "Image.Image":
    """
    Decode base64 text into a PIL image.

    Accepts either bare base64 text or a ``data:<mime>;base64,<payload>`` URI. For a
    URI, the header before the first comma is dropped before decoding.

    Returns:
        An opened PIL image backed by an in-memory buffer.
    """
    if im_b64.startswith("data:") and "," in im_b64:
        # b64decode silently discards non-alphabet characters. Leaving the
        # "data:image/png;base64," header in place would prepend garbage bytes.
        im_b64 = im_b64.split(",", 1)[1]
    im_bytes = base64.b64decode(bytes(im_b64, "utf-8"))
    return Image.open(io.BytesIO(im_bytes))
def file_to_base64(file_path) -> str:
    """Read *file_path* as raw bytes and return them encoded as base64 text."""
    with open(file_path, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode("utf-8")
def dreambooth_api(_, app: FastAPI):
    """Register the Dreambooth REST endpoints (``/dreambooth/...``) on *app*.

    Passed to ``script_callbacks.on_app_started`` at the bottom of this module. The
    first positional argument is unused (presumably the web UI's Gradio app — confirm).
    """
    print("API LOAD")

    # Registered on the whole app, so it applies to every route, not only /dreambooth/*.
    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError):
        """Return request-validation failures as 422 JSON with the errors and the raw request body."""
        return JSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
        )
@app.get("/dreambooth/cancel")
async def cancel_jobs(
        api_key: str = Query("", description="If an API key is set, this must be present.", )) -> \
        JSONResponse:
    """
    Ask the running Dreambooth job to stop.

    If ``shared.status.job_count`` is zero there is nothing to cancel. Otherwise
    ``shared.status.interrupted`` is set to ``True``. The job is not stopped here;
    presumably it polls that flag (not visible in this file — confirm).

    @param api_key: Required when an API key has been configured.
    @return: A JSON message describing the outcome, or a 401/403 key error.
    """
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    if shared.status.job_count == 0:
        return JSONResponse(content={"message": "Nothing to cancel."})
    shared.status.interrupted = True
    return JSONResponse(content={"message": "Processes cancelled."})
@app.get("/dreambooth/checkpoint")
async def get_checkpoint(
        model_name: str = Query(description="The model name of the checkpoint to get."),
        skip_build: bool = Query(True, description="Set to false to re-compile the checkpoint before retrieval."),
        api_key: str = Query("", description="If an API key is set, this must be present.", )
):
    """
    Return a model's compiled checkpoint file, compiling it first when needed.

    With ``skip_build`` (the default), the endpoint looks for an existing
    ``<name>_<revision>.ckpt`` (or ``.safetensors`` when ``save_safetensors`` is set).
    It looks in ``shared.ckpt_dir``, or ``<models_path>/Stable-diffusion`` when that is
    unset, inside a ``<name>`` subfolder if the config uses ``use_subdir``. ``<name>``
    is ``custom_model_name`` when set, else ``model_name``.

    If the file is missing, or ``skip_build`` is false, the checkpoint is rebuilt
    with ``compile_checkpoint``.

    Returns a ``FileResponse`` on success, otherwise ``{"exception": ...}``.
    """
    global active
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    if model_name is None or model_name == "":
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    config = from_file(model_name)
    if config is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    running = is_running()
    if running:
        return running
    path = None
    save_model_name = config.custom_model_name or config.model_name
    if skip_build:
        models_path = shared.ckpt_dir
        if models_path is None:
            models_path = os.path.join(shared.models_path, "Stable-diffusion")
        # Attribute access, as used for configs elsewhere in this file.
        # Subscripting (config["use_subdir"]) is not necessarily supported by the config object.
        use_subdir = getattr(config, "use_subdir", False)
        ext = ".safetensors" if config.save_safetensors else ".ckpt"
        file_name = f"{save_model_name}_{config.revision}{ext}"
        if use_subdir:
            checkpoint_path = os.path.join(models_path, save_model_name, file_name)
        else:
            checkpoint_path = os.path.join(models_path, file_name)
        logger.debug(f"Looking for checkpoint at {checkpoint_path}")
        if os.path.exists(checkpoint_path):
            logger.debug("Existing checkpoint found, returning.")
            path = checkpoint_path
        else:
            skip_build = False
    if not skip_build:
        shared.status.begin()
        active = True
        try:
            ckpt_result = compile_checkpoint(model_name, reload_models=False, log=False)
        finally:
            # Always release the job state, even if compilation raises.
            active = False
            shared.status.end()
        if ckpt_result and "Checkpoint compiled successfully" in ckpt_result:
            path = ckpt_result.replace("Checkpoint compiled successfully:", "").strip()
            logger.debug(f"Checkpoint aved to path: {path}")
    if path is not None and os.path.exists(path):
        logger.debug(f"Returning file response: {path}-{os.path.splitext(path)}")
        return FileResponse(path)
    return {"exception": "Unable to find or compile checkpoint."}
@app.get("/dreambooth/checkpoints")
async def get_checkpoints(
        api_key: str = Query("", description="If an API key is set, this must be present.", )
):
    """
    List the names of the source checkpoints known to the web UI.

    Refreshes ``modules.sd_models`` and returns the keys of
    ``sd_models.checkpoints_list`` as a JSON array. If the web UI modules are
    unavailable or listing fails, the failure is logged and an empty list is returned.
    """
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    ckpt_list = {}
    try:
        from modules import sd_models
        sd_models.list_models()
        ckpt_list = sd_models.checkpoints_list
    except Exception:
        # Best-effort: keep returning whatever was gathered, but leave a trace.
        logger.debug("Unable to list source checkpoints.", exc_info=True)
    return JSONResponse(content=list(ckpt_list))
@app.post("/dreambooth/classifiers")
async def generate_classes(
        model_name: str = Form(description="The model name to generate classifiers for."),
        use_txt2img: bool = Form(False, description="Use Txt2Image to generate classifiers."),
        api_key: str = Form("", description="If an API key is set, this must be present.")
):
    """
    Start generating classification images for a model from its saved config.

    The work runs in the background via ``run_in_background``. The response only
    confirms that generation was scheduled.
    """
    global active
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    if model_name is None or model_name == "":
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    config = from_file(model_name)
    if config is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    running = is_running()
    if running:
        return running
    active = True
    shared.status.begin()
    # run_in_background resets `active` once the job ends. Resetting it here would
    # mark the job idle while it is still running.
    run_in_background(
        generate_classifiers,
        config,
        use_txt2img
    )
    return JSONResponse(content={"message": "Generating classifiers..."})
@app.get("/dreambooth/classifiers")
async def get_classifiers(
        model_name: str = Query(description="The model name to retrieve classifiers for."),
        concept_idx: int = Query(-1,
                                 description="If set, will retrieve the specified concept's class images. Otherwise, all class images will be retrieved."),
        api_key: str = Query("", description="If an API key is set, this must be present.", )
):
    """
    Return a model's generated class (classifier) images as a ZIP download.

    Each concept's images come from its ``class_data_dir``. When that is empty,
    unset, or equal to ``shared.script_path``, they come from
    ``<model_dir>/classifiers_<concept index>`` instead. With ``concept_idx >= 0`` only
    that concept is included.
    """
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    if model_name is None or model_name == "":
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    config = from_file(model_name)
    if config is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    concepts = config.concepts()
    if concept_idx >= 0:
        if concept_idx >= len(concepts):
            return {"Exception": f"Concept index {concept_idx} out of range."}
        logger.debug(f"Returning class images for concept {concept_idx}")
        concept_dict = {concept_idx: concepts[concept_idx]}
    else:
        # Key every concept by its own index. The index also names the default class dir.
        concept_dict = dict(enumerate(concepts))
    out_images = []
    for concept_key, concept in concept_dict.items():
        class_images_dir = concept["class_data_dir"]
        if class_images_dir == "" or class_images_dir is None or class_images_dir == shared.script_path:
            class_images_dir = os.path.join(config.model_dir, f"classifiers_{concept_key}")
            logger.debug(f"Class image dir is not set, defaulting to {class_images_dir}")
        if os.path.exists(class_images_dir):
            out_images.extend(str(image) for image in get_images(class_images_dir))
    if out_images:
        return zip_files(model_name, out_images, "_class")
    return {"Result": "No images found."}
@app.post("/dreambooth/concept")
async def set_model_concept(
        model_name: str = Form(description="The model name to fetch config for."),
        instance_dir: str = Form("", description="The directory containing training images."),
        instance_token: str = Form("", description="The instance token to use."),
        class_token: str = Form("", description="The class token to use."),
        api_key: str = Form("", description="If an API key is set, this must be present."),
        concept: Union[Concept, None] = Body(None, description="A concept to update or add to the model.")
) -> JSONResponse:
    """
    Add or update a concept. Provide either a full json concept or path to instance dir.

    Without a ``concept`` body, one is built from ``instance_dir`` and the tokens, using
    ``[filewords]`` prompts. An existing concept with the same ``instance_data_dir``
    is replaced; otherwise the concept is appended. Returns the saved concept list,
    or 422 when neither a concept nor an instance_dir is given.
    """
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    if model_name is None or model_name == "":
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    config = from_file(model_name)
    if config is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    if concept is None:
        if instance_dir == "":
            return JSONResponse(status_code=422,
                                content={"message": "Provide a concept or an instance_dir."})
        # Build the concept from the form fields and use it below.
        concept = Concept()
        concept.instance_data_dir = instance_dir
        concept.instance_token = instance_token
        concept.class_token = class_token
        concept.class_prompt = "[filewords]"
        concept.instance_prompt = "[filewords]"
        concept.save_sample_prompt = "[filewords]"
        concept.is_valid = os.path.exists(instance_dir)
    new_concepts = []
    replaced = False
    for ex_concept in config.concepts():
        if ex_concept.instance_data_dir == concept.instance_data_dir:
            new_concepts.append(concept.__dict__)
            replaced = True
        else:
            new_concepts.append(ex_concept)
    if not replaced:
        new_concepts.append(concept.__dict__)
    config.concepts_list = new_concepts
    config.save()
    return JSONResponse(content=config.concepts())
@app.get("/dreambooth/concepts")
async def get_model_concepts(
        model_name: str = Query(description="The model name to fetch config for."),
        api_key: str = Query("", description="If an API key is set, this must be present.", )
) -> JSONResponse:
    """
    Return the concept list stored in a model's config.

    Responds 422 for an empty model name or when no config can be loaded.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    if not model_name:
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    cfg = from_file(model_name)
    if cfg is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    stored = cfg.concepts()
    return JSONResponse(content=stored)
@app.post("/dreambooth/concepts")
async def set_model_concepts(
        model_name: str = Form(description="The model name to fetch config for."),
        api_key: str = Form("", description="If an API key is set, this must be present."),
        concepts: List[Concept] = Body()
) -> JSONResponse:
    """
    Replace a model's whole concept list with the supplied concepts, save, and echo it back.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    if not model_name:
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    cfg = from_file(model_name)
    if cfg is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    cfg.concepts_list = [item.__dict__ for item in concepts]
    cfg.save()
    return JSONResponse(content=cfg.concepts())
@app.post("/dreambooth/createModel")
async def create_db_model(
        new_model_name: str = Query(description="The name of the model to create.", ),
        new_model_src: str = Query(description="The source checkpoint to extract to create this model.", ),
        new_model_scheduler: str = Query("ddim", description="The scheduler to use. V2+ models ignore this.", ),
        create_from_hub: bool = Query(False, description="Create this model from the hub", ),
        new_model_url: str = Query(None,
                                   description="The hub URL to use for this model. Must contain diffusers model.", ),
        is_512: bool = Query(False,
                             description="Whether or not the model is 512x resolution.", ),
        train_unfrozen: bool = Query(True,
                                     description="Un-freeze the model.", ),
        new_model_token: str = Query(None, description="Your huggingface hub token.", ),
        new_model_extract_ema: bool = Query(False, description="Whether to extract EMA weights if present.", ),
        api_key: str = Query("", description="If an API key is set, this must be present.", ),
):
    """
    Create a new Dreambooth model via ``create_model`` and return the last element of its result.

    ``new_model_scheduler`` is accepted for compatibility but is not forwarded to
    ``create_model``.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    if not new_model_name:
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    busy = is_running()
    if busy:
        return busy
    logger.debug(f"Creating new Checkpoint: {new_model_name}")
    creation_args = (
        new_model_name,
        new_model_src,
        create_from_hub,
        new_model_url,
        new_model_token,
        new_model_extract_ema,
        train_unfrozen,
        is_512,
    )
    outcome = create_model(*creation_args)
    return JSONResponse(outcome[-1])
@app.delete("/dreambooth/model")
async def delete_model(
        model_name: str = Form(description="The model to delete."),
        api_key: str = Form("", description="If an API key is set, this must be present."),
) -> JSONResponse:
    """
    Delete a Dreambooth model's compiled checkpoints and its working directory.

    Compiled checkpoints named ``<name>_<revision>.safetensors`` or ``.ckpt`` are
    removed from the checkpoint directory. That directory is ``shared.ckpt_dir``, or
    ``<models_path>/Stable-diffusion`` when unset, plus a ``<name>`` subfolder when the
    config uses ``use_subdir``. ``<name>`` is ``custom_model_name`` if set, else
    ``model_name``. ``config.model_dir`` is then removed.

    Deletion is best-effort: individual failures are logged and do not abort the request.
    """
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    if model_name is None or model_name == "":
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    config = from_file(model_name)
    if config is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    model_dir = config.model_dir
    # Resolve the checkpoint directory the same way /dreambooth/checkpoint does.
    # "Stable-diffusion" must match case on case-sensitive filesystems.
    models_path = shared.ckpt_dir
    if models_path is None:
        models_path = os.path.join(shared.models_path, "Stable-diffusion")
    model_base = config.custom_model_name or config.model_name
    if getattr(config, "use_subdir", False):
        models_path = os.path.join(models_path, model_base)
    try:
        model_files = os.listdir(models_path)
    except OSError:
        logger.debug(f"Checkpoint dir not readable: {models_path}", exc_info=True)
        model_files = []
    for mf in model_files:
        stem, ext = os.path.splitext(mf)
        if ext not in (".safetensors", ".ckpt"):
            continue
        # Only "<model_base>_<integer revision>" files belong to this model.
        # Split off the extension first, otherwise the revision is "N.safetensors".
        base, sep, rev = stem.rpartition("_")
        if not sep or base != model_base or not rev.isdecimal():
            continue
        full_file = os.path.join(models_path, mf)
        print(f"Removing model: {full_file}")
        try:
            os.remove(full_file)
        except OSError:
            logger.warning(f"Unable to remove {full_file}", exc_info=True)
    if model_dir:
        shutil.rmtree(model_dir, ignore_errors=True)
    return JSONResponse(f"Model {model_name} has been deleted.")
@app.get("/dreambooth/model_config")
async def get_model_config(
        model_name: str = Query(description="The model name to fetch config for."),
        api_key: str = Query("", description="If an API key is set, this must be present.", )
) -> JSONResponse:
    """
    Return every attribute of a model's saved config as a JSON object.

    Responds 422 for an empty model name or when no config can be loaded.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    if not model_name:
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    cfg = from_file(model_name)
    if cfg is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    return JSONResponse(content=cfg.__dict__)
@app.post("/dreambooth/model_config")
async def set_model_config(
        model_cfg: DreamboothConfig = Body(description="The config to save"),
        api_key: str = Header(description="If an API key is set, this must be present.", default="")
):
    """
    Save a model config from a JSON body.

    A fresh ``DreamboothConfig`` is created and only fields it already has are copied
    over from the request. It is then saved and returned. Failures are printed and
    reported as ``{"Exception saving model": ...}``.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    try:
        logger.debug("Create config")
        config = DreamboothConfig()
        incoming = model_cfg.dict()
        for field_name, field_value in incoming.items():
            if field_name in config.__dict__:
                config.__dict__[field_name] = field_value
        config.save()
        logger.debug("Saved?")
        return JSONResponse(content=config.__dict__)
    except Exception as err:
        traceback.print_exc()
        return {"Exception saving model": f"{err}"}
@app.post("/dreambooth/model_params")
async def set_model_params(
        model_name: str = Query(description="The model name to update params for."),
        api_key: str = Query("", description="If an API key is set, this must be present."),
        params: str = Query(description="A json string representing a dictionary of parameters to set.")
) -> JSONResponse:
    """
    Update an existing model configuration's parameters from a dictionary of values.

    ``params`` must be a JSON object string. It is applied with
    ``config.load_params`` and the config is saved and returned. Responds 422 for
    invalid JSON, non-object JSON, an empty model name, or a missing config.
    """
    # Authenticate before touching the request payload.
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    try:
        param_dict = json.loads(params)
    except json.JSONDecodeError as e:
        return JSONResponse(status_code=422, content={"message": f"Invalid params JSON: {e}"})
    if not isinstance(param_dict, dict):
        return JSONResponse(status_code=422, content={"message": "Params must be a JSON object."})
    if model_name is None or model_name == "":
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    config = from_file(model_name)
    if config is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    logger.debug(f"Loading new params: {param_dict}")
    config.load_params(param_dict)
    config.save()
    return JSONResponse(content=config.__dict__)
@app.get("/dreambooth/models")
async def get_models(
        api_key: str = Query("", description="If an API key is set, this must be present."),
) -> JSONResponse:
    """
    List the Dreambooth models reported by ``get_db_models``.

    Args:
        api_key: Required when an API key has been configured.
    Returns: A JSON list of Dreambooth model names.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    return JSONResponse(get_db_models())
@app.get("/dreambooth/models_lora")
async def get_models_lora(
        api_key: str = Query("", description="If an API key is set, this must be present."),
        model_name: str = Query(description="The model name to query for lora files."),
) -> JSONResponse:
    """
    List the LoRA files associated with a model.

    Args:
        api_key: Required when an API key has been configured.
        model_name: The model whose config is passed to ``get_lora_models``.
    Returns: A JSON list of LoRA models, or "Config not found" when a named config is missing.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    cfg = from_file(model_name)
    missing = bool(model_name) and cfg is None
    if missing:
        return JSONResponse("Config not found")
    return JSONResponse(get_lora_models(cfg))
@app.get("/dreambooth/samples")
async def api_generate_samples(
        model_name: str = Query(description="The model name to use for generating samples."),
        sample_prompt: str = Query("", description="The prompt to use to generate sample images."),
        sample_prompt_file: str = Query("", description="A file of pre-generated sample prompts."),
        negative_prompt: str = Query("", description="An optional negative prompt to use when generating images."),
        width: int = Query(512, description="Sample width"),
        height: int = Query(512, description="Sample height"),
        num_images: int = Query(1, description="The number of sample images to generate."),
        batch_size: int = Query(1, description="How many images to generate at once."),
        seed: int = Query(-1, description="The seed to use when generating samples"),
        steps: int = Query(60, description="Number of sampling steps to use when generating images."),
        scale: float = Query(7.5, description="CFG scale to use when generating images."),
        use_txt2img: bool = Query(True, description="Use txt2img to generate samples"),
        scheduler: str = Query("DEISMultistep", description="Sampler to use if not using txt2img"),
        api_key: str = Query("", description="If an API key is set, this must be present.", )
):
    """
    Generate sample images for a specified model.

    Returns a ZIP when several images are produced, a single PNG when exactly one
    is, and a JSON message when none are.
    """
    key_check = check_api_key(api_key)
    if key_check is not None:
        return key_check
    running = is_running()
    if running:
        return running
    # Validate before begin() so an early return never leaves the job state open.
    config = from_file(model_name)
    if config is None:
        return JSONResponse("Config not found")
    shared.status.begin()
    try:
        images, _prompts, _msg = generate_samples(
            model_name,
            prompt=sample_prompt,
            prompt_file=sample_prompt_file,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_samples=num_images,
            batch_size=batch_size,
            seed=seed,
            scale=scale,
            steps=steps,
            use_txt2img=use_txt2img,
            scheduler=scheduler
        )
    finally:
        shared.status.end()
    if not images:
        return JSONResponse(content={"message": "No images."})
    if len(images) > 1:
        return zip_files(model_name, images, "_sample")
    image = images[0]
    if isinstance(image, str):
        image = Image.open(image)
    # Encode as PNG. Image.tobytes() would return raw pixel data, not a PNG file.
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    return Response(content=png_buffer.getvalue(), media_type="image/png")
@app.get("/dreambooth/status")
async def check_status(
        api_key: str = Query("", description="If an API key is set, this must be present.", )) -> \
        JSONResponse:
    """
    Report the shared Dreambooth job state.

    @return: ``{"current_state": <JSON string of shared.status.dict()>}``. The state is
             JSON-encoded into a string field, so clients decode it a second time.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    state_json = json.dumps(shared.status.dict())
    return JSONResponse(content={"current_state": state_json})
@app.get("/dreambooth/status_images")
async def check_status_images(
        api_key: str = Query("", description="If an API key is set, this must be present.", )) -> JSONResponse:
    """
    Return whatever preview image(s) the shared job state currently holds.

    Args:
        api_key: Required when an API key has been configured.
    Returns:
        A JSON "No images." message, a single PNG, or a ZIP when several images exist.
        String entries are treated as file paths and opened with PIL.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    shared.status.set_current_image()
    current = shared.status.current_image
    if isinstance(current, list):
        images = current
    elif current is None:
        images = []
    else:
        images = [current]
    if not images:
        return JSONResponse(content={"message": "No images."})
    if len(images) > 1:
        return zip_files("status", images, "_sample")
    first = images[0]
    picture = Image.open(first) if isinstance(first, str) else first
    png_bytes = io.BytesIO()
    picture.save(png_bytes, format='PNG')
    return Response(content=png_bytes.getvalue(), media_type="image/png")
@app.post("/dreambooth/start_training")
async def train(
        model_name: str = Query(None,
                                description="The model name to load params for.", ),
        use_tx2img: bool = Query(True, description="Use txt2img to generate class images."),
        api_key: str = Query("", description="If an API key is set, this must be present.")
):
    """
    Kick off Dreambooth training for a model in the background.

    The model's config must load. ``start_training(model_name, use_tx2img)`` is
    scheduled with ``run_in_background`` and the call returns immediately.
    """
    denied = check_api_key(api_key)
    if denied is not None:
        return denied
    if not model_name:
        return JSONResponse(status_code=422, content={"message": "Invalid model name."})
    if from_file(model_name) is None:
        return JSONResponse(status_code=422, content={"message": "Invalid config."})
    busy = is_running()
    if busy:
        return busy
    logger.debug("Starting Training")
    shared.status.begin()
    run_in_background(start_training, model_name, use_tx2img)
    return {"Status": "Training started."}
@app.post("/dreambooth/upload")
async def upload_db_images(
        model_name: str = Query(description="The model name to upload images for."),
        instance_name: str = Query(description="The concept/instance name the images are for."),
        create_concept: bool = Query(True,
                                     description="Enable to automatically append the new concept to the model config."),
        images: DbImagesRequest = Body(description="A dictionary of images, filenames, and prompts to save."),
        api_key: str = Query("", description="If an API key is set, this must be present.", )
):
    """
    Upload images for training.
    Request body should be a JSON Object. Primary key is 'imageList'.
    'imageList' is a list of objects. Each object should have three values:
    'data' - A base64-encoded string containing the binary data of the image.
    'name' - The filename to store the image under (only its base name is used).
    'txt' - The caption for the image. Will be stored in a text file beside the image.

    Files are written to ``<script_path>/../InstanceImages/<model_name>/<instance_name>``.
    Names that would resolve outside that root are rejected with 422. When
    ``create_concept`` is true, the directory is added to the model config as a
    ``[filewords]`` concept, replacing any concept with the same instance dir.
    """
    logger.debug("API UPLOAD STARTED.")
    key_check = check_api_key(api_key)
    if key_check is not None:
        logger.debug("NO KEY")
        return key_check
    root_img_path = os.path.abspath(os.path.join(shared.script_path, "..", "InstanceImages"))
    if not os.path.exists(root_img_path):
        logger.debug(f"Creating root instance dir: {root_img_path}")
        os.makedirs(root_img_path)
    else:
        logger.debug(f"Root dir exists already: {root_img_path}")
    image_dir = os.path.abspath(os.path.join(root_img_path, model_name, instance_name))
    # model_name / instance_name come from the request. Refuse values ("../", absolute
    # paths) that would place files outside the InstanceImages root.
    try:
        inside_root = os.path.commonpath([root_img_path, image_dir]) == root_img_path
    except ValueError:  # e.g. paths on different drives (Windows)
        inside_root = False
    if not inside_root:
        return JSONResponse(status_code=422, content={"message": "Invalid model or instance name."})
    os.makedirs(image_dir, exist_ok=True)
    # Log a count instead of the request itself, which carries full base64 payloads.
    logger.debug(f"Input data: {len(images.imageList)} image(s)")
    image_paths = []
    for img_data in images.imageList:
        # Only the base name is honored, so a crafted name cannot escape image_dir.
        name = os.path.basename(img_data.name)
        if not name:
            logger.debug(f"Skipping image with invalid name: {img_data.name!r}")
            continue
        img = base64_to_pil(img_data.data)
        prompt = img_data.txt
        logger.debug(f"Input prompt for image: {prompt} {name}")
        image_path = os.path.join(image_dir, name)
        text_path = f"{os.path.splitext(image_path)[0]}.txt"
        logger.debug(f"Saving image to: {image_path}")
        img.save(image_path)
        logger.debug(f"Saving prompt ({prompt}) to: {text_path}")
        with open(text_path, "w") as tx_file:
            tx_file.write(prompt)
        logger.debug(f"Saved prompt text to: {text_path}")
        image_paths.append(image_path)
    result = {"Status": f"Saved {len(image_paths)} images.", "Image dir": image_dir}
    if create_concept:
        config = from_file(model_name)
        logger.debug(f"Creating concept: {model_name}")
        if config is None:
            result["Status"] += " Unable to load model config."
        else:
            new_concept = Concept()
            new_concept.instance_data_dir = image_dir
            new_concept.class_prompt = "[filewords]"
            new_concept.instance_prompt = "[filewords]"
            new_concept.save_sample_prompt = "[filewords]"
            new_concept.is_valid = True
            logger.debug(f"New concept: {new_concept}")
            new_concepts = []
            replaced = False
            for concept in config.concepts():
                if concept.instance_data_dir == new_concept.instance_data_dir:
                    new_concepts.append(new_concept.__dict__)
                    replaced = True
                else:
                    new_concepts.append(concept.__dict__)
            if not replaced:
                new_concepts.append(new_concept.__dict__)
            config.concepts_list = new_concepts
            config.save()
            logger.debug("Saved concepts.")
            result["Concepts"] = config.concepts_list
    return result
@app.post("/dreambooth/upload_url")
async def upload_db_images_url(
        model_name: str = Query(description="The model name to upload images for."),
        instance_name: str = Query(description="The concept/instance name the images are for."),
        create_concept: bool = Query(True,
                                     description="Enable to automatically append the new concept to the model config."),
        images: DbImagesRequest = Body(description="A dictionary of images, filenames, and prompts to save."),
        api_key: str = Query("", description="If an API key is set, this must be present.", )
):
    """
    Download training images from URLs and save them with their captions.
    Request body should be a JSON Object. Primary key is 'imageList'.
    'imageList' is a list of objects. Each object should have three values:
    'data' - The URL to download the image from.
    'name' - The filename to store the image under (only its base name is used).
    'txt' - The caption for the image. Will be stored in a text file beside the image.

    Files are written to ``<script_path>/../InstanceImages/<model_name>/<instance_name>``.
    Names that would resolve outside that root are rejected with 422. Images that
    fail to download are logged and skipped. When ``create_concept`` is true, the
    directory is added to the model config as a ``[filewords]`` concept, replacing
    any concept with the same instance dir.
    """
    logger.debug("API UPLOAD STARTED.")
    key_check = check_api_key(api_key)
    if key_check is not None:
        logger.debug("NO KEY")
        return key_check
    root_img_path = os.path.abspath(os.path.join(shared.script_path, "..", "InstanceImages"))
    if not os.path.exists(root_img_path):
        logger.debug(f"Creating root instance dir: {root_img_path}")
        os.makedirs(root_img_path)
    else:
        logger.debug(f"Root dir exists already: {root_img_path}")
    image_dir = os.path.abspath(os.path.join(root_img_path, model_name, instance_name))
    # model_name / instance_name come from the request. Refuse values ("../", absolute
    # paths) that would place files outside the InstanceImages root.
    try:
        inside_root = os.path.commonpath([root_img_path, image_dir]) == root_img_path
    except ValueError:  # e.g. paths on different drives (Windows)
        inside_root = False
    if not inside_root:
        return JSONResponse(status_code=422, content={"message": "Invalid model or instance name."})
    os.makedirs(image_dir, exist_ok=True)
    logger.debug(f"Input data: {images}")
    image_paths = []
    for img_data in images.imageList:
        img_url = img_data.data
        # Only the base name is honored, so a crafted name cannot escape image_dir.
        name = os.path.basename(img_data.name)
        if not name:
            logger.debug(f"Skipping image with invalid name: {img_data.name!r}")
            continue
        prompt = img_data.txt
        logger.debug(f"Input prompt for image: {prompt} {name}")
        image_path = os.path.join(image_dir, name)
        text_path = f"{os.path.splitext(image_path)[0]}.txt"
        logger.debug(f"Saving image to: {image_path}")
        # Download to a side file and move it into place only when complete, so a
        # failed transfer neither leaves a truncated image nor clobbers an existing one.
        partial_path = f"{image_path}.part"
        try:
            # NOTE(security): the URL is fetched server-side, so an authorized caller can
            # make this host request any address it can reach. The timeout bounds
            # connect/read waits so a stalled server cannot hang the request forever.
            with requests.get(img_url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, image_path)
        except Exception as e:
            logger.exception(f"Error downloading image from {img_url}: {e}")
            if os.path.exists(partial_path):
                try:
                    os.remove(partial_path)
                except OSError:
                    logger.debug(f"Unable to remove partial download {partial_path}", exc_info=True)
            continue
        logger.debug(f"Saving prompt ({prompt}) to: {text_path}")
        with open(text_path, "w") as tx_file:
            tx_file.write(prompt)
        logger.debug(f"Saved prompt text to: {text_path}")
        image_paths.append(image_path)
    result = {"Status": f"Saved {len(image_paths)} images.", "Image dir": image_dir}
    if create_concept:
        config = from_file(model_name)
        logger.debug(f"Creating concept: {model_name}")
        if config is None:
            result["Status"] += " Unable to load model config."
        else:
            new_concept = Concept()
            new_concept.instance_data_dir = image_dir
            new_concept.class_prompt = "[filewords]"
            new_concept.instance_prompt = "[filewords]"
            new_concept.save_sample_prompt = "[filewords]"
            new_concept.is_valid = True
            logger.debug(f"New concept: {new_concept}")
            new_concepts = []
            replaced = False
            for concept in config.concepts():
                if concept.instance_data_dir == new_concept.instance_data_dir:
                    new_concepts.append(new_concept.__dict__)
                    replaced = True
                else:
                    new_concepts.append(concept.__dict__)
            if not replaced:
                new_concepts.append(new_concept.__dict__)
            config.concepts_list = new_concepts
            config.save()
            logger.debug("Saved concepts.")
            result["Concepts"] = config.concepts_list
    return result
# Mount the API when running inside the SD web UI. Elsewhere (e.g. the import
# fails because `modules` is absent), the endpoints are simply not registered.
try:
    import modules.script_callbacks as script_callbacks
    script_callbacks.on_app_started(dreambooth_api)
    logger.debug("SD-Webui API layer loaded")
except Exception:
    # Best-effort registration, but keep the reason available in debug logs.
    logger.debug("Unable to import script callbacks.", exc_info=True)