chore: require numpy/hnswlib for topic clustering
parent
25190f7307
commit
3ba8997626
|
|
@ -18,16 +18,31 @@ from pydantic import BaseModel
|
|||
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, ImageEmbeddingFail, TopicClusterCache, TopicTitleCache
|
||||
from scripts.iib.tool import cwd
|
||||
|
||||
# Optional perf deps (heavy but worth it for 10k+ images)
|
||||
try:
|
||||
import numpy as _np # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
_np = None
|
||||
# Perf deps (required for this feature)
|
||||
_np = None
|
||||
_hnswlib = None
|
||||
_PERF_DEPS_READY = False
|
||||
|
||||
try:
|
||||
import hnswlib as _hnswlib # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
_hnswlib = None
|
||||
|
||||
def _ensure_perf_deps() -> None:
    """
    Lazily import numpy and hnswlib and publish them as module globals.

    We do NOT allow users to use the TopicSearch/TopicClustering feature
    without these deps. But we also do NOT want the whole server to crash on
    import, so the imports happen here, on first use, and a failure is
    reported at the API layer instead.

    On success the module globals ``_np`` and ``_hnswlib`` are bound to the
    imported modules and ``_PERF_DEPS_READY`` is set to True, so every later
    call returns immediately without re-importing.

    Raises:
        HTTPException: with status code 500 when either import fails. The
            original import error is attached as the exception's cause and
            is also summarized in ``detail``.
    """
    global _np, _hnswlib, _PERF_DEPS_READY

    # Fast path: a previous call already resolved both modules.
    if _PERF_DEPS_READY:
        return

    try:
        import numpy as np  # type: ignore
        import hnswlib  # type: ignore
    except Exception as e:
        # Kept broad on purpose: any import-time failure (not only
        # ImportError) is converted into an API-level error. Chain the cause
        # so the original traceback is preserved for debugging.
        raise HTTPException(
            status_code=500,
            detail=f"Topic clustering requires numpy+hnswlib. Please install them. import_error={type(e).__name__}: {e}",
        ) from e

    # Publish only after BOTH imports succeeded, so the globals are never
    # left half-initialized.
    _np = np
    _hnswlib = hnswlib
    _PERF_DEPS_READY = True
|
||||
|
||||
|
||||
_TOPIC_CLUSTER_JOBS: Dict[str, Dict] = {}
|
||||
|
|
@ -333,18 +348,13 @@ def _dot(a: array, b: array) -> float:
|
|||
return sum((x * y for x, y in zip(a, b)))
|
||||
|
||||
|
||||
def _can_use_ann() -> bool:
    """Report whether both module-level handles `_np` and `_hnswlib` are set (not None)."""
    either_missing = _np is None or _hnswlib is None
    return not either_missing
|
||||
|
||||
|
||||
def _centroid_vec_np(sum_vec: array, norm_sq: float):
|
||||
"""
|
||||
Convert centroid sum-vector to a normalized numpy float32 vector.
|
||||
Cosine between unit v and centroid is dot(v, sum)/sqrt(norm_sq).
|
||||
So centroid direction is sum / ||sum||.
|
||||
"""
|
||||
if _np is None:
|
||||
return None
|
||||
_ensure_perf_deps()
|
||||
if norm_sq <= 0:
|
||||
return None
|
||||
inv = 1.0 / math.sqrt(norm_sq)
|
||||
|
|
@ -359,7 +369,8 @@ def _build_hnsw_index(centroids_np, *, ef: int = 64, M: int = 32):
|
|||
Build a cosine HNSW index over centroid vectors.
|
||||
Returns index or None.
|
||||
"""
|
||||
if not _can_use_ann() or centroids_np is None:
|
||||
_ensure_perf_deps()
|
||||
if centroids_np is None:
|
||||
return None
|
||||
if len(centroids_np) == 0:
|
||||
return None
|
||||
|
|
@ -620,6 +631,9 @@ def mount_topic_cluster_routes(
|
|||
"""
|
||||
Mount embedding + topic clustering endpoints (MVP: manual, iib_output only).
|
||||
"""
|
||||
# Fail fast at API layer if required perf deps are missing.
|
||||
# (We don't crash the whole server at import time.)
|
||||
_ensure_perf_deps()
|
||||
|
||||
async def _run_cluster_job(job_id: str, req) -> None:
|
||||
try:
|
||||
|
|
@ -1142,7 +1156,7 @@ def mount_topic_cluster_routes(
|
|||
best_sim = -1.0
|
||||
best_dot = 0.0
|
||||
# Build / rebuild ANN index when helpful (many clusters)
|
||||
if _can_use_ann() and len(clusters) >= 64:
|
||||
if len(clusters) >= 64:
|
||||
if ann_idx is None or (idx % ann_rebuild_every == 0):
|
||||
# rebuild from current centroids
|
||||
cents = []
|
||||
|
|
@ -1226,7 +1240,7 @@ def mount_topic_cluster_routes(
|
|||
new_large.append(c)
|
||||
# Build ANN over large centroids once (optional)
|
||||
ann_large = None
|
||||
if _can_use_ann() and len(new_large) >= 64:
|
||||
if len(new_large) >= 64:
|
||||
cents = []
|
||||
for c in new_large:
|
||||
cv = _centroid_vec_np(c["sum"], c["norm_sq"])
|
||||
|
|
|
|||
Loading…
Reference in New Issue