# sd_dreambooth_extension/scripts/api.py
# (source listing metadata: 654 lines, 28 KiB, Python)

import base64
import base64
import functools
import hashlib
import io
import json
import os
import traceback
import zipfile
from pathlib import Path
import gradio as gr
from PIL import Image
from fastapi import FastAPI, Response, Query, Body
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from pydantic import BaseModel, Field
from pydantic.dataclasses import Union
from pydantic.types import List
from extensions.sd_dreambooth_extension.dreambooth import db_shared
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file, DreamboothConfig
from extensions.sd_dreambooth_extension.dreambooth.db_shared import DreamState
from extensions.sd_dreambooth_extension.dreambooth.diff_to_sd import compile_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.finetune_utils import FilenameTextGetter, generate_classifiers
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.secret import get_secret
from extensions.sd_dreambooth_extension.dreambooth.utils import get_images
from extensions.sd_dreambooth_extension.scripts import dreambooth
from extensions.sd_dreambooth_extension.scripts.dreambooth import ui_samples
from modules import sd_models
class InstanceData(BaseModel):
    """One uploaded training image: base64 payload, target file name and caption.

    Element type of ``DbImagesRequest.imageList``.
    """
    # Decoded into an image by the upload endpoint via base64_to_pil.
    data: str = Field(title="File data", description="Base64 representation of the file")
    # File name the image is saved under inside the instance image directory.
    name: str = Field(title="File name")
    # Caption; the upload endpoint writes it to a .txt file with the image's stem.
    txt: str = Field(title="Prompt")
class ImageData:
    """Plain holder for one image payload plus its caption.

    Attributes:
        name: File name of the image.
        prompt: Caption text for the image.
        data: Encoded image contents (the test-data endpoint passes base64 text
            produced by file_to_base64).
    """

    def __init__(self, name, prompt, data):
        self.name, self.prompt, self.data = name, prompt, data

    def dict(self):
        """Return the payload keyed like ``InstanceData`` (the prompt goes under ``txt``)."""
        payload = {"name": self.name}
        payload["data"] = self.data
        payload["txt"] = self.prompt
        return payload
class DbImagesRequest(BaseModel):
    """Request body for ``POST /dreambooth/upload``: the images to store for one concept."""
    # Each entry carries base64 image data, a file name and a caption.
    imageList: List[InstanceData] = Field(title="Images",
                                          description="List of images to work on. Must be Base64 strings")
# API Representation of concept data
class DreamboothConcept(BaseModel):
    """One training concept as sent by API clients.

    Used as the element type of ``DreamboothParameters.concepts_list``. Every
    field has a default, so clients may send only the fields they need.
    Fields typed ``Union[str, None]`` also accept an explicit null.
    """
    instance_data_dir: str = ""
    class_data_dir: str = ""
    instance_prompt: str = ""
    class_prompt: Union[str, None] = ""
    save_sample_prompt: Union[str, None] = ""
    save_sample_template: Union[str, None] = ""
    instance_token: Union[str, None] = ""
    class_token: Union[str, None] = ""
    num_class_images_per: int = 0
    class_negative_prompt: Union[str, None] = ""
    class_guidance_scale: float = 7.5
    class_infer_steps: int = 60
    save_sample_negative_prompt: Union[str, None] = ""
    n_save_sample: int = 1
    sample_seed: int = -1
    save_guidance_scale: float = 7.5
    save_infer_steps: int = 60
# API Representation of db config
class DreamboothParameters(BaseModel):
    """Request body accepted by ``POST /dreambooth/model_config``.

    The endpoint creates a fresh ``DreamboothConfig``, copies onto it every field
    here whose name is also an attribute of that config, and saves it. Fields with
    no matching attribute are silently ignored. ``concepts_list`` is the only
    required field; every other field falls back to the default shown.
    """
    # Required: one entry per training concept.
    concepts_list: List[DreamboothConcept]
    attention: str = "default"
    cache_latents: bool = True
    center_crop: bool = False
    clip_skip: int = 1
    concepts_path: Union[str, None] = ""
    custom_model_name: Union[str, None] = ""
    epoch_pause_frequency: int = 0
    epoch_pause_time: int = 60
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = True
    gradient_set_to_none: bool = True
    graph_smoothing: int = 50
    half_model: bool = False
    hflip: bool = True
    learning_rate: float = 0.000002
    learning_rate_min: float = 0.000001
    lora_learning_rate: float = 0.0002
    lora_model_name: str = ""
    lora_rank: int = 4
    lora_txt_learning_rate: float = 0.0002
    # NOTE(review): the two LoRA weights are ints here (as in /get_checkpoint) but
    # floats in /samples — confirm which type is intended.
    lora_txt_weight: int = 1
    lora_weight: int = 1
    lr_cycles: int = 1
    lr_factor: float = 0.5
    lr_power: float = 1.0
    lr_scale_pos: float = 0.5
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 500
    max_token_length: int = 75
    mixed_precision: str = "no"
    adamw_weight_decay: float = 1e-2
    model_name: str = ""
    num_train_epochs: int = 100
    pad_tokens: bool = True
    pretrained_vae_name_or_path: Union[str, None] = ""
    prior_loss_weight: float = 1.0
    resolution: int = 512
    revision: int = 0
    sample_batch_size: int = 1
    sanity_prompt: str = ""
    sanity_seed: int = 420420
    save_ckpt_after: bool = True
    save_ckpt_cancel: bool = False
    save_ckpt_during: bool = True
    save_embedding_every: int = 25
    save_lora_after: bool = True
    save_lora_cancel: bool = False
    save_lora_during: bool = True
    save_preview_every: int = 5
    save_state_after: bool = False
    save_state_cancel: bool = False
    save_state_during: bool = False
    src: Union[str, None] = ""
    shuffle_tags: bool = False
    train_batch_size: int = 1
    train_imagic: bool = False
    stop_text_encoder: float = 0
    use_8bit_adam: bool = False
    use_concepts: bool = False
    use_ema: bool = True
    use_lora: bool = False
    use_subdir: bool = True
import asyncio
# Strong references to in-flight background tasks. The event loop only keeps weak
# references to tasks, so a task nobody references can be garbage-collected mid-run.
_background_tasks = set()


def _finish_background_task(task):
    """Drop the finished task's strong reference and print any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()  # also marks the exception as retrieved
    if exc is not None:
        traceback.print_exception(type(exc), exc, exc.__traceback__)


def run_in_background(func, *args, **kwargs):
    """Run a blocking callable on the loop's default executor without awaiting it.

    Must be called from inside a running event loop (e.g. an ``async`` route
    handler); ``asyncio.create_task`` raises ``RuntimeError`` otherwise.

    Args:
        func: Synchronous callable to run.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The ``asyncio.Task`` wrapping the call. Callers may ignore it: the task
        is kept alive until it finishes, and any exception it raises is printed
        instead of surfacing later as "Task exception was never retrieved".
    """
    async def wrapper():
        call = functools.partial(func, *args, **kwargs)
        await asyncio.get_running_loop().run_in_executor(None, call)

    task = asyncio.create_task(wrapper())
    _background_tasks.add(task)
    task.add_done_callback(_finish_background_task)
    return task
def _zip_arcname(path):
    """Return ``<parent dir name>/<file name>`` for *path*, used as its name inside a zip."""
    return os.path.join(os.path.basename(os.path.dirname(path)), os.path.basename(path))


def zip_files(db_model_name, files, name_part=""):
    """Bundle images (and any caption files) into an in-memory zip and stream it back.

    Args:
        db_model_name: Prefix of the download file name.
        files: Items are either filesystem paths (``str`` / ``os.PathLike``) or
            in-memory image objects exposing ``save(fp, format=...)`` and
            ``tobytes()`` (PIL images).
        name_part: Optional suffix inserted between the model name and ``_images.zip``.

    Returns:
        A ``StreamingResponse`` carrying the archive as an attachment.

    Path entries are stored as ``<parent dir name>/<file name>`` so the archive does
    not expose the server's absolute directory layout. A caption with the same stem
    (``foo.png`` -> ``foo.txt``) is included when it exists. Path entries that are
    not existing files are skipped.
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
        for file in files:
            if isinstance(file, (str, os.PathLike)):
                file = os.fspath(file)
                print(f"Zipping img: {file}")
                if not os.path.isfile(file):
                    continue
                zip_file.write(file, arcname=_zip_arcname(file))
                # Caption sits beside the image with the same stem: foo.png -> foo.txt.
                check_txt = f"{os.path.splitext(file)[0]}.txt"
                if check_txt != file and os.path.isfile(check_txt):
                    print(f"Zipping txt: {check_txt}")
                    zip_file.write(check_txt, arcname=_zip_arcname(check_txt))
            else:
                img_byte_arr = io.BytesIO()
                file.save(img_byte_arr, format='PNG')
                # In-memory images have no name; derive a stable one from the pixel data.
                file_name = hashlib.sha1(file.tobytes()).hexdigest()
                zip_file.writestr(f"{file_name}.png", img_byte_arr.getvalue())
    return StreamingResponse(
        iter([zip_buffer.getvalue()]),
        media_type="application/x-zip-compressed",
        # Quoted so model names containing spaces or ';' produce a valid header.
        headers={"Content-Disposition": f'attachment; filename="{db_model_name}{name_part}_images.zip"'}
    )
def check_api_key(key):
    """Validate a caller-supplied API key against the configured secret.

    Args:
        key: The key sent by the client (may be ``None`` or empty).

    Returns:
        ``None`` when access is allowed: either no secret is configured or the key
        matches. Otherwise returns a ``JSONResponse`` the route should return as-is:
        401 when the key is missing, 403 when it does not match.
    """
    import hmac  # local import keeps this fix self-contained

    current_key = get_secret()
    if current_key is None or current_key == "":
        return None
    if key is None or key == "":
        return JSONResponse(status_code=401, content={"message": "API Key Required."})
    # Constant-time comparison so response timing does not reveal how much of the key matched.
    if not hmac.compare_digest(str(key).encode("utf-8"), str(current_key).encode("utf-8")):
        return JSONResponse(status_code=403, content={"message": "Invalid API Key."})
    return None
def base64_to_pil(im_b64) -> Image.Image:
    """Decode a base64-encoded image into a PIL image.

    Args:
        im_b64: Base64 text as ``str`` or ``bytes``. It may also be a data URI
            such as ``data:image/png;base64,iVBOR...``. The URI header is stripped
            first, because the lenient decoder would otherwise fold the header's
            letters into the payload and corrupt it.

    Returns:
        The opened image. Decoding and image-identification errors propagate to the caller.
    """
    if isinstance(im_b64, str):
        im_b64 = im_b64.encode("utf-8")
    # Plain base64 never contains ',' so this only triggers for data URIs.
    if im_b64[:5].lower() == b"data:" and b"," in im_b64:
        im_b64 = im_b64.split(b",", 1)[1]
    im_bytes = base64.b64decode(im_b64)
    # The BytesIO stays referenced by the image, which PIL reads lazily.
    return Image.open(io.BytesIO(im_bytes))
def file_to_base64(file_path) -> str:
    """Read a file from disk and return its contents as base64 text.

    Args:
        file_path: Path of the file to encode.

    Returns:
        The base64 encoding of the file's bytes, as ``str``.
    """
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
def dreambooth_api(_: gr.Blocks, app: FastAPI):
    """Register the Dreambooth HTTP endpoints on the web UI's FastAPI ``app``.

    Registered as an ``on_app_started`` callback. The first argument (the Gradio
    ``Blocks`` instance) is ignored. Every route takes an ``api_key`` query
    parameter that is checked with ``check_api_key`` before any work is done.
    """

    def _busy_response():
        """Return a "job already in progress" response if a job is running, else None."""
        if db_shared.status.job_count != 0:
            print("Something is already running.")
            return JSONResponse(content={"message": "Job already in progress.", "status": db_shared.status.dict()})
        return None

    def _is_within(base, target):
        """Return True if ``target`` resolves to ``base`` itself or to a path beneath it."""
        base_real = os.path.realpath(base)
        target_real = os.path.realpath(target)
        try:
            return os.path.commonpath([base_real, target_real]) == base_real
        except ValueError:
            # e.g. paths on different drives on Windows: definitely not inside base.
            return False

    def _png_response(image):
        """Encode one image (PIL image or file path) as a PNG HTTP response."""
        img_byte_arr = io.BytesIO()
        if isinstance(image, str):
            # Close the file handle once re-encoded instead of leaking it.
            with Image.open(image) as opened:
                opened.save(img_byte_arr, format='PNG')
        else:
            image.save(img_byte_arr, format='PNG')
        return Response(content=img_byte_arr.getvalue(), media_type="image/png")

    @app.post("/dreambooth/createModel")
    async def create_model(
            new_model_name: str = Query(None, description="The name of the model to create.", ),
            new_model_src: str = Query(None, description="The source checkpoint to extract to create this model.", ),
            new_model_scheduler: str = Query(None, description="The scheduler to use. V2+ models ignore this.", ),
            create_from_hub: bool = Query(False, description="Create this model from the hub", ),
            new_model_url: str = Query(None,
                                       description="The hub URL to use for this model. Must contain diffusers model.", ),
            new_model_token: str = Query(None, description="Your huggingface hub token.", ),
            new_model_extract_ema: bool = Query(False, description="Whether to extract EMA weights if present.", ),
            api_key: str = Query("", description="If an API key is set, this must be present.", ),
    ):
        """
        Create a new Dreambooth model.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if new_model_name is None or new_model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        busy = _busy_response()
        if busy is not None:
            return busy
        print("Creating new Checkpoint: " + new_model_name)
        # Called synchronously inside this async handler, so the event loop is
        # blocked until extraction returns.
        extract_checkpoint(new_model_name,
                           new_model_src,
                           new_model_scheduler,
                           create_from_hub,
                           new_model_url,
                           new_model_token,
                           new_model_extract_ema)
        # Give clients an explicit body instead of a bare ``null``.
        return JSONResponse(content={"message": "Model creation finished."})

    @app.post("/dreambooth/start_training")
    async def start_training(
            model_name: str = Query(None,
                                    description="The model name to load params for.", ),
            use_tx2img: bool = Query(True, description="Use txt2img to generate class images.", ),
            api_key: str = Query("", description="If an API key is set, this must be present.", ),
    ):
        """
        Start training dreambooth.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        busy = _busy_response()
        if busy is not None:
            return busy
        print("Starting Training")
        db_shared.status.begin()
        # Scheduled with run_in_background; the request returns without waiting for training.
        run_in_background(dreambooth.start_training, model_name, use_tx2img)
        return {"Status": "Training started."}

    @app.get("/dreambooth/cancel")
    async def cancel_jobs(
            api_key: str = Query("", description="If an API key is set, this must be present.", )) -> \
            Union[DreamState, JSONResponse]:
        """
        Request cancellation of the running Dreambooth job, if any.
        @return:
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if db_shared.status.job_count == 0:
            return JSONResponse(content={"message": "Nothing to cancel."})
        # Only sets the flag; presumably the running job polls it and stops itself.
        db_shared.status.interrupted = True
        return JSONResponse(content={"message": "Processes cancelled."})

    @app.get("/dreambooth/status")
    async def check_status(
            api_key: str = Query("", description="If an API key is set, this must be present.", )) -> \
            Union[DreamState, JSONResponse]:
        """
        Check the current state of Dreambooth processes.
        @return:
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        # NOTE: "current_state" is a JSON-encoded string, not a nested object; kept for client compatibility.
        return JSONResponse(content={"current_state": f"{json.dumps(db_shared.status.dict())}"})

    @app.get("/dreambooth/status_images")
    async def check_status_images(
            api_key: str = Query("", description="If an API key is set, this must be present.", )) -> JSONResponse:
        """
        Retrieve any images that may currently be present in the state.
        Args:
            api_key: An API key, if one has been set in the UI.
        Returns:
            A single image or zip of images, depending on how many exist.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        db_shared.status.set_current_image()
        images = db_shared.status.current_image
        if images is None:
            images = []
        elif isinstance(images, (list, tuple)):
            images = list(images)
        else:
            images = [images]
        if len(images) == 0:
            return JSONResponse(content={"message": "No images."})
        if len(images) > 1:
            return zip_files("status", images, "_sample")
        return _png_response(images[0])

    @app.get("/dreambooth/model_config")
    async def get_model_config(
            model_name: str = Query(None, description="The model name to fetch config for."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ) -> Union[DreamboothConfig, JSONResponse]:
        """
        Get a specified model config.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        busy = _busy_response()
        if busy is not None:
            return busy
        return JSONResponse(content=config.__dict__)

    @app.post("/dreambooth/model_config")
    async def set_model_config(
            model_cfg: DreamboothParameters = Body(description="The config to save"),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Save a model config from JSON.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        try:
            print("Create config")
            config = DreamboothConfig()
            # Copy only fields the config object already has; unknown request fields are ignored.
            for key, value in model_cfg.dict().items():
                if key in config.__dict__:
                    config.__dict__[key] = value
            config.save()
            print("Saved?")
            return JSONResponse(content=config.__dict__)
        except Exception as e:
            traceback.print_exc()
            return {"Exception saving model": f"{e}"}

    @app.get("/dreambooth/get_checkpoint")
    async def get_checkpoint(
            model_name: str = Query(description="The model name of the checkpoint to get."),
            skip_build: bool = Query(True, description="Set to false to re-compile the checkpoint before retrieval."),
            lora_model_name: str = Query("",
                                         description="The (optional) name of the lora model to merge with the checkpoint."),
            save_model_name: str = Query("", description="A custom name to use when generating the checkpoint."),
            lora_weight: int = Query(1, description="The weight of the lora UNET when merged with the checkpoint."),
            lora_text_weight: int = Query(1,
                                          description="The weight of the lora Text Encoder when merged with the checkpoint."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Generate and zip a checkpoint for a specified model.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        busy = _busy_response()
        if busy is not None:
            return busy
        path = None
        if save_model_name == "" or save_model_name is None:
            save_model_name = model_name
        if skip_build:
            ckpt_dir = db_shared.ckpt_dir
            models_path = os.path.join(db_shared.models_path, "Stable-diffusion")
            if ckpt_dir is not None:
                models_path = ckpt_dir
            # The config is used as an object elsewhere (config.revision, config.__dict__);
            # getattr works whether or not it also supports item access.
            use_subdir = getattr(config, "use_subdir", False)
            total_steps = config.revision
            if use_subdir:
                checkpoint_path = os.path.join(models_path, save_model_name, f"{save_model_name}_{total_steps}.ckpt")
            else:
                checkpoint_path = os.path.join(models_path, f"{save_model_name}_{total_steps}.ckpt")
            print(f"Looking for checkpoint at {checkpoint_path}")
            if os.path.exists(checkpoint_path):
                print("Existing checkpoint found, returning.")
                path = checkpoint_path
            else:
                # Nothing prebuilt to return; fall through to compiling one.
                skip_build = False
        if not skip_build:
            ckpt_result = compile_checkpoint(model_name, config.half_model, False, lora_model_name, lora_weight,
                                             lora_text_weight, save_model_name, False, True)
            # The result is a status string; on success it ends with the output path.
            if isinstance(ckpt_result, str) and "Checkpoint compiled successfully" in ckpt_result:
                path = ckpt_result.replace("Checkpoint compiled successfully:", "").strip()
                print(f"Checkpoint saved to path: {path}")
        if path is not None and os.path.exists(path):
            print(f"Returning file response: {path}-{os.path.splitext(path)}")
            return FileResponse(path)
        return {"exception": "Unable to find or compile checkpoint."}

    @app.get("/dreambooth/list_checkpoints")
    async def get_checkpoints(
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Collect the current list of available source checkpoints.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        sd_models.list_models()
        ckpt_list = sd_models.checkpoints_list
        return JSONResponse(content=ckpt_list)

    @app.get("/dreambooth/samples")
    async def generate_samples(
            model_name: str = Query(description="The model name to use for generating samples."),
            sample_prompt: str = Query(description="The prompt to use to generate sample images."),
            num_images: int = Query(1, description="The number of sample images to generate."),
            batch_size: int = Query(1, description="How many images to generate at once."),
            lora_model_path: str = Query("", description="The path to a lora model to use when generating images."),
            lora_rank: int = Query(1,
                                   description="LORA rank when training, or something.."),
            lora_weight: float = Query(1.0,
                                       description="The weight of the lora unet when merging with the base model."),
            lora_txt_weight: float = Query(1.0,
                                           description="The weight of the lora text encoder when merging with the base model"),
            negative_prompt: str = Query("", description="An optional negative prompt to use when generating images."),
            seed: int = Query(-1, description="The seed to use when generating samples"),
            steps: int = Query(60, description="Number of sampling steps to use when generating images."),
            scale: float = Query(7.5, description="CFG scale to use when generating images."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Generate sample images for a specified model.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        busy = _busy_response()
        if busy is not None:
            return busy
        images, msg, status = ui_samples(
            model_dir=model_name,
            save_sample_prompt=sample_prompt,
            num_samples=num_images,
            sample_batch_size=batch_size,
            lora_model_path=lora_model_path,
            lora_rank=lora_rank,
            lora_weight=lora_weight,
            lora_txt_weight=lora_txt_weight,
            negative_prompt=negative_prompt,
            seed=seed,
            steps=steps,
            scale=scale
        )
        # Guard against an empty result instead of failing on images[0].
        if not images:
            return JSONResponse(content={"message": "No images."})
        if len(images) > 1:
            return zip_files(model_name, images, "_sample")
        return _png_response(images[0])

    @app.post("/dreambooth/classifiers")
    async def generate_classes(
            model_name: str = Query(description="The model name to generate classifiers for."),
            use_txt2img: bool = Query(False, description="Use Txt2Image to generate classifiers."),
            api_key: str = Query("", description="If an API key is set, this must be present.")
    ):
        """
        Generate classification images for a model based on a saved config.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        busy = _busy_response()
        if busy is not None:
            return busy
        db_shared.status.begin()
        run_in_background(
            generate_classifiers,
            config,
            use_txt2img
        )
        return JSONResponse(content={"message": "Generating classifiers..."})

    @app.get("/dreambooth/classifiers")
    async def get_classifiers(
            model_name: str = Query(description="The model name to retrieve classifiers for."),
            concept_idx: int = Query(-1,
                                     description="If set, will retrieve the specified concept's class images. Otherwise, all class images will be retrieved."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Retrieve generated classifier images from a saved model config.
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        if model_name is None or model_name == "":
            return JSONResponse(status_code=422, content={"message": "Invalid model name."})
        config = from_file(model_name)
        if config is None:
            return JSONResponse(status_code=422, content={"message": "Invalid config."})
        concepts = config.concepts_list
        if concept_idx >= 0:
            if concept_idx >= len(concepts):
                return {"Exception": f"Concept index {concept_idx} out of range."}
            print(f"Returning class images for concept {concept_idx}")
            concept_dict = {concept_idx: concepts[concept_idx]}
        else:
            # Key each concept by its own index (the index also names the default class dir).
            concept_dict = dict(enumerate(concepts))
        out_images = []
        for concept_key, concept in concept_dict.items():
            class_images_dir = concept["class_data_dir"]
            if class_images_dir == "" or class_images_dir is None or class_images_dir == db_shared.script_path:
                class_images_dir = os.path.join(config.model_dir, f"classifiers_{concept_key}")
                print(f"Class image dir is not set, defaulting to {class_images_dir}")
            if os.path.exists(class_images_dir):
                out_images.extend(str(image) for image in get_images(class_images_dir))
        if len(out_images) > 0:
            return zip_files(model_name, out_images, "_class")
        return {"Result": "No images found."}

    @app.post("/dreambooth/upload")
    async def upload_db_images(
            model_name: str = Query(description="The model name to upload images for."),
            instance_name: str = Query(description="The concept/instance name the images are for."),
            images: DbImagesRequest = Body(description="A dictionary of images, filenames, and prompts to save."),
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Upload images for training.
        Request body should be a JSON Object. Primary key is 'imageList'.
        'imageList' is a list of objects. Each object should have three values:
        'data' - A base64-encoded string containing the binary data of the image.
        'name' - The filename to store the image under.
        'txt' - The caption for the image. Will be stored in a text file beside the image.
        Names that would place files outside the InstanceImages directory are rejected (422).
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        root_img_path = os.path.join(db_shared.script_path, "..", "InstanceImages")
        image_dir = os.path.join(root_img_path, model_name, instance_name)
        # model_name / instance_name come straight from the client: refuse anything
        # (e.g. "../..", absolute paths) that would escape the InstanceImages root.
        if not _is_within(root_img_path, image_dir):
            return JSONResponse(status_code=422, content={"message": "Invalid model or instance name."})
        # Validate every file name before writing anything so a bad entry can't leave a partial upload.
        for img_data in images.imageList:
            name = img_data.name
            if name in ("", ".", "..") or os.path.basename(name) != name:
                return JSONResponse(status_code=422, content={"message": f"Invalid file name: {name}"})
        if not os.path.exists(root_img_path):
            print(f"Creating root instance dir: {root_img_path}")
        os.makedirs(image_dir, exist_ok=True)
        image_paths = []
        for img_data in images.imageList:
            img = base64_to_pil(img_data.data)
            image_path = os.path.join(image_dir, img_data.name)
            text_path = f"{os.path.splitext(image_path)[0]}.txt"
            print(f"Saving image to: {image_path}")
            img.save(image_path)
            print(f"Saving prompt to: {text_path}")
            with open(text_path, "w") as tx_file:
                tx_file.write(img_data.txt)
            image_paths.append(image_path)
        # A list (not a set) keeps upload order in the response.
        return {"Status": f"Saved {len(image_paths)} images.", "Images": image_paths}

    @app.get("/dreambooth/testimg")
    async def generate_test_data(
            api_key: str = Query("", description="If an API key is set, this must be present.", )
    ):
        """
        Return base64-encoded images and captions from a hard-coded local directory (debug helper).
        """
        key_check = check_api_key(api_key)
        if key_check is not None:
            return key_check
        model_dir = "E:\\dev\\sd_db\\mj_5"
        if not os.path.isdir(model_dir):
            return JSONResponse(status_code=404, content={"message": "Test data directory not found."})
        text_getter = FilenameTextGetter(False)
        inst_datas = []
        for x in get_images(model_dir):
            image_bytes = file_to_base64(x)
            # x is Path-like: .name already includes the suffix (appending it gave "img.png.png").
            name = x.name
            txt = text_getter.read_text(x)
            inst_datas.append(ImageData(name, txt, image_bytes).dict())
        return JSONResponse(content=inst_datas)
# Register the API routes with the web UI. Registration is best-effort: a failure to
# import or register is reported instead of aborting the import of this module.
try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(dreambooth_api)
    print("SD-Webui API layer loaded")
except Exception as e:  # deliberately broad, but no longer swallows SystemExit/KeyboardInterrupt
    print(f"SD-Webui API layer not loaded: {e}")