feat(dicts): add remote dict management via HuggingFace

Add endpoints for browsing, downloading, and deleting dicts hosted on
HuggingFace. Manifest is fetched and cached from the remote repo to
show available dicts with download status and version comparison.

- /dicts/remote GET: list available dicts from HF manifest
- /dicts/{name}/download POST: download dict from HF with atomic write
- /dicts/{name} DELETE: remove local copy and evict cache
- ItemDictRemote model with downloaded/update_available flags
- Exclude manifest.json from local dict listing in all scan paths
pull/4707/head
CalamitousFelicitousness 2026-03-24 23:40:23 +00:00
parent 0405689606
commit f3e0b4a86b
3 changed files with 150 additions and 5 deletions

View File

@ -1,21 +1,27 @@
"""V1 dictionary / tag autocomplete endpoints.
Serves pre-built tag dictionaries (Danbooru, e621, natural language, artists)
from JSON files in the configured dicts directory.
from JSON files in the configured dicts directory. Remote dicts are hosted
on HuggingFace and downloaded on demand.
"""
import asyncio
import json
import os
from datetime import datetime
from fastapi.exceptions import HTTPException
from modules.api.models import ItemDict, ItemDictContent
from modules.api.models import ItemDict, ItemDictContent, ItemDictRemote
from modules.logger import log
dicts_dir: str = ""
cache: dict[str, dict] = {}
HF_REPO = "CalamitousFelicitousness/prompt-vocab"
HF_BASE = f"https://huggingface.co/datasets/{HF_REPO}/resolve/main"
MANIFEST_CACHE_SEC = 300 # re-fetch manifest every 5 minutes
manifest_cache: dict = {} # {"data": [...], "fetched_at": float}
def init(path: str) -> None:
"""Set the dicts directory path. Called once during API registration."""
@ -61,7 +67,7 @@ def list_dicts_sync() -> list[ItemDict]:
return []
items = []
for filename in sorted(os.listdir(dicts_dir)):
if not filename.endswith('.json') or filename.startswith('.'):
if not filename.endswith('.json') or filename.startswith('.') or filename == 'manifest.json':
continue
name = filename.rsplit('.', 1)[0]
try:
@ -103,6 +109,136 @@ async def get_dict(name: str) -> ItemDictContent:
)
# ── Remote dict management ──
def fetch_manifest_sync() -> list[dict]:
"""Fetch manifest.json from HuggingFace, with caching."""
import time
import requests
now = time.time()
if manifest_cache.get('data') and now - manifest_cache.get('fetched_at', 0) < MANIFEST_CACHE_SEC:
return manifest_cache['data']
url = f"{HF_BASE}/manifest.json"
try:
resp = requests.get(url, timeout=15)
resp.raise_for_status()
data = resp.json()
entries = data.get('dicts', data) if isinstance(data, dict) else data
manifest_cache['data'] = entries
manifest_cache['fetched_at'] = now
return entries
except Exception as e:
log.warning(f"Failed to fetch dict manifest: {e}")
return manifest_cache.get('data', [])
def local_dict_names() -> set[str]:
"""Return set of locally available dict names."""
if not dicts_dir or not os.path.isdir(dicts_dir):
return set()
return {
f.rsplit('.', 1)[0]
for f in os.listdir(dicts_dir)
if f.endswith('.json') and not f.startswith('.') and f != 'manifest.json'
}
def local_dict_version(name: str) -> str:
"""Return the version string of a locally downloaded dict, or empty string."""
path = os.path.join(dicts_dir, f"{name}.json")
if not os.path.isfile(path):
return ""
try:
entry = cache.get(name)
if entry:
return entry['meta'].get('version', '')
with open(path, encoding='utf-8') as f:
data = json.load(f)
return data.get('version', '')
except Exception:
return ""
async def list_remote() -> list[ItemDictRemote]:
"""List dicts available for download from HuggingFace."""
entries = await asyncio.to_thread(fetch_manifest_sync)
local = await asyncio.to_thread(local_dict_names)
results = []
for e in entries:
name = e['name']
is_local = name in local
remote_version = e.get('version', '')
update = False
if is_local and remote_version:
local_version = await asyncio.to_thread(local_dict_version, name)
update = bool(local_version and local_version != remote_version)
results.append(ItemDictRemote(
name=name,
description=e.get('description', ''),
version=remote_version,
tag_count=e.get('tag_count', 0),
size_mb=e.get('size_mb', 0),
downloaded=is_local,
update_available=update,
))
return results
def download_dict_sync(name: str) -> str:
"""Download a dict file from HuggingFace to the local dicts directory."""
import requests
if '/' in name or '\\' in name or '..' in name:
raise HTTPException(status_code=400, detail="Invalid dict name")
os.makedirs(dicts_dir, exist_ok=True)
url = f"{HF_BASE}/{name}.json"
log.info(f"Downloading dict: {url}")
try:
resp = requests.get(url, timeout=120, stream=True)
resp.raise_for_status()
except requests.RequestException as e:
raise HTTPException(status_code=502, detail=f"Failed to download {name}: {e}") from e
target = os.path.join(dicts_dir, f"{name}.json")
tmp = target + ".tmp"
size = 0
with open(tmp, 'wb') as f:
for chunk in resp.iter_content(chunk_size=1024 * 256):
f.write(chunk)
size += len(chunk)
os.replace(tmp, target)
cache.pop(name, None)
log.info(f"Downloaded dict: {name} ({size / 1024 / 1024:.1f} MB)")
return target
async def download_dict(name: str):
"""Download a dict from HuggingFace."""
path = await asyncio.to_thread(download_dict_sync, name)
entry = await asyncio.to_thread(get_cached, name)
meta = entry['meta']
return ItemDict(
name=meta['name'],
version=meta['version'],
tag_count=meta['tag_count'],
categories=meta['categories'],
size=entry['size'],
)
async def delete_dict(name: str):
"""Delete a locally downloaded dict."""
if '/' in name or '\\' in name or '..' in name:
raise HTTPException(status_code=400, detail="Invalid dict name")
path = os.path.join(dicts_dir, f"{name}.json")
if not os.path.isfile(path):
raise HTTPException(status_code=404, detail=f"Dict not found: {name}")
await asyncio.to_thread(os.remove, path)
cache.pop(name, None)
return {"status": "deleted", "name": name}
def register_api(app):
app.add_api_route("/sdapi/v1/dicts", list_dicts, methods=["GET"], response_model=list[ItemDict], tags=["Enumerators"])
app.add_api_route("/sdapi/v1/dicts/remote", list_remote, methods=["GET"], response_model=list[ItemDictRemote], tags=["Enumerators"])
app.add_api_route("/sdapi/v1/dicts/{name}", get_dict, methods=["GET"], response_model=ItemDictContent, tags=["Enumerators"])
app.add_api_route("/sdapi/v1/dicts/{name}/download", download_dict, methods=["POST"], response_model=ItemDict, tags=["Enumerators"])
app.add_api_route("/sdapi/v1/dicts/{name}", delete_dict, methods=["DELETE"], tags=["Enumerators"])

View File

@ -522,6 +522,15 @@ class ItemDictContent(BaseModel):
categories: dict = Field(default_factory=dict, title="Categories", description="Category definitions with name and color")
tags: list = Field(default_factory=list, title="Tags", description="Tag entries as [name, category_id, post_count] tuples")
class ItemDictRemote(BaseModel):
name: str = Field(title="Name", description="Dictionary identifier")
description: str = Field(default="", title="Description", description="Human-readable description")
version: str = Field(default="", title="Version", description="Version string")
tag_count: int = Field(default=0, title="Tag count", description="Number of tags")
size_mb: float = Field(default=0, title="Size (MB)", description="Approximate file size in megabytes")
downloaded: bool = Field(default=False, title="Downloaded", description="Whether the dict is available locally")
update_available: bool = Field(default=False, title="Update available", description="Whether a newer version exists remotely")
# helper function
def create_model_from_signature(func: Callable, model_name: str, base_model: type[BaseModel] = BaseModel, additional_fields: list | None = None, exclude_fields: list[str] | None = None) -> type[BaseModel]:

View File

@ -60,7 +60,7 @@ def list_dict_names():
return sorted(
os.path.splitext(f)[0]
for f in os.listdir(dicts_dir)
if f.endswith('.json') and not f.startswith('.')
if f.endswith('.json') and not f.startswith('.') and f != 'manifest.json'
)