chore: require numpy/hnswlib for topic clustering

pull/870/head
wuqinchuan 2026-01-03 16:32:09 +08:00 committed by zanllp
parent 25190f7307
commit 3ba8997626
1 changed file with 32 additions and 18 deletions

View File

@@ -18,16 +18,31 @@ from pydantic import BaseModel
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, ImageEmbeddingFail, TopicClusterCache, TopicTitleCache from scripts.iib.db.datamodel import DataBase, ImageEmbedding, ImageEmbeddingFail, TopicClusterCache, TopicTitleCache
from scripts.iib.tool import cwd from scripts.iib.tool import cwd
# Optional perf deps (heavy but worth it for 10k+ images) # Perf deps (required for this feature)
try: _np = None
import numpy as _np # type: ignore _hnswlib = None
except Exception: # pragma: no cover _PERF_DEPS_READY = False
_np = None
try:
import hnswlib as _hnswlib # type: ignore def _ensure_perf_deps() -> None:
except Exception: # pragma: no cover """
_hnswlib = None We do NOT allow users to use TopicSearch/TopicClustering feature without these deps.
But we also do NOT want the whole server to crash on import, so we lazy-load and fail at API layer.
"""
global _np, _hnswlib, _PERF_DEPS_READY
if _PERF_DEPS_READY:
return
try:
import numpy as np # type: ignore
import hnswlib as hnswlib # type: ignore
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Topic clustering requires numpy+hnswlib. Please install them. import_error={type(e).__name__}: {e}",
)
_np = np
_hnswlib = hnswlib
_PERF_DEPS_READY = True
_TOPIC_CLUSTER_JOBS: Dict[str, Dict] = {} _TOPIC_CLUSTER_JOBS: Dict[str, Dict] = {}
@@ -333,18 +348,13 @@ def _dot(a: array, b: array) -> float:
return sum((x * y for x, y in zip(a, b))) return sum((x * y for x, y in zip(a, b)))
def _can_use_ann() -> bool:
return _np is not None and _hnswlib is not None
def _centroid_vec_np(sum_vec: array, norm_sq: float): def _centroid_vec_np(sum_vec: array, norm_sq: float):
""" """
Convert centroid sum-vector to a normalized numpy float32 vector. Convert centroid sum-vector to a normalized numpy float32 vector.
Cosine between unit v and centroid is dot(v, sum)/sqrt(norm_sq). Cosine between unit v and centroid is dot(v, sum)/sqrt(norm_sq).
So centroid direction is sum / ||sum||. So centroid direction is sum / ||sum||.
""" """
if _np is None: _ensure_perf_deps()
return None
if norm_sq <= 0: if norm_sq <= 0:
return None return None
inv = 1.0 / math.sqrt(norm_sq) inv = 1.0 / math.sqrt(norm_sq)
@@ -359,7 +369,8 @@ def _build_hnsw_index(centroids_np, *, ef: int = 64, M: int = 32):
Build a cosine HNSW index over centroid vectors. Build a cosine HNSW index over centroid vectors.
Returns index or None. Returns index or None.
""" """
if not _can_use_ann() or centroids_np is None: _ensure_perf_deps()
if centroids_np is None:
return None return None
if len(centroids_np) == 0: if len(centroids_np) == 0:
return None return None
@@ -620,6 +631,9 @@ def mount_topic_cluster_routes(
""" """
Mount embedding + topic clustering endpoints (MVP: manual, iib_output only). Mount embedding + topic clustering endpoints (MVP: manual, iib_output only).
""" """
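# Fail fast at API layer if required perf deps are missing.
# (We don't crash the whole server at import time.)
# NOTE(review): this call sits directly in the route-mounting function body, so it
# runs when the routes are mounted rather than per request. Confirm that the
# HTTPException raised by _ensure_perf_deps() on a missing dependency cannot
# abort server startup; if it can, move the check into the request handlers.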
_ensure_perf_deps()
async def _run_cluster_job(job_id: str, req) -> None: async def _run_cluster_job(job_id: str, req) -> None:
try: try:
@@ -1142,7 +1156,7 @@ def mount_topic_cluster_routes(
best_sim = -1.0 best_sim = -1.0
best_dot = 0.0 best_dot = 0.0
# Build / rebuild ANN index when helpful (many clusters) # Build / rebuild ANN index when helpful (many clusters)
if _can_use_ann() and len(clusters) >= 64: if len(clusters) >= 64:
if ann_idx is None or (idx % ann_rebuild_every == 0): if ann_idx is None or (idx % ann_rebuild_every == 0):
# rebuild from current centroids # rebuild from current centroids
cents = [] cents = []
@@ -1226,7 +1240,7 @@ def mount_topic_cluster_routes(
new_large.append(c) new_large.append(c)
# Build ANN over large centroids once (optional) # Build ANN over large centroids once (optional)
ann_large = None ann_large = None
if _can_use_ann() and len(new_large) >= 64: if len(new_large) >= 64:
cents = [] cents = []
for c in new_large: for c in new_large:
cv = _centroid_vec_np(c["sum"], c["norm_sq"]) cv = _centroid_vec_np(c["sum"], c["norm_sq"])