sd-webui-infinite-image-bro.../scripts/iib/topic_cluster.py

1562 lines
62 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import hashlib
import asyncio
import json
import math
import os
import re
import time
import uuid
import threading
from array import array
from contextlib import closing
from typing import Callable, Dict, List, Optional, Tuple
import requests
from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, TopicTitleCache
from scripts.iib.tool import cwd
# In-memory registry of background topic-cluster jobs, keyed by job id.
_TOPIC_CLUSTER_JOBS: Dict[str, Dict] = {}
# Held by the _job_* helpers while they read or mutate _TOPIC_CLUSTER_JOBS.
_TOPIC_CLUSTER_JOBS_LOCK = threading.Lock()
# Maximum number of job records retained by _job_trim().
_TOPIC_CLUSTER_JOBS_MAX = 16
# Soft per-input token budget applied when truncating text before embedding.
_EMBEDDING_MAX_TOKENS_SOFT = 7800
# Some providers enforce a max token budget per *request* (sum of all inputs), not only per input.
# Keep a conservative per-request budget to avoid "context length exceeded" across batched inputs.
_EMBEDDING_REQUEST_MAX_TOKENS_SOFT = 7600
# Per-batch debug logging; on when IIB_EMBEDDING_DEBUG is one of 1/true/yes/on (case-insensitive).
_EMBEDDING_DEBUG = os.getenv("IIB_EMBEDDING_DEBUG", "0").strip().lower() in ["1", "true", "yes", "on"]
def _estimate_tokens_soft(s: str) -> int:
    """
    Estimate the token count of *s* without a real tokenizer.

    Code points below 128 are charged one token per four characters
    (rounded up); every other code point is charged one token each.
    The estimate errs on the high side on purpose so that requests stay
    under provider-side token limits. Empty or falsy input yields 0.
    """
    if not s:
        return 0
    wide = sum(1 for c in s if ord(c) >= 128)
    narrow = len(s) - wide
    return (narrow + 3) // 4 + wide
def _truncate_for_embedding_tokens(s: str, max_tokens: int = _EMBEDDING_MAX_TOKENS_SOFT) -> str:
    """
    Truncate text to a conservative token budget.

    No provider tokenizer is used; the budget is measured with the same
    rule as _estimate_tokens_soft (ASCII: one token per started group of
    four characters, non-ASCII: one token per character), so the returned
    text always satisfies ``_estimate_tokens_soft(result) <= max_tokens``.

    Args:
        s: Text to truncate; ``None``/empty is treated as "".
        max_tokens: Soft token budget for the result.

    Returns:
        The stripped input if it already fits, otherwise the longest
        prefix that fits the budget (stripped).
    """
    s = (s or "").strip()
    if not s:
        return ""
    if _estimate_tokens_soft(s) <= max_tokens:
        return s
    tokens = 0
    ascii_seen = 0
    cut = len(s)
    for pos, ch in enumerate(s):
        if ord(ch) < 128:
            # Charge a token when a new group of 4 ASCII chars starts; this equals
            # ceil(ascii / 4), matching the estimator (a floor here would let the
            # result overshoot the budget by one token).
            if ascii_seen % 4 == 0:
                tokens += 1
            ascii_seen += 1
        else:
            tokens += 1
        if tokens > max_tokens:
            cut = pos
            break
    return s[:cut].strip()
def _batched_by_token_budget(
    items: List[Dict],
    *,
    max_items: int,
    max_tokens_sum: int,
) -> List[List[Dict]]:
    """
    Group *items* (dicts carrying a 'text' entry) into consecutive batches.

    A batch is closed as soon as adding the next item would exceed either
    ``max_items`` entries or ``max_tokens_sum`` estimated tokens. An item
    is always placed into a batch, even when it alone exceeds the token
    budget, so the loop always makes progress.
    """
    result: List[List[Dict]] = []
    bucket: List[Dict] = []
    bucket_tokens = 0
    for entry in items:
        cost = _estimate_tokens_soft(str(entry.get("text") or ""))
        is_full = len(bucket) >= max_items
        over_budget = (bucket_tokens + cost) > max_tokens_sum
        # Only flush a non-empty bucket; an oversized single item still gets its own batch.
        if bucket and (is_full or over_budget):
            result.append(bucket)
            bucket, bucket_tokens = [], 0
        bucket.append(entry)
        bucket_tokens += cost
    if bucket:
        result.append(bucket)
    return result
def _job_now() -> float:
    """Return the current time from time.time(), used for job timestamps."""
    timestamp = time.time()
    return timestamp
def _job_trim() -> None:
    """
    Cap the in-memory job registry at _TOPIC_CLUSTER_JOBS_MAX entries.

    When the registry is over the cap, only the entries with the largest
    'updated_at' values are kept (a missing value counts as 0).
    """
    with _TOPIC_CLUSTER_JOBS_LOCK:
        if len(_TOPIC_CLUSTER_JOBS) <= _TOPIC_CLUSTER_JOBS_MAX:
            return

        def _recency(pair):
            return pair[1].get("updated_at") or 0

        newest_first = sorted(_TOPIC_CLUSTER_JOBS.items(), key=_recency, reverse=True)
        survivors = dict(newest_first[:_TOPIC_CLUSTER_JOBS_MAX])
        # Mutate in place so every holder of the registry sees the trimmed content.
        _TOPIC_CLUSTER_JOBS.clear()
        _TOPIC_CLUSTER_JOBS.update(survivors)
def _job_get(job_id: str) -> Optional[Dict]:
    """Return a shallow copy of the job record for *job_id*, or None if absent."""
    with _TOPIC_CLUSTER_JOBS_LOCK:
        record = _TOPIC_CLUSTER_JOBS.get(job_id)
        if not isinstance(record, dict):
            return None
        return dict(record)
def _job_upsert(job_id: str, patch: Dict) -> None:
    """
    Merge *patch* into the job record for *job_id*, creating it if needed.

    The record's 'updated_at' is refreshed on every call, then the
    registry is trimmed to its size cap.
    """
    with _TOPIC_CLUSTER_JOBS_LOCK:
        record = _TOPIC_CLUSTER_JOBS.get(job_id)
        if not isinstance(record, dict):
            record = {"job_id": job_id}
        if patch:
            record.update(patch)
        record["updated_at"] = _job_now()
        _TOPIC_CLUSTER_JOBS[job_id] = record
    # Called after releasing the lock: _job_trim() acquires the same
    # non-reentrant threading.Lock itself.
    _job_trim()
def _normalize_base_url(base_url: str) -> str:
    """Return *base_url* without one trailing "/" (if it has one)."""
    if base_url.endswith("/"):
        return base_url[:-1]
    return base_url
# Prompt normalization is on unless IIB_PROMPT_NORMALIZE is one of 0/false/no/off (case-insensitive).
_PROMPT_NORMALIZE_ENABLED = os.getenv("IIB_PROMPT_NORMALIZE", "1").strip().lower() not in ["0", "false", "no", "off"]
# balanced: keep some discriminative style words (e.g. "科学插图/纪实/胶片", i.e. "scientific illustration /
#           documentary / film") while dropping boilerplate quality/camera terms
# theme_only: more aggressive removal, closer to "only subject/theme nouns"
# The mode is read from IIB_PROMPT_NORMALIZE_MODE; an unset or empty value falls back to "balanced".
_PROMPT_NORMALIZE_MODE = (os.getenv("IIB_PROMPT_NORMALIZE_MODE", "balanced") or "balanced").strip().lower()
# Remove common SD boilerplate / quality descriptors.
# Each entry is a regex source string; _build_drop_re() ORs them into one case-insensitive pattern.
_DROP_PATTERNS_COMMON = [
# SD / A1111 tags
r"<lora:[^>]+>",
r"<lyco:[^>]+>",
# quality / resolution / generic
r"\b(masterpiece|best\s*quality|high\s*quality|best\s*rating|(?:highly|ultra|hyper)[-\s\u2010\u2011\u2012\u2013\u2212]*detailed|absurdres|absurd\s*res|hires|hdr|uhd|8k|4k|2k|raw\s*photo|photorealistic|realistic|cinematic)\b",
# photography / camera / lens
r"\b(film\s+photography|photography|dslr|camera|canon|nikon|sony|sigma|leica|lens|bokeh|depth\s+of\s+field|dof|sharp\s+focus|wide\s+angle|fisheye)\b",
# camera settings: ISO values, f-stops, focal lengths in mm
r"\b(iso\s*\d{2,5}|f\/\d+(?:\.\d+)?|\d{2,4}mm)\b",
]
# chinese quality descriptors (common); always dropped, regardless of mode
_DROP_PATTERNS_ZH_COMMON = [
r"(超高分辨率|高分辨率|高清|超清|8K|4K|2K|照片级|高质量|最佳质量|大师作品|杰作|超细节|细节丰富|极致细节|极致|完美)",
]
# chinese "style/photography" words: keep in balanced mode (discriminative), drop in theme_only mode
_DROP_PATTERNS_ZH_STYLE = [
r"(电影质感|写真|写实|真实感|摄影|摄影作品|摄影图像|摄影图|镜头|景深|胶片|光圈|光影|构图|色彩|渲染|纪实|插图|科学插图)",
]
def _build_drop_re() -> re.Pattern:
    """
    Compile all active drop patterns into a single case-insensitive regex.

    The common and Chinese-common lists are always included; the Chinese
    style list is added only for the "theme", "theme_only" and "strict" modes.
    """
    strict_modes = ["theme", "theme_only", "strict"]
    patterns = [*_DROP_PATTERNS_COMMON, *_DROP_PATTERNS_ZH_COMMON]
    if _PROMPT_NORMALIZE_MODE in strict_modes:
        patterns.extend(_DROP_PATTERNS_ZH_STYLE)
    # Wrap each source in a non-capturing group so "|" inside a pattern stays local to it.
    alternation = "|".join("(?:" + pat + ")" for pat in patterns)
    return re.compile(alternation, flags=re.IGNORECASE)


# Compiled once at import time and reused by _clean_prompt_for_semantic().
_DROP_RE = _build_drop_re()
def _compute_prompt_normalize_version() -> str:
    """
    Derive a short fingerprint of the active normalization rules.

    IMPORTANT:
    - The version is intentionally NOT overridable through environment variables.
    - It is computed from the rules themselves (enabled flag, mode, and all three
      pattern lists), so changing a rule in code or switching mode changes the
      fingerprint and invalidates caches keyed on it.

    Returns:
        "nv_" followed by the first 12 hex digits of a SHA-1 over the rules.
    """
    rules_json = json.dumps(
        {
            "enabled": bool(_PROMPT_NORMALIZE_ENABLED),
            "mode": str(_PROMPT_NORMALIZE_MODE),
            "drop_common": list(_DROP_PATTERNS_COMMON),
            "drop_zh_common": list(_DROP_PATTERNS_ZH_COMMON),
            "drop_zh_style": list(_DROP_PATTERNS_ZH_STYLE),
        },
        ensure_ascii=False,
        sort_keys=True,
    )
    digest = hashlib.sha1(rules_json.encode("utf-8")).hexdigest()
    return "nv_" + digest[:12]


# Derived normalize version fingerprint (for embedding/title cache invalidation)
_PROMPT_NORMALIZE_VERSION = _compute_prompt_normalize_version()
def _extract_prompt_text(raw_exif: str, max_chars: int = 4000) -> str:
    """
    Extract the natural-language prompt part from stored exif text.

    Everything from the first 'Negative prompt:' marker (ASCII
    case-insensitive) onwards is discarded, then the remainder is capped
    at ``max_chars`` characters.

    Args:
        raw_exif: Stored exif text; non-string values yield "".
        max_chars: Maximum number of characters to keep.

    Returns:
        The stripped positive-prompt text, possibly empty.
    """
    if not isinstance(raw_exif, str):
        return ""
    s = raw_exif.strip()
    if not s:
        return ""
    # Search the original string instead of s.lower(): lower() can change the
    # string length (e.g. U+0130 becomes two code points), which would make an
    # index found in the lowered copy point at the wrong place in s.
    marker = re.search(r"negative prompt:", s, flags=re.IGNORECASE | re.ASCII)
    if marker:
        s = s[: marker.start()].strip()
    if len(s) > max_chars:
        s = s[:max_chars]
    return s.strip()
def _clean_prompt_for_semantic(text: str) -> str:
    """
    Light, dependency-free prompt normalization:
    - remove lora tags / SD boilerplate / quality & photography descriptors
    - keep remaining text as 'theme' semantic signal for embeddings/clustering

    Args:
        text: Prompt text; non-string values yield "".

    Returns:
        The surviving comma/semicolon-separated segments joined with ", ".
        If no segment survives the length filter, the cleaned text is
        returned unsplit.
    """
    if not isinstance(text, str):
        return ""
    s = text
    # remove negative prompt tail early (safety if caller passes raw exif)
    s = re.sub(r"(negative prompt:).*", " ", s, flags=re.IGNORECASE | re.DOTALL)
    # remove weights like (foo:1.2), keeping "foo"
    s = re.sub(r"\(([^()]{1,80}):\s*\d+(?:\.\d+)?\)", r"\1", s)
    # drop boilerplate patterns
    s = _DROP_RE.sub(" ", s)
    # normalize separators
    s = s.replace("**", " ")
    s = re.sub(r"[\[\]{}()]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    # If it's a comma-tag style prompt, remove empty / tiny segments
    # (leftover boilerplate: too short or all punctuation).
    segments = (part.strip() for part in re.split(r"[,\n;]+", s))
    kept: List[str] = [seg for seg in segments if len(seg) > 2]
    if not kept:
        return s.strip()
    # Re-join with a visible separator: joining with "" would fuse neighbouring
    # tags into one word ("1girl" + "red dress" -> "1girlred dress").
    return ", ".join(kept).strip()
def _clean_for_title(text: str) -> str:
    """
    Strip markup and prompt labels from *text* before it is used for titling.

    Removes "**" markers, <lora:...>/<lyco:...> tags, everything from
    'negative prompt:' onwards, and label prefixes such as "prompt:" or
    their Chinese equivalents, then collapses whitespace.

    Args:
        text: Prompt text; non-string values yield "".

    Returns:
        The cleaned, whitespace-normalized text.
    """
    if not isinstance(text, str):
        return ""
    s = text.replace("**", " ")
    s = re.sub(r"<lora:[^>]+>", " ", s, flags=re.IGNORECASE)
    s = re.sub(r"<lyco:[^>]+>", " ", s, flags=re.IGNORECASE)
    # Must run before the label removal below, which would otherwise eat the
    # "prompt:" part of "negative prompt:".
    s = re.sub(r"(negative prompt:).*", " ", s, flags=re.IGNORECASE | re.DOTALL)
    # Accept both the ASCII colon and the fullwidth colon (U+FF1A), which is what
    # Chinese-language labels normally use.
    s = re.sub(
        r"(prompt[:\uFF1A]|提示词[:\uFF1A]|提示[:\uFF1A]|输出[:\uFF1A])",
        " ",
        s,
        flags=re.IGNORECASE,
    )
    s = re.sub(r"\s+", " ", s).strip()
    return s
def _title_from_representative_prompt(text: str, max_len: int = 18) -> str:
    """
    Local fallback title: take the first sentence/clause and truncate.
    This is much more readable than token n-grams without Chinese word segmentation.

    Args:
        text: A representative prompt for the cluster.
        max_len: Maximum title length in characters.

    Returns:
        The derived title, or "主题" when nothing usable remains.
    """
    # Use the same semantic cleaner as embeddings to avoid boilerplate titles like
    # "masterpiece, best quality" / "A highly detailed ..."
    base = _clean_for_title(text)
    s = _clean_prompt_for_semantic(base) if _PROMPT_NORMALIZE_ENABLED else base
    if not s:
        s = base
    if not s:
        return "主题"
    # Split by common sentence punctuation (ASCII and fullwidth), keep the first segment.
    seg = re.split(r"[。!?\uFF01\uFF1F\n\r;\uFF1B]+", s)[0].strip()
    # Remove leading punctuation / separators (ASCII and fullwidth)
    seg = re.sub(r"^[,;:\uFF0C\uFF1B\uFF1A\s-]+", "", seg).strip()
    # Remove leading english articles for nicer titles
    seg = re.sub(r"^(a|an|the)\s+", "", seg, flags=re.IGNORECASE).strip()
    # Strip common boilerplate templates in titles while keeping discriminative words.
    # English: "highly detailed scientific rendering/illustration of ..."
    seg = re.sub(
        r"^(?:(?:highly|ultra|hyper)[-\s\u2010\u2011\u2012\u2013\u2212]*detailed\s+)?(?:scientific\s+)?(?:rendering|illustration|image|depiction|scene)\s+of\s+",
        "",
        seg,
        flags=re.IGNORECASE,
    ).strip()
    # Chinese: "一张...图像/插图/照片..." ("a picture/illustration/photo of ...") template
    seg = re.sub(
        r"^一[张幅]\s*[^,\uFF0C。]{0,20}(?:图像|插图|照片|摄影图像|摄影作品)\s*[,\uFF0C ]*",
        "",
        seg,
    ).strip()
    # Remove trailing commas/colons (ASCII and fullwidth)
    seg = re.sub(r"[,:\uFF0C\uFF1A]\s*$", "", seg).strip()
    # If still too long, hard truncate.
    if len(seg) > max_len:
        seg = seg[:max_len].rstrip()
    return seg or "主题"
def _vec_to_blob_f32(vec: List[float]) -> bytes:
    """Pack *vec* into bytes via array typecode "f" (machine-native C floats)."""
    return array("f", vec).tobytes()
def _blob_to_vec_f32(blob: bytes) -> array:
    """Decode bytes produced by _vec_to_blob_f32 back into an array("f")."""
    decoded = array("f")
    decoded.frombytes(blob)
    return decoded
def _l2_norm_sq(vec: array) -> float:
    """Return the squared Euclidean norm of *vec* (sum of squared components)."""
    squares = [component * component for component in vec]
    return sum(squares)
def _dot(a: array, b: array) -> float:
    """Return the dot product of *a* and *b*, pairing components with zip()."""
    products = [left * right for left, right in zip(a, b)]
    return sum(products)
def _cos_sum(a_sum: array, a_norm_sq: float, b_sum: array, b_norm_sq: float) -> float:
    """
    Cosine similarity of two vectors given their precomputed squared norms.

    Returns 0.0 when either squared norm is not positive.
    """
    if a_norm_sq <= 0 or b_norm_sq <= 0:
        return 0.0
    numerator = sum([x * y for x, y in zip(a_sum, b_sum)])
    denominator = math.sqrt(a_norm_sq) * math.sqrt(b_norm_sq)
    return numerator / denominator
def _call_embeddings_sync(
    *,
    inputs: List[str],
    model: str,
    base_url: str,
    api_key: str,
) -> List[List[float]]:
    """
    POST *inputs* to ``{base_url}/embeddings`` and return one vector per input.

    The returned list is ordered by each item's "index" field in the response.

    Raises:
        HTTPException: 500 when the API key is missing or the response shape
            is invalid; 502 when the request fails or the body is not JSON;
            the upstream status (401 mapped to 400) for non-200 responses.
    """
    if not api_key:
        raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
    url = f"{_normalize_base_url(base_url)}/embeddings"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {"model": model, "input": inputs}
    try:
        resp = requests.post(url, json=payload, headers=headers, timeout=120)
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Embedding API request failed: {e}") from e
    if resp.status_code != 200:
        # Do not leak upstream 401 to frontend (treat as bad request/config to avoid confusing auth state).
        status = 400 if resp.status_code == 401 else resp.status_code
        body = (resp.text or "")[:600]
        raise HTTPException(status_code=status, detail=body)
    # A 200 with a non-JSON body (e.g. a proxy's HTML page) is reported as 502,
    # the same way the chat-title call handles it.
    try:
        data = resp.json()
    except ValueError as e:
        txt = (resp.text or "")[:600]
        raise HTTPException(status_code=502, detail=f"Embedding API response is not JSON: {e}; body={txt}") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=500, detail="Invalid embeddings response format")
    items = data.get("data") or []
    if not isinstance(items, list) or any(not isinstance(x, dict) for x in items):
        raise HTTPException(status_code=500, detail="Invalid embeddings response format")
    # Restore input order; a missing/null index sorts as 0.
    ordered = sorted(items, key=lambda x: x.get("index") or 0)
    embeddings = [x.get("embedding") for x in ordered]
    if any(not isinstance(v, list) for v in embeddings):
        raise HTTPException(status_code=500, detail="Invalid embeddings response format")
    return embeddings
async def _call_embeddings(
    *,
    inputs: List[str],
    model: str,
    base_url: str,
    api_key: str,
) -> List[List[float]]:
    """
    Async wrapper around _call_embeddings_sync.

    IMPORTANT:
    - The underlying HTTP call is synchronous and may take a long time; running it
      on the event loop thread would stall every other request.
    - It is therefore executed via asyncio.to_thread and awaited here.
    """
    call_kwargs = {
        "inputs": inputs,
        "model": model,
        "base_url": base_url,
        "api_key": api_key,
    }
    vectors = await asyncio.to_thread(_call_embeddings_sync, **call_kwargs)
    return vectors
def _call_chat_title_sync(
    *,
    base_url: str,
    api_key: str,
    model: str,
    prompt_samples: List[str],
    output_lang: str,
) -> Optional[Dict]:
    """
    Ask LLM to generate a short topic title and a few keywords.

    Up to 6 cleaned samples (400 chars each) are sent in a single
    ``{base_url}/chat/completions`` request that forces a ``set_topic``
    tool call. There is no retry or fallback.

    Returns:
        ``{"title": str (max 24 chars), "keywords": list of up to 6 str}``.

    Raises:
        HTTPException: 500 when the API key is missing; 400 when no usable
            sample remains; 502 when the request fails or the response
            cannot be parsed into a titled JSON object; the upstream status
            (401 mapped to 400) for non-200 responses.
    """
    if not api_key:
        raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
    url = f"{_normalize_base_url(base_url)}/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    samples = [
        (_clean_prompt_for_semantic(_clean_for_title(s) or s) or s).strip()
        for s in prompt_samples
        if (s or "").strip()
    ]
    samples = [s[:400] for s in samples][:6]
    if not samples:
        raise HTTPException(status_code=400, detail="No prompt samples for title generation")
    json_example = '{"title":"...","keywords":["...","..."]}'
    # The numeric ranges below must read "3-6", "4-12" and "2-6": they match the
    # tool description and the 6-keyword / 24-character caps applied when parsing.
    system_prompt = (
        "You are a topic naming assistant for image-generation prompts.\n"
        "Given several prompt snippets that belong to the SAME theme, output:\n"
        "- a short topic title\n"
        "- 3-6 keywords.\n"
        "\n"
        "Rules:\n"
        f"- Output language MUST be: {output_lang}\n"
        "- Prefer 4-12 characters for Chinese (Simplified/Traditional), otherwise 2-6 English/German words.\n"
        "- Avoid generic boilerplate like: masterpiece, best quality, highly detailed, cinematic, etc.\n"
        "- Keep distinctive terms if they help differentiate themes (e.g., Warhammer 40K, Lolita, scientific illustration).\n"
        "- Do NOT output explanations. Do NOT output markdown/code fences.\n"
        "- The output MUST start with '{' and end with '}' (no leading/trailing characters).\n"
        "\n"
        "Output STRICT JSON only:\n"
        + json_example
    )
    user = "Prompt snippets:\n" + "\n".join([f"- {s}" for s in samples])
    payload = {
        "model": model,
        "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user}],
        # Prefer deterministic, JSON-only output
        "temperature": 0.0,
        "top_p": 1.0,
        # Give enough room for JSON across providers.
        "max_tokens": 2048,
        # Prefer tool/function call to force structured output across providers (e.g. Gemini).
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "set_topic",
                    "description": "Return a concise topic title and 3-6 keywords.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string"},
                            "keywords": {"type": "array", "items": {"type": "string"}},
                        },
                        "required": ["title", "keywords"],
                        "additionalProperties": False,
                    },
                },
            }
        ],
        "tool_choice": {"type": "function", "function": {"name": "set_topic"}},
    }
    # Some OpenAI-compatible providers may use different token limit fields / casing.
    # Set them all (still a single request; no retry/fallback).
    payload["max_output_tokens"] = payload["max_tokens"]
    payload["max_completion_tokens"] = payload["max_tokens"]
    payload["maxOutputTokens"] = payload["max_tokens"]
    payload["maxCompletionTokens"] = payload["max_tokens"]

    def _post_and_parse(payload_obj: Dict) -> Dict:
        try:
            resp = requests.post(url, json=payload_obj, headers=headers, timeout=60)
        except requests.RequestException as e:
            raise HTTPException(status_code=502, detail=f"Chat API request failed: {e}") from e
        if resp.status_code != 200:
            # keep response body for debugging (truncated)
            body = (resp.text or "")[:600]
            status = 400 if resp.status_code == 401 else resp.status_code
            raise HTTPException(status_code=status, detail=body)
        try:
            data = resp.json()
        except Exception as e:
            txt = (resp.text or "")[:600]
            raise HTTPException(status_code=502, detail=f"Chat API response is not JSON: {e}; body={txt}") from e
        # Valid JSON that is not an object has no "choices"; treat it like an empty reply
        # so it ends in the 502 below instead of an AttributeError.
        choices = data.get("choices") if isinstance(data, dict) else None
        choice0 = (choices or [{}])[0] if isinstance(choices, list) else {}
        msg = (choice0.get("message") if isinstance(choice0, dict) else None) or {}
        # OpenAI-compatible providers may return JSON in different places:
        # - message.tool_calls[].function.arguments (JSON string) <-- preferred when tools are used
        # - message.function_call.arguments (legacy)
        # - message.content (common)
        # - choice.text (legacy completions)
        raw = ""
        if isinstance(msg, dict):
            tcs = msg.get("tool_calls") or []
            if isinstance(tcs, list) and tcs:
                fn = ((tcs[0] or {}).get("function") or {}) if isinstance(tcs[0], dict) else {}
                args = (fn.get("arguments") or "") if isinstance(fn, dict) else ""
                if isinstance(args, str) and args.strip():
                    raw = args.strip()
        if not raw and isinstance(msg, dict):
            fc = msg.get("function_call") or {}
            args = (fc.get("arguments") or "") if isinstance(fc, dict) else ""
            if isinstance(args, str) and args.strip():
                raw = args.strip()
        if not raw:
            content = (msg.get("content") or "") if isinstance(msg, dict) else ""
            if isinstance(content, str) and content.strip():
                raw = content.strip()
        if not raw:
            txt = (choice0.get("text") or "") if isinstance(choice0, dict) else ""
            if isinstance(txt, str) and txt.strip():
                raw = txt.strip()
        # Take the outermost {...} span so stray text around the JSON is tolerated.
        m = re.search(r"\{[\s\S]*\}", raw)
        if not m:
            snippet = (raw or "")[:400].replace("\n", "\\n")
            choice_dump = json.dumps(choice0, ensure_ascii=False)[:600] if isinstance(choice0, dict) else str(choice0)[:600]
            raise HTTPException(
                status_code=502,
                detail=f"Chat API response has no JSON object; content_snippet={snippet}; choice0={choice_dump}",
            )
        try:
            obj = json.loads(m.group(0))
        except Exception as e:
            snippet = (m.group(0) or "")[:400].replace("\n", "\\n")
            raise HTTPException(status_code=502, detail=f"Chat API JSON parse failed: {e}; json_snippet={snippet}") from e
        if not isinstance(obj, dict):
            raise HTTPException(status_code=502, detail="Chat API response JSON is not an object")
        title = str(obj.get("title") or "").strip()
        keywords = obj.get("keywords") or []
        if not title:
            raise HTTPException(status_code=502, detail="Chat API response missing title")
        if not isinstance(keywords, list):
            keywords = []
        keywords = [str(x).strip() for x in keywords if str(x).strip()][:6]
        return {"title": title[:24], "keywords": keywords}

    # No fallback / no retry: fail fast if provider/model doesn't support tool calls or returns invalid output.
    return _post_and_parse(payload)
async def _call_chat_title(
    *,
    base_url: str,
    api_key: str,
    model: str,
    prompt_samples: List[str],
    output_lang: str,
) -> Dict:
    """
    Async wrapper around _call_chat_title_sync.

    Same rationale as embeddings: the synchronous HTTP call could block the
    event loop thread for tens of seconds, so it runs in a worker thread.

    Raises:
        HTTPException: 502 when the sync call does not return a dict.
    """
    call_kwargs = {
        "base_url": base_url,
        "api_key": api_key,
        "model": model,
        "prompt_samples": prompt_samples,
        "output_lang": output_lang,
    }
    title_payload = await asyncio.to_thread(_call_chat_title_sync, **call_kwargs)
    if isinstance(title_payload, dict):
        return title_payload
    raise HTTPException(status_code=502, detail="Chat API returned empty title payload")
def mount_topic_cluster_routes(
app: FastAPI,
db_api_base: str,
verify_secret,
write_permission_required,
*,
openai_base_url: str,
openai_api_key: str,
embedding_model: str,
ai_model: str,
):
"""
Mount embedding + topic clustering endpoints (MVP: manual, iib_output only).
"""
async def _run_cluster_job(job_id: str, req) -> None:
    """
    Run one background job: validate folders, build embeddings per folder,
    then cluster and title. Progress and the final result (or error) are
    written to the job registry under *job_id*; exceptions are recorded
    there instead of being raised.
    """
    try:
        started_at = _job_now()
        req_snapshot = req.model_dump() if hasattr(req, "model_dump") else req.dict()
        _job_upsert(
            job_id,
            {"status": "running", "stage": "init", "created_at": started_at, "req": req_snapshot},
        )
        folders = _extract_and_validate_folders(req)
        _job_upsert(job_id, {"folders": folders, "stage": "embedding"})

        # Latest embedding progress payload per folder; totals are summed across folders.
        folder_progress: Dict[str, Dict] = {}
        counter_keys = ("scanned", "to_embed", "embedded_done", "updated", "skipped")

        def _embed_cb(p: Dict) -> None:
            if not isinstance(p, dict) or p.get("stage") != "embedding":
                return
            f = str(p.get("folder") or "")
            if f:
                folder_progress[f] = dict(p)
            totals = {
                key: sum(int(snap.get(key) or 0) for snap in folder_progress.values())
                for key in counter_keys
            }
            totals["folder"] = f
            totals["batch_n"] = int(p.get("batch_n") or 0)
            _job_upsert(job_id, {"stage": "embedding", "progress": totals})

        # Ensure embeddings exist (incremental per folder); request values are clamped.
        embed_kwargs = {
            "model": req.model or embedding_model,
            "force": bool(req.force_embed),
            "batch_size": max(1, min(int(req.batch_size or 64), 256)),
            "max_chars": max(256, min(int(req.max_chars or 4000), 8000)),
            "progress_cb": _embed_cb,
        }
        for f in folders:
            await _build_embeddings_one_folder(folder=f, **embed_kwargs)

        # Clustering + titling progress
        tracked_keys = ("items_total", "items_done", "clusters_total", "clusters_done", "folder")

        def _cluster_cb(p: Dict) -> None:
            if not isinstance(p, dict):
                return
            st = str(p.get("stage") or "")
            if not st:
                return
            patch = {"stage": st, "status": "running"}
            # keep small progress fields only
            prog = {k: p.get(k) for k in tracked_keys if k in p}
            if prog:
                previous = (_job_get(job_id) or {}).get("progress", {})
                patch["progress"] = {**previous, **prog}
            _job_upsert(job_id, patch)

        res = await _cluster_after_embeddings(req, folders, progress_cb=_cluster_cb)
        _job_upsert(job_id, {"status": "done", "stage": "done", "result": res})
    except HTTPException as e:
        _job_upsert(job_id, {"status": "error", "stage": "error", "error": str(e.detail)})
    except Exception as e:
        _job_upsert(job_id, {"status": "error", "stage": "error", "error": f"{type(e).__name__}: {e}"})
async def _build_embeddings_one_folder(
    *,
    folder: str,
    model: str,
    force: bool,
    batch_size: int,
    max_chars: int,
    progress_cb: Optional[Callable[[Dict], None]] = None,
) -> Dict:
    """
    Build embeddings for a single folder with optional progress callback.

    Images whose stored embedding already matches the model and text hash are
    skipped unless *force* is set. The remainder is embedded in batches
    limited by *batch_size* and the per-request token budget, and each batch
    is committed before the next one starts.

    Progress payload (best-effort):
    - stage: "embedding"
    - folder, scanned, to_embed, embedded_done, updated, skipped
    - batch_n (only on per-batch updates)

    Returns:
        {"folder", "count", "updated", "skipped", "model"}.

    Raises:
        HTTPException: 500 when API key/base URL are not configured or the
            embedding count does not match the batch; 400 when the folder
            does not exist.
    """
    if not openai_api_key:
        raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
    if not openai_base_url:
        raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")
    folder = os.path.normpath(folder)
    if not (os.path.exists(folder) and os.path.isdir(folder)):
        raise HTTPException(status_code=400, detail=f"Folder not found: {folder}")

    def _notify(scanned, pending, done, written, reused, batch_n=None) -> None:
        # Emit one progress payload; 'batch_n' is only present on per-batch updates.
        if not progress_cb:
            return
        report = {
            "stage": "embedding",
            "folder": folder,
            "scanned": scanned,
            "to_embed": pending,
            "embedded_done": done,
            "updated": written,
            "skipped": reused,
        }
        if batch_n is not None:
            report["batch_n"] = batch_n
        progress_cb(report)

    conn = DataBase.get_conn()
    like_prefix = os.path.join(folder, "%")
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT id, path, exif FROM image WHERE path LIKE ?", (like_prefix,))
        rows = cur.fetchall()

    images = []
    for image_id, path, exif in rows:
        if not isinstance(path, str) or not os.path.exists(path):
            continue
        text_raw = _extract_prompt_text(exif, max_chars=max_chars)
        cleaned = _clean_prompt_for_semantic(text_raw) if _PROMPT_NORMALIZE_ENABLED else ""
        # Fall back to the raw prompt when normalization is off or removed everything.
        text = cleaned or text_raw
        # Some embedding models/providers have strict context limits (often 8192 tokens).
        # Apply a conservative truncation before sending to embedding API.
        text = _truncate_for_embedding_tokens(text, _EMBEDDING_MAX_TOKENS_SOFT)
        if text:
            images.append({"id": int(image_id), "path": path, "text": text})

    if not images:
        _notify(0, 0, 0, 0, 0)
        return {"folder": folder, "count": 0, "updated": 0, "skipped": 0, "model": model}

    existing = ImageEmbedding.get_by_image_ids(conn, [entry["id"] for entry in images])
    to_embed = []
    skipped = 0
    for item in images:
        # include normalize version to force refresh when rules change
        text_hash = ImageEmbedding.compute_text_hash(f"{_PROMPT_NORMALIZE_VERSION}:{item['text']}")
        old = existing.get(item["id"])
        reusable = (
            (not force)
            and old
            and old.get("model") == model
            and old.get("text_hash") == text_hash
            and old.get("vec")
        )
        if reusable:
            skipped += 1
        else:
            to_embed.append({**item, "text_hash": text_hash})

    _notify(len(images), len(to_embed), 0, 0, skipped)

    updated = 0
    embedded_done = 0
    batches = _batched_by_token_budget(
        to_embed,
        max_items=batch_size,
        max_tokens_sum=_EMBEDDING_REQUEST_MAX_TOKENS_SOFT,
    )
    for bi, batch in enumerate(batches):
        inputs = [entry["text"] for entry in batch]
        if _EMBEDDING_DEBUG:
            token_sum = sum(_estimate_tokens_soft(s) for s in inputs)
            token_max = max((_estimate_tokens_soft(s) for s in inputs), default=0)
            print(
                f"[iib][embed] folder={folder} batch={bi+1}/{len(batches)} n={len(inputs)} token_sum~={token_sum} token_max~={token_max}"
            )
        vectors = await _call_embeddings(
            inputs=inputs,
            model=model,
            base_url=openai_base_url,
            api_key=openai_api_key,
        )
        if len(vectors) != len(batch):
            raise HTTPException(status_code=500, detail="Embeddings count mismatch")
        for entry, vec in zip(batch, vectors):
            ImageEmbedding.upsert(
                conn=conn,
                image_id=entry["id"],
                model=model,
                dim=len(vec),
                text_hash=entry["text_hash"],
                vec_blob=_vec_to_blob_f32(vec),
            )
            updated += 1
            embedded_done += 1
        conn.commit()
        _notify(len(images), len(to_embed), embedded_done, updated, skipped, batch_n=len(inputs))
        # yield between batches
        await asyncio.sleep(0)
    return {"folder": folder, "count": len(images), "updated": updated, "skipped": skipped, "model": model}
# Request body for the build_iib_output_embeddings endpoint.
# (Documented with comments rather than a class docstring so the generated
# schema stays unchanged.)
class BuildIibOutputEmbeddingReq(BaseModel):
    # Folder to embed.
    folder: Optional[str] = None # default: {cwd}/iib_output
    # Embedding model name; the endpoint falls back to the configured embedding_model.
    model: Optional[str] = None
    # Re-embed even when a matching stored embedding exists.
    force: Optional[bool] = False
    # Max inputs per embedding request; the endpoint clamps it to 1..256.
    batch_size: Optional[int] = 64
    # Max prompt characters per image; the endpoint clamps it to 256..8000.
    max_chars: Optional[int] = 4000
# Endpoint: build embeddings for one folder synchronously (no progress reporting).
# Request values are clamped (batch_size to 1..256, max_chars to 256..8000) and
# defaults are filled in before delegating to _build_embeddings_one_folder.
@app.post(
    f"{db_api_base}/build_iib_output_embeddings",
    dependencies=[Depends(verify_secret), Depends(write_permission_required)],
)
async def build_iib_output_embeddings(req: BuildIibOutputEmbeddingReq):
    if not openai_api_key:
        raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
    if not openai_base_url:
        raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")
    target_folder = req.folder or os.path.join(cwd, "iib_output")
    chosen_model = req.model or embedding_model
    requested_batch = int(req.batch_size or 64)
    requested_chars = int(req.max_chars or 4000)
    summary = await _build_embeddings_one_folder(
        folder=target_folder,
        model=chosen_model,
        force=bool(req.force),
        batch_size=max(1, min(requested_batch, 256)),
        max_chars=max(256, min(requested_chars, 8000)),
        progress_cb=None,
    )
    return summary
# Request body for the cluster job endpoints.
# (Documented with comments rather than a class docstring so the generated
# schema stays unchanged.)
class ClusterIibOutputReq(BaseModel):
    # Single folder to include; merged with folder_paths by _extract_and_validate_folders.
    folder: Optional[str] = None
    # Folders to cluster; at least one folder must be given via folder/folder_paths.
    folder_paths: Optional[List[str]] = None
    # Embedding model name; falls back to the configured embedding_model.
    model: Optional[str] = None
    # Re-embed even when a matching stored embedding exists.
    force_embed: Optional[bool] = False
    # Clustering similarity threshold.
    threshold: Optional[float] = 0.86
    # Max inputs per embedding request (clamped to 1..256 by the job runner).
    batch_size: Optional[int] = 64
    # Max prompt characters per image (clamped to 256..8000 by the job runner).
    max_chars: Optional[int] = 4000
    # Minimum cluster size.
    min_cluster_size: Optional[int] = 2
    # B: LLM title generation
    title_model: Optional[str] = None
    # Reduce noise by reassigning small-cluster members to best large cluster if similarity is high enough
    assign_noise_threshold: Optional[float] = None
    # Cache titles in sqlite to avoid repeated LLM calls
    use_title_cache: Optional[bool] = True
    # Force title regeneration (presumably bypasses the title cache; confirm in the clustering code).
    force_title: Optional[bool] = False
    # Output language for titles/keywords (from frontend globalStore.lang)
    lang: Optional[str] = None
# Strong references to in-flight job tasks. The event loop only keeps weak
# references to tasks, so a task whose result is discarded may be garbage
# collected before it finishes; each task removes itself from this set when done.
_cluster_job_tasks: set = set()


@app.post(
    f"{db_api_base}/cluster_iib_output_job_start",
    dependencies=[Depends(verify_secret), Depends(write_permission_required)],
)
async def cluster_iib_output_job_start(req: ClusterIibOutputReq):
    """
    Start a background job for embedding + clustering + LLM titling.
    Returns job_id immediately; frontend should poll job_status to show progress.
    """
    job_id = uuid.uuid4().hex
    # Register the job before scheduling so an immediate status poll finds it.
    _job_upsert(job_id, {"status": "queued", "stage": "queued", "created_at": _job_now()})
    task = asyncio.create_task(_run_cluster_job(job_id, req))
    _cluster_job_tasks.add(task)
    task.add_done_callback(_cluster_job_tasks.discard)
    return {"job_id": job_id}
# Endpoint: return the stored record of a background cluster job, or 404 if unknown.
@app.get(
    f"{db_api_base}/cluster_iib_output_job_status",
    dependencies=[Depends(verify_secret), Depends(write_permission_required)],
)
async def cluster_iib_output_job_status(job_id: str):
    snapshot = _job_get(job_id)
    if snapshot:
        return snapshot
    raise HTTPException(status_code=404, detail="job not found")
def _extract_and_validate_folders(req: ClusterIibOutputReq) -> List[str]:
    """
    Collect the folders named in *req* (folder_paths first, then folder),
    normalize them, drop duplicates while keeping order, and verify each
    one is an existing directory.

    Raises:
        HTTPException: 400 when no folder is given or a folder does not exist.
    """
    requested = [*(req.folder_paths or []), req.folder]
    folders: List[str] = [
        os.path.normpath(entry.strip())
        for entry in requested
        if isinstance(entry, str) and entry.strip()
    ]
    # Users are not expected to rely on a default iib_output folder:
    # if no scope was selected, fail right away.
    if not folders:
        raise HTTPException(status_code=400, detail="folder_paths is required (select folders to cluster)")
    unique_folders = list(dict.fromkeys(folders))
    for f in unique_folders:
        if not (os.path.exists(f) and os.path.isdir(f)):
            raise HTTPException(status_code=400, detail=f"Folder not found: {f}")
    return unique_folders
def _best_cluster_match(v, clusters: List[Dict]) -> Tuple[int, float, float]:
    """
    Find the cluster whose summed member vector points closest to ``v``.

    The score is ``dot(v, sum) / sqrt(norm_sq)``; callers pass an L2-normalized
    ``v`` so this is the cosine similarity to the cluster's centroid direction.

    Returns:
        ``(index, similarity, dot)``. ``index`` is -1 when nothing scored above
        the initial -1.0 (for example when ``clusters`` is empty). On ties the
        first cluster wins.
    """
    best_ci = -1
    best_sim = -1.0
    best_dot = 0.0
    for ci, c in enumerate(clusters):
        dotv = _dot(v, c["sum"])
        denom = math.sqrt(c["norm_sq"]) if c["norm_sq"] > 0 else 1.0
        sim = dotv / denom
        if sim > best_sim:
            best_ci, best_sim, best_dot = ci, sim, dotv
    return best_ci, best_sim, best_dot


async def _cluster_after_embeddings(
    req: ClusterIibOutputReq,
    folders: List[str],
    progress_cb: Optional[Callable[[Dict], None]] = None,
) -> Dict:
    """
    Cluster the stored embeddings of images under ``folders`` and title each cluster.

    Steps: load embeddings for the selected model, normalize them, assign each
    vector greedily to the most similar existing cluster (or open a new one),
    merge near-duplicate clusters, try to re-home members of undersized
    clusters into a large cluster, then obtain a title per remaining cluster
    (from the title cache or the chat API).

    Args:
        req: Clustering options (model, threshold, min_cluster_size, title options, ...).
        folders: Non-empty list of folder paths; ``folders[0]`` is echoed as ``folder``.
        progress_cb: Optional callback receiving progress dicts (``stage`` plus counters).

    Returns:
        Dict with ``clusters`` (sorted by size, largest first) and ``noise``
        (paths that ended up in no reported cluster), plus the effective
        parameters. When no usable embedding is found, a reduced dict with
        empty ``clusters`` and ``noise`` is returned.

    Raises:
        HTTPException: 502 when the chat API yields an empty title.
    """
    folder = folders[0]
    model = req.model or embedding_model
    threshold = float(req.threshold or 0.86)
    threshold = max(0.0, min(threshold, 0.999))
    min_cluster_size = max(1, int(req.min_cluster_size or 2))
    title_model = req.title_model or os.getenv("TOPIC_TITLE_MODEL") or ai_model
    output_lang = _normalize_output_lang(req.lang)
    assign_noise_threshold = req.assign_noise_threshold
    if assign_noise_threshold is None:
        # conservative: only reassign if very likely belongs to a large topic
        assign_noise_threshold = max(0.72, min(threshold - 0.035, 0.93))
    else:
        assign_noise_threshold = max(0.0, min(float(assign_noise_threshold), 0.999))
    use_title_cache = bool(True if req.use_title_cache is None else req.use_title_cache)
    force_title = bool(req.force_title)
    max_chars = int(req.max_chars or 4000)

    if progress_cb:
        progress_cb({"stage": "clustering", "folder": folder, "folders": folders})

    conn = DataBase.get_conn()
    like_prefixes = [os.path.join(f, "%") for f in folders]
    with closing(conn.cursor()) as cur:
        where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes))
        cur.execute(
            f"""SELECT image.id, image.path, image.exif, image_embedding.vec
FROM image
INNER JOIN image_embedding ON image_embedding.image_id = image.id
WHERE ({where}) AND image_embedding.model = ?""",
            (*like_prefixes, model),
        )
        rows = cur.fetchall()

    items = []
    for n, (image_id, path, exif, vec_blob) in enumerate(rows):
        if not isinstance(path, str) or not os.path.exists(path):
            continue
        if not vec_blob:
            continue
        vec = _blob_to_vec_f32(vec_blob)
        n2 = _l2_norm_sq(vec)
        if n2 <= 0:
            continue
        # Normalize in place so dot products below are cosine similarities.
        inv = 1.0 / math.sqrt(n2)
        for i in range(len(vec)):
            vec[i] *= inv
        text_raw = _extract_prompt_text(exif, max_chars=max_chars)
        if _PROMPT_NORMALIZE_ENABLED:
            text = _clean_prompt_for_semantic(text_raw)
            if not text:
                text = text_raw
        else:
            text = text_raw
        items.append({"id": int(image_id), "path": path, "text": text, "vec": vec})
        if (n + 1) % 800 == 0:
            # Yield to the event loop periodically during long scans.
            await asyncio.sleep(0)

    if not items:
        return {"folder": folder, "folders": folders, "model": model, "threshold": threshold, "clusters": [], "noise": []}

    # Incremental clustering by centroid-direction (sum vector)
    clusters = []  # {sum, norm_sq, members:[idx], sample_text}
    for idx, it in enumerate(items):
        v = it["vec"]
        best_ci, best_sim, best_dot = _best_cluster_match(v, clusters)
        if best_ci != -1 and best_sim >= threshold:
            c = clusters[best_ci]
            for i in range(len(v)):
                c["sum"][i] += v[i]
            # |s + v|^2 = |s|^2 + 2 s.v + |v|^2, with |v| = 1
            c["norm_sq"] = c["norm_sq"] + 2.0 * best_dot + 1.0
            c["members"].append(idx)
        else:
            clusters.append({"sum": array("f", v), "norm_sq": 1.0, "members": [idx], "sample_text": it.get("text") or ""})
        if (idx + 1) % 800 == 0:
            if progress_cb:
                progress_cb({"stage": "clustering", "items_total": len(items), "items_done": idx + 1})
            await asyncio.sleep(0)

    # Merge highly similar clusters (fix: same theme split into multiple clusters)
    merge_threshold = min(0.995, max(threshold + 0.04, 0.90))
    merged = True
    while merged and len(clusters) > 1:
        merged = False
        best_i = best_j = -1
        best_sim = merge_threshold
        for i in range(len(clusters)):
            ci = clusters[i]
            for j in range(i + 1, len(clusters)):
                cj = clusters[j]
                sim = _cos_sum(ci["sum"], ci["norm_sq"], cj["sum"], cj["norm_sq"])
                if sim >= best_sim:
                    best_sim = sim
                    best_i, best_j = i, j
        if best_i != -1:
            a = clusters[best_i]
            b = clusters[best_j]
            for k in range(len(a["sum"])):
                a["sum"][k] += b["sum"][k]
            a["norm_sq"] = _l2_norm_sq(a["sum"])
            a["members"].extend(b["members"])
            if not a.get("sample_text"):
                a["sample_text"] = b.get("sample_text", "")
            clusters.pop(best_j)
            merged = True
        await asyncio.sleep(0)

    # Reassign members from small clusters into best large cluster to reduce noise
    orphan_members: List[int] = []  # small-cluster members that no large cluster accepted
    if min_cluster_size > 1 and assign_noise_threshold > 0 and clusters:
        large = [c for c in clusters if len(c["members"]) >= min_cluster_size]
        if large:
            for small in clusters:
                if len(small["members"]) >= min_cluster_size:
                    continue
                for mi in small["members"]:
                    v = items[mi]["vec"]
                    best_ci, best_sim, best_dot = _best_cluster_match(v, large)
                    if best_ci != -1 and best_sim >= assign_noise_threshold:
                        bigc = large[best_ci]
                        for k in range(len(v)):
                            bigc["sum"][k] += v[k]
                        bigc["norm_sq"] = bigc["norm_sq"] + 2.0 * best_dot + 1.0
                        bigc["members"].append(mi)
                    else:
                        # Small clusters are discarded below, so remember the
                        # member here or it would vanish from the response.
                        orphan_members.append(mi)
                await asyncio.sleep(0)
            clusters = large

    # Split small clusters to noise, generate titles
    out_clusters = []
    noise = [items[mi]["path"] for mi in orphan_members]
    if progress_cb:
        progress_cb({"stage": "titling", "clusters_total": len(clusters)})
    for cidx, c in enumerate(clusters):
        if len(c["members"]) < min_cluster_size:
            for mi in c["members"]:
                noise.append(items[mi]["path"])
            continue

        member_items = [items[mi] for mi in c["members"]]
        paths = [x["path"] for x in member_items]
        texts = [x.get("text") or "" for x in member_items]
        member_ids = [x["id"] for x in member_items]
        # Representative prompt for LLM title generation
        rep = (c.get("sample_text") or (texts[0] if texts else "")).strip()

        cached = None
        cluster_hash = _cluster_sig(
            member_ids=member_ids,
            model=model,
            threshold=threshold,
            min_cluster_size=min_cluster_size,
            title_model=title_model,
            lang=output_lang,
        )
        if use_title_cache and (not force_title):
            cached = TopicTitleCache.get(conn, cluster_hash)

        if cached and isinstance(cached, dict) and cached.get("title"):
            title = str(cached.get("title"))
            keywords = cached.get("keywords") or []
        else:
            llm = await _call_chat_title(
                base_url=openai_base_url,
                api_key=openai_api_key,
                model=title_model,
                prompt_samples=[rep] + texts[:5],
                output_lang=output_lang,
            )
            title = (llm or {}).get("title")
            keywords = (llm or {}).get("keywords", [])
            if not title:
                raise HTTPException(status_code=502, detail="Chat API returned empty title")
            if use_title_cache and title:
                try:
                    TopicTitleCache.upsert(conn, cluster_hash, str(title), list(keywords or []), str(title_model))
                    conn.commit()
                except Exception:
                    # Best effort: a failed cache write must not fail the clustering run.
                    pass

        out_clusters.append(
            {
                "id": f"topic_{cidx}",
                "title": title,
                "keywords": keywords,
                "size": len(paths),
                "paths": paths,
                "sample_prompt": _clean_for_title(rep)[:200],
            }
        )
        if (cidx + 1) % 6 == 0:
            if progress_cb:
                progress_cb({"stage": "titling", "clusters_total": len(clusters), "clusters_done": cidx + 1})
            await asyncio.sleep(0)

    out_clusters.sort(key=lambda x: x["size"], reverse=True)
    return {
        "folder": folder,
        "folders": folders,
        "count": len(items),
        "threshold": threshold,
        "min_cluster_size": min_cluster_size,
        "model": model,
        "assign_noise_threshold": assign_noise_threshold,
        "clusters": out_clusters,
        "noise": noise,
    }
class PromptSearchReq(BaseModel):
    """Request body for the prompt similarity search endpoint."""

    # Free-text query to embed and compare against stored prompt embeddings.
    query: str
    # Search scope: a single folder and/or a list of folders.
    folder: Optional[str] = None
    folder_paths: Optional[List[str]] = None
    # Embedding model name; the handler substitutes its default when unset.
    model: Optional[str] = None
    # Result count and minimum similarity; both are clamped by the handler.
    top_k: Optional[int] = 50
    min_score: Optional[float] = 0.0
    # Build missing/outdated embeddings for the folders before searching.
    ensure_embed: Optional[bool] = True
    # Prompt text length cap, shared with the clustering normalization.
    max_chars: Optional[int] = 4000
@app.post(
    f"{db_api_base}/search_iib_output_by_prompt",
    dependencies=[Depends(verify_secret), Depends(write_permission_required)],
)
async def search_iib_output_by_prompt(req: PromptSearchReq):
    """
    Rank images under the requested folders by prompt similarity to a text query.

    The query is embedded with the same model as the stored image embeddings
    and every stored vector is scored by cosine similarity (brute force).

    Args:
        req: Query text, folder scope, model and ranking limits.

    Returns:
        Dict with the effective parameters, ``count`` (number of vectors
        scored) and ``results`` (up to ``top_k`` items, best score first).

    Raises:
        HTTPException: 500 when the API key or base URL is not configured;
            400 for an empty query, missing folders or a nonexistent folder;
            502 when the embedding API returns an empty or zero-norm vector.
    """
    import heapq

    if not openai_api_key:
        raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
    if not openai_base_url:
        raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")

    q = (req.query or "").strip()
    if not q:
        raise HTTPException(status_code=400, detail="query is required")

    folders: List[str] = []
    if req.folder_paths:
        for p in req.folder_paths:
            if isinstance(p, str) and p.strip():
                folders.append(os.path.normpath(p.strip()))
    if req.folder and isinstance(req.folder, str) and req.folder.strip():
        folders.append(os.path.normpath(req.folder.strip()))
    # Users never rely on a default iib_output folder; fail fast when no scope is given.
    if not folders:
        raise HTTPException(status_code=400, detail="folder_paths is required (select folders to search)")
    # validate folders
    folders = list(dict.fromkeys(folders))  # de-dup keep order
    for f in folders:
        if not os.path.exists(f) or not os.path.isdir(f):
            raise HTTPException(status_code=400, detail=f"Folder not found: {f}")

    folder = folders[0]
    model = req.model or embedding_model
    top_k = max(1, min(int(req.top_k or 50), 500))
    min_score = float(req.min_score or 0.0)
    min_score = max(-1.0, min(min_score, 1.0))
    max_chars = max(256, min(int(req.max_chars or 4000), 8000))

    if bool(req.ensure_embed):
        for f in folders:
            await build_iib_output_embeddings(
                BuildIibOutputEmbeddingReq(folder=f, model=model, force=False, batch_size=64, max_chars=max_chars)
            )

    # Build query embedding
    q_text = _extract_prompt_text(q, max_chars=max_chars)
    if _PROMPT_NORMALIZE_ENABLED:
        q_text2 = _clean_prompt_for_semantic(q_text)
        if q_text2:
            q_text = q_text2
    q_text = _truncate_for_embedding_tokens(q_text, _EMBEDDING_MAX_TOKENS_SOFT)
    vecs = await _call_embeddings(inputs=[q_text], model=model, base_url=openai_base_url, api_key=openai_api_key)
    if not vecs or not isinstance(vecs[0], list) or not vecs[0]:
        raise HTTPException(status_code=502, detail="Embedding API returned empty vector")
    qv = array("f", [float(x) for x in vecs[0]])
    qn2 = _l2_norm_sq(qv)
    if qn2 <= 0:
        raise HTTPException(status_code=502, detail="Query embedding has zero norm")
    qinv = 1.0 / math.sqrt(qn2)
    for i in range(len(qv)):
        qv[i] *= qinv

    conn = DataBase.get_conn()
    like_prefixes = [os.path.join(f, "%") for f in folders]
    with closing(conn.cursor()) as cur:
        where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes))
        cur.execute(
            f"""SELECT image.id, image.path, image.exif, image_embedding.vec
FROM image
INNER JOIN image_embedding ON image_embedding.image_id = image.id
WHERE ({where}) AND image_embedding.model = ?""",
            (*like_prefixes, model),
        )
        rows = cur.fetchall()

    # TopK by cosine similarity (brute force; MVP only)
    # Heap entries are (score, -row_no, item). The integer tie-breaker keeps
    # heapq from ever comparing the item dicts (which raises TypeError) when
    # two rows have the same score, e.g. images sharing an identical prompt.
    # Negating row_no makes the most recently seen row the first one evicted
    # among equal scores.
    heap: List[Tuple[float, int, Dict]] = []
    total = 0
    for row_no, (image_id, path, exif, vec_blob) in enumerate(rows):
        if not isinstance(path, str) or not os.path.exists(path):
            continue
        if not vec_blob:
            continue
        v = _blob_to_vec_f32(vec_blob)
        n2 = _l2_norm_sq(v)
        if n2 <= 0:
            continue
        inv = 1.0 / math.sqrt(n2)
        for i in range(len(v)):
            v[i] *= inv
        score = _dot(qv, v)
        total += 1
        if score < min_score:
            continue
        item = {
            "id": int(image_id),
            "path": path,
            "score": float(score),
            "sample_prompt": _clean_for_title(_extract_prompt_text(exif, max_chars=max_chars))[:200],
        }
        if len(heap) < top_k:
            heapq.heappush(heap, (score, -row_no, item))
        elif score > heap[0][0]:
            heapq.heapreplace(heap, (score, -row_no, item))

    # Best score first; equal scores keep database row order.
    heap.sort(key=lambda entry: entry[:2], reverse=True)
    results = [entry[2] for entry in heap]
    return {
        "query": q,
        "folder": folder,
        "folders": folders,
        "model": model,
        "count": total,
        "top_k": top_k,
        "results": results,
    }
def _cluster_sig(
    *,
    member_ids: List[int],
    model: str,
    threshold: float,
    min_cluster_size: int,
    title_model: str,
    lang: str,
) -> str:
    """
    Build the SHA-1 hex signature that identifies a cluster for title caching.

    The digest covers the clustering/titling parameters, the prompt
    normalization version and mode, and the member image ids in sorted order,
    so the order in which ids are passed does not affect the result.
    """
    header = (
        f"m={model}|t={threshold:.6f}|min={min_cluster_size}|tm={title_model}"
        f"|lang={lang}|nv={_PROMPT_NORMALIZE_VERSION}|nm={_PROMPT_NORMALIZE_MODE}"
    )
    id_suffix = "".join(f"|{int(image_id)}" for image_id in sorted(member_ids))
    return hashlib.sha1((header + id_suffix).encode("utf-8")).hexdigest()
def _normalize_output_lang(lang: Optional[str]) -> str:
"""
Map frontend language keys to a human-readable instruction for LLM output language.
Frontend uses: en / zhHans / zhHant / de
"""
if not lang:
return "English"
l = str(lang).strip()
ll = l.lower()
if ll in ["zh", "zhhans", "zh-hans", "zh_cn", "zh-cn", "cn", "zh-hans-cns", "zhs"]:
return "Chinese (Simplified)"
if ll in ["zhhant", "zh-hant", "zh_tw", "zh-tw", "zh_hk", "zh-hk", "tw", "hk", "zht"]:
return "Chinese (Traditional)"
if ll.startswith("de"):
return "German"
if ll.startswith("en"):
return "English"
# fallback
return "English"
@app.post(
    f"{db_api_base}/cluster_iib_output",
    dependencies=[Depends(verify_secret), Depends(write_permission_required)],
)
async def cluster_iib_output(req: ClusterIibOutputReq):
    """
    Synchronous variant of the clustering job: embed, cluster and title in one request.

    Folder validation and the clustering/titling pipeline are delegated to
    ``_extract_and_validate_folders`` and ``_cluster_after_embeddings`` so this
    endpoint and the background job share one implementation.

    Args:
        req: Clustering options and folder scope.

    Returns:
        The dict produced by ``_cluster_after_embeddings`` with ``title_model`` added.

    Raises:
        HTTPException: 500 when the API key or base URL is not configured;
            400 when no folder is given or a folder does not exist;
            502 when the chat API yields an empty title.
    """
    if not openai_api_key:
        raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
    if not openai_base_url:
        raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")

    folders = _extract_and_validate_folders(req)

    # Ensure embeddings exist (incremental per folder)
    for f in folders:
        await build_iib_output_embeddings(
            BuildIibOutputEmbeddingReq(
                folder=f,
                model=req.model,
                force=req.force_embed,
                batch_size=req.batch_size,
                max_chars=req.max_chars,
            )
        )

    result = await _cluster_after_embeddings(req, folders)
    # This endpoint has always reported which model produced the titles.
    result.setdefault("title_model", req.title_model or os.getenv("TOPIC_TITLE_MODEL") or ai_model)
    return result