automatic/modules/api/dicts.py

245 lines
8.5 KiB
Python

"""V1 dictionary / tag autocomplete endpoints.
Serves pre-built tag dictionaries (Danbooru, e621, natural language, artists)
from JSON files in the configured dicts directory. Remote dicts are hosted
on HuggingFace and downloaded on demand.
"""
import asyncio
import json
import os
from fastapi.exceptions import HTTPException
from modules.api.models import ItemDict, ItemDictContent, ItemDictRemote
from modules.logger import log
# Directory that holds the local dict JSON files; assigned by init().
dicts_dir: str = ""
# In-memory cache of loaded dicts keyed by dict name; entries are built by get_cached().
cache: dict[str, dict] = {}
# HuggingFace dataset repository that hosts the remote dicts.
HF_REPO = "CalamitousFelicitousness/prompt-vocab"
# Base URL that remote file names are appended to when downloading from HF_REPO.
HF_BASE = f"https://huggingface.co/datasets/{HF_REPO}/resolve/main"
MANIFEST_CACHE_SEC = 300 # re-fetch manifest every 5 minutes
manifest_cache: dict = {} # {"data": [...], "fetched_at": float}
def init(path: str) -> None:
    """Remember where the dict files live.

    Stores ``path`` in the module-level ``dicts_dir``, which the other
    functions in this module read when resolving dict file paths.
    """
    global dicts_dir  # noqa: PLW0603
    dicts_dir = path
def get_cached(name: str) -> dict:
    """Return the cache entry for dict ``name``, reloading it from disk when stale.

    The entry is a dict with keys ``mtime``, ``size``, ``meta`` (name, version,
    tag_count, categories) and ``content`` (the parsed JSON). A cached entry is
    reused only while the file's modification time matches the cached one.

    Raises:
        HTTPException: 400 if ``name`` contains a path separator or ``..``,
            404 if ``<dicts_dir>/<name>.json`` is not a file.
    """
    # Reject anything that could escape the dicts directory.
    if any(token in name for token in ('/', '\\', '..')):
        raise HTTPException(status_code=400, detail="Invalid dict name")

    path = os.path.join(dicts_dir, f"{name}.json")
    if not os.path.isfile(path):
        # The file is gone, so drop any stale cache entry along with it.
        cache.pop(name, None)
        raise HTTPException(status_code=404, detail=f"Dict not found: {name}")

    info = os.stat(path)
    cached = cache.get(name)
    if cached and cached['mtime'] == info.st_mtime:
        return cached

    with open(path, encoding='utf-8') as fh:
        payload = json.load(fh)

    meta = {
        'name': payload.get('name', name),
        'version': payload.get('version', ''),
        'tag_count': len(payload.get('tags', [])),
    }
    # Category values may be either {"name": ...} objects or plain labels.
    categories = {}
    for key, value in payload.get('categories', {}).items():
        label = str(key)
        categories[label] = value.get('name', label) if isinstance(value, dict) else str(value)
    meta['categories'] = categories

    fresh = {
        'mtime': info.st_mtime,
        'size': info.st_size,
        'meta': meta,
        'content': payload,
    }
    cache[name] = fresh
    return fresh
def list_dicts_sync() -> list[ItemDict]:
    """Scan the dicts directory and return metadata for each dict file.

    Considers every ``*.json`` file in ``dicts_dir`` except hidden files and
    ``manifest.json``, in sorted filename order. Files that cannot be loaded
    are skipped, and a warning is logged so the failure is visible.

    Returns:
        One ``ItemDict`` per loadable dict; an empty list when ``dicts_dir``
        is unset or is not a directory.
    """
    if not dicts_dir or not os.path.isdir(dicts_dir):
        return []
    items = []
    for filename in sorted(os.listdir(dicts_dir)):
        if not filename.endswith('.json') or filename.startswith('.') or filename == 'manifest.json':
            continue
        name = filename.rsplit('.', 1)[0]
        try:
            entry = get_cached(name)
            meta = entry['meta']
            items.append(ItemDict(
                name=meta['name'],
                version=meta['version'],
                tag_count=meta['tag_count'],
                categories=meta['categories'],
                size=entry['size'],
            ))
        except Exception as e:
            # Best-effort listing: one broken file must not hide the others,
            # but the failure should not disappear silently either.
            log.warning(f"Failed to load dict: {name}: {e}")
    return items
async def list_dicts() -> list[ItemDict]:
    """Return metadata for the local tag dictionaries.

    The directory scan is delegated to ``list_dicts_sync`` via
    ``asyncio.to_thread``.
    """
    items = await asyncio.to_thread(list_dicts_sync)
    return items
async def get_dict(name: str) -> ItemDictContent:
    """Return the full content of the dict called ``name``.

    Loading goes through ``get_cached`` in a worker thread.

    Raises:
        HTTPException: whatever ``get_cached`` raised, passed through
            unchanged; any other load error is reported as a 500 whose
            detail is the error text.
    """
    try:
        entry = await asyncio.to_thread(get_cached, name)
    except HTTPException:
        raise
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err)) from err

    payload = entry['content']
    fields = {
        'name': payload.get('name', name),
        'version': payload.get('version', ''),
        'categories': payload.get('categories', {}),
        'tags': payload.get('tags', []),
    }
    return ItemDictContent(**fields)
# ── Remote dict management ──
def fetch_manifest_sync() -> list[dict]:
    """Fetch ``manifest.json`` from HuggingFace, with caching.

    A non-empty cached manifest younger than ``MANIFEST_CACHE_SEC`` seconds is
    returned without a network request. The manifest may be either a JSON list
    of entries or a JSON object holding that list under the ``dicts`` key.

    Returns:
        The list of manifest entries. If the request fails or the manifest does
        not contain a list, a warning is logged and the last cached list (or an
        empty list) is returned instead.
    """
    import time
    import requests
    now = time.time()
    cached = manifest_cache.get('data')
    if cached and now - manifest_cache.get('fetched_at', 0) < MANIFEST_CACHE_SEC:
        return cached
    url = f"{HF_BASE}/manifest.json"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
        data = resp.json()
        entries = data.get('dicts', data) if isinstance(data, dict) else data
        if not isinstance(entries, list):
            # An object without a 'dicts' list (or a bare scalar) is not a usable
            # manifest; refuse to cache it so callers always receive a list.
            raise ValueError(f"unexpected manifest format: {type(entries).__name__}")
        manifest_cache['data'] = entries
        manifest_cache['fetched_at'] = now
        return entries
    except Exception as e:
        log.warning(f"Failed to fetch dict manifest: {e}")
        # Fall back to whatever was fetched last, even if it is older than the TTL.
        return manifest_cache.get('data', [])
def local_dict_names() -> set[str]:
    """Collect the names of the dicts present in ``dicts_dir``.

    A name is the filename without its final extension. Only ``*.json`` files
    count; hidden files and ``manifest.json`` are ignored. Returns an empty set
    when ``dicts_dir`` is unset or is not a directory.
    """
    if not (dicts_dir and os.path.isdir(dicts_dir)):
        return set()
    names = set()
    for entry in os.listdir(dicts_dir):
        if entry.startswith('.') or entry == 'manifest.json' or not entry.endswith('.json'):
            continue
        names.add(entry.rsplit('.', 1)[0])
    return names
def local_dict_version(name: str) -> str:
    """Look up the version string of the local dict ``name``.

    Prefers the version recorded in the in-memory cache; otherwise reads it
    from ``<dicts_dir>/<name>.json``. Returns an empty string when the file
    does not exist, has no version, or cannot be read.
    """
    path = os.path.join(dicts_dir, f"{name}.json")
    if not os.path.isfile(path):
        return ""
    try:
        cached = cache.get(name)
        if cached:
            return cached['meta'].get('version', '')
        with open(path, encoding='utf-8') as fh:
            payload = json.load(fh)
        return payload.get('version', '')
    except Exception:
        # Best effort: an unreadable file is treated as having no known version.
        return ""
async def list_remote() -> list[ItemDictRemote]:
    """List dicts available for download from HuggingFace.

    Combines the remote manifest with the set of locally present dicts. An
    entry is flagged ``update_available`` when it exists locally, both versions
    are known, and they differ. Manifest entries that are not objects or have
    no usable ``name`` are skipped with a warning instead of failing the whole
    listing.
    """
    entries = await asyncio.to_thread(fetch_manifest_sync)
    local = await asyncio.to_thread(local_dict_names)
    results = []
    for e in entries:
        name = e.get('name') if isinstance(e, dict) else None
        if not name or not isinstance(name, str):
            # The manifest is remote data; one malformed entry must not break the endpoint.
            log.warning(f"Skipping invalid dict manifest entry: {e}")
            continue
        is_local = name in local
        remote_version = e.get('version', '')
        update = False
        if is_local and remote_version:
            local_version = await asyncio.to_thread(local_dict_version, name)
            # An unknown local version is not reported as an update.
            update = bool(local_version and local_version != remote_version)
        results.append(ItemDictRemote(
            name=name,
            description=e.get('description', ''),
            version=remote_version,
            tag_count=e.get('tag_count', 0),
            size_mb=e.get('size_mb', 0),
            downloaded=is_local,
            update_available=update,
        ))
    return results
def download_dict_sync(name: str) -> str:
    """Download a dict file from HuggingFace to the local dicts directory.

    The body is streamed into ``<name>.json.tmp`` and moved over the final
    file with ``os.replace`` only once the download has completed. The cache
    entry for ``name`` is dropped after the file has been replaced.

    Returns:
        Path of the downloaded file.

    Raises:
        HTTPException: 400 if ``name`` contains a path separator or ``..``;
            502 if the request fails, including failures while the body is
            being streamed.
    """
    import contextlib
    import requests
    if '/' in name or '\\' in name or '..' in name:
        raise HTTPException(status_code=400, detail="Invalid dict name")
    os.makedirs(dicts_dir, exist_ok=True)
    url = f"{HF_BASE}/{name}.json"
    target = os.path.join(dicts_dir, f"{name}.json")
    tmp = target + ".tmp"
    size = 0
    log.info(f"Downloading dict: {url}")
    try:
        # The context manager closes the streamed response even when the
        # download is interrupted part-way through.
        with requests.get(url, timeout=120, stream=True) as resp:
            resp.raise_for_status()
            with open(tmp, 'wb') as f:
                for chunk in resp.iter_content(chunk_size=1024 * 256):
                    f.write(chunk)
                    size += len(chunk)
        os.replace(tmp, target)
    except Exception as e:
        # Never leave a partial download behind; the temp file may not exist
        # yet if the request itself failed.
        with contextlib.suppress(OSError):
            os.remove(tmp)
        if isinstance(e, requests.RequestException):
            raise HTTPException(status_code=502, detail=f"Failed to download {name}: {e}") from e
        raise
    cache.pop(name, None)
    log.info(f"Downloaded dict: {name} ({size / 1024 / 1024:.1f} MB)")
    return target
async def download_dict(name: str) -> ItemDict:
    """Download a dict from HuggingFace and return its metadata.

    Both the download and the subsequent load run in worker threads.

    Raises:
        HTTPException: whatever ``download_dict_sync`` or ``get_cached``
            raised; any other error while loading the downloaded file is
            reported as a 500 whose detail is the error text, as ``get_dict``
            does.
    """
    await asyncio.to_thread(download_dict_sync, name)
    try:
        entry = await asyncio.to_thread(get_cached, name)
    except HTTPException:
        raise
    except Exception as e:
        # For example, the downloaded file is not valid JSON.
        raise HTTPException(status_code=500, detail=str(e)) from e
    meta = entry['meta']
    return ItemDict(
        name=meta['name'],
        version=meta['version'],
        tag_count=meta['tag_count'],
        categories=meta['categories'],
        size=entry['size'],
    )
async def delete_dict(name: str):
    """Remove the local file of dict ``name`` and forget its cache entry.

    Returns:
        ``{"status": "deleted", "name": name}`` on success.

    Raises:
        HTTPException: 400 if ``name`` contains a path separator or ``..``,
            404 if ``<dicts_dir>/<name>.json`` is not a file.
    """
    if any(token in name for token in ('/', '\\', '..')):
        raise HTTPException(status_code=400, detail="Invalid dict name")

    target = os.path.join(dicts_dir, f"{name}.json")
    if os.path.isfile(target):
        await asyncio.to_thread(os.remove, target)
        cache.pop(name, None)
        return {"status": "deleted", "name": name}

    raise HTTPException(status_code=404, detail=f"Dict not found: {name}")
def register_api(app):
    """Attach the dict endpoints to ``app`` under ``/sdapi/v1/dicts``.

    Routes are added in the order listed below, so the fixed ``/remote`` path
    is registered before the ``/{name}`` pattern.
    """
    routes = (
        ("/sdapi/v1/dicts", list_dicts, "GET", {"response_model": list[ItemDict]}),
        ("/sdapi/v1/dicts/remote", list_remote, "GET", {"response_model": list[ItemDictRemote]}),
        ("/sdapi/v1/dicts/{name}", get_dict, "GET", {"response_model": ItemDictContent}),
        ("/sdapi/v1/dicts/{name}/download", download_dict, "POST", {"response_model": ItemDict}),
        ("/sdapi/v1/dicts/{name}", delete_dict, "DELETE", {}),
    )
    for route_path, endpoint, method, extra in routes:
        app.add_api_route(route_path, endpoint, methods=[method], tags=["Enumerators"], **extra)