diff --git a/modules/api/dicts.py b/modules/api/dicts.py
index 1d6400717..f22e0b18d 100644
--- a/modules/api/dicts.py
+++ b/modules/api/dicts.py
@@ -1,21 +1,27 @@
 """V1 dictionary / tag autocomplete endpoints.
 
 Serves pre-built tag dictionaries (Danbooru, e621, natural language, artists)
-from JSON files in the configured dicts directory.
+from JSON files in the configured dicts directory. Remote dicts are hosted
+on HuggingFace and downloaded on demand.
 """
 
 import asyncio
 import json
 import os
-from datetime import datetime
 
 from fastapi.exceptions import HTTPException
 
-from modules.api.models import ItemDict, ItemDictContent
+from modules.api.models import ItemDict, ItemDictContent, ItemDictRemote
+from modules.logger import log
 
 dicts_dir: str = ""
 cache: dict[str, dict] = {}
 
+HF_REPO = "CalamitousFelicitousness/prompt-vocab"
+HF_BASE = f"https://huggingface.co/datasets/{HF_REPO}/resolve/main"
+MANIFEST_CACHE_SEC = 300  # re-fetch manifest every 5 minutes
+manifest_cache: dict = {}  # {"data": [...], "fetched_at": float}
+
 
 def init(path: str) -> None:
     """Set the dicts directory path. Called once during API registration."""
@@ -61,7 +67,7 @@ def list_dicts_sync() -> list[ItemDict]:
         return []
     items = []
     for filename in sorted(os.listdir(dicts_dir)):
-        if not filename.endswith('.json') or filename.startswith('.'):
+        if not filename.endswith('.json') or filename.startswith('.') or filename == 'manifest.json':
             continue
         name = filename.rsplit('.', 1)[0]
         try:
@@ -103,6 +109,182 @@ async def get_dict(name: str) -> ItemDictContent:
     )
 
 
+# ── Remote dict management ──
+
+def validate_name(name: str) -> None:
+    """Reject dict names that are empty, reserved, or could escape dicts_dir.
+
+    Raises:
+        HTTPException: 400 when the name cannot be used as a dict file stem.
+    """
+    # 'manifest' is reserved: manifest.json is the remote index, not a dict
+    if not name or name == 'manifest' or '/' in name or '\\' in name or '..' in name:
+        raise HTTPException(status_code=400, detail="Invalid dict name")
+
+
+def fetch_manifest_sync() -> list[dict]:
+    """Fetch manifest.json from HuggingFace, with caching.
+
+    Returns:
+        Manifest entries, each a dict with a non-empty string 'name'. When the
+        fetch fails, the previously cached entries (or an empty list).
+    """
+    import time
+    import requests
+    now = time.time()
+    # test key presence, not truthiness, so an empty manifest is cached as well
+    if 'data' in manifest_cache and now - manifest_cache.get('fetched_at', 0) < MANIFEST_CACHE_SEC:
+        return manifest_cache['data']
+    url = f"{HF_BASE}/manifest.json"
+    try:
+        resp = requests.get(url, timeout=15)
+        resp.raise_for_status()
+        data = resp.json()
+        raw = data.get('dicts') if isinstance(data, dict) else data
+        if not isinstance(raw, list):
+            raise ValueError("manifest does not contain a list of dicts")
+        # drop malformed entries here so consumers can rely on e['name']
+        entries = [e for e in raw if isinstance(e, dict) and isinstance(e.get('name'), str) and e['name']]
+        manifest_cache['data'] = entries
+        manifest_cache['fetched_at'] = now
+        return entries
+    except Exception as e:
+        log.warning(f"Failed to fetch dict manifest: {e}")
+        return manifest_cache.get('data', [])
+
+
+def local_dict_names() -> set[str]:
+    """Return set of locally available dict names."""
+    if not dicts_dir or not os.path.isdir(dicts_dir):
+        return set()
+    return {
+        f.rsplit('.', 1)[0]
+        for f in os.listdir(dicts_dir)
+        if f.endswith('.json') and not f.startswith('.') and f != 'manifest.json'
+    }
+
+
+def local_dict_version(name: str) -> str:
+    """Return the version string of a locally downloaded dict, or empty string."""
+    path = os.path.join(dicts_dir, f"{name}.json")
+    if not dicts_dir or not os.path.isfile(path):
+        return ""
+    try:
+        entry = cache.get(name)
+        if entry:
+            version = entry['meta'].get('version', '')
+        else:
+            with open(path, encoding='utf-8') as f:
+                version = json.load(f).get('version', '')
+        # normalise to str so it compares cleanly with the manifest version
+        return str(version) if version else ""
+    except Exception:
+        return ""
+
+
+async def list_remote() -> list[ItemDictRemote]:
+    """List dicts available for download from HuggingFace."""
+    entries = await asyncio.to_thread(fetch_manifest_sync)
+    local = await asyncio.to_thread(local_dict_names)
+    results = []
+    for e in entries:
+        name = e['name']
+        is_local = name in local
+        # manifest values may be null or non-string; coerce before validation
+        remote_version = str(e.get('version') or '')
+        update = False
+        if is_local and remote_version:
+            local_version = await asyncio.to_thread(local_dict_version, name)
+            update = bool(local_version and local_version != remote_version)
+        results.append(ItemDictRemote(
+            name=name,
+            description=e.get('description') or '',
+            version=remote_version,
+            tag_count=e.get('tag_count') or 0,
+            size_mb=e.get('size_mb') or 0,
+            downloaded=is_local,
+            update_available=update,
+        ))
+    return results
+
+
+def download_dict_sync(name: str) -> str:
+    """Download a dict file from HuggingFace to the local dicts directory.
+
+    The payload is streamed to a '.tmp' sibling and renamed into place only
+    once complete, so a failed download never replaces a working dict.
+
+    Returns:
+        Path of the downloaded file.
+
+    Raises:
+        HTTPException: 400 for an invalid name, 502 when the download fails.
+    """
+    import requests
+    validate_name(name)
+    os.makedirs(dicts_dir, exist_ok=True)
+    url = f"{HF_BASE}/{name}.json"
+    target = os.path.join(dicts_dir, f"{name}.json")
+    tmp = target + ".tmp"
+    size = 0
+    log.info(f"Downloading dict: {url}")
+    try:
+        # the context manager releases the connection even if streaming aborts
+        with requests.get(url, timeout=120, stream=True) as resp:
+            resp.raise_for_status()
+            with open(tmp, 'wb') as f:
+                for chunk in resp.iter_content(chunk_size=1024 * 256):
+                    f.write(chunk)
+                    size += len(chunk)
+        os.replace(tmp, target)
+    except requests.RequestException as e:
+        raise HTTPException(status_code=502, detail=f"Failed to download {name}: {e}") from e
+    finally:
+        # still present only when the transfer or the rename did not complete
+        if os.path.exists(tmp):
+            os.remove(tmp)
+    cache.pop(name, None)
+    log.info(f"Downloaded dict: {name} ({size / 1024 / 1024:.1f} MB)")
+    return target
+
+
+async def download_dict(name: str) -> ItemDict:
+    """Download a dict from HuggingFace and return its local metadata."""
+    await asyncio.to_thread(download_dict_sync, name)
+    entry = await asyncio.to_thread(get_cached, name)
+    meta = entry['meta']
+    return ItemDict(
+        name=meta['name'],
+        version=meta['version'],
+        tag_count=meta['tag_count'],
+        categories=meta['categories'],
+        size=entry['size'],
+    )
+
+
+async def delete_dict(name: str):
+    """Delete a locally downloaded dict.
+
+    Raises:
+        HTTPException: 400 for an invalid name, 404 when the dict is not present.
+    """
+    validate_name(name)
+    path = os.path.join(dicts_dir, f"{name}.json")
+    # an unset dicts_dir would resolve the path against the working directory
+    if not dicts_dir or not os.path.isfile(path):
+        raise HTTPException(status_code=404, detail=f"Dict not found: {name}")
+    try:
+        await asyncio.to_thread(os.remove, path)
+    except FileNotFoundError as e:
+        # removed by a concurrent request between the check and the delete
+        raise HTTPException(status_code=404, detail=f"Dict not found: {name}") from e
+    cache.pop(name, None)
+    return {"status": "deleted", "name": name}
+
+
 def register_api(app):
     app.add_api_route("/sdapi/v1/dicts", list_dicts, methods=["GET"], response_model=list[ItemDict], tags=["Enumerators"])
+    app.add_api_route("/sdapi/v1/dicts/remote", list_remote, methods=["GET"], response_model=list[ItemDictRemote], tags=["Enumerators"])
     app.add_api_route("/sdapi/v1/dicts/{name}", get_dict, methods=["GET"], response_model=ItemDictContent, tags=["Enumerators"])
+    app.add_api_route("/sdapi/v1/dicts/{name}/download", download_dict, methods=["POST"], response_model=ItemDict, tags=["Enumerators"])
+    app.add_api_route("/sdapi/v1/dicts/{name}", delete_dict, methods=["DELETE"], tags=["Enumerators"])
diff --git a/modules/api/models.py b/modules/api/models.py
index d3d030dec..16af70af3 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -522,6 +522,15 @@ class ItemDictContent(BaseModel):
     categories: dict = Field(default_factory=dict, title="Categories", description="Category definitions with name and color")
     tags: list = Field(default_factory=list, title="Tags", description="Tag entries as [name, category_id, post_count] tuples")
 
+class ItemDictRemote(BaseModel):
+    name: str = Field(title="Name", description="Dictionary identifier")
+    description: str = Field(default="", title="Description", description="Human-readable description")
+    version: str = Field(default="", title="Version", description="Version string")
+    tag_count: int = Field(default=0, title="Tag count", description="Number of tags")
+    size_mb: float = Field(default=0, title="Size (MB)", description="Approximate file size in megabytes")
+    downloaded: bool = Field(default=False, title="Downloaded", description="Whether the dict is available locally")
+    update_available: bool = Field(default=False, title="Update available", description="Whether a newer version exists remotely")
+
 
 # helper function
 def create_model_from_signature(func: Callable, model_name: str, base_model: type[BaseModel] = BaseModel, additional_fields: list | None = None, exclude_fields: list[str] | None = None) -> type[BaseModel]:
diff --git a/modules/ui_definitions.py b/modules/ui_definitions.py
index 176257a52..ebcce9b4b 100644
--- a/modules/ui_definitions.py
+++ b/modules/ui_definitions.py
@@ -60,7 +60,7 @@ def list_dict_names():
     return sorted(
         os.path.splitext(f)[0]
         for f in os.listdir(dicts_dir)
-        if f.endswith('.json') and not f.startswith('.')
+        if f.endswith('.json') and not f.startswith('.') and f != 'manifest.json'
     )
 
 