# mirror of https://github.com/vladmandic/automatic
"""V1 tag autocomplete endpoints.
|
|
|
|
Serves pre-built tag files (Danbooru, e621, natural language, artists)
|
|
from JSON files in the configured autocomplete directory. Remote files
|
|
are hosted on HuggingFace and downloaded on demand.
|
|
"""
|
|
|
|
import asyncio
import json
import os

from fastapi.exceptions import HTTPException

from modules.api.models import ItemAutocomplete, ItemAutocompleteContent, ItemAutocompleteRemote
from modules.logger import log
# Directory holding the local tag JSON files; populated once via init().
autocomplete_dir: str = ""
# Per-file cache keyed by tag-file name; each entry stores mtime/size/meta/content.
cache: dict[str, dict] = {}
# HuggingFace dataset that hosts the downloadable tag files and manifest.
HF_REPO = "CalamitousFelicitousness/prompt-vocab"
HF_BASE = f"https://huggingface.co/datasets/{HF_REPO}/resolve/main"
MANIFEST_CACHE_SEC = 300  # re-fetch manifest every 5 minutes
manifest_cache: dict = {}  # {"data": [...], "fetched_at": float}
def init(path: str) -> None:
    """Record the autocomplete directory location; invoked once at API registration."""
    global autocomplete_dir  # pylint: disable=global-statement
    autocomplete_dir = path
def get_cached(name: str) -> dict:
    """Return the cache entry for *name*, reloading from disk when the file changed.

    A missing file is auto-downloaded from HuggingFace when the remote
    manifest lists it; otherwise a 404 is raised. Path-traversal names
    yield a 400.
    """
    if any(token in name for token in ('/', '\\', '..')):
        raise HTTPException(status_code=400, detail="Invalid name")
    path = os.path.join(autocomplete_dir, f"{name}.json")
    if not os.path.isfile(path):
        cache.pop(name, None)
        # Auto-download from HF if available in manifest
        try:
            known = fetch_manifest_sync()
            if not any(item.get('name') == name for item in known):
                raise HTTPException(status_code=404, detail=f"Not found: {name}")
            log.info(f'Autocomplete: name="{name}" auto-download')
            download_sync(name)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=404, detail=f"Not found: {name} ({e})") from e
    stat = os.stat(path)
    cached = cache.get(name)
    # Cache hit is valid only while the on-disk mtime is unchanged.
    if cached is not None and cached['mtime'] == stat.st_mtime:
        return cached
    with open(path, encoding='utf-8') as f:
        data = json.load(f)
    # Category values may be dicts ({"name": ...}) or plain scalars; normalize to str.
    categories = {
        str(k): v.get('name', str(k)) if isinstance(v, dict) else str(v)
        for k, v in data.get('categories', {}).items()
    }
    fresh = {
        'mtime': stat.st_mtime,
        'size': stat.st_size,
        'meta': {
            'name': data.get('name', name),
            'version': data.get('version', ''),
            'tag_count': len(data.get('tags', [])),
            'categories': categories,
        },
        'content': data,
    }
    cache[name] = fresh
    return fresh
def list_all_sync() -> list[ItemAutocomplete]:
    """Scan the autocomplete directory and return metadata for each tag file.

    Listing is best-effort: a file that fails to load is skipped (and logged
    at debug level) so one corrupt file cannot break the whole listing.
    """
    if not autocomplete_dir or not os.path.isdir(autocomplete_dir):
        return []
    items: list[ItemAutocomplete] = []
    for filename in sorted(os.listdir(autocomplete_dir)):
        # Only visible *.json tag files count; the manifest is metadata, not tags.
        if not filename.endswith('.json') or filename.startswith('.') or filename == 'manifest.json':
            continue
        name = filename.rsplit('.', 1)[0]
        try:
            entry = get_cached(name)
            meta = entry['meta']
            items.append(ItemAutocomplete(
                name=meta['name'],
                version=meta['version'],
                tag_count=meta['tag_count'],
                categories=meta['categories'],
                size=entry['size'],
            ))
        except Exception as e:
            # was a silent `pass`; surface the reason without failing the listing
            log.debug(f'Autocomplete: name="{name}" skipped: {e}')
    return items
async def list_all() -> list[ItemAutocomplete]:
    """List available tag autocomplete files without blocking the event loop."""
    return await asyncio.to_thread(list_all_sync)
async def get_content(name: str) -> ItemAutocompleteContent:
    """Return the full parsed content of the tag file *name*.

    404/400 from get_cached pass through; any other failure becomes a 500.
    """
    try:
        # to_thread forwards positional args, so no wrapper closure is needed
        entry = await asyncio.to_thread(get_cached, name)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    data = entry['content']
    return ItemAutocompleteContent(
        name=data.get('name', name),
        version=data.get('version', ''),
        categories=data.get('categories', {}),
        tags=data.get('tags', []),
    )
# -- Remote management --

def fetch_manifest_sync() -> list[dict]:
    """Fetch manifest.json from HuggingFace, caching it for MANIFEST_CACHE_SEC seconds.

    On network failure the previously cached manifest (or an empty list) is
    returned so callers degrade gracefully instead of raising.
    """
    import time
    import requests
    now = time.time()
    fresh_until = manifest_cache.get('fetched_at', 0) + MANIFEST_CACHE_SEC
    if manifest_cache.get('data') and now < fresh_until:
        return manifest_cache['data']
    try:
        resp = requests.get(f"{HF_BASE}/manifest.json", timeout=15)
        resp.raise_for_status()
        payload = resp.json()
    except Exception as e:
        log.warning(f"Autocomplete: Failed to fetch manifest: {e}")
        return manifest_cache.get('data', [])
    # Manifest may be either {"entries": [...]} or a bare list.
    entries = payload.get('entries', payload) if isinstance(payload, dict) else payload
    manifest_cache['data'] = entries
    manifest_cache['fetched_at'] = now
    return entries
def local_names() -> set[str]:
    """Return the set of tag-file names present in the autocomplete directory."""
    if not autocomplete_dir or not os.path.isdir(autocomplete_dir):
        return set()
    names: set[str] = set()
    for filename in os.listdir(autocomplete_dir):
        if filename.endswith('.json') and not filename.startswith('.') and filename != 'manifest.json':
            names.add(filename.rsplit('.', 1)[0])
    return names
def local_version(name: str) -> str:
    """Return the version string of a local tag file, or '' when unavailable.

    The in-memory cache is trusted only while its recorded mtime matches the
    file on disk (mirrors the staleness check in get_cached); a stale entry
    previously could report an outdated version after a re-download. Any I/O
    or parse error yields ''.
    """
    path = os.path.join(autocomplete_dir, f"{name}.json")
    if not os.path.isfile(path):
        return ""
    try:
        entry = cache.get(name)
        if entry and entry['mtime'] == os.stat(path).st_mtime:
            return entry['meta'].get('version', '')
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        return data.get('version', '')
    except Exception:
        return ""
async def list_remote() -> list[ItemAutocompleteRemote]:
    """List tag files available for download from HuggingFace.

    Each entry is flagged as downloaded and/or update-available by comparing
    the manifest version against the locally stored file's version.
    """
    entries = await asyncio.to_thread(fetch_manifest_sync)
    local = await asyncio.to_thread(local_names)
    results = []
    for e in entries:
        name = e.get('name')
        if not name:
            # skip malformed manifest entries instead of raising KeyError (was e['name'])
            continue
        is_local = name in local
        remote_version = e.get('version', '')
        update = False
        if is_local and remote_version:
            lv = await asyncio.to_thread(local_version, name)
            # Update available only when a local version exists and differs.
            update = bool(lv and lv != remote_version)
        results.append(ItemAutocompleteRemote(
            name=name,
            description=e.get('description', ''),
            version=remote_version,
            tag_count=e.get('tag_count', 0),
            size_mb=e.get('size_mb', 0),
            downloaded=is_local,
            update_available=update,
        ))
    return results
def download_sync(name: str) -> str:
    """Download a tag file from HuggingFace into the local autocomplete directory.

    Streams the response into a temporary file and atomically replaces the
    target, so a failed download never leaves a truncated tag file behind.
    The streamed connection is closed via context manager (it previously
    leaked), and a partial ".tmp" file is removed on failure.

    Raises:
        HTTPException: 400 for path-traversal names, 502 on download failure.

    Returns:
        Path of the written file.
    """
    import requests
    if '/' in name or '\\' in name or '..' in name:
        raise HTTPException(status_code=400, detail="Invalid name")
    os.makedirs(autocomplete_dir, exist_ok=True)
    url = f"{HF_BASE}/{name}.json"
    target = os.path.join(autocomplete_dir, f"{name}.json")
    tmp = target + ".tmp"
    size = 0
    try:
        with requests.get(url, timeout=120, stream=True) as resp:  # ensure connection is released
            resp.raise_for_status()
            with open(tmp, 'wb') as f:
                for chunk in resp.iter_content(chunk_size=1024 * 256):
                    f.write(chunk)
                    size += len(chunk)
        os.replace(tmp, target)  # atomic swap: readers never see a partial file
    except requests.RequestException as e:
        _remove_quiet(tmp)
        raise HTTPException(status_code=502, detail=f"Failed to download {name}: {e}") from e
    except OSError:
        _remove_quiet(tmp)
        raise
    cache.pop(name, None)  # invalidate so the next read picks up the new file
    log.info(f'Autocomplete: name="{name}" url={url} ({size / 1024 / 1024:.2f}MB) downloaded')
    return target


def _remove_quiet(path: str) -> None:
    """Best-effort removal of a leftover temporary file."""
    try:
        os.remove(path)
    except OSError:
        pass
async def download(name: str):
    """Download a tag file from HuggingFace and return its metadata."""
    await asyncio.to_thread(download_sync, name)
    # Re-load through the cache so the response reflects the file just written.
    entry = await asyncio.to_thread(get_cached, name)
    meta = entry['meta']
    return ItemAutocomplete(
        name=meta['name'],
        version=meta['version'],
        tag_count=meta['tag_count'],
        categories=meta['categories'],
        size=entry['size'],
    )
async def delete(name: str):
    """Delete a locally downloaded tag file and drop its cache entry."""
    if any(token in name for token in ('/', '\\', '..')):
        raise HTTPException(status_code=400, detail="Invalid name")
    path = os.path.join(autocomplete_dir, f"{name}.json")
    if not os.path.isfile(path):
        raise HTTPException(status_code=404, detail=f"Not found: {name}")
    await asyncio.to_thread(os.remove, path)
    cache.pop(name, None)
    return {"status": "deleted", "name": name}
def register_api(api):
    """Mount the tag-autocomplete routes on the given FastAPI app.

    NOTE(review): the static ".../remote" route is registered before the
    dynamic ".../{name}" routes — FastAPI matches routes in registration
    order, so this keeps "remote" from being captured as a file name.
    """
    api.add_api_route("/sdapi/v1/autocomplete", list_all, methods=["GET"], response_model=list[ItemAutocomplete], tags=["Enumerators"])
    api.add_api_route("/sdapi/v1/autocomplete/remote", list_remote, methods=["GET"], response_model=list[ItemAutocompleteRemote], tags=["Enumerators"])
    api.add_api_route("/sdapi/v1/autocomplete/{name}", get_content, methods=["GET"], response_model=ItemAutocompleteContent, tags=["Enumerators"])
    api.add_api_route("/sdapi/v1/autocomplete/{name}/download", download, methods=["POST"], response_model=ItemAutocomplete, tags=["Enumerators"])
    api.add_api_route("/sdapi/v1/autocomplete/{name}", delete, methods=["DELETE"], tags=["Enumerators"])