91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
from secrets import compare_digest
from threading import Lock
from typing import Callable, Optional

from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials

from modules import shared
from modules.api.api import decode_base64_to_image
from modules.call_queue import queue_lock

from tagger import utils
from tagger import api_models as models
|
|
|
|
|
|
class Api:
    """Register the tagger HTTP endpoints on a FastAPI application.

    On construction two routes are added under ``prefix``:

    * ``POST <prefix>/interrogate``  -> :meth:`endpoint_interrogate`
    * ``GET  <prefix>/interrogators`` -> :meth:`endpoint_interrogators`

    When ``shared.cmd_opts.api_auth`` is set, every route is additionally
    guarded by HTTP Basic authentication (see :meth:`auth`).
    """

    def __init__(self, app: FastAPI, queue_lock: Lock, prefix: Optional[str] = None) -> None:
        """Store the collaborators and register the routes.

        Args:
            app: FastAPI application the routes are added to.
            queue_lock: Lock held while an interrogator runs.
            prefix: Optional path prefix prepended to every route.

        Raises:
            ValueError: If an entry of ``shared.cmd_opts.api_auth`` contains
                no ``:`` separator.
        """
        # Always defined, so `auth` never depends on how the instance was
        # configured to find this attribute.
        self.credentials = {}
        if shared.cmd_opts.api_auth:
            # Expected form: "user1:pass1,user2:pass2".
            for auth in shared.cmd_opts.api_auth.split(","):
                # Split on the first colon only, so a password that itself
                # contains ":" does not break the two-name unpacking.
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.app = app
        self.queue_lock = queue_lock
        self.prefix = prefix

        self.add_api_route(
            'interrogate',
            self.endpoint_interrogate,
            methods=['POST'],
            response_model=models.TaggerInterrogateResponse
        )

        self.add_api_route(
            'interrogators',
            self.endpoint_interrogators,
            methods=['GET'],
            response_model=models.InterrogatorsResponse
        )

    def auth(self, creds: HTTPBasicCredentials = Depends(HTTPBasic())):
        """Validate HTTP Basic credentials against the configured users.

        Returns:
            ``True`` when the username is known and the password matches.

        Raises:
            HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when
                the username is unknown or the password does not match.
        """
        expected = self.credentials.get(creds.username)
        if expected is not None:
            # compare_digest accepts `str` only when it is pure ASCII and
            # raises TypeError otherwise; comparing the UTF-8 bytes keeps the
            # constant-time check working for any password.
            if compare_digest(creds.password.encode("utf-8"), expected.encode("utf-8")):
                return True

        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={
                "WWW-Authenticate": "Basic"
            })

    def add_api_route(self, path: str, endpoint: Callable, **kwargs):
        """Add ``endpoint`` to the app at ``path``, honouring prefix and auth.

        Args:
            path: Route path; ``self.prefix`` is prepended when it is set.
            endpoint: Callable handling the request.
            **kwargs: Forwarded unchanged to ``app.add_api_route``.

        Returns:
            Whatever ``app.add_api_route`` returns.
        """
        if self.prefix:
            path = f'{self.prefix}/{path}'

        if shared.cmd_opts.api_auth:
            # Authentication is attached per route as a dependency.
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
        return self.app.add_api_route(path, endpoint, **kwargs)

    def endpoint_interrogate(self, req: models.TaggerInterrogateRequest):
        """Run the requested interrogator on the submitted image.

        Reads ``req.image``, ``req.model`` and ``req.threshold``.

        Returns:
            A ``TaggerInterrogateResponse`` whose ``caption`` merges the
            ratings with the post-processed tags.

        Raises:
            HTTPException: 404 when ``req.image`` is ``None`` or when
                ``req.model`` is not a key of ``utils.interrogators``.
        """
        if req.image is None:
            raise HTTPException(404, 'Image not found')

        if req.model not in utils.interrogators:
            raise HTTPException(404, 'Model not found')

        image = decode_base64_to_image(req.image)
        interrogator = utils.interrogators[req.model]

        # Only the interrogation itself runs under the lock; post-processing
        # happens after it is released.
        with self.queue_lock:
            ratings, tags = interrogator.interrogate(image)

        return models.TaggerInterrogateResponse(
            caption={
                **ratings,
                **interrogator.postprocess_tags(
                    tags,
                    req.threshold
                )
            })

    def endpoint_interrogators(self):
        """Return the names of all registered interrogators."""
        return models.InterrogatorsResponse(
            models=list(utils.interrogators.keys())
        )
|
|
|
|
|
|
def on_app_started(_, app: FastAPI):
    """Attach the tagger API to ``app`` under the ``/tagger/v1`` prefix.

    The first positional argument is ignored.
    """
    route_prefix = '/tagger/v1'
    # Constructing Api registers its routes on the app as a side effect.
    Api(app=app, queue_lock=queue_lock, prefix=route_prefix)
|