feat: cache embedding failures and persist cluster results
parent
2edf9e52d7
commit
efcb500c53
|
|
@ -82,7 +82,9 @@ class DataBase:
|
|||
DirCoverCache.create_table(conn)
|
||||
GlobalSetting.create_table(conn)
|
||||
ImageEmbedding.create_table(conn)
|
||||
ImageEmbeddingFail.create_table(conn)
|
||||
TopicTitleCache.create_table(conn)
|
||||
TopicClusterCache.create_table(conn)
|
||||
finally:
|
||||
conn.commit()
|
||||
clz.num += 1
|
||||
|
|
@ -383,6 +385,77 @@ class ImageEmbedding:
|
|||
)
|
||||
|
||||
|
||||
class ImageEmbeddingFail:
    """
    Cache embedding failures per image+model+text_hash to avoid repeatedly hitting the API
    for known-failing inputs. This helps keep clustering/search usable by skipping bad items.
    """

    # Upper bound on ids bound into a single "IN (...)" list. SQLite builds
    # older than 3.32 reject statements with more than 999 bound variables,
    # and one slot is already taken by the model parameter.
    _MAX_IDS_PER_QUERY = 900

    @classmethod
    def create_table(cls, conn: Connection):
        """Create the image_embedding_fail table and its model index if missing."""
        with closing(conn.cursor()) as cur:
            cur.execute(
                """CREATE TABLE IF NOT EXISTS image_embedding_fail (
                    image_id INTEGER NOT NULL,
                    model TEXT NOT NULL,
                    text_hash TEXT NOT NULL,
                    error TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    PRIMARY KEY(image_id, model, text_hash)
                )"""
            )
            cur.execute("CREATE INDEX IF NOT EXISTS image_embedding_fail_idx_model ON image_embedding_fail(model)")

    @classmethod
    def get_by_image_ids(cls, conn: Connection, image_ids: List[int], model: str) -> Dict[int, Dict]:
        """
        Return the cached failure for each of the given images under `model`.

        The result maps image_id -> {"text_hash", "error", "updated_at"}; images
        without a cached failure are absent. An image may have several rows (one
        per text_hash); the row with the greatest updated_at is the one returned.
        """
        if not image_ids:
            return {}
        ids = [int(x) for x in image_ids]
        out: Dict[int, Dict] = {}
        with closing(conn.cursor()) as cur:
            # Query in chunks so large folders never exceed SQLite's
            # bound-variable limit.
            for start in range(0, len(ids), cls._MAX_IDS_PER_QUERY):
                chunk = ids[start:start + cls._MAX_IDS_PER_QUERY]
                placeholders = ",".join(["?"] * len(chunk))
                cur.execute(
                    f"SELECT image_id, text_hash, error, updated_at FROM image_embedding_fail "
                    f"WHERE model = ? AND image_id IN ({placeholders}) ORDER BY updated_at ASC",
                    (str(model), *chunk),
                )
                # Rows arrive oldest first, so a later assignment for the same
                # image_id overwrites an earlier one: the newest failure wins.
                for image_id, text_hash, error, updated_at in cur.fetchall() or []:
                    out[int(image_id)] = {
                        "text_hash": str(text_hash or ""),
                        "error": str(error or ""),
                        "updated_at": str(updated_at or ""),
                    }
        return out

    @classmethod
    def upsert(
        cls,
        conn: Connection,
        *,
        image_id: int,
        model: str,
        text_hash: str,
        error: str,
        updated_at: Optional[str] = None,
    ):
        """
        Insert a failure row, or refresh error/updated_at when the
        (image_id, model, text_hash) key already exists.

        The error text is truncated to 600 characters. updated_at defaults to
        the current local time in ISO format. The caller is responsible for
        committing.
        """
        updated_at = updated_at or datetime.now().isoformat()
        with closing(conn.cursor()) as cur:
            cur.execute(
                """INSERT INTO image_embedding_fail (image_id, model, text_hash, error, updated_at)
                   VALUES (?, ?, ?, ?, ?)
                   ON CONFLICT(image_id, model, text_hash) DO UPDATE SET
                     error = excluded.error,
                     updated_at = excluded.updated_at
                """,
                (int(image_id), str(model), str(text_hash), str(error or "")[:600], str(updated_at)),
            )

    @classmethod
    def delete(cls, conn: Connection, *, image_id: int, model: str):
        """Remove every cached failure for image_id under model, regardless of text_hash."""
        with closing(conn.cursor()) as cur:
            cur.execute("DELETE FROM image_embedding_fail WHERE image_id = ? AND model = ?", (int(image_id), str(model)))
|
||||
|
||||
|
||||
class TopicTitleCache:
|
||||
"""
|
||||
Cache cluster titles/keywords to avoid repeated LLM calls.
|
||||
|
|
@ -449,6 +522,112 @@ class TopicTitleCache:
|
|||
)
|
||||
|
||||
|
||||
class TopicClusterCache:
    """
    Persist the final clustering result (clusters/noise) to avoid re-clustering when:
    - embeddings haven't changed (by max(updated_at) & count), and
    - clustering parameters are unchanged.

    This is intentionally lightweight:
    - result is stored as JSON text
    - caller defines cache_key (sha1 over params + folders + normalize version + lang, etc.)
    """

    @classmethod
    def create_table(cls, conn: Connection):
        """Create the topic_cluster_cache table and its model index if missing."""
        with closing(conn.cursor()) as cur:
            cur.execute(
                """CREATE TABLE IF NOT EXISTS topic_cluster_cache (
                    cache_key TEXT PRIMARY KEY,
                    folders TEXT NOT NULL,
                    model TEXT NOT NULL,
                    params TEXT NOT NULL,
                    embeddings_count INTEGER NOT NULL,
                    embeddings_max_updated_at TEXT NOT NULL,
                    result TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )"""
            )
            cur.execute("CREATE INDEX IF NOT EXISTS topic_cluster_cache_idx_model ON topic_cluster_cache(model)")

    @staticmethod
    def _decode_json(raw, expected_type, default):
        """
        Decode a JSON text column, falling back to `default`.

        `default` is returned when `raw` is not a str, is not valid JSON, or
        (when `expected_type` is given) decodes to a value of another type.
        """
        if not isinstance(raw, str):
            return default
        try:
            value = json.loads(raw)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; RecursionError covers
            # pathologically nested input. A corrupt row is a cache miss,
            # not an error.
            return default
        if expected_type is not None and not isinstance(value, expected_type):
            return default
        return value

    @classmethod
    def get(cls, conn: Connection, cache_key: str):
        """
        Load the cache entry stored under `cache_key`.

        Returns None when no row exists. Otherwise returns a dict with the keys
        cache_key, folders (list), model, params (dict), embeddings_count (int),
        embeddings_max_updated_at, result (decoded JSON, or None if unreadable)
        and updated_at.
        """
        with closing(conn.cursor()) as cur:
            cur.execute(
                "SELECT folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at FROM topic_cluster_cache WHERE cache_key = ?",
                (cache_key,),
            )
            row = cur.fetchone()
        if not row:
            return None
        folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at = row
        return {
            "cache_key": cache_key,
            "folders": cls._decode_json(folders, list, []),
            "model": str(model),
            "params": cls._decode_json(params, dict, {}),
            "embeddings_count": int(embeddings_count or 0),
            "embeddings_max_updated_at": str(embeddings_max_updated_at or ""),
            # The result type is not checked here; callers verify it themselves.
            "result": cls._decode_json(result, None, None),
            "updated_at": str(updated_at or ""),
        }

    @classmethod
    def upsert(
        cls,
        conn: Connection,
        *,
        cache_key: str,
        folders: List[str],
        model: str,
        params: Dict,
        embeddings_count: int,
        embeddings_max_updated_at: str,
        result: Dict,
        updated_at: Optional[str] = None,
    ):
        """
        Insert the entry for `cache_key`, or overwrite every column when it exists.

        folders, params and result are serialized to JSON text (params with
        sorted keys). updated_at defaults to the current local time in ISO
        format. json.dumps raises if a value is not JSON-serializable. The
        caller is responsible for committing.
        """
        updated_at = updated_at or datetime.now().isoformat()
        folders_s = json.dumps([str(x) for x in (folders or [])], ensure_ascii=False)
        params_s = json.dumps(params or {}, ensure_ascii=False, sort_keys=True)
        result_s = json.dumps(result or {}, ensure_ascii=False)
        with closing(conn.cursor()) as cur:
            cur.execute(
                """INSERT INTO topic_cluster_cache
                   (cache_key, folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                   ON CONFLICT(cache_key) DO UPDATE SET
                     folders = excluded.folders,
                     model = excluded.model,
                     params = excluded.params,
                     embeddings_count = excluded.embeddings_count,
                     embeddings_max_updated_at = excluded.embeddings_max_updated_at,
                     result = excluded.result,
                     updated_at = excluded.updated_at
                """,
                (
                    str(cache_key),
                    folders_s,
                    str(model),
                    params_s,
                    int(embeddings_count or 0),
                    str(embeddings_max_updated_at or ""),
                    result_s,
                    str(updated_at),
                ),
            )
|
||||
|
||||
class Tag:
|
||||
def __init__(self, name: str, score: int, type: str, count=0, color = ""):
|
||||
self.name = name
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import requests
|
|||
from fastapi import Depends, FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, TopicTitleCache
|
||||
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, ImageEmbeddingFail, TopicClusterCache, TopicTitleCache
|
||||
from scripts.iib.tool import cwd
|
||||
|
||||
|
||||
|
|
@ -634,6 +634,45 @@ def mount_topic_cluster_routes(
|
|||
progress_cb=_embed_cb,
|
||||
)
|
||||
|
||||
# If embeddings didn't change and we have a cached clustering result, return it directly.
|
||||
conn = DataBase.get_conn()
|
||||
like_prefixes = [os.path.join(f, "%") for f in folders]
|
||||
with closing(conn.cursor()) as cur:
|
||||
where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes))
|
||||
cur.execute(
|
||||
f"""SELECT COUNT(*), MAX(image_embedding.updated_at)
|
||||
FROM image
|
||||
INNER JOIN image_embedding ON image_embedding.image_id = image.id
|
||||
WHERE ({where}) AND image_embedding.model = ?""",
|
||||
(*like_prefixes, model),
|
||||
)
|
||||
row = cur.fetchone() or (0, "")
|
||||
embeddings_count = int(row[0] or 0)
|
||||
embeddings_max_updated_at = str(row[1] or "")
|
||||
|
||||
cache_params = {
|
||||
"model": model,
|
||||
"threshold": float(req.threshold or 0.86),
|
||||
"min_cluster_size": int(req.min_cluster_size or 2),
|
||||
"assign_noise_threshold": req.assign_noise_threshold,
|
||||
"title_model": req.title_model,
|
||||
"lang": str(req.lang or ""),
|
||||
"nv": _PROMPT_NORMALIZE_VERSION,
|
||||
"nm": _PROMPT_NORMALIZE_MODE,
|
||||
}
|
||||
h = hashlib.sha1()
|
||||
h.update(json.dumps({"folders": folders, "params": cache_params}, ensure_ascii=False, sort_keys=True).encode("utf-8"))
|
||||
cache_key = h.hexdigest()
|
||||
cached = TopicClusterCache.get(conn, cache_key)
|
||||
if (
|
||||
cached
|
||||
and int(cached.get("embeddings_count") or 0) == embeddings_count
|
||||
and str(cached.get("embeddings_max_updated_at") or "") == embeddings_max_updated_at
|
||||
and isinstance(cached.get("result"), dict)
|
||||
):
|
||||
_job_upsert(job_id, {"status": "done", "stage": "done", "result": cached["result"], "cache_hit": True})
|
||||
return
|
||||
|
||||
# Clustering + titling progress
|
||||
def _cluster_cb(p: Dict) -> None:
|
||||
if not isinstance(p, dict):
|
||||
|
|
@ -657,6 +696,20 @@ def mount_topic_cluster_routes(
|
|||
_job_upsert(job_id, patch)
|
||||
|
||||
res = await _cluster_after_embeddings(req, folders, progress_cb=_cluster_cb)
|
||||
try:
|
||||
TopicClusterCache.upsert(
|
||||
conn,
|
||||
cache_key=cache_key,
|
||||
folders=folders,
|
||||
model=model,
|
||||
params=cache_params,
|
||||
embeddings_count=embeddings_count,
|
||||
embeddings_max_updated_at=embeddings_max_updated_at,
|
||||
result=res,
|
||||
)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
_job_upsert(job_id, {"status": "done", "stage": "done", "result": res})
|
||||
except HTTPException as e:
|
||||
_job_upsert(job_id, {"status": "error", "stage": "error", "error": str(e.detail)})
|
||||
|
|
@ -728,13 +781,16 @@ def mount_topic_cluster_routes(
|
|||
|
||||
id_list = [x["id"] for x in images]
|
||||
existing = ImageEmbedding.get_by_image_ids(conn, id_list)
|
||||
existing_fail = ImageEmbeddingFail.get_by_image_ids(conn, id_list, model)
|
||||
|
||||
to_embed = []
|
||||
skipped = 0
|
||||
skipped_failed = 0
|
||||
for item in images:
|
||||
# include normalize version to force refresh when rules change
|
||||
text_hash = ImageEmbedding.compute_text_hash(f"{_PROMPT_NORMALIZE_VERSION}:{item['text']}")
|
||||
old = existing.get(item["id"])
|
||||
old_fail = existing_fail.get(item["id"])
|
||||
if (
|
||||
(not force)
|
||||
and old
|
||||
|
|
@ -744,6 +800,10 @@ def mount_topic_cluster_routes(
|
|||
):
|
||||
skipped += 1
|
||||
continue
|
||||
# Skip known failures for the same model+text_hash (unless force is enabled).
|
||||
if (not force) and old_fail and str(old_fail.get("text_hash") or "") == text_hash:
|
||||
skipped_failed += 1
|
||||
continue
|
||||
to_embed.append({**item, "text_hash": text_hash})
|
||||
|
||||
if progress_cb:
|
||||
|
|
@ -756,11 +816,14 @@ def mount_topic_cluster_routes(
|
|||
"embedded_done": 0,
|
||||
"updated": 0,
|
||||
"skipped": skipped,
|
||||
"skipped_failed": skipped_failed,
|
||||
"failed": 0,
|
||||
}
|
||||
)
|
||||
|
||||
updated = 0
|
||||
embedded_done = 0
|
||||
failed = 0
|
||||
batches = _batched_by_token_budget(
|
||||
to_embed,
|
||||
max_items=batch_size,
|
||||
|
|
@ -774,12 +837,49 @@ def mount_topic_cluster_routes(
|
|||
print(
|
||||
f"[iib][embed] folder={folder} batch={bi+1}/{len(batches)} n={len(inputs)} token_sum~={token_sum} token_max~={token_max}"
|
||||
)
|
||||
vectors = await _call_embeddings(
|
||||
inputs=inputs,
|
||||
model=model,
|
||||
base_url=openai_base_url,
|
||||
api_key=openai_api_key,
|
||||
)
|
||||
try:
|
||||
vectors = await _call_embeddings(
|
||||
inputs=inputs,
|
||||
model=model,
|
||||
base_url=openai_base_url,
|
||||
api_key=openai_api_key,
|
||||
)
|
||||
except HTTPException as e:
|
||||
# Cache failures for this batch and continue (skip these images for now).
|
||||
err = str(e.detail)
|
||||
for it in batch:
|
||||
try:
|
||||
ImageEmbeddingFail.upsert(
|
||||
conn,
|
||||
image_id=int(it["id"]),
|
||||
model=str(model),
|
||||
text_hash=str(it.get("text_hash") or ""),
|
||||
error=err,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
failed += len(batch)
|
||||
if progress_cb:
|
||||
progress_cb(
|
||||
{
|
||||
"stage": "embedding",
|
||||
"folder": folder,
|
||||
"scanned": len(images),
|
||||
"to_embed": len(to_embed),
|
||||
"embedded_done": embedded_done,
|
||||
"updated": updated,
|
||||
"skipped": skipped,
|
||||
"skipped_failed": skipped_failed,
|
||||
"failed": failed,
|
||||
"batch_n": len(inputs),
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
continue
|
||||
if len(vectors) != len(batch):
|
||||
raise HTTPException(status_code=500, detail="Embeddings count mismatch")
|
||||
for item, vec in zip(batch, vectors):
|
||||
|
|
@ -791,6 +891,11 @@ def mount_topic_cluster_routes(
|
|||
text_hash=item["text_hash"],
|
||||
vec_blob=_vec_to_blob_f32(vec),
|
||||
)
|
||||
# Success -> clear fail cache for this image+model (any old failure becomes irrelevant).
|
||||
try:
|
||||
ImageEmbeddingFail.delete(conn, image_id=int(item["id"]), model=str(model))
|
||||
except Exception:
|
||||
pass
|
||||
updated += 1
|
||||
embedded_done += 1
|
||||
conn.commit()
|
||||
|
|
@ -804,13 +909,23 @@ def mount_topic_cluster_routes(
|
|||
"embedded_done": embedded_done,
|
||||
"updated": updated,
|
||||
"skipped": skipped,
|
||||
"skipped_failed": skipped_failed,
|
||||
"failed": failed,
|
||||
"batch_n": len(inputs),
|
||||
}
|
||||
)
|
||||
# yield between batches
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return {"folder": folder, "count": len(images), "updated": updated, "skipped": skipped, "model": model}
|
||||
return {
|
||||
"folder": folder,
|
||||
"count": len(images),
|
||||
"updated": updated,
|
||||
"skipped": skipped,
|
||||
"skipped_failed": skipped_failed,
|
||||
"failed": failed,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
class BuildIibOutputEmbeddingReq(BaseModel):
|
||||
folder: Optional[str] = None # default: {cwd}/iib_output
|
||||
|
|
@ -1312,250 +1427,76 @@ def mount_topic_cluster_routes(
|
|||
dependencies=[Depends(verify_secret), Depends(write_permission_required)],
|
||||
)
|
||||
async def cluster_iib_output(req: ClusterIibOutputReq):
|
||||
if not openai_api_key:
|
||||
raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
|
||||
if not openai_base_url:
|
||||
raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")
|
||||
folders: List[str] = []
|
||||
if req.folder_paths:
|
||||
for p in req.folder_paths:
|
||||
if isinstance(p, str) and p.strip():
|
||||
folders.append(os.path.normpath(p.strip()))
|
||||
if req.folder and isinstance(req.folder, str) and req.folder.strip():
|
||||
folders.append(os.path.normpath(req.folder.strip()))
|
||||
# 用户不会用默认 iib_output:未指定范围则直接报错
|
||||
if not folders:
|
||||
raise HTTPException(status_code=400, detail="folder_paths is required (select folders to cluster)")
|
||||
# Keep this endpoint for compatibility, but avoid duplicating the heavy logic.
|
||||
# If embeddings haven't changed and we have a cached clustering result, return it directly.
|
||||
folders = _extract_and_validate_folders(req)
|
||||
|
||||
folders = list(dict.fromkeys(folders))
|
||||
for f in folders:
|
||||
if not os.path.exists(f) or not os.path.isdir(f):
|
||||
raise HTTPException(status_code=400, detail=f"Folder not found: {f}")
|
||||
|
||||
# Ensure embeddings exist (incremental per folder)
|
||||
for f in folders:
|
||||
await build_iib_output_embeddings(
|
||||
BuildIibOutputEmbeddingReq(
|
||||
folder=f,
|
||||
model=req.model,
|
||||
force=req.force_embed,
|
||||
batch_size=req.batch_size,
|
||||
max_chars=req.max_chars,
|
||||
)
|
||||
)
|
||||
|
||||
folder = folders[0]
|
||||
model = req.model or embedding_model
|
||||
threshold = float(req.threshold or 0.86)
|
||||
threshold = max(0.0, min(threshold, 0.999))
|
||||
min_cluster_size = max(1, int(req.min_cluster_size or 2))
|
||||
title_model = req.title_model or os.getenv("TOPIC_TITLE_MODEL") or ai_model
|
||||
output_lang = _normalize_output_lang(req.lang)
|
||||
assign_noise_threshold = req.assign_noise_threshold
|
||||
if assign_noise_threshold is None:
|
||||
# conservative: only reassign if very likely belongs to a large topic
|
||||
assign_noise_threshold = max(0.72, min(threshold - 0.035, 0.93))
|
||||
else:
|
||||
assign_noise_threshold = max(0.0, min(float(assign_noise_threshold), 0.999))
|
||||
use_title_cache = bool(True if req.use_title_cache is None else req.use_title_cache)
|
||||
force_title = bool(req.force_title)
|
||||
batch_size = max(1, min(int(req.batch_size or 64), 256))
|
||||
max_chars = max(256, min(int(req.max_chars or 4000), 8000))
|
||||
force = bool(req.force_embed)
|
||||
for f in folders:
|
||||
await _build_embeddings_one_folder(
|
||||
folder=f,
|
||||
model=model,
|
||||
force=force,
|
||||
batch_size=batch_size,
|
||||
max_chars=max_chars,
|
||||
progress_cb=None,
|
||||
)
|
||||
|
||||
conn = DataBase.get_conn()
|
||||
like_prefixes = [os.path.join(f, "%") for f in folders]
|
||||
with closing(conn.cursor()) as cur:
|
||||
where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes))
|
||||
cur.execute(
|
||||
f"""SELECT image.id, image.path, image.exif, image_embedding.vec
|
||||
f"""SELECT COUNT(*), MAX(image_embedding.updated_at)
|
||||
FROM image
|
||||
INNER JOIN image_embedding ON image_embedding.image_id = image.id
|
||||
WHERE ({where}) AND image_embedding.model = ?""",
|
||||
(*like_prefixes, model),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
row = cur.fetchone() or (0, "")
|
||||
embeddings_count = int(row[0] or 0)
|
||||
embeddings_max_updated_at = str(row[1] or "")
|
||||
|
||||
items = []
|
||||
for image_id, path, exif, vec_blob in rows:
|
||||
if not isinstance(path, str) or not os.path.exists(path):
|
||||
continue
|
||||
if not vec_blob:
|
||||
continue
|
||||
vec = _blob_to_vec_f32(vec_blob)
|
||||
n2 = _l2_norm_sq(vec)
|
||||
if n2 <= 0:
|
||||
continue
|
||||
inv = 1.0 / math.sqrt(n2)
|
||||
for i in range(len(vec)):
|
||||
vec[i] *= inv
|
||||
text_raw = _extract_prompt_text(exif, max_chars=int(req.max_chars or 4000))
|
||||
if _PROMPT_NORMALIZE_ENABLED:
|
||||
text = _clean_prompt_for_semantic(text_raw)
|
||||
if not text:
|
||||
text = text_raw
|
||||
else:
|
||||
text = text_raw
|
||||
items.append({"id": int(image_id), "path": path, "text": text, "vec": vec})
|
||||
|
||||
if not items:
|
||||
return {"folder": folder, "folders": folders, "model": model, "threshold": threshold, "clusters": [], "noise": []}
|
||||
|
||||
# Incremental clustering by centroid-direction (sum vector)
|
||||
clusters = [] # {sum, norm_sq, members:[idx], sample_text}
|
||||
for idx, it in enumerate(items):
|
||||
v = it["vec"]
|
||||
best_ci = -1
|
||||
best_sim = -1.0
|
||||
best_dot = 0.0
|
||||
for ci, c in enumerate(clusters):
|
||||
dotv = _dot(v, c["sum"])
|
||||
denom = math.sqrt(c["norm_sq"]) if c["norm_sq"] > 0 else 1.0
|
||||
sim = dotv / denom
|
||||
if sim > best_sim:
|
||||
best_sim = sim
|
||||
best_ci = ci
|
||||
best_dot = dotv
|
||||
if best_ci != -1 and best_sim >= threshold:
|
||||
c = clusters[best_ci]
|
||||
for i in range(len(v)):
|
||||
c["sum"][i] += v[i]
|
||||
c["norm_sq"] = c["norm_sq"] + 2.0 * best_dot + 1.0
|
||||
c["members"].append(idx)
|
||||
else:
|
||||
clusters.append({"sum": array("f", v), "norm_sq": 1.0, "members": [idx], "sample_text": it.get("text") or ""})
|
||||
|
||||
# Merge highly similar clusters (fix: same theme split into multiple clusters)
|
||||
merge_threshold = min(0.995, max(threshold + 0.04, 0.90))
|
||||
merged = True
|
||||
while merged and len(clusters) > 1:
|
||||
merged = False
|
||||
best_i = best_j = -1
|
||||
best_sim = merge_threshold
|
||||
for i in range(len(clusters)):
|
||||
ci = clusters[i]
|
||||
for j in range(i + 1, len(clusters)):
|
||||
cj = clusters[j]
|
||||
sim = _cos_sum(ci["sum"], ci["norm_sq"], cj["sum"], cj["norm_sq"])
|
||||
if sim >= best_sim:
|
||||
best_sim = sim
|
||||
best_i, best_j = i, j
|
||||
if best_i != -1:
|
||||
a = clusters[best_i]
|
||||
b = clusters[best_j]
|
||||
for k in range(len(a["sum"])):
|
||||
a["sum"][k] += b["sum"][k]
|
||||
a["norm_sq"] = _l2_norm_sq(a["sum"])
|
||||
a["members"].extend(b["members"])
|
||||
if not a.get("sample_text"):
|
||||
a["sample_text"] = b.get("sample_text", "")
|
||||
clusters.pop(best_j)
|
||||
merged = True
|
||||
|
||||
# Reassign members from small clusters into best large cluster to reduce noise
|
||||
if min_cluster_size > 1 and assign_noise_threshold > 0 and clusters:
|
||||
large = [c for c in clusters if len(c["members"]) >= min_cluster_size]
|
||||
if large:
|
||||
new_large = []
|
||||
# copy large clusters first
|
||||
for c in clusters:
|
||||
if len(c["members"]) >= min_cluster_size:
|
||||
new_large.append(c)
|
||||
# reassign items from small clusters
|
||||
for c in clusters:
|
||||
if len(c["members"]) >= min_cluster_size:
|
||||
continue
|
||||
for mi in c["members"]:
|
||||
v = items[mi]["vec"]
|
||||
best_ci = -1
|
||||
best_sim = -1.0
|
||||
best_dot = 0.0
|
||||
for ci, bigc in enumerate(new_large):
|
||||
dotv = _dot(v, bigc["sum"])
|
||||
denom = math.sqrt(bigc["norm_sq"]) if bigc["norm_sq"] > 0 else 1.0
|
||||
sim = dotv / denom
|
||||
if sim > best_sim:
|
||||
best_sim = sim
|
||||
best_ci = ci
|
||||
best_dot = dotv
|
||||
if best_ci != -1 and best_sim >= assign_noise_threshold:
|
||||
bigc = new_large[best_ci]
|
||||
for k in range(len(v)):
|
||||
bigc["sum"][k] += v[k]
|
||||
bigc["norm_sq"] = bigc["norm_sq"] + 2.0 * best_dot + 1.0
|
||||
bigc["members"].append(mi)
|
||||
# else: keep in small cluster -> will become noise below
|
||||
clusters = new_large
|
||||
|
||||
# Split small clusters to noise, generate titles
|
||||
out_clusters = []
|
||||
noise = []
|
||||
for cidx, c in enumerate(clusters):
|
||||
if len(c["members"]) < min_cluster_size:
|
||||
for mi in c["members"]:
|
||||
noise.append(items[mi]["path"])
|
||||
continue
|
||||
|
||||
member_items = [items[mi] for mi in c["members"]]
|
||||
paths = [x["path"] for x in member_items]
|
||||
texts = [x.get("text") or "" for x in member_items]
|
||||
member_ids = [x["id"] for x in member_items]
|
||||
|
||||
# Representative prompt for LLM title generation
|
||||
rep = (c.get("sample_text") or (texts[0] if texts else "")).strip()
|
||||
|
||||
cached = None
|
||||
cluster_hash = _cluster_sig(
|
||||
member_ids=member_ids,
|
||||
model=model,
|
||||
threshold=threshold,
|
||||
min_cluster_size=min_cluster_size,
|
||||
title_model=title_model,
|
||||
lang=output_lang,
|
||||
)
|
||||
if use_title_cache and (not force_title):
|
||||
cached = TopicTitleCache.get(conn, cluster_hash)
|
||||
if cached and isinstance(cached, dict) and cached.get("title"):
|
||||
title = str(cached.get("title"))
|
||||
keywords = cached.get("keywords") or []
|
||||
else:
|
||||
llm = await _call_chat_title(
|
||||
base_url=openai_base_url,
|
||||
api_key=openai_api_key,
|
||||
model=title_model,
|
||||
prompt_samples=[rep] + texts[:5],
|
||||
output_lang=output_lang,
|
||||
)
|
||||
title = (llm or {}).get("title")
|
||||
keywords = (llm or {}).get("keywords", [])
|
||||
if not title:
|
||||
raise HTTPException(status_code=502, detail="Chat API returned empty title")
|
||||
if use_title_cache and title:
|
||||
try:
|
||||
TopicTitleCache.upsert(conn, cluster_hash, str(title), list(keywords or []), str(title_model))
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
out_clusters.append(
|
||||
{
|
||||
"id": f"topic_{cidx}",
|
||||
"title": title,
|
||||
"keywords": keywords,
|
||||
"size": len(paths),
|
||||
"paths": paths,
|
||||
"sample_prompt": _clean_for_title(rep)[:200],
|
||||
}
|
||||
)
|
||||
|
||||
out_clusters.sort(key=lambda x: x["size"], reverse=True)
|
||||
return {
|
||||
"folder": folder,
|
||||
"folders": folders,
|
||||
cache_params = {
|
||||
"model": model,
|
||||
"threshold": threshold,
|
||||
"min_cluster_size": min_cluster_size,
|
||||
"clusters": out_clusters,
|
||||
"noise": noise,
|
||||
"count": len(items),
|
||||
"title_model": title_model,
|
||||
"threshold": float(req.threshold or 0.86),
|
||||
"min_cluster_size": int(req.min_cluster_size or 2),
|
||||
"assign_noise_threshold": req.assign_noise_threshold,
|
||||
"title_model": req.title_model,
|
||||
"lang": str(req.lang or ""),
|
||||
"nv": _PROMPT_NORMALIZE_VERSION,
|
||||
"nm": _PROMPT_NORMALIZE_MODE,
|
||||
}
|
||||
h = hashlib.sha1()
|
||||
h.update(json.dumps({"folders": folders, "params": cache_params}, ensure_ascii=False, sort_keys=True).encode("utf-8"))
|
||||
cache_key = h.hexdigest()
|
||||
cached = TopicClusterCache.get(conn, cache_key)
|
||||
if (
|
||||
cached
|
||||
and int(cached.get("embeddings_count") or 0) == embeddings_count
|
||||
and str(cached.get("embeddings_max_updated_at") or "") == embeddings_max_updated_at
|
||||
and isinstance(cached.get("result"), dict)
|
||||
):
|
||||
return cached["result"]
|
||||
|
||||
res = await _cluster_after_embeddings(req, folders, progress_cb=None)
|
||||
try:
|
||||
TopicClusterCache.upsert(
|
||||
conn,
|
||||
cache_key=cache_key,
|
||||
folders=folders,
|
||||
model=model,
|
||||
params=cache_params,
|
||||
embeddings_count=embeddings_count,
|
||||
embeddings_max_updated_at=embeddings_max_updated_at,
|
||||
result=res,
|
||||
)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue