mirror of https://github.com/vladmandic/automatic
feat(dicts): add remote dict management via HuggingFace
Add endpoints for browsing, downloading, and deleting dicts hosted on
HuggingFace. Manifest is fetched and cached from the remote repo to
show available dicts with download status and version comparison.
- /dicts/remote GET: list available dicts from HF manifest
- /dicts/{name}/download POST: download dict from HF with atomic write
- /dicts/{name} DELETE: remove local copy and evict cache
- ItemDictRemote model with downloaded/update_available flags
- Exclude manifest.json from local dict listing in all scan paths
pull/4707/head
parent
0405689606
commit
f3e0b4a86b
|
|
@ -1,21 +1,27 @@
|
|||
"""V1 dictionary / tag autocomplete endpoints.
|
||||
|
||||
Serves pre-built tag dictionaries (Danbooru, e621, natural language, artists)
|
||||
from JSON files in the configured dicts directory.
|
||||
from JSON files in the configured dicts directory. Remote dicts are hosted
|
||||
on HuggingFace and downloaded on demand.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from modules.api.models import ItemDict, ItemDictContent
|
||||
from modules.api.models import ItemDict, ItemDictContent, ItemDictRemote
|
||||
from modules.logger import log
|
||||
|
||||
dicts_dir: str = ""
|
||||
cache: dict[str, dict] = {}
|
||||
|
||||
HF_REPO = "CalamitousFelicitousness/prompt-vocab"
|
||||
HF_BASE = f"https://huggingface.co/datasets/{HF_REPO}/resolve/main"
|
||||
MANIFEST_CACHE_SEC = 300 # re-fetch manifest every 5 minutes
|
||||
manifest_cache: dict = {} # {"data": [...], "fetched_at": float}
|
||||
|
||||
|
||||
def init(path: str) -> None:
|
||||
"""Set the dicts directory path. Called once during API registration."""
|
||||
|
|
@ -61,7 +67,7 @@ def list_dicts_sync() -> list[ItemDict]:
|
|||
return []
|
||||
items = []
|
||||
for filename in sorted(os.listdir(dicts_dir)):
|
||||
if not filename.endswith('.json') or filename.startswith('.'):
|
||||
if not filename.endswith('.json') or filename.startswith('.') or filename == 'manifest.json':
|
||||
continue
|
||||
name = filename.rsplit('.', 1)[0]
|
||||
try:
|
||||
|
|
@ -103,6 +109,136 @@ async def get_dict(name: str) -> ItemDictContent:
|
|||
)
|
||||
|
||||
|
||||
# ── Remote dict management ──
|
||||
|
||||
def fetch_manifest_sync() -> list[dict]:
|
||||
"""Fetch manifest.json from HuggingFace, with caching."""
|
||||
import time
|
||||
import requests
|
||||
now = time.time()
|
||||
if manifest_cache.get('data') and now - manifest_cache.get('fetched_at', 0) < MANIFEST_CACHE_SEC:
|
||||
return manifest_cache['data']
|
||||
url = f"{HF_BASE}/manifest.json"
|
||||
try:
|
||||
resp = requests.get(url, timeout=15)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
entries = data.get('dicts', data) if isinstance(data, dict) else data
|
||||
manifest_cache['data'] = entries
|
||||
manifest_cache['fetched_at'] = now
|
||||
return entries
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to fetch dict manifest: {e}")
|
||||
return manifest_cache.get('data', [])
|
||||
|
||||
|
||||
def local_dict_names() -> set[str]:
|
||||
"""Return set of locally available dict names."""
|
||||
if not dicts_dir or not os.path.isdir(dicts_dir):
|
||||
return set()
|
||||
return {
|
||||
f.rsplit('.', 1)[0]
|
||||
for f in os.listdir(dicts_dir)
|
||||
if f.endswith('.json') and not f.startswith('.') and f != 'manifest.json'
|
||||
}
|
||||
|
||||
|
||||
def local_dict_version(name: str) -> str:
|
||||
"""Return the version string of a locally downloaded dict, or empty string."""
|
||||
path = os.path.join(dicts_dir, f"{name}.json")
|
||||
if not os.path.isfile(path):
|
||||
return ""
|
||||
try:
|
||||
entry = cache.get(name)
|
||||
if entry:
|
||||
return entry['meta'].get('version', '')
|
||||
with open(path, encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data.get('version', '')
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
async def list_remote() -> list[ItemDictRemote]:
|
||||
"""List dicts available for download from HuggingFace."""
|
||||
entries = await asyncio.to_thread(fetch_manifest_sync)
|
||||
local = await asyncio.to_thread(local_dict_names)
|
||||
results = []
|
||||
for e in entries:
|
||||
name = e['name']
|
||||
is_local = name in local
|
||||
remote_version = e.get('version', '')
|
||||
update = False
|
||||
if is_local and remote_version:
|
||||
local_version = await asyncio.to_thread(local_dict_version, name)
|
||||
update = bool(local_version and local_version != remote_version)
|
||||
results.append(ItemDictRemote(
|
||||
name=name,
|
||||
description=e.get('description', ''),
|
||||
version=remote_version,
|
||||
tag_count=e.get('tag_count', 0),
|
||||
size_mb=e.get('size_mb', 0),
|
||||
downloaded=is_local,
|
||||
update_available=update,
|
||||
))
|
||||
return results
|
||||
|
||||
|
||||
def download_dict_sync(name: str) -> str:
|
||||
"""Download a dict file from HuggingFace to the local dicts directory."""
|
||||
import requests
|
||||
if '/' in name or '\\' in name or '..' in name:
|
||||
raise HTTPException(status_code=400, detail="Invalid dict name")
|
||||
os.makedirs(dicts_dir, exist_ok=True)
|
||||
url = f"{HF_BASE}/{name}.json"
|
||||
log.info(f"Downloading dict: {url}")
|
||||
try:
|
||||
resp = requests.get(url, timeout=120, stream=True)
|
||||
resp.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to download {name}: {e}") from e
|
||||
target = os.path.join(dicts_dir, f"{name}.json")
|
||||
tmp = target + ".tmp"
|
||||
size = 0
|
||||
with open(tmp, 'wb') as f:
|
||||
for chunk in resp.iter_content(chunk_size=1024 * 256):
|
||||
f.write(chunk)
|
||||
size += len(chunk)
|
||||
os.replace(tmp, target)
|
||||
cache.pop(name, None)
|
||||
log.info(f"Downloaded dict: {name} ({size / 1024 / 1024:.1f} MB)")
|
||||
return target
|
||||
|
||||
|
||||
async def download_dict(name: str):
|
||||
"""Download a dict from HuggingFace."""
|
||||
path = await asyncio.to_thread(download_dict_sync, name)
|
||||
entry = await asyncio.to_thread(get_cached, name)
|
||||
meta = entry['meta']
|
||||
return ItemDict(
|
||||
name=meta['name'],
|
||||
version=meta['version'],
|
||||
tag_count=meta['tag_count'],
|
||||
categories=meta['categories'],
|
||||
size=entry['size'],
|
||||
)
|
||||
|
||||
|
||||
async def delete_dict(name: str):
|
||||
"""Delete a locally downloaded dict."""
|
||||
if '/' in name or '\\' in name or '..' in name:
|
||||
raise HTTPException(status_code=400, detail="Invalid dict name")
|
||||
path = os.path.join(dicts_dir, f"{name}.json")
|
||||
if not os.path.isfile(path):
|
||||
raise HTTPException(status_code=404, detail=f"Dict not found: {name}")
|
||||
await asyncio.to_thread(os.remove, path)
|
||||
cache.pop(name, None)
|
||||
return {"status": "deleted", "name": name}
|
||||
|
||||
|
||||
def register_api(app):
|
||||
app.add_api_route("/sdapi/v1/dicts", list_dicts, methods=["GET"], response_model=list[ItemDict], tags=["Enumerators"])
|
||||
app.add_api_route("/sdapi/v1/dicts/remote", list_remote, methods=["GET"], response_model=list[ItemDictRemote], tags=["Enumerators"])
|
||||
app.add_api_route("/sdapi/v1/dicts/{name}", get_dict, methods=["GET"], response_model=ItemDictContent, tags=["Enumerators"])
|
||||
app.add_api_route("/sdapi/v1/dicts/{name}/download", download_dict, methods=["POST"], response_model=ItemDict, tags=["Enumerators"])
|
||||
app.add_api_route("/sdapi/v1/dicts/{name}", delete_dict, methods=["DELETE"], tags=["Enumerators"])
|
||||
|
|
|
|||
|
|
@ -522,6 +522,15 @@ class ItemDictContent(BaseModel):
|
|||
categories: dict = Field(default_factory=dict, title="Categories", description="Category definitions with name and color")
|
||||
tags: list = Field(default_factory=list, title="Tags", description="Tag entries as [name, category_id, post_count] tuples")
|
||||
|
||||
class ItemDictRemote(BaseModel):
|
||||
name: str = Field(title="Name", description="Dictionary identifier")
|
||||
description: str = Field(default="", title="Description", description="Human-readable description")
|
||||
version: str = Field(default="", title="Version", description="Version string")
|
||||
tag_count: int = Field(default=0, title="Tag count", description="Number of tags")
|
||||
size_mb: float = Field(default=0, title="Size (MB)", description="Approximate file size in megabytes")
|
||||
downloaded: bool = Field(default=False, title="Downloaded", description="Whether the dict is available locally")
|
||||
update_available: bool = Field(default=False, title="Update available", description="Whether a newer version exists remotely")
|
||||
|
||||
# helper function
|
||||
|
||||
def create_model_from_signature(func: Callable, model_name: str, base_model: type[BaseModel] = BaseModel, additional_fields: list | None = None, exclude_fields: list[str] | None = None) -> type[BaseModel]:
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ def list_dict_names():
|
|||
return sorted(
|
||||
os.path.splitext(f)[0]
|
||||
for f in os.listdir(dicts_dir)
|
||||
if f.endswith('.json') and not f.startswith('.')
|
||||
if f.endswith('.json') and not f.startswith('.') and f != 'manifest.json'
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue