automatic/modules/civitai/api_civitai.py

69 lines
2.4 KiB
Python

from starlette.responses import JSONResponse
def models_to_json(all_models:list, model_id:int=None):
dct = []
for model in all_models:
if model_id is not None and model.id != model_id:
continue
model_dct = model.__dict__.copy()
versions_dct = []
for version in model.versions:
version_dct = version.__dict__.copy()
version_dct['files'] = [f.__dict__.copy() for f in version.files]
version_dct['images'] = [i.__dict__.copy() for i in version.images]
versions_dct.append(version_dct)
model_dct['versions'] = versions_dct
dct.append(model_dct)
# obj = json.dumps(dct, indent=2, ensure_ascii=False)
return dct
def get_civitai(
model_id:int=None, # if model_id is provided assume fetch-from-cache
query:str = '', # search query or tag is required
tag:str = '', # search query or tag is required
types:str = '', # Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses
sort:str = '', # Highest Rated, Most Downloaded, Newest
period:str = '', # AllTime, Year, Month, Week, Day
nsfw:bool = None, # optional:bool
limit:int = 0,
base:str = '',
token:str = None,
exact:bool = True,
):
from modules.civitai import search_civitai
if model_id is not None:
dct = models_to_json(search_civitai.models, model_id=model_id)
return JSONResponse(content=dct, status_code=200)
if len(query) > 0 or len(tag) > 0:
models = search_civitai.search_civitai(
query=query,
tag=tag,
types=types,
sort=sort,
period=period,
nsfw=nsfw,
limit=limit,
base=base,
token=token,
exact=exact
)
dct = models_to_json(models)
return JSONResponse(content=dct, status_code=200)
return JSONResponse(content=[], status_code=200)
def post_civitai(page:str=None):
from modules.civitai import metadata_civitai
result = []
for r in metadata_civitai.civit_search_metadata(title=page, raw=True):
result = r # get the last yielded result
return result
def register_api():
from modules.shared import api
api.add_api_route("/sdapi/v1/civitai", get_civitai, methods=["GET"], response_model=list)
api.add_api_route("/sdapi/v1/civitai", post_civitai, methods=["POST"], response_model=list)