From f3e0b4a86b8d8c7c30db22e4dea6e2d6cd953f5f Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Tue, 24 Mar 2026 23:40:23 +0000 Subject: [PATCH] feat(dicts): add remote dict management via HuggingFace Add endpoints for browsing, downloading, and deleting dicts hosted on HuggingFace. Manifest is fetched and cached from the remote repo to show available dicts with download status and version comparison. - /dicts/remote GET: list available dicts from HF manifest - /dicts/{name}/download POST: download dict from HF with atomic write - /dicts/{name} DELETE: remove local copy and evict cache - ItemDictRemote model with downloaded/update_available flags - Exclude manifest.json from local dict listing in all scan paths --- modules/api/dicts.py | 144 ++++++++++++++++++++++++++++++++++++-- modules/api/models.py | 9 +++ modules/ui_definitions.py | 2 +- 3 files changed, 150 insertions(+), 5 deletions(-) diff --git a/modules/api/dicts.py b/modules/api/dicts.py index 1d6400717..f22e0b18d 100644 --- a/modules/api/dicts.py +++ b/modules/api/dicts.py @@ -1,21 +1,27 @@ """V1 dictionary / tag autocomplete endpoints. Serves pre-built tag dictionaries (Danbooru, e621, natural language, artists) -from JSON files in the configured dicts directory. +from JSON files in the configured dicts directory. Remote dicts are hosted +on HuggingFace and downloaded on demand. """ import asyncio import json import os -from datetime import datetime from fastapi.exceptions import HTTPException -from modules.api.models import ItemDict, ItemDictContent +from modules.api.models import ItemDict, ItemDictContent, ItemDictRemote +from modules.logger import log dicts_dir: str = "" cache: dict[str, dict] = {} +HF_REPO = "CalamitousFelicitousness/prompt-vocab" +HF_BASE = f"https://huggingface.co/datasets/{HF_REPO}/resolve/main" +MANIFEST_CACHE_SEC = 300 # re-fetch manifest every 5 minutes +manifest_cache: dict = {} # {"data": [...], "fetched_at": float} + def init(path: str) -> None: """Set the dicts directory path. Called once during API registration.""" @@ -61,7 +67,7 @@ def list_dicts_sync() -> list[ItemDict]: return [] items = [] for filename in sorted(os.listdir(dicts_dir)): - if not filename.endswith('.json') or filename.startswith('.'): + if not filename.endswith('.json') or filename.startswith('.') or filename == 'manifest.json': continue name = filename.rsplit('.', 1)[0] try: @@ -103,6 +109,136 @@ async def get_dict(name: str) -> ItemDictContent: ) +# ── Remote dict management ── + +def fetch_manifest_sync() -> list[dict]: + """Fetch manifest.json from HuggingFace, with caching.""" + import time + import requests + now = time.time() + if manifest_cache.get('data') and now - manifest_cache.get('fetched_at', 0) < MANIFEST_CACHE_SEC: + return manifest_cache['data'] + url = f"{HF_BASE}/manifest.json" + try: + resp = requests.get(url, timeout=15) + resp.raise_for_status() + data = resp.json() + entries = data.get('dicts', data) if isinstance(data, dict) else data + manifest_cache['data'] = entries + manifest_cache['fetched_at'] = now + return entries + except Exception as e: + log.warning(f"Failed to fetch dict manifest: {e}") + return manifest_cache.get('data', []) + + +def local_dict_names() -> set[str]: + """Return set of locally available dict names.""" + if not dicts_dir or not os.path.isdir(dicts_dir): + return set() + return { + f.rsplit('.', 1)[0] + for f in os.listdir(dicts_dir) + if f.endswith('.json') and not f.startswith('.') and f != 'manifest.json' + } + + +def local_dict_version(name: str) -> str: + """Return the version string of a locally downloaded dict, or empty string.""" + path = os.path.join(dicts_dir, f"{name}.json") + if not os.path.isfile(path): + return "" + try: + entry = cache.get(name) + if entry: + return entry['meta'].get('version', '') + with open(path, encoding='utf-8') as f: + data = json.load(f) + return data.get('version', '') + except Exception: + return "" + + +async def list_remote() -> list[ItemDictRemote]: + """List dicts available for download from HuggingFace.""" + entries = await asyncio.to_thread(fetch_manifest_sync) + local = await asyncio.to_thread(local_dict_names) + results = [] + for e in entries: + name = e['name'] + is_local = name in local + remote_version = e.get('version', '') + update = False + if is_local and remote_version: + local_version = await asyncio.to_thread(local_dict_version, name) + update = bool(local_version and local_version != remote_version) + results.append(ItemDictRemote( + name=name, + description=e.get('description', ''), + version=remote_version, + tag_count=e.get('tag_count', 0), + size_mb=e.get('size_mb', 0), + downloaded=is_local, + update_available=update, + )) + return results + + +def download_dict_sync(name: str) -> str: + """Download a dict file from HuggingFace to the local dicts directory.""" + import requests + if '/' in name or '\\' in name or '..' in name: + raise HTTPException(status_code=400, detail="Invalid dict name") + os.makedirs(dicts_dir, exist_ok=True) + url = f"{HF_BASE}/{name}.json" + log.info(f"Downloading dict: {url}") + try: + resp = requests.get(url, timeout=120, stream=True) + resp.raise_for_status() + except requests.RequestException as e: + raise HTTPException(status_code=502, detail=f"Failed to download {name}: {e}") from e + target = os.path.join(dicts_dir, f"{name}.json") + tmp = target + ".tmp" + size = 0 + with open(tmp, 'wb') as f: + for chunk in resp.iter_content(chunk_size=1024 * 256): + f.write(chunk) + size += len(chunk) + os.replace(tmp, target) + cache.pop(name, None) + log.info(f"Downloaded dict: {name} ({size / 1024 / 1024:.1f} MB)") + return target + + +async def download_dict(name: str): + """Download a dict from HuggingFace.""" + path = await asyncio.to_thread(download_dict_sync, name) + entry = await asyncio.to_thread(get_cached, name) + meta = entry['meta'] + return ItemDict( + name=meta['name'], + version=meta['version'], + tag_count=meta['tag_count'], + categories=meta['categories'], + size=entry['size'], + ) + + +async def delete_dict(name: str): + """Delete a locally downloaded dict.""" + if '/' in name or '\\' in name or '..' in name: + raise HTTPException(status_code=400, detail="Invalid dict name") + path = os.path.join(dicts_dir, f"{name}.json") + if not os.path.isfile(path): + raise HTTPException(status_code=404, detail=f"Dict not found: {name}") + await asyncio.to_thread(os.remove, path) + cache.pop(name, None) + return {"status": "deleted", "name": name} + + def register_api(app): app.add_api_route("/sdapi/v1/dicts", list_dicts, methods=["GET"], response_model=list[ItemDict], tags=["Enumerators"]) + app.add_api_route("/sdapi/v1/dicts/remote", list_remote, methods=["GET"], response_model=list[ItemDictRemote], tags=["Enumerators"]) app.add_api_route("/sdapi/v1/dicts/{name}", get_dict, methods=["GET"], response_model=ItemDictContent, tags=["Enumerators"]) + app.add_api_route("/sdapi/v1/dicts/{name}/download", download_dict, methods=["POST"], response_model=ItemDict, tags=["Enumerators"]) + app.add_api_route("/sdapi/v1/dicts/{name}", delete_dict, methods=["DELETE"], tags=["Enumerators"]) diff --git a/modules/api/models.py b/modules/api/models.py index d3d030dec..16af70af3 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -522,6 +522,15 @@ class ItemDictContent(BaseModel): categories: dict = Field(default_factory=dict, title="Categories", description="Category definitions with name and color") tags: list = Field(default_factory=list, title="Tags", description="Tag entries as [name, category_id, post_count] tuples") +class ItemDictRemote(BaseModel): + name: str = Field(title="Name", description="Dictionary identifier") + description: str = Field(default="", title="Description", description="Human-readable description") + version: str = Field(default="", title="Version", description="Version string") + tag_count: int = Field(default=0, title="Tag count", description="Number of tags") + size_mb: float = Field(default=0, title="Size (MB)", description="Approximate file size in megabytes") + downloaded: bool = Field(default=False, title="Downloaded", description="Whether the dict is available locally") + update_available: bool = Field(default=False, title="Update available", description="Whether a newer version exists remotely") + # helper function def create_model_from_signature(func: Callable, model_name: str, base_model: type[BaseModel] = BaseModel, additional_fields: list | None = None, exclude_fields: list[str] | None = None) -> type[BaseModel]: diff --git a/modules/ui_definitions.py b/modules/ui_definitions.py index 176257a52..ebcce9b4b 100644 --- a/modules/ui_definitions.py +++ b/modules/ui_definitions.py @@ -60,7 +60,7 @@ def list_dict_names(): return sorted( os.path.splitext(f)[0] for f in os.listdir(dicts_dir) - if f.endswith('.json') and not f.startswith('.') + if f.endswith('.json') and not f.startswith('.') and f != 'manifest.json' )