diff --git a/scripts/iib/db/datamodel.py b/scripts/iib/db/datamodel.py index e950f12..e597a53 100644 --- a/scripts/iib/db/datamodel.py +++ b/scripts/iib/db/datamodel.py @@ -82,7 +82,9 @@ class DataBase: DirCoverCache.create_table(conn) GlobalSetting.create_table(conn) ImageEmbedding.create_table(conn) + ImageEmbeddingFail.create_table(conn) TopicTitleCache.create_table(conn) + TopicClusterCache.create_table(conn) finally: conn.commit() clz.num += 1 @@ -383,6 +385,77 @@ class ImageEmbedding: ) +class ImageEmbeddingFail: + """ + Cache embedding failures per image+model+text_hash to avoid repeatedly hitting the API + for known-failing inputs. This helps keep clustering/search usable by skipping bad items. + """ + + @classmethod + def create_table(cls, conn: Connection): + with closing(conn.cursor()) as cur: + cur.execute( + """CREATE TABLE IF NOT EXISTS image_embedding_fail ( + image_id INTEGER NOT NULL, + model TEXT NOT NULL, + text_hash TEXT NOT NULL, + error TEXT NOT NULL, + updated_at TEXT NOT NULL, + PRIMARY KEY(image_id, model, text_hash) + )""" + ) + cur.execute("CREATE INDEX IF NOT EXISTS image_embedding_fail_idx_model ON image_embedding_fail(model)") + + @classmethod + def get_by_image_ids(cls, conn: Connection, image_ids: List[int], model: str) -> Dict[int, Dict]: + if not image_ids: + return {} + ids = [int(x) for x in image_ids] + placeholders = ",".join(["?"] * len(ids)) + with closing(conn.cursor()) as cur: + cur.execute( + f"SELECT image_id, text_hash, error, updated_at FROM image_embedding_fail WHERE model = ? AND image_id IN ({placeholders})", + (str(model), *ids), + ) + rows = cur.fetchall() + out: Dict[int, Dict] = {} + for image_id, text_hash, error, updated_at in rows or []: + out[int(image_id)] = { + "text_hash": str(text_hash or ""), + "error": str(error or ""), + "updated_at": str(updated_at or ""), + } + return out + + @classmethod + def upsert( + cls, + conn: Connection, + *, + image_id: int, + model: str, + text_hash: str, + error: str, + updated_at: Optional[str] = None, + ): + updated_at = updated_at or datetime.now().isoformat() + with closing(conn.cursor()) as cur: + cur.execute( + """INSERT INTO image_embedding_fail (image_id, model, text_hash, error, updated_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(image_id, model, text_hash) DO UPDATE SET + error = excluded.error, + updated_at = excluded.updated_at + """, + (int(image_id), str(model), str(text_hash), str(error or "")[:600], str(updated_at)), + ) + + @classmethod + def delete(cls, conn: Connection, *, image_id: int, model: str): + with closing(conn.cursor()) as cur: + cur.execute("DELETE FROM image_embedding_fail WHERE image_id = ? AND model = ?", (int(image_id), str(model))) + + class TopicTitleCache: """ Cache cluster titles/keywords to avoid repeated LLM calls. @@ -449,6 +522,112 @@ class TopicTitleCache: ) +class TopicClusterCache: + """ + Persist the final clustering result (clusters/noise) to avoid re-clustering when: + - embeddings haven't changed (by max(updated_at) & count), and + - clustering parameters are unchanged. + + This is intentionally lightweight: + - result is stored as JSON text + - caller defines cache_key (sha1 over params + folders + normalize version + lang, etc.) 
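+
+    A key-derivation sketch, mirroring what both callers in topic_cluster.py
+    actually do (sort_keys keeps the JSON canonical so the hash is stable):
+
+        payload = json.dumps({"folders": folders, "params": cache_params},
+                             ensure_ascii=False, sort_keys=True)
+        cache_key = hashlib.sha1(payload.encode("utf-8")).hexdigest()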
+ """ + + @classmethod + def create_table(cls, conn: Connection): + with closing(conn.cursor()) as cur: + cur.execute( + """CREATE TABLE IF NOT EXISTS topic_cluster_cache ( + cache_key TEXT PRIMARY KEY, + folders TEXT NOT NULL, + model TEXT NOT NULL, + params TEXT NOT NULL, + embeddings_count INTEGER NOT NULL, + embeddings_max_updated_at TEXT NOT NULL, + result TEXT NOT NULL, + updated_at TEXT NOT NULL + )""" + ) + cur.execute("CREATE INDEX IF NOT EXISTS topic_cluster_cache_idx_model ON topic_cluster_cache(model)") + + @classmethod + def get(cls, conn: Connection, cache_key: str): + with closing(conn.cursor()) as cur: + cur.execute( + "SELECT folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at FROM topic_cluster_cache WHERE cache_key = ?", + (cache_key,), + ) + row = cur.fetchone() + if not row: + return None + folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at = row + try: + folders_obj = json.loads(folders) if isinstance(folders, str) else [] + except Exception: + folders_obj = [] + try: + params_obj = json.loads(params) if isinstance(params, str) else {} + except Exception: + params_obj = {} + try: + result_obj = json.loads(result) if isinstance(result, str) else None + except Exception: + result_obj = None + return { + "cache_key": cache_key, + "folders": folders_obj if isinstance(folders_obj, list) else [], + "model": str(model), + "params": params_obj if isinstance(params_obj, dict) else {}, + "embeddings_count": int(embeddings_count or 0), + "embeddings_max_updated_at": str(embeddings_max_updated_at or ""), + "result": result_obj, + "updated_at": str(updated_at or ""), + } + + @classmethod + def upsert( + cls, + conn: Connection, + *, + cache_key: str, + folders: List[str], + model: str, + params: Dict, + embeddings_count: int, + embeddings_max_updated_at: str, + result: Dict, + updated_at: Optional[str] = None, + ): + updated_at = updated_at or datetime.now().isoformat() + folders_s = json.dumps([str(x) for x in (folders or [])], ensure_ascii=False) + params_s = json.dumps(params or {}, ensure_ascii=False, sort_keys=True) + result_s = json.dumps(result or {}, ensure_ascii=False) + with closing(conn.cursor()) as cur: + cur.execute( + """INSERT INTO topic_cluster_cache + (cache_key, folders, model, params, embeddings_count, embeddings_max_updated_at, result, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(cache_key) DO UPDATE SET + folders = excluded.folders, + model = excluded.model, + params = excluded.params, + embeddings_count = excluded.embeddings_count, + embeddings_max_updated_at = excluded.embeddings_max_updated_at, + result = excluded.result, + updated_at = excluded.updated_at + """, + ( + cache_key, + folders_s, + str(model), + params_s, + int(embeddings_count or 0), + str(embeddings_max_updated_at or ""), + result_s, + updated_at, + ), + ) + class Tag: def __init__(self, name: str, score: int, type: str, count=0, color = ""): self.name = name diff --git a/scripts/iib/topic_cluster.py b/scripts/iib/topic_cluster.py index cfb551b..d581781 100644 --- a/scripts/iib/topic_cluster.py +++ b/scripts/iib/topic_cluster.py @@ -15,7 +15,7 @@ import requests from fastapi import Depends, FastAPI, HTTPException from pydantic import BaseModel -from scripts.iib.db.datamodel import DataBase, ImageEmbedding, TopicTitleCache +from scripts.iib.db.datamodel import DataBase, ImageEmbedding, ImageEmbeddingFail, TopicClusterCache, TopicTitleCache from scripts.iib.tool import cwd @@ -634,6 +634,45 @@ def mount_topic_cluster_routes( progress_cb=_embed_cb, ) + # If embeddings didn't change and we have a cached clustering result, return it directly. + conn = DataBase.get_conn() + like_prefixes = [os.path.join(f, "%") for f in folders] + with closing(conn.cursor()) as cur: + where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes)) + cur.execute( + f"""SELECT COUNT(*), MAX(image_embedding.updated_at) + FROM image + INNER JOIN image_embedding ON image_embedding.image_id = image.id + WHERE ({where}) AND image_embedding.model = ?""", + (*like_prefixes, model), + ) + row = cur.fetchone() or (0, "") + embeddings_count = int(row[0] or 0) + embeddings_max_updated_at = str(row[1] or "") + + cache_params = { + "model": model, + "threshold": float(req.threshold or 0.86), + "min_cluster_size": int(req.min_cluster_size or 2), + "assign_noise_threshold": req.assign_noise_threshold, + "title_model": req.title_model, + "lang": str(req.lang or ""), + "nv": _PROMPT_NORMALIZE_VERSION, + "nm": _PROMPT_NORMALIZE_MODE, + } + h = hashlib.sha1() + h.update(json.dumps({"folders": folders, "params": cache_params}, ensure_ascii=False, sort_keys=True).encode("utf-8")) + cache_key = h.hexdigest() + cached = TopicClusterCache.get(conn, cache_key) + if ( + cached + and int(cached.get("embeddings_count") or 0) == embeddings_count + and str(cached.get("embeddings_max_updated_at") or "") == embeddings_max_updated_at + and isinstance(cached.get("result"), dict) + ): + _job_upsert(job_id, {"status": "done", "stage": "done", "result": cached["result"], "cache_hit": True}) + return + # Clustering + titling progress def _cluster_cb(p: Dict) -> None: if not isinstance(p, dict): @@ -657,6 +696,20 @@ def mount_topic_cluster_routes( _job_upsert(job_id, patch) res = await _cluster_after_embeddings(req, folders, progress_cb=_cluster_cb) + try: + TopicClusterCache.upsert( + conn, + cache_key=cache_key, + folders=folders, + model=model, + params=cache_params, + embeddings_count=embeddings_count, + embeddings_max_updated_at=embeddings_max_updated_at, + result=res, + ) + conn.commit() + except Exception: + pass _job_upsert(job_id, {"status": "done", "stage": "done", "result": res}) except HTTPException as e: _job_upsert(job_id, {"status": "error", "stage": "error", "error": str(e.detail)}) @@ -728,13 +781,16 @@ def mount_topic_cluster_routes( id_list = [x["id"] for x in images] existing = ImageEmbedding.get_by_image_ids(conn, id_list) + 
existing_fail = ImageEmbeddingFail.get_by_image_ids(conn, id_list, model) to_embed = [] skipped = 0 + skipped_failed = 0 for item in images: # include normalize version to force refresh when rules change text_hash = ImageEmbedding.compute_text_hash(f"{_PROMPT_NORMALIZE_VERSION}:{item['text']}") old = existing.get(item["id"]) + old_fail = existing_fail.get(item["id"]) if ( (not force) and old @@ -744,6 +800,10 @@ def mount_topic_cluster_routes( ): skipped += 1 continue + # Skip known failures for the same model+text_hash (unless force is enabled). + if (not force) and old_fail and str(old_fail.get("text_hash") or "") == text_hash: + skipped_failed += 1 + continue to_embed.append({**item, "text_hash": text_hash}) if progress_cb: @@ -756,11 +816,14 @@ def mount_topic_cluster_routes( "embedded_done": 0, "updated": 0, "skipped": skipped, + "skipped_failed": skipped_failed, + "failed": 0, } ) updated = 0 embedded_done = 0 + failed = 0 batches = _batched_by_token_budget( to_embed, max_items=batch_size, @@ -774,12 +837,49 @@ def mount_topic_cluster_routes( print( f"[iib][embed] folder={folder} batch={bi+1}/{len(batches)} n={len(inputs)} token_sum~={token_sum} token_max~={token_max}" ) - vectors = await _call_embeddings( - inputs=inputs, - model=model, - base_url=openai_base_url, - api_key=openai_api_key, - ) + try: + vectors = await _call_embeddings( + inputs=inputs, + model=model, + base_url=openai_base_url, + api_key=openai_api_key, + ) + except HTTPException as e: + # Cache failures for this batch and continue (skip these images for now). + err = str(e.detail) + for it in batch: + try: + ImageEmbeddingFail.upsert( + conn, + image_id=int(it["id"]), + model=str(model), + text_hash=str(it.get("text_hash") or ""), + error=err, + ) + except Exception: + pass + try: + conn.commit() + except Exception: + pass + failed += len(batch) + if progress_cb: + progress_cb( + { + "stage": "embedding", + "folder": folder, + "scanned": len(images), + "to_embed": len(to_embed), + "embedded_done": embedded_done, + "updated": updated, + "skipped": skipped, + "skipped_failed": skipped_failed, + "failed": failed, + "batch_n": len(inputs), + } + ) + await asyncio.sleep(0) + continue if len(vectors) != len(batch): raise HTTPException(status_code=500, detail="Embeddings count mismatch") for item, vec in zip(batch, vectors): @@ -791,6 +891,11 @@ def mount_topic_cluster_routes( text_hash=item["text_hash"], vec_blob=_vec_to_blob_f32(vec), ) + # Success -> clear fail cache for this image+model (any old failure becomes irrelevant). 
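+ # Note: delete() is keyed on (image_id, model) only, so this also clears
+ # failures recorded under older text_hashes of the same image.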
+ try: + ImageEmbeddingFail.delete(conn, image_id=int(item["id"]), model=str(model)) + except Exception: + pass updated += 1 embedded_done += 1 conn.commit() @@ -804,13 +909,23 @@ def mount_topic_cluster_routes( "embedded_done": embedded_done, "updated": updated, "skipped": skipped, + "skipped_failed": skipped_failed, + "failed": failed, "batch_n": len(inputs), } ) # yield between batches await asyncio.sleep(0) - return {"folder": folder, "count": len(images), "updated": updated, "skipped": skipped, "model": model} + return { + "folder": folder, + "count": len(images), + "updated": updated, + "skipped": skipped, + "skipped_failed": skipped_failed, + "failed": failed, + "model": model, + } class BuildIibOutputEmbeddingReq(BaseModel): folder: Optional[str] = None # default: {cwd}/iib_output @@ -1312,250 +1427,76 @@ def mount_topic_cluster_routes( dependencies=[Depends(verify_secret), Depends(write_permission_required)], ) async def cluster_iib_output(req: ClusterIibOutputReq): - if not openai_api_key: - raise HTTPException(status_code=500, detail="OpenAI API Key not configured") - if not openai_base_url: - raise HTTPException(status_code=500, detail="OpenAI Base URL not configured") - folders: List[str] = [] - if req.folder_paths: - for p in req.folder_paths: - if isinstance(p, str) and p.strip(): - folders.append(os.path.normpath(p.strip())) - if req.folder and isinstance(req.folder, str) and req.folder.strip(): - folders.append(os.path.normpath(req.folder.strip())) - # 用户不会用默认 iib_output:未指定范围则直接报错 - if not folders: - raise HTTPException(status_code=400, detail="folder_paths is required (select folders to cluster)") + # Keep this endpoint for compatibility, but avoid duplicating the heavy logic. + # If embeddings haven't changed and we have a cached clustering result, return it directly. 
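+ # Flow: validate folders -> ensure embeddings incrementally per folder ->
+ # fingerprint the embedding set via (COUNT(*), MAX(updated_at)) -> return the
+ # cached result on a hit, otherwise cluster once and persist the result.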
+ folders = _extract_and_validate_folders(req) - folders = list(dict.fromkeys(folders)) - for f in folders: - if not os.path.exists(f) or not os.path.isdir(f): - raise HTTPException(status_code=400, detail=f"Folder not found: {f}") - - # Ensure embeddings exist (incremental per folder) - for f in folders: - await build_iib_output_embeddings( - BuildIibOutputEmbeddingReq( - folder=f, - model=req.model, - force=req.force_embed, - batch_size=req.batch_size, - max_chars=req.max_chars, - ) - ) - - folder = folders[0] model = req.model or embedding_model - threshold = float(req.threshold or 0.86) - threshold = max(0.0, min(threshold, 0.999)) - min_cluster_size = max(1, int(req.min_cluster_size or 2)) - title_model = req.title_model or os.getenv("TOPIC_TITLE_MODEL") or ai_model - output_lang = _normalize_output_lang(req.lang) - assign_noise_threshold = req.assign_noise_threshold - if assign_noise_threshold is None: - # conservative: only reassign if very likely belongs to a large topic - assign_noise_threshold = max(0.72, min(threshold - 0.035, 0.93)) - else: - assign_noise_threshold = max(0.0, min(float(assign_noise_threshold), 0.999)) - use_title_cache = bool(True if req.use_title_cache is None else req.use_title_cache) - force_title = bool(req.force_title) + batch_size = max(1, min(int(req.batch_size or 64), 256)) + max_chars = max(256, min(int(req.max_chars or 4000), 8000)) + force = bool(req.force_embed) + for f in folders: + await _build_embeddings_one_folder( + folder=f, + model=model, + force=force, + batch_size=batch_size, + max_chars=max_chars, + progress_cb=None, + ) conn = DataBase.get_conn() like_prefixes = [os.path.join(f, "%") for f in folders] with closing(conn.cursor()) as cur: where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes)) cur.execute( - f"""SELECT image.id, image.path, image.exif, image_embedding.vec + f"""SELECT COUNT(*), MAX(image_embedding.updated_at) FROM image INNER JOIN image_embedding ON image_embedding.image_id = image.id WHERE ({where}) AND image_embedding.model = ?""", (*like_prefixes, model), ) - rows = cur.fetchall() + row = cur.fetchone() or (0, "") + embeddings_count = int(row[0] or 0) + embeddings_max_updated_at = str(row[1] or "") - items = [] - for image_id, path, exif, vec_blob in rows: - if not isinstance(path, str) or not os.path.exists(path): - continue - if not vec_blob: - continue - vec = _blob_to_vec_f32(vec_blob) - n2 = _l2_norm_sq(vec) - if n2 <= 0: - continue - inv = 1.0 / math.sqrt(n2) - for i in range(len(vec)): - vec[i] *= inv - text_raw = _extract_prompt_text(exif, max_chars=int(req.max_chars or 4000)) - if _PROMPT_NORMALIZE_ENABLED: - text = _clean_prompt_for_semantic(text_raw) - if not text: - text = text_raw - else: - text = text_raw - items.append({"id": int(image_id), "path": path, "text": text, "vec": vec}) - - if not items: - return {"folder": folder, "folders": folders, "model": model, "threshold": threshold, "clusters": [], "noise": []} - - # Incremental clustering by centroid-direction (sum vector) - clusters = [] # {sum, norm_sq, members:[idx], sample_text} - for idx, it in enumerate(items): - v = it["vec"] - best_ci = -1 - best_sim = -1.0 - best_dot = 0.0 - for ci, c in enumerate(clusters): - dotv = _dot(v, c["sum"]) - denom = math.sqrt(c["norm_sq"]) if c["norm_sq"] > 0 else 1.0 - sim = dotv / denom - if sim > best_sim: - best_sim = sim - best_ci = ci - best_dot = dotv - if best_ci != -1 and best_sim >= threshold: - c = clusters[best_ci] - for i in range(len(v)): - c["sum"][i] += v[i] - c["norm_sq"] = c["norm_sq"] + 
2.0 * best_dot + 1.0 - c["members"].append(idx) - else: - clusters.append({"sum": array("f", v), "norm_sq": 1.0, "members": [idx], "sample_text": it.get("text") or ""}) - - # Merge highly similar clusters (fix: same theme split into multiple clusters) - merge_threshold = min(0.995, max(threshold + 0.04, 0.90)) - merged = True - while merged and len(clusters) > 1: - merged = False - best_i = best_j = -1 - best_sim = merge_threshold - for i in range(len(clusters)): - ci = clusters[i] - for j in range(i + 1, len(clusters)): - cj = clusters[j] - sim = _cos_sum(ci["sum"], ci["norm_sq"], cj["sum"], cj["norm_sq"]) - if sim >= best_sim: - best_sim = sim - best_i, best_j = i, j - if best_i != -1: - a = clusters[best_i] - b = clusters[best_j] - for k in range(len(a["sum"])): - a["sum"][k] += b["sum"][k] - a["norm_sq"] = _l2_norm_sq(a["sum"]) - a["members"].extend(b["members"]) - if not a.get("sample_text"): - a["sample_text"] = b.get("sample_text", "") - clusters.pop(best_j) - merged = True - - # Reassign members from small clusters into best large cluster to reduce noise - if min_cluster_size > 1 and assign_noise_threshold > 0 and clusters: - large = [c for c in clusters if len(c["members"]) >= min_cluster_size] - if large: - new_large = [] - # copy large clusters first - for c in clusters: - if len(c["members"]) >= min_cluster_size: - new_large.append(c) - # reassign items from small clusters - for c in clusters: - if len(c["members"]) >= min_cluster_size: - continue - for mi in c["members"]: - v = items[mi]["vec"] - best_ci = -1 - best_sim = -1.0 - best_dot = 0.0 - for ci, bigc in enumerate(new_large): - dotv = _dot(v, bigc["sum"]) - denom = math.sqrt(bigc["norm_sq"]) if bigc["norm_sq"] > 0 else 1.0 - sim = dotv / denom - if sim > best_sim: - best_sim = sim - best_ci = ci - best_dot = dotv - if best_ci != -1 and best_sim >= assign_noise_threshold: - bigc = new_large[best_ci] - for k in range(len(v)): - bigc["sum"][k] += v[k] - bigc["norm_sq"] = bigc["norm_sq"] + 2.0 * best_dot + 1.0 - bigc["members"].append(mi) - # else: keep in small cluster -> will become noise below - clusters = new_large - - # Split small clusters to noise, generate titles - out_clusters = [] - noise = [] - for cidx, c in enumerate(clusters): - if len(c["members"]) < min_cluster_size: - for mi in c["members"]: - noise.append(items[mi]["path"]) - continue - - member_items = [items[mi] for mi in c["members"]] - paths = [x["path"] for x in member_items] - texts = [x.get("text") or "" for x in member_items] - member_ids = [x["id"] for x in member_items] - - # Representative prompt for LLM title generation - rep = (c.get("sample_text") or (texts[0] if texts else "")).strip() - - cached = None - cluster_hash = _cluster_sig( - member_ids=member_ids, - model=model, - threshold=threshold, - min_cluster_size=min_cluster_size, - title_model=title_model, - lang=output_lang, - ) - if use_title_cache and (not force_title): - cached = TopicTitleCache.get(conn, cluster_hash) - if cached and isinstance(cached, dict) and cached.get("title"): - title = str(cached.get("title")) - keywords = cached.get("keywords") or [] - else: - llm = await _call_chat_title( - base_url=openai_base_url, - api_key=openai_api_key, - model=title_model, - prompt_samples=[rep] + texts[:5], - output_lang=output_lang, - ) - title = (llm or {}).get("title") - keywords = (llm or {}).get("keywords", []) - if not title: - raise HTTPException(status_code=502, detail="Chat API returned empty title") - if use_title_cache and title: - try: - TopicTitleCache.upsert(conn, 
cluster_hash, str(title), list(keywords or []), str(title_model)) - conn.commit() - except Exception: - pass - - out_clusters.append( - { - "id": f"topic_{cidx}", - "title": title, - "keywords": keywords, - "size": len(paths), - "paths": paths, - "sample_prompt": _clean_for_title(rep)[:200], - } - ) - - out_clusters.sort(key=lambda x: x["size"], reverse=True) - return { - "folder": folder, - "folders": folders, + cache_params = { "model": model, - "threshold": threshold, - "min_cluster_size": min_cluster_size, - "clusters": out_clusters, - "noise": noise, - "count": len(items), - "title_model": title_model, + "threshold": float(req.threshold or 0.86), + "min_cluster_size": int(req.min_cluster_size or 2), + "assign_noise_threshold": req.assign_noise_threshold, + "title_model": req.title_model, + "lang": str(req.lang or ""), + "nv": _PROMPT_NORMALIZE_VERSION, + "nm": _PROMPT_NORMALIZE_MODE, } + h = hashlib.sha1() + h.update(json.dumps({"folders": folders, "params": cache_params}, ensure_ascii=False, sort_keys=True).encode("utf-8")) + cache_key = h.hexdigest() + cached = TopicClusterCache.get(conn, cache_key) + if ( + cached + and int(cached.get("embeddings_count") or 0) == embeddings_count + and str(cached.get("embeddings_max_updated_at") or "") == embeddings_max_updated_at + and isinstance(cached.get("result"), dict) + ): + return cached["result"] + + res = await _cluster_after_embeddings(req, folders, progress_cb=None) + try: + TopicClusterCache.upsert( + conn, + cache_key=cache_key, + folders=folders, + model=model, + params=cache_params, + embeddings_count=embeddings_count, + embeddings_max_updated_at=embeddings_max_updated_at, + result=res, + ) + conn.commit() + except Exception: + pass + return res
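
Usage sketch (illustrative only): exercises the two new cache classes
end-to-end. It assumes scripts.iib.db.datamodel imports cleanly on its own and
substitutes a throwaway in-memory SQLite connection for DataBase.get_conn();
the class names and signatures come from the diff above.

    import sqlite3

    from scripts.iib.db.datamodel import ImageEmbeddingFail, TopicClusterCache

    conn = sqlite3.connect(":memory:")  # stand-in for DataBase.get_conn()
    ImageEmbeddingFail.create_table(conn)
    TopicClusterCache.create_table(conn)

    # A failed embedding call is recorded per (image, model, text_hash) ...
    ImageEmbeddingFail.upsert(conn, image_id=1, model="m", text_hash="h1", error="429 rate limited")
    fails = ImageEmbeddingFail.get_by_image_ids(conn, [1], "m")
    assert fails[1]["error"] == "429 rate limited"  # the embed pass will now skip this image

    # ... and cleared for the whole image+model pair once an embedding succeeds.
    ImageEmbeddingFail.delete(conn, image_id=1, model="m")
    assert ImageEmbeddingFail.get_by_image_ids(conn, [1], "m") == {}

    # Clustering results round-trip as JSON under the caller-defined cache_key.
    TopicClusterCache.upsert(
        conn,
        cache_key="deadbeef",
        folders=["/tmp/a"],
        model="m",
        params={"threshold": 0.86},
        embeddings_count=10,
        embeddings_max_updated_at="2024-01-01T00:00:00",
        result={"clusters": [], "noise": []},
    )
    assert TopicClusterCache.get(conn, "deadbeef")["result"] == {"clusters": [], "noise": []}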