feat: cache embedding failures and persist cluster results

pull/870/head
wuqinchuan 2025-12-29 00:39:10 +08:00 committed by zanllp
parent 2edf9e52d7
commit efcb500c53
2 changed files with 357 additions and 237 deletions


@@ -82,7 +82,9 @@ class DataBase:
DirCoverCache.create_table(conn)
GlobalSetting.create_table(conn)
ImageEmbedding.create_table(conn)
ImageEmbeddingFail.create_table(conn)
TopicTitleCache.create_table(conn)
TopicClusterCache.create_table(conn)
finally:
conn.commit()
clz.num += 1
@@ -383,6 +385,77 @@ class ImageEmbedding:
)
class ImageEmbeddingFail:
"""
Cache embedding failures per image+model+text_hash to avoid repeatedly hitting the API
for known-failing inputs. This helps keep clustering/search usable by skipping bad items.
"""
@classmethod
def create_table(cls, conn: Connection):
with closing(conn.cursor()) as cur:
cur.execute(
"""CREATE TABLE IF NOT EXISTS image_embedding_fail (
image_id INTEGER NOT NULL,
model TEXT NOT NULL,
text_hash TEXT NOT NULL,
error TEXT NOT NULL,
updated_at TEXT NOT NULL,
PRIMARY KEY(image_id, model, text_hash)
)"""
)
cur.execute("CREATE INDEX IF NOT EXISTS image_embedding_fail_idx_model ON image_embedding_fail(model)")
@classmethod
def get_by_image_ids(cls, conn: Connection, image_ids: List[int], model: str) -> Dict[int, Dict]:
if not image_ids:
return {}
ids = [int(x) for x in image_ids]
placeholders = ",".join(["?"] * len(ids))
with closing(conn.cursor()) as cur:
cur.execute(
f"SELECT image_id, text_hash, error, updated_at FROM image_embedding_fail WHERE model = ? AND image_id IN ({placeholders})",
(str(model), *ids),
)
rows = cur.fetchall()
out: Dict[int, Dict] = {}
for image_id, text_hash, error, updated_at in rows or []:
out[int(image_id)] = {
"text_hash": str(text_hash or ""),
"error": str(error or ""),
"updated_at": str(updated_at or ""),
}
return out
@classmethod
def upsert(
cls,
conn: Connection,
*,
image_id: int,
model: str,
text_hash: str,
error: str,
updated_at: Optional[str] = None,
):
updated_at = updated_at or datetime.now().isoformat()
with closing(conn.cursor()) as cur:
cur.execute(
"""INSERT INTO image_embedding_fail (image_id, model, text_hash, error, updated_at)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(image_id, model, text_hash) DO UPDATE SET
error = excluded.error,
updated_at = excluded.updated_at
""",
(int(image_id), str(model), str(text_hash), str(error or "")[:600], str(updated_at)),
)
@classmethod
def delete(cls, conn: Connection, *, image_id: int, model: str):
with closing(conn.cursor()) as cur:
cur.execute("DELETE FROM image_embedding_fail WHERE image_id = ? AND model = ?", (int(image_id), str(model)))
class TopicTitleCache:
"""
Cache cluster titles/keywords to avoid repeated LLM calls.
@@ -449,6 +522,112 @@ class TopicTitleCache:
)
class TopicClusterCache:
"""
Persist the final clustering result (clusters/noise) to avoid re-clustering when:
- embeddings haven't changed (by max(updated_at) & count), and
- clustering parameters are unchanged.
This is intentionally lightweight:
- result is stored as JSON text
- caller defines cache_key (sha1 over params + folders + normalize version + lang, etc.)
"""
@classmethod
def create_table(cls, conn: Connection):
with closing(conn.cursor()) as cur:
cur.execute(
"""CREATE TABLE IF NOT EXISTS topic_cluster_cache (
cache_key TEXT PRIMARY KEY,
folders TEXT NOT NULL,
model TEXT NOT NULL,
params TEXT NOT NULL,
embeddings_count INTEGER NOT NULL,
embeddings_max_updated_at TEXT NOT NULL,
result TEXT NOT NULL,
updated_at TEXT NOT NULL
)"""
)
cur.execute("CREATE INDEX IF NOT EXISTS topic_cluster_cache_idx_model ON topic_cluster_cache(model)")
@classmethod
def get(cls, conn: Connection, cache_key: str):
with closing(conn.cursor()) as cur:
cur.execute(
"SELECT folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at FROM topic_cluster_cache WHERE cache_key = ?",
(cache_key,),
)
row = cur.fetchone()
if not row:
return None
folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at = row
try:
folders_obj = json.loads(folders) if isinstance(folders, str) else []
except Exception:
folders_obj = []
try:
params_obj = json.loads(params) if isinstance(params, str) else {}
except Exception:
params_obj = {}
try:
result_obj = json.loads(result) if isinstance(result, str) else None
except Exception:
result_obj = None
return {
"cache_key": cache_key,
"folders": folders_obj if isinstance(folders_obj, list) else [],
"model": str(model),
"params": params_obj if isinstance(params_obj, dict) else {},
"embeddings_count": int(embeddings_count or 0),
"embeddings_max_updated_at": str(embeddings_max_updated_at or ""),
"result": result_obj,
"updated_at": str(updated_at or ""),
}
@classmethod
def upsert(
cls,
conn: Connection,
*,
cache_key: str,
folders: List[str],
model: str,
params: Dict,
embeddings_count: int,
embeddings_max_updated_at: str,
result: Dict,
updated_at: Optional[str] = None,
):
updated_at = updated_at or datetime.now().isoformat()
folders_s = json.dumps([str(x) for x in (folders or [])], ensure_ascii=False)
params_s = json.dumps(params or {}, ensure_ascii=False, sort_keys=True)
result_s = json.dumps(result or {}, ensure_ascii=False)
with closing(conn.cursor()) as cur:
cur.execute(
"""INSERT INTO topic_cluster_cache
(cache_key, folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(cache_key) DO UPDATE SET
folders = excluded.folders,
model = excluded.model,
params = excluded.params,
embeddings_count = excluded.embeddings_count,
embeddings_max_updated_at = excluded.embeddings_max_updated_at,
result = excluded.result,
updated_at = excluded.updated_at
""",
(
cache_key,
folders_s,
str(model),
params_s,
int(embeddings_count or 0),
str(embeddings_max_updated_at or ""),
result_s,
updated_at,
),
)
class Tag:
def __init__(self, name: str, score: int, type: str, count=0, color = ""):
self.name = name

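
For context, the failure cache above is keyed on (image_id, model, text_hash), so an image is retried only when its prompt text (and therefore its hash) or the model changes, or when force is set. A minimal sketch of the intended consult/record flow around ImageEmbeddingFail, not part of the commit; `embed_one` is a hypothetical stand-in for the real embeddings call, and `item` mirrors the {"id", "text", "text_hash"} dicts built in the route below:

from scripts.iib.db.datamodel import DataBase, ImageEmbeddingFail

def embed_with_fail_cache(item, model, force=False):
    conn = DataBase.get_conn()
    old_fail = ImageEmbeddingFail.get_by_image_ids(conn, [item["id"]], model).get(item["id"])
    if (not force) and old_fail and old_fail.get("text_hash") == item["text_hash"]:
        return None  # known-bad input for this model + prompt hash: skip the API call
    try:
        vec = embed_one(item["text"], model)  # hypothetical embeddings call
    except Exception as e:
        ImageEmbeddingFail.upsert(conn, image_id=item["id"], model=model,
                                  text_hash=item["text_hash"], error=str(e))
        conn.commit()
        return None
    # success: any previously recorded failure for this image+model is now stale
    ImageEmbeddingFail.delete(conn, image_id=item["id"], model=model)
    conn.commit()
    return vec
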

@@ -15,7 +15,7 @@ import requests
from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, TopicTitleCache
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, ImageEmbeddingFail, TopicClusterCache, TopicTitleCache
from scripts.iib.tool import cwd
@@ -634,6 +634,45 @@ def mount_topic_cluster_routes(
progress_cb=_embed_cb,
)
# If embeddings didn't change and we have a cached clustering result, return it directly.
conn = DataBase.get_conn()
like_prefixes = [os.path.join(f, "%") for f in folders]
with closing(conn.cursor()) as cur:
where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes))
cur.execute(
f"""SELECT COUNT(*), MAX(image_embedding.updated_at)
FROM image
INNER JOIN image_embedding ON image_embedding.image_id = image.id
WHERE ({where}) AND image_embedding.model = ?""",
(*like_prefixes, model),
)
row = cur.fetchone() or (0, "")
embeddings_count = int(row[0] or 0)
embeddings_max_updated_at = str(row[1] or "")
cache_params = {
"model": model,
"threshold": float(req.threshold or 0.86),
"min_cluster_size": int(req.min_cluster_size or 2),
"assign_noise_threshold": req.assign_noise_threshold,
"title_model": req.title_model,
"lang": str(req.lang or ""),
"nv": _PROMPT_NORMALIZE_VERSION,
"nm": _PROMPT_NORMALIZE_MODE,
}
h = hashlib.sha1()
h.update(json.dumps({"folders": folders, "params": cache_params}, ensure_ascii=False, sort_keys=True).encode("utf-8"))
cache_key = h.hexdigest()
cached = TopicClusterCache.get(conn, cache_key)
if (
cached
and int(cached.get("embeddings_count") or 0) == embeddings_count
and str(cached.get("embeddings_max_updated_at") or "") == embeddings_max_updated_at
and isinstance(cached.get("result"), dict)
):
_job_upsert(job_id, {"status": "done", "stage": "done", "result": cached["result"], "cache_hit": True})
return
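
The cache key here is a sha1 over a canonical JSON of the selected folders plus every parameter that affects the output. A standalone sketch of that derivation (the folder and model values are illustrative placeholders, not project defaults):

import hashlib
import json

def topic_cluster_cache_key(folders, cache_params):
    payload = json.dumps(
        {"folders": folders, "params": cache_params},
        ensure_ascii=False,
        sort_keys=True,
    ).encode("utf-8")
    return hashlib.sha1(payload).hexdigest()

# Note: sort_keys only canonicalizes dict keys; the folder list keeps its order,
# so the same folders selected in a different order produce a different key.
key = topic_cluster_cache_key(
    ["/data/outputs/txt2img"],  # illustrative folder
    {"model": "your-embedding-model", "threshold": 0.86, "min_cluster_size": 2},
)
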
# Clustering + titling progress
def _cluster_cb(p: Dict) -> None:
if not isinstance(p, dict):
@@ -657,6 +696,20 @@ def mount_topic_cluster_routes(
_job_upsert(job_id, patch)
res = await _cluster_after_embeddings(req, folders, progress_cb=_cluster_cb)
try:
TopicClusterCache.upsert(
conn,
cache_key=cache_key,
folders=folders,
model=model,
params=cache_params,
embeddings_count=embeddings_count,
embeddings_max_updated_at=embeddings_max_updated_at,
result=res,
)
conn.commit()
except Exception:
pass
_job_upsert(job_id, {"status": "done", "stage": "done", "result": res})
except HTTPException as e:
_job_upsert(job_id, {"status": "error", "stage": "error", "error": str(e.detail)})
@@ -728,13 +781,16 @@ def mount_topic_cluster_routes(
id_list = [x["id"] for x in images]
existing = ImageEmbedding.get_by_image_ids(conn, id_list)
existing_fail = ImageEmbeddingFail.get_by_image_ids(conn, id_list, model)
to_embed = []
skipped = 0
skipped_failed = 0
for item in images:
# include normalize version to force refresh when rules change
text_hash = ImageEmbedding.compute_text_hash(f"{_PROMPT_NORMALIZE_VERSION}:{item['text']}")
old = existing.get(item["id"])
old_fail = existing_fail.get(item["id"])
if (
(not force)
and old
@@ -744,6 +800,10 @@ def mount_topic_cluster_routes(
):
skipped += 1
continue
# Skip known failures for the same model+text_hash (unless force is enabled).
if (not force) and old_fail and str(old_fail.get("text_hash") or "") == text_hash:
skipped_failed += 1
continue
to_embed.append({**item, "text_hash": text_hash})
if progress_cb:
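
Because the normalize version is prefixed to the prompt text before hashing, bumping `_PROMPT_NORMALIZE_VERSION` invalidates both the embedding rows and the failure rows at once: the stored text_hash no longer matches and the item is re-embedded. A small sketch of that effect, assuming a plain content hash as a stand-in for `ImageEmbedding.compute_text_hash` (its implementation is not shown in this diff):

import hashlib

def compute_text_hash(text: str) -> str:  # assumed stand-in, see note above
    return hashlib.sha1(text.encode("utf-8")).hexdigest()

prompt = "a cat sitting on a window sill"
h_v1 = compute_text_hash("v1:" + prompt)
h_v2 = compute_text_hash("v2:" + prompt)
assert h_v1 != h_v2  # same prompt, different normalize version -> stored hashes no longer match
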
@@ -756,11 +816,14 @@ def mount_topic_cluster_routes(
"embedded_done": 0,
"updated": 0,
"skipped": skipped,
"skipped_failed": skipped_failed,
"failed": 0,
}
)
updated = 0
embedded_done = 0
failed = 0
batches = _batched_by_token_budget(
to_embed,
max_items=batch_size,
@@ -774,12 +837,49 @@ def mount_topic_cluster_routes(
print(
f"[iib][embed] folder={folder} batch={bi+1}/{len(batches)} n={len(inputs)} token_sum~={token_sum} token_max~={token_max}"
)
vectors = await _call_embeddings(
inputs=inputs,
model=model,
base_url=openai_base_url,
api_key=openai_api_key,
)
try:
vectors = await _call_embeddings(
inputs=inputs,
model=model,
base_url=openai_base_url,
api_key=openai_api_key,
)
except HTTPException as e:
# Cache failures for this batch and continue (skip these images for now).
err = str(e.detail)
for it in batch:
try:
ImageEmbeddingFail.upsert(
conn,
image_id=int(it["id"]),
model=str(model),
text_hash=str(it.get("text_hash") or ""),
error=err,
)
except Exception:
pass
try:
conn.commit()
except Exception:
pass
failed += len(batch)
if progress_cb:
progress_cb(
{
"stage": "embedding",
"folder": folder,
"scanned": len(images),
"to_embed": len(to_embed),
"embedded_done": embedded_done,
"updated": updated,
"skipped": skipped,
"skipped_failed": skipped_failed,
"failed": failed,
"batch_n": len(inputs),
}
)
await asyncio.sleep(0)
continue
if len(vectors) != len(batch):
raise HTTPException(status_code=500, detail="Embeddings count mismatch")
for item, vec in zip(batch, vectors):
@@ -791,6 +891,11 @@ def mount_topic_cluster_routes(
text_hash=item["text_hash"],
vec_blob=_vec_to_blob_f32(vec),
)
# Success -> clear fail cache for this image+model (any old failure becomes irrelevant).
try:
ImageEmbeddingFail.delete(conn, image_id=int(item["id"]), model=str(model))
except Exception:
pass
updated += 1
embedded_done += 1
conn.commit()
@@ -804,13 +909,23 @@ def mount_topic_cluster_routes(
"embedded_done": embedded_done,
"updated": updated,
"skipped": skipped,
"skipped_failed": skipped_failed,
"failed": failed,
"batch_n": len(inputs),
}
)
# yield between batches
await asyncio.sleep(0)
return {"folder": folder, "count": len(images), "updated": updated, "skipped": skipped, "model": model}
return {
"folder": folder,
"count": len(images),
"updated": updated,
"skipped": skipped,
"skipped_failed": skipped_failed,
"failed": failed,
"model": model,
}
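
The batching helper `_batched_by_token_budget` used in the loop above is not part of this diff. The sketch below shows one way such a helper could work, splitting items by both a maximum item count and an approximate token budget; the name, the `len(text) // 4` token estimate, and the default budget are assumptions, not the project's actual implementation:

from typing import Dict, List

def batched_by_token_budget(items: List[Dict], max_items: int, max_tokens: int = 8000) -> List[List[Dict]]:
    """Greedy batching by item count and a rough token budget."""
    batches: List[List[Dict]] = []
    cur: List[Dict] = []
    cur_tokens = 0
    for it in items:
        tokens = max(1, len(it.get("text") or "") // 4)  # crude token proxy
        if cur and (len(cur) >= max_items or cur_tokens + tokens > max_tokens):
            batches.append(cur)
            cur, cur_tokens = [], 0
        cur.append(it)
        cur_tokens += tokens
    if cur:
        batches.append(cur)
    return batches
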
class BuildIibOutputEmbeddingReq(BaseModel):
folder: Optional[str] = None # default: {cwd}/iib_output
@@ -1312,250 +1427,76 @@ def mount_topic_cluster_routes(
dependencies=[Depends(verify_secret), Depends(write_permission_required)],
)
async def cluster_iib_output(req: ClusterIibOutputReq):
if not openai_api_key:
raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
if not openai_base_url:
raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")
folders: List[str] = []
if req.folder_paths:
for p in req.folder_paths:
if isinstance(p, str) and p.strip():
folders.append(os.path.normpath(p.strip()))
if req.folder and isinstance(req.folder, str) and req.folder.strip():
folders.append(os.path.normpath(req.folder.strip()))
# Users will not rely on the default iib_output; raise an error when no scope is specified
if not folders:
raise HTTPException(status_code=400, detail="folder_paths is required (select folders to cluster)")
# Keep this endpoint for compatibility, but avoid duplicating the heavy logic.
# If embeddings haven't changed and we have a cached clustering result, return it directly.
folders = _extract_and_validate_folders(req)
folders = list(dict.fromkeys(folders))
for f in folders:
if not os.path.exists(f) or not os.path.isdir(f):
raise HTTPException(status_code=400, detail=f"Folder not found: {f}")
# Ensure embeddings exist (incremental per folder)
for f in folders:
await build_iib_output_embeddings(
BuildIibOutputEmbeddingReq(
folder=f,
model=req.model,
force=req.force_embed,
batch_size=req.batch_size,
max_chars=req.max_chars,
)
)
folder = folders[0]
model = req.model or embedding_model
threshold = float(req.threshold or 0.86)
threshold = max(0.0, min(threshold, 0.999))
min_cluster_size = max(1, int(req.min_cluster_size or 2))
title_model = req.title_model or os.getenv("TOPIC_TITLE_MODEL") or ai_model
output_lang = _normalize_output_lang(req.lang)
assign_noise_threshold = req.assign_noise_threshold
if assign_noise_threshold is None:
# conservative: only reassign an item if it very likely belongs to a large topic
assign_noise_threshold = max(0.72, min(threshold - 0.035, 0.93))
else:
assign_noise_threshold = max(0.0, min(float(assign_noise_threshold), 0.999))
use_title_cache = bool(True if req.use_title_cache is None else req.use_title_cache)
force_title = bool(req.force_title)
batch_size = max(1, min(int(req.batch_size or 64), 256))
max_chars = max(256, min(int(req.max_chars or 4000), 8000))
force = bool(req.force_embed)
for f in folders:
await _build_embeddings_one_folder(
folder=f,
model=model,
force=force,
batch_size=batch_size,
max_chars=max_chars,
progress_cb=None,
)
conn = DataBase.get_conn()
like_prefixes = [os.path.join(f, "%") for f in folders]
with closing(conn.cursor()) as cur:
where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes))
cur.execute(
f"""SELECT image.id, image.path, image.exif, image_embedding.vec
f"""SELECT COUNT(*), MAX(image_embedding.updated_at)
FROM image
INNER JOIN image_embedding ON image_embedding.image_id = image.id
WHERE ({where}) AND image_embedding.model = ?""",
(*like_prefixes, model),
)
rows = cur.fetchall()
row = cur.fetchone() or (0, "")
embeddings_count = int(row[0] or 0)
embeddings_max_updated_at = str(row[1] or "")
items = []
for image_id, path, exif, vec_blob in rows:
if not isinstance(path, str) or not os.path.exists(path):
continue
if not vec_blob:
continue
vec = _blob_to_vec_f32(vec_blob)
n2 = _l2_norm_sq(vec)
if n2 <= 0:
continue
inv = 1.0 / math.sqrt(n2)
for i in range(len(vec)):
vec[i] *= inv
text_raw = _extract_prompt_text(exif, max_chars=int(req.max_chars or 4000))
if _PROMPT_NORMALIZE_ENABLED:
text = _clean_prompt_for_semantic(text_raw)
if not text:
text = text_raw
else:
text = text_raw
items.append({"id": int(image_id), "path": path, "text": text, "vec": vec})
if not items:
return {"folder": folder, "folders": folders, "model": model, "threshold": threshold, "clusters": [], "noise": []}
# Incremental clustering by centroid-direction (sum vector)
clusters = [] # {sum, norm_sq, members:[idx], sample_text}
for idx, it in enumerate(items):
v = it["vec"]
best_ci = -1
best_sim = -1.0
best_dot = 0.0
for ci, c in enumerate(clusters):
dotv = _dot(v, c["sum"])
denom = math.sqrt(c["norm_sq"]) if c["norm_sq"] > 0 else 1.0
sim = dotv / denom
if sim > best_sim:
best_sim = sim
best_ci = ci
best_dot = dotv
if best_ci != -1 and best_sim >= threshold:
c = clusters[best_ci]
for i in range(len(v)):
c["sum"][i] += v[i]
c["norm_sq"] = c["norm_sq"] + 2.0 * best_dot + 1.0
c["members"].append(idx)
else:
clusters.append({"sum": array("f", v), "norm_sq": 1.0, "members": [idx], "sample_text": it.get("text") or ""})
# Merge highly similar clusters (fixes the same theme being split across multiple clusters)
merge_threshold = min(0.995, max(threshold + 0.04, 0.90))
merged = True
while merged and len(clusters) > 1:
merged = False
best_i = best_j = -1
best_sim = merge_threshold
for i in range(len(clusters)):
ci = clusters[i]
for j in range(i + 1, len(clusters)):
cj = clusters[j]
sim = _cos_sum(ci["sum"], ci["norm_sq"], cj["sum"], cj["norm_sq"])
if sim >= best_sim:
best_sim = sim
best_i, best_j = i, j
if best_i != -1:
a = clusters[best_i]
b = clusters[best_j]
for k in range(len(a["sum"])):
a["sum"][k] += b["sum"][k]
a["norm_sq"] = _l2_norm_sq(a["sum"])
a["members"].extend(b["members"])
if not a.get("sample_text"):
a["sample_text"] = b.get("sample_text", "")
clusters.pop(best_j)
merged = True
# Reassign members from small clusters into the best large cluster to reduce noise
if min_cluster_size > 1 and assign_noise_threshold > 0 and clusters:
large = [c for c in clusters if len(c["members"]) >= min_cluster_size]
if large:
new_large = []
# copy large clusters first
for c in clusters:
if len(c["members"]) >= min_cluster_size:
new_large.append(c)
# reassign items from small clusters
for c in clusters:
if len(c["members"]) >= min_cluster_size:
continue
for mi in c["members"]:
v = items[mi]["vec"]
best_ci = -1
best_sim = -1.0
best_dot = 0.0
for ci, bigc in enumerate(new_large):
dotv = _dot(v, bigc["sum"])
denom = math.sqrt(bigc["norm_sq"]) if bigc["norm_sq"] > 0 else 1.0
sim = dotv / denom
if sim > best_sim:
best_sim = sim
best_ci = ci
best_dot = dotv
if best_ci != -1 and best_sim >= assign_noise_threshold:
bigc = new_large[best_ci]
for k in range(len(v)):
bigc["sum"][k] += v[k]
bigc["norm_sq"] = bigc["norm_sq"] + 2.0 * best_dot + 1.0
bigc["members"].append(mi)
# else: keep in small cluster -> will become noise below
clusters = new_large
# Split small clusters to noise, generate titles
out_clusters = []
noise = []
for cidx, c in enumerate(clusters):
if len(c["members"]) < min_cluster_size:
for mi in c["members"]:
noise.append(items[mi]["path"])
continue
member_items = [items[mi] for mi in c["members"]]
paths = [x["path"] for x in member_items]
texts = [x.get("text") or "" for x in member_items]
member_ids = [x["id"] for x in member_items]
# Representative prompt for LLM title generation
rep = (c.get("sample_text") or (texts[0] if texts else "")).strip()
cached = None
cluster_hash = _cluster_sig(
member_ids=member_ids,
model=model,
threshold=threshold,
min_cluster_size=min_cluster_size,
title_model=title_model,
lang=output_lang,
)
if use_title_cache and (not force_title):
cached = TopicTitleCache.get(conn, cluster_hash)
if cached and isinstance(cached, dict) and cached.get("title"):
title = str(cached.get("title"))
keywords = cached.get("keywords") or []
else:
llm = await _call_chat_title(
base_url=openai_base_url,
api_key=openai_api_key,
model=title_model,
prompt_samples=[rep] + texts[:5],
output_lang=output_lang,
)
title = (llm or {}).get("title")
keywords = (llm or {}).get("keywords", [])
if not title:
raise HTTPException(status_code=502, detail="Chat API returned empty title")
if use_title_cache and title:
try:
TopicTitleCache.upsert(conn, cluster_hash, str(title), list(keywords or []), str(title_model))
conn.commit()
except Exception:
pass
out_clusters.append(
{
"id": f"topic_{cidx}",
"title": title,
"keywords": keywords,
"size": len(paths),
"paths": paths,
"sample_prompt": _clean_for_title(rep)[:200],
}
)
out_clusters.sort(key=lambda x: x["size"], reverse=True)
return {
"folder": folder,
"folders": folders,
cache_params = {
"model": model,
"threshold": threshold,
"min_cluster_size": min_cluster_size,
"clusters": out_clusters,
"noise": noise,
"count": len(items),
"title_model": title_model,
"threshold": float(req.threshold or 0.86),
"min_cluster_size": int(req.min_cluster_size or 2),
"assign_noise_threshold": req.assign_noise_threshold,
"title_model": req.title_model,
"lang": str(req.lang or ""),
"nv": _PROMPT_NORMALIZE_VERSION,
"nm": _PROMPT_NORMALIZE_MODE,
}
h = hashlib.sha1()
h.update(json.dumps({"folders": folders, "params": cache_params}, ensure_ascii=False, sort_keys=True).encode("utf-8"))
cache_key = h.hexdigest()
cached = TopicClusterCache.get(conn, cache_key)
if (
cached
and int(cached.get("embeddings_count") or 0) == embeddings_count
and str(cached.get("embeddings_max_updated_at") or "") == embeddings_max_updated_at
and isinstance(cached.get("result"), dict)
):
return cached["result"]
res = await _cluster_after_embeddings(req, folders, progress_cb=None)
try:
TopicClusterCache.upsert(
conn,
cache_key=cache_key,
folders=folders,
model=model,
params=cache_params,
embeddings_count=embeddings_count,
embeddings_max_updated_at=embeddings_max_updated_at,
result=res,
)
conn.commit()
except Exception:
pass
return res
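
The staleness test is deliberately coarse: the cached result is reused only if both the number of matching embedding rows and their latest updated_at are unchanged, so any re-embedded or newly embedded image under the selected folders invalidates it. A minimal sketch of that check, assuming the dict shape returned by TopicClusterCache.get above:

def cache_is_fresh(cached, embeddings_count: int, embeddings_max_updated_at: str) -> bool:
    """Reuse condition applied to a TopicClusterCache.get() result (mirrors the endpoint above)."""
    return bool(
        cached
        and int(cached.get("embeddings_count") or 0) == embeddings_count
        and str(cached.get("embeddings_max_updated_at") or "") == embeddings_max_updated_at
        and isinstance(cached.get("result"), dict)
    )

# usage, as in the endpoint:
# cached = TopicClusterCache.get(conn, cache_key)
# if cache_is_fresh(cached, embeddings_count, embeddings_max_updated_at):
#     return cached["result"]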