# mirror of https://github.com/vladmandic/automatic
#!/usr/bin/env python3
"""Fetch and convert booru tag databases to SD.Next dict format.

Usage:
    python cli/fetch_dicts.py danbooru [--output PATH] [--min-count N]
    python cli/fetch_dicts.py e621 [--output PATH] [--min-count N]
    python cli/fetch_dicts.py rule34 --key USER_ID:API_KEY [--output PATH] [--min-count N]
    python cli/fetch_dicts.py sankaku [--output PATH] [--min-count N]
    python cli/fetch_dicts.py all --key USER_ID:API_KEY [--output-dir DIR] [--min-count N]

Output format:
    JSON with { name, version, categories, tags: [[name, category_id, post_count], ...] }
    Tags sorted by post_count descending.
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from datetime import date
|
|
|
|
import requests
|
|
|
|
# Per-source category tables. Keys are the numeric category IDs each API
# returns (kept as strings because they become JSON object keys in the
# output dict); values hold a display name and a hex color — presumably the
# autocomplete UI accent color, not used within this script itself.
DANBOORU_CATEGORIES = {
    "0": {"name": "general", "color": "#0075f8"},
    "1": {"name": "artist", "color": "#a800aa"},
    "3": {"name": "copyright", "color": "#dd00dd"},
    "4": {"name": "character", "color": "#00ab2c"},
    "5": {"name": "meta", "color": "#ee8800"},
}

# e621 extends the Danbooru scheme with species/invalid/lore categories.
E621_CATEGORIES = {
    "0": {"name": "general", "color": "#0075f8"},
    "1": {"name": "artist", "color": "#a800aa"},
    "3": {"name": "copyright", "color": "#dd00dd"},
    "4": {"name": "character", "color": "#00ab2c"},
    "5": {"name": "species", "color": "#ed5d1f"},
    "6": {"name": "invalid", "color": "#ff3d3d"},
    "7": {"name": "meta", "color": "#ee8800"},
    "8": {"name": "lore", "color": "#228b22"},
}

# Rule34 uses the same category IDs as Danbooru (Gelbooru-compatible)
RULE34_CATEGORIES = {
    "0": {"name": "general", "color": "#0075f8"},
    "1": {"name": "artist", "color": "#a800aa"},
    "3": {"name": "copyright", "color": "#dd00dd"},
    "4": {"name": "character", "color": "#00ab2c"},
    "5": {"name": "meta", "color": "#ee8800"},
}

SANKAKU_CATEGORIES = {
    "0": {"name": "general", "color": "#0075f8"},
    "1": {"name": "artist", "color": "#a800aa"},
    "2": {"name": "studio", "color": "#a800aa"},
    "3": {"name": "copyright", "color": "#dd00dd"},
    "4": {"name": "character", "color": "#00ab2c"},
    "5": {"name": "species", "color": "#ed5d1f"},
    "8": {"name": "medium", "color": "#ee8800"},
    "9": {"name": "meta", "color": "#ee8800"},
}

# Identifying User-Agent sent with every request (booru APIs commonly
# require a descriptive UA instead of the requests default).
USER_AGENT = "SDNext-DictFetcher/1.0 (tag autocomplete)"
|
|
|
|
|
|
def fetch_danbooru(min_count: int = 10, **_kwargs) -> list:
    """Fetch tags from the Danbooru API, paginated.

    Pages through /tags.json ordered by post count descending and stops once
    every tag on a page falls below *min_count* (later pages can only have
    lower counts given the ordering).

    Args:
        min_count: minimum post count for a tag to be included.

    Returns:
        List of [name, category_id, post_count] sorted by post_count descending.
    """
    tags = []
    page = 1
    session = requests.Session()
    session.headers["User-Agent"] = USER_AGENT
    while True:
        url = f"https://danbooru.donmai.us/tags.json?limit=1000&page={page}&search[order]=count"
        try:
            resp = session.get(url, timeout=30)
            resp.raise_for_status()
            # decode inside the try: a non-JSON body (e.g. an HTML error
            # page) previously raised an uncaught JSONDecodeError
            data = resp.json()
        except (requests.RequestException, ValueError) as e:
            print(f" Error on page {page}: {e}", file=sys.stderr)
            break
        # Danbooru reports API errors (rate limit, auth) as a JSON object
        # rather than a list of tags — iterating one would crash below
        if not isinstance(data, list) or not data:
            if isinstance(data, dict):
                print(f" Error on page {page}: {data}", file=sys.stderr)
            break
        for tag in data:
            count = tag.get("post_count", 0)
            if count < min_count:
                continue
            tags.append([tag["name"], tag["category"], count])
        # ordered by count desc, so an all-below page means we're done
        below_threshold = all(t.get("post_count", 0) < min_count for t in data)
        print(f" Page {page}: {len(data)} tags (total: {len(tags)})", file=sys.stderr)
        if below_threshold:
            break
        page += 1
        time.sleep(0.5)  # stay well under Danbooru's rate limit
    tags.sort(key=lambda t: t[2], reverse=True)
    return tags
|
|
|
|
|
|
def fetch_e621(min_count: int = 10, **_kwargs) -> list:
    """Fetch tags from the e621 API, one page at a time.

    Walks /tags.json ordered by post count (descending) and stops as soon as
    an entire page falls under *min_count*.

    Args:
        min_count: minimum post count for a tag to be included.

    Returns:
        List of [name, category_id, post_count] sorted by count descending.
    """
    collected = []
    session = requests.Session()
    session.headers["User-Agent"] = USER_AGENT
    page = 1
    while True:
        url = f"https://e621.net/tags.json?limit=320&page={page}&search[order]=count"
        try:
            resp = session.get(url, timeout=30)
            resp.raise_for_status()
        except requests.RequestException as e:
            print(f" Error on page {page}: {e}", file=sys.stderr)
            break
        batch = resp.json()
        if not batch:
            break
        # keep paginating only while at least one tag meets the threshold
        keep_going = False
        for entry in batch:
            n_posts = entry.get("post_count", 0)
            if n_posts >= min_count:
                keep_going = True
                collected.append([entry["name"], entry["category"], n_posts])
        print(f" Page {page}: {len(batch)} tags (total: {len(collected)})", file=sys.stderr)
        if not keep_going:
            break
        page += 1
        time.sleep(1.0)  # e621 rate limit is stricter
    collected.sort(key=lambda row: row[2], reverse=True)
    return collected
|
|
|
|
|
|
def fetch_gelbooru(base_url: str, min_count: int = 10, api_key: str | None = None, rate_limit: float = 0.5) -> list:
    """Fetch tags from a Gelbooru-compatible API (rule34, gelbooru, etc.).

    Unlike Danbooru/e621, the Gelbooru tag endpoint doesn't support ordering
    by count, so we paginate through all tags and filter client-side.
    The tag endpoint returns XML (json=1 is not supported for tags).

    Args:
        base_url: the dapi index.php endpoint of the booru.
        min_count: minimum post count for a tag to be kept.
        api_key: optional credential in "USER_ID:API_KEY" form.
        rate_limit: seconds to sleep between page requests.

    Returns:
        List of [name, category_id, post_count] sorted by count descending.
    """
    import xml.etree.ElementTree as ET

    tags = []
    page = 0
    page_size = 1000
    session = requests.Session()
    session.headers["User-Agent"] = USER_AGENT
    params: dict[str, str] = {
        "page": "dapi", "s": "tag", "q": "index",
        "limit": str(page_size),
    }
    if api_key:
        if ":" not in api_key:
            print(" Error: --key must be USER_ID:API_KEY format", file=sys.stderr)
            return []
        uid, key = api_key.split(":", 1)
        params["user_id"] = uid
        params["api_key"] = key

    while True:
        params["pid"] = str(page)
        try:
            resp = session.get(base_url, params=params, timeout=30)
            resp.raise_for_status()
        except requests.RequestException as e:
            print(f" Error on page {page}: {e}", file=sys.stderr)
            break
        text = resp.text.strip()
        # anything not starting with an XML declaration (empty body, quoted
        # JSON error string, HTML page) is a non-XML failure response; the
        # old `startswith('"')` test was subsumed by this check
        if not text.startswith('<?xml'):
            # Non-XML response (auth error, empty, etc.)
            if page == 0:
                print(f" Error: {text[:200]}", file=sys.stderr)
            break
        try:
            root = ET.fromstring(text)
        except ET.ParseError as e:
            # truncated/malformed XML — keep what we collected so far
            # instead of crashing the whole fetch
            print(f" Error parsing page {page}: {e}", file=sys.stderr)
            break
        elements = root.findall('tag')
        for el in elements:
            count = int(el.get("count", "0"))
            if count >= min_count:
                tags.append([el.get("name", ""), int(el.get("type", "0")), count])
        print(f" Page {page}: {len(elements)} tags (total: {len(tags)})", file=sys.stderr)
        # a short page means we've reached the end of the tag listing
        if len(elements) < page_size:
            break
        page += 1
        time.sleep(rate_limit)
    tags.sort(key=lambda t: t[2], reverse=True)
    return tags
|
|
|
|
|
|
def fetch_rule34(min_count: int = 10, api_key: str | None = None, **_kwargs) -> list:
    """Fetch tags from rule34.xxx via its Gelbooru-compatible endpoint."""
    if not api_key:
        # proceed anyway — the endpoint works unauthenticated, just slower
        print(" Warning: no --key provided, rule34 may rate-limit aggressively", file=sys.stderr)
    endpoint = "https://api.rule34.xxx/index.php"
    return fetch_gelbooru(endpoint, min_count=min_count, api_key=api_key)
|
|
|
|
|
|
def fetch_sankaku(min_count: int = 10, **_kwargs) -> list:
    """Fetch tags from Sankaku Complex (chan.sankakucomplex.com).

    Uses the public JSON API at sankakuapi.com which supports order=count,
    so we can stop early when counts drop below min_count.
    Tag names come as English with spaces — converted to lowercase.

    Args:
        min_count: minimum post count for a tag to be included.

    Returns:
        List of [name, category_id, post_count] sorted by count descending.
    """
    tags = []
    page = 1
    page_size = 200  # API max
    session = requests.Session()
    session.headers["User-Agent"] = USER_AGENT
    while True:
        try:
            resp = session.get(
                "https://sankakuapi.com/tags",
                params={"limit": page_size, "page": page, "order": "count"},
                timeout=30,
            )
            resp.raise_for_status()
        except requests.RequestException as e:
            print(f" Error on page {page}: {e}", file=sys.stderr)
            break
        data = resp.json()
        if not data:
            break
        # Stop only when the WHOLE page is below threshold (same pattern as
        # the other fetchers). The previous implementation reassigned the
        # flag per-tag, so only the LAST tag on each page decided whether
        # pagination continued.
        below_threshold = all(t.get("post_count", 0) < min_count for t in data)
        for tag in data:
            count = tag.get("post_count", 0)
            if count < min_count:
                continue
            # prefer the English name; fall back to Japanese
            name = tag.get("name_en") or tag.get("name_ja", "")
            if not name:
                continue
            tags.append([name.strip().lower(), tag.get("type", 0), count])
        print(f" Page {page}: {len(data)} tags (total: {len(tags)})", file=sys.stderr)
        if below_threshold:
            break
        page += 1
        time.sleep(0.5)
    tags.sort(key=lambda t: t[2], reverse=True)
    return tags
|
|
|
|
|
|
# Source registry: name -> fetch callable + category table.
# fetch_source() dispatches through this dict, and main() derives its
# CLI `source` choices from its keys.
SOURCES = {
    "danbooru": {
        "fetch": fetch_danbooru,
        "categories": DANBOORU_CATEGORIES,
    },
    "e621": {
        "fetch": fetch_e621,
        "categories": E621_CATEGORIES,
    },
    "rule34": {
        "fetch": fetch_rule34,
        "categories": RULE34_CATEGORIES,
    },
    "sankaku": {
        "fetch": fetch_sankaku,
        "categories": SANKAKU_CATEGORIES,
    },
}
|
|
|
|
|
|
def write_dict(name: str, tags: list, categories: dict, output_path: str, separator: str = "_"):
    """Serialize a tag dict to compact JSON, writing atomically via rename.

    separator controls the word separator in tag names:
      "_" (default) → "high_resolution" (booru convention, anime/illustration models)
      " "           → "high resolution" (natural language, SDXL/Flux-style models)
    """
    # normalize every tag name to the requested word separator
    src_char, dst_char = (" ", "_") if separator == "_" else ("_", " ")
    normalized = [[entry[0].replace(src_char, dst_char), entry[1], entry[2]] for entry in tags]
    payload = {
        "name": name,
        "version": date.today().isoformat(),
        "categories": categories,
        "tags": normalized,
    }
    target_dir = os.path.dirname(output_path) or '.'
    os.makedirs(target_dir, exist_ok=True)
    # write to a sibling temp file then rename, so readers never observe a
    # partially written dict
    tmp_path = output_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, separators=(",", ":"))
    os.replace(tmp_path, output_path)
    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f" Written: {output_path} ({len(normalized)} tags, {size_mb:.1f} MB)", file=sys.stderr)
|
|
|
|
|
|
def fetch_source(name: str, output: str, min_count: int, api_key: str | None = None, separator: str = "_"):
    """Fetch one source by name and write its dict file to *output*."""
    source = SOURCES.get(name)
    if source is None:
        print(f"Unknown source: {name}. Available: {', '.join(SOURCES.keys())}", file=sys.stderr)
        sys.exit(1)
    print(f"Fetching {name} (min_count={min_count})...", file=sys.stderr)
    tags = source["fetch"](min_count=min_count, api_key=api_key)
    if not tags:
        # nothing fetched (network error, bad key, etc.) — skip writing
        print(f" No tags fetched for {name}", file=sys.stderr)
        return
    write_dict(name, tags, source["categories"], output, separator=separator)
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments and dispatch to fetch_source()."""
    parser = argparse.ArgumentParser(description="Fetch booru tag databases for SD.Next dict autocomplete")
    parser.add_argument("source", choices=list(SOURCES.keys()) + ["all"], help="Tag source to fetch")
    parser.add_argument("--output", "-o", help="Output file path (for single source)")
    parser.add_argument("--output-dir", "-d", help="Output directory (for 'all')")
    parser.add_argument("--min-count", "-m", type=int, default=10, help="Minimum post count to include (default: 10)")
    parser.add_argument("--key", "-k", help="API key as USER_ID:API_KEY (required for rule34)")
    parser.add_argument("--spaces", action="store_true", help="Use spaces instead of underscores in tag names (for natural language models like SDXL/Flux)")
    args = parser.parse_args()
    sep = " " if args.spaces else "_"

    # 'all' fans out over every registered source; otherwise fetch just one
    fetch_all = args.source == "all"
    targets = list(SOURCES) if fetch_all else [args.source]
    for src in targets:
        if fetch_all:
            dest = os.path.join(args.output_dir or ".", f"{src}.json")
        else:
            dest = args.output or f"{src}.json"
        fetch_source(src, dest, args.min_count, api_key=args.key, separator=sep)

    print("Done.", file=sys.stderr)
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|