305 lines
11 KiB
Python
305 lines
11 KiB
Python
import atexit
|
|
import json
|
|
import random
|
|
import tempfile
|
|
import traceback
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dataclasses import dataclass, replace
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List
|
|
from deforum_api_models import Batch, DeforumJobErrorType, DeforumJobStatusCategory, DeforumJobPhase, DeforumJobStatus
|
|
|
|
import gradio as gr
|
|
from deforum_helpers.args import (DeforumAnimArgs, DeforumArgs,
|
|
DeforumOutputArgs, LoopArgs, ParseqArgs,
|
|
RootArgs, get_component_names)
|
|
from fastapi import FastAPI, Response, status
|
|
|
|
from modules.shared import cmd_opts, opts
|
|
|
|
|
|
def make_ids(job_count: int):
    """Generate a batch identifier and one job identifier per job.

    The batch id has the form ``batch(<n>)`` where ``<n>`` is a random
    integer in ``[0, 10**9]``. Each job id is the batch id followed by
    ``-<index>``, with indices starting at 0.

    Args:
        job_count: Number of job ids to generate.

    Returns:
        A two-element list ``[batch_id, job_ids]``.
    """
    # randint needs integer bounds: the float literal 1e9 is rejected with a
    # TypeError on Python 3.12+, so spell the upper bound as an int.
    batch_id = f"batch({random.randint(0, 10**9)})"
    job_ids = [f"{batch_id}-{i}" for i in range(job_count)]
    return [batch_id, job_ids]
|
|
|
|
|
|
def get_default_value(name:str):
    """Return the default value registered for the argument called *name*.

    All argument collections are merged with ``|``. If the merged entry is a
    dict, its ``"value"`` item is returned (``None`` when that key is absent);
    any other entry is returned unchanged. Unknown names yield ``None``.
    """
    merged_args = (
        RootArgs()
        | DeforumAnimArgs()
        | DeforumArgs()
        | LoopArgs()
        | ParseqArgs()
        | DeforumOutputArgs()
    )

    if name not in merged_args:
        return None

    entry = merged_args[name]
    if isinstance(entry, dict):
        return entry.get("value", None)
    return entry
|
|
|
|
def run_deforum_batch(batch_id : str, deforum_settings : List[Any] ):
    """Run the Deforum pipeline once for a batch of settings files.

    Builds the positional argument list expected by ``run_deforum``: two
    leading slots (slot 0 carries the batch id, slot 1 stays ``None``)
    followed by one slot per component name, pre-filled with defaults and
    then patched so that the supplied settings files override everything.

    Args:
        batch_id: Identifier passed to ``run_deforum`` as its first argument.
        deforum_settings: Settings files forwarded as ``custom_settings_file``.
    """
    print("started run_deforum_batch")

    names = get_component_names()
    leading_slots = 2

    # Start from defaults: the batch files override everything, but some
    # values are validated eagerly and therefore have to look valid up front.
    call_args = [None] * leading_slots + [get_default_value(n) for n in names]

    patches = {
        # Defaults for these two do not pass validation.
        'animation_prompts': '{"0":"dummy value"}',
        'animation_prompts_negative': '',
        # Batch override: make deforum load its settings from the given files.
        'override_settings_with_file': True,
        'custom_settings_file': deforum_settings,
    }
    for component, value in patches.items():
        call_args[leading_slots + names.index(component)] = value

    # Arg 0 is a UID for the batch.
    call_args[0] = batch_id

    # Imported here rather than at module top, as in the original code.
    from deforum_helpers.run_deforum import run_deforum
    run_deforum(*call_args)
|
|
|
|
|
|
|
|
#
|
|
# API to allow a batch of jobs to be submitted to the deforum pipeline.
|
|
# A batch is a settings object OR a list of settings objects.
|
|
# A settings object is the JSON structure you can find in your saved settings.txt files.
|
|
#
|
|
# Request format:
|
|
# {
|
|
# "deforum_settings": [
|
|
# { ... settings object ... },
|
|
# { ... settings object ... },
|
|
# ]
|
|
# }
|
|
# OR:
|
|
# {
|
|
# "deforum_settings": { ... settings object ... }
|
|
# }
|
|
#
|
|
# Each settings object in the request represents a job to run as part of the batch.
|
|
# Each submitted batch will be given a batch ID which the user can use to query the status of all jobs in the batch.
|
|
#
|
|
def deforum_api(_: gr.Blocks, app: FastAPI):
    """Register the Deforum batch and job REST endpoints on *app*.

    Args:
        _: Unused first callback argument.
        app: FastAPI application the routes are attached to.
    """

    apiState = ApiState()

    @app.post("/deforum_api/batches")
    async def run_batch(batch: Batch, response: Response):
        """Accept a batch of settings objects and queue it for execution.

        Responds 400 when no settings are supplied, otherwise 202 together
        with the generated batch id and job ids.
        """

        # Extract the settings files from the request
        deforum_settings_data = batch.deforum_settings
        if not deforum_settings_data:
            response.status_code = status.HTTP_400_BAD_REQUEST
            return {"message": "No settings files provided. Please provide an element 'deforum_settings' of type list in the request JSON payload."}

        # A single settings object is treated as a batch of one.
        if not isinstance(deforum_settings_data, list):
            deforum_settings_data = [deforum_settings_data]

        deforum_settings_tempfiles = []
        for data in deforum_settings_data:
            # delete=False keeps the file on disk after it is closed; the
            # closed wrapper is handed on so the batch runner can find it
            # (presumably through its .name -- verify against run_deforum).
            # The with-block closes the handle even if serialisation fails.
            with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as temp_file:
                json.dump(data, temp_file)
            deforum_settings_tempfiles.append(temp_file)

        job_count = len(deforum_settings_tempfiles)
        [batch_id, job_ids] = make_ids(job_count)

        # Register the jobs BEFORE handing the work to the thread pool.
        # Status updates for unknown job ids are ignored by the tracker, so a
        # task that starts (or fails) immediately would otherwise lose its
        # update and then be overwritten with a fresh ACCEPTED record.
        tracker = JobStatusTracker()
        for job_id in job_ids:
            tracker.accept_job(batch_id=batch_id, job_id=job_id)

        apiState.submit_job(batch_id, job_ids, deforum_settings_tempfiles, batch.options_overrides)

        response.status_code = status.HTTP_202_ACCEPTED
        return {"message": "Job(s) accepted", "batch_id": batch_id, "job_ids": job_ids }

    @app.get("/deforum_api/batches/{id}")
    async def get_batch(id: str, response: Response):
        """Return the status of every job in batch *id*, or 404 if unknown."""
        tracker = JobStatusTracker()
        if id not in tracker.batches:
            response.status_code = status.HTTP_404_NOT_FOUND
            return {"id": id, "status": "NOT FOUND"}
        return [tracker.get(job_id) for job_id in tracker.batches[id]]

    ## TODO ##
    @app.delete("/deforum_api/batches/{id}")
    async def stop_batch(id: str, response: Response):
        """Placeholder: cancelling a batch is not implemented (501)."""
        response.status_code = status.HTTP_501_NOT_IMPLEMENTED
        return {"id": id, "status": "NOT IMPLEMENTED"}

    @app.get("/deforum_api/jobs")
    async def list_jobs():
        """Return the status of every job known to the tracker."""
        return JobStatusTracker().statuses

    @app.get("/deforum_api/jobs/{id}")
    async def get_job(id: str, response: Response):
        """Return the status of job *id*, or 404 if unknown."""
        tracker = JobStatusTracker()
        if id not in tracker.statuses:
            response.status_code = status.HTTP_404_NOT_FOUND
            return {"id": id, "status": "NOT FOUND"}
        return tracker.get(id)

    ## TODO ##
    @app.delete("/deforum_api/jobs/{id}")
    async def stop_job(id: str, response: Response):
        """Placeholder: cancelling a job is not implemented (501)."""
        response.status_code = status.HTTP_501_NOT_IMPLEMENTED
        return {"id": id, "status": "NOT IMPLEMENTED"}
|
|
|
|
class Singleton(type):
    """Metaclass that hands out a single shared instance per class.

    The first call to a class using this metaclass constructs the instance
    with the given arguments; every later call returns that same object and
    ignores its arguments.
    """

    # Maps each class to its one instance; shared by all users of the metaclass.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
|
|
|
|
# Maintains persistent state required by API, e.g. thread pool, list of submitted jobs.
|
|
class ApiState(metaclass=Singleton):
    """Process-wide API state: the worker pool and the futures it produced."""

    # Worker pool on which submitted batches are executed.
    deforum_api_executor = ThreadPoolExecutor(max_workers=5)
    # Maps batch id -> Future returned by the executor for that batch.
    submitted_jobs : Dict[str, Any] = {}

    @staticmethod
    def cleanup():
        """Shut the worker pool down without waiting for running tasks."""
        ApiState().deforum_api_executor.shutdown(wait=False)

    def submit_job(self, batch_id: str, job_ids: List[str], deforum_settings: List[Any], opts_overrides: Dict[str, Any] ):
        """Queue a batch for execution on the worker pool.

        Args:
            batch_id: Identifier of the batch; also the key in ``submitted_jobs``.
            job_ids: Ids of the jobs in the batch, marked failed if the run raises.
            deforum_settings: Settings files forwarded to ``run_deforum_batch``.
            opts_overrides: Optional option name -> value mapping applied to
                ``opts`` for the duration of the run; may be ``None`` or empty.
        """
        def task():
            print("started task")
            # Must be bound before the try block: the finally clause reads it
            # even when no overrides were requested.
            original_opts = None
            try:
                # Any non-empty mapping counts, including a single override.
                if opts_overrides:
                    # Only options that already exist can be restored later.
                    original_opts = {k: opts.data[k] for k in opts_overrides if k in opts.data}
                    print(f"Captured options to override: {original_opts}")
                    print(f"Overriding with: {opts_overrides}")
                    # NOTE(review): opts is shared module state, so overrides
                    # are visible to other batches running at the same time.
                    for k, v in opts_overrides.items():
                        setattr(opts, k, v)
                run_deforum_batch(batch_id, deforum_settings)
            except Exception as e:
                print(f"Batch {batch_id} failed: {e}")
                traceback.print_exc()
                # Mark all jobs in this batch as failed
                for job_id in job_ids:
                    # Pass the message as text, not as a set holding the exception.
                    JobStatusTracker().fail_job(job_id, 'TERMINAL', str(e))
            finally:
                if original_opts is not None:
                    print("Restoring options")
                    for k, v in original_opts.items():
                        setattr(opts, k, v)

        print("submitting task")
        future = self.deforum_api_executor.submit(task)
        self.submitted_jobs[batch_id] = future
|
|
|
|
# Shut the shared worker pool down when the interpreter exits.
atexit.register(ApiState.cleanup)
|
|
|
|
# Maintains state that tracks status of submitted jobs,
|
|
# so that clients can query job status.
|
|
class JobStatusTracker(metaclass=Singleton):
    """Tracks the status of submitted jobs so that clients can query them."""

    # Maps job id -> latest status record for that job.
    statuses: Dict[str, DeforumJobStatus] = {}
    # Maps batch id -> ids of the jobs accepted for that batch, in order.
    batches: Dict[str, List[str]] = {}

    def accept_job(self, batch_id : str, job_id: str):
        """Record *job_id* under *batch_id* with a fresh ACCEPTED/QUEUED status."""
        self.batches.setdefault(batch_id, []).append(job_id)

        now = datetime.now().timestamp()
        self.statuses[job_id] = DeforumJobStatus(
            id=job_id,
            status=DeforumJobStatusCategory.ACCEPTED,
            phase=DeforumJobPhase.QUEUED,
            error_type=DeforumJobErrorType.NONE,
            phase_progress=0.0,
            started_at=now,
            last_updated=now,
            execution_time=0,
            update_interval_time=0,
            updates=0,
            message=None,
            outdir=None,
            timestring=None
        )

    def _apply_update(self, job_id: str, **changes):
        """Replace the status of *job_id* with a copy carrying *changes*.

        Also refreshes the bookkeeping fields shared by every update:
        ``last_updated``, ``execution_time`` (time since ``started_at``),
        ``update_interval_time`` (time since the previous update) and the
        ``updates`` counter. Unknown job ids are ignored.
        """
        if job_id not in self.statuses:
            return
        current_status = self.statuses[job_id]
        now = datetime.now().timestamp()
        self.statuses[job_id] = replace(
            current_status,
            last_updated=now,
            execution_time=now-current_status.started_at,
            update_interval_time=now-current_status.last_updated,
            updates=current_status.updates+1,
            **changes
        )

    def update_phase(self, job_id: str, phase: DeforumJobPhase, progress: float = 0):
        """Set the phase and phase progress of a known job."""
        self._apply_update(job_id, phase=phase, phase_progress=progress)

    def update_output_info(self, job_id: str, outdir: str, timestring: str):
        """Record the output directory and timestring of a known job."""
        self._apply_update(job_id, outdir=outdir, timestring=timestring)

    def complete_job(self, job_id: str):
        """Mark a known job as SUCCEEDED, in phase DONE, with progress 1.0."""
        self._apply_update(
            job_id,
            status=DeforumJobStatusCategory.SUCCEEDED,
            phase=DeforumJobPhase.DONE,
            phase_progress=1.0
        )

    def fail_job(self, job_id: str, error_type: str, message: str):
        """Mark a known job as FAILED with the given error type and message."""
        self._apply_update(
            job_id,
            status=DeforumJobStatusCategory.FAILED,
            error_type=error_type,
            message=message
        )

    def get(self, job_id:str):
        """Return the status record of *job_id*; raises KeyError if unknown."""
        return self.statuses[job_id]
|
|
|
|
def deforum_init_batch(_: gr.Blocks, app: FastAPI):
    """Run the settings files named by ``cmd_opts.deforum_run_now`` as one batch.

    ``deforum_run_now`` is split on commas into file names. Each file is
    opened for reading and the open files are passed to ``run_deforum_batch``.
    When ``cmd_opts.deforum_terminate_after_run_now`` is set, the process
    exits once the batch has run.

    Args:
        _: Unused first callback argument.
        app: Unused second callback argument.
    """
    from contextlib import ExitStack

    filenames = cmd_opts.deforum_run_now.split(",")

    # ExitStack closes every file that was opened, both after the run and when
    # opening a later file in the list fails.
    with ExitStack() as stack:
        settings_files = [stack.enter_context(open(filename, 'r')) for filename in filenames]
        [batch_id, job_ids] = make_ids(len(settings_files))
        print(f"Starting batch with job(s) {job_ids}...")

        run_deforum_batch(batch_id, settings_files)

    if cmd_opts.deforum_terminate_after_run_now:
        import os
        # os._exit ends the process immediately, without normal interpreter cleanup.
        os._exit(0)
|
|
|
|
# Setup A1111 initialisation hooks
|
|
try:
    import modules.script_callbacks as script_callbacks
    # Each feature is opt-in through its own command-line flag.
    if cmd_opts.deforum_api:
        script_callbacks.on_app_started(deforum_api)
    if cmd_opts.deforum_run_now:
        script_callbacks.on_app_started(deforum_init_batch)
except Exception as e:
    # Registration stays best-effort, but report why it was skipped and let
    # KeyboardInterrupt / SystemExit propagate instead of swallowing them.
    print(f"Deforum API hooks were not registered: {e}")
|