chore: require numpy/hnswlib for topic clustering
parent
25190f7307
commit
3ba8997626
|
|
@ -18,16 +18,31 @@ from pydantic import BaseModel
|
||||||
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, ImageEmbeddingFail, TopicClusterCache, TopicTitleCache
|
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, ImageEmbeddingFail, TopicClusterCache, TopicTitleCache
|
||||||
from scripts.iib.tool import cwd
|
from scripts.iib.tool import cwd
|
||||||
|
|
||||||
# Optional perf deps (heavy but worth it for 10k+ images)
|
# Perf deps (required for this feature)
|
||||||
try:
|
|
||||||
import numpy as _np # type: ignore
|
|
||||||
except Exception: # pragma: no cover
|
|
||||||
_np = None
|
_np = None
|
||||||
|
|
||||||
try:
|
|
||||||
import hnswlib as _hnswlib # type: ignore
|
|
||||||
except Exception: # pragma: no cover
|
|
||||||
_hnswlib = None
|
_hnswlib = None
|
||||||
|
_PERF_DEPS_READY = False
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_perf_deps() -> None:
|
||||||
|
"""
|
||||||
|
We do NOT allow users to use TopicSearch/TopicClustering feature without these deps.
|
||||||
|
But we also do NOT want the whole server to crash on import, so we lazy-load and fail at API layer.
|
||||||
|
"""
|
||||||
|
global _np, _hnswlib, _PERF_DEPS_READY
|
||||||
|
if _PERF_DEPS_READY:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
import numpy as np # type: ignore
|
||||||
|
import hnswlib as hnswlib # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Topic clustering requires numpy+hnswlib. Please install them. import_error={type(e).__name__}: {e}",
|
||||||
|
)
|
||||||
|
_np = np
|
||||||
|
_hnswlib = hnswlib
|
||||||
|
_PERF_DEPS_READY = True
|
||||||
|
|
||||||
|
|
||||||
_TOPIC_CLUSTER_JOBS: Dict[str, Dict] = {}
|
_TOPIC_CLUSTER_JOBS: Dict[str, Dict] = {}
|
||||||
|
|
@ -333,18 +348,13 @@ def _dot(a: array, b: array) -> float:
|
||||||
return sum((x * y for x, y in zip(a, b)))
|
return sum((x * y for x, y in zip(a, b)))
|
||||||
|
|
||||||
|
|
||||||
def _can_use_ann() -> bool:
|
|
||||||
return _np is not None and _hnswlib is not None
|
|
||||||
|
|
||||||
|
|
||||||
def _centroid_vec_np(sum_vec: array, norm_sq: float):
|
def _centroid_vec_np(sum_vec: array, norm_sq: float):
|
||||||
"""
|
"""
|
||||||
Convert centroid sum-vector to a normalized numpy float32 vector.
|
Convert centroid sum-vector to a normalized numpy float32 vector.
|
||||||
Cosine between unit v and centroid is dot(v, sum)/sqrt(norm_sq).
|
Cosine between unit v and centroid is dot(v, sum)/sqrt(norm_sq).
|
||||||
So centroid direction is sum / ||sum||.
|
So centroid direction is sum / ||sum||.
|
||||||
"""
|
"""
|
||||||
if _np is None:
|
_ensure_perf_deps()
|
||||||
return None
|
|
||||||
if norm_sq <= 0:
|
if norm_sq <= 0:
|
||||||
return None
|
return None
|
||||||
inv = 1.0 / math.sqrt(norm_sq)
|
inv = 1.0 / math.sqrt(norm_sq)
|
||||||
|
|
@ -359,7 +369,8 @@ def _build_hnsw_index(centroids_np, *, ef: int = 64, M: int = 32):
|
||||||
Build a cosine HNSW index over centroid vectors.
|
Build a cosine HNSW index over centroid vectors.
|
||||||
Returns index or None.
|
Returns index or None.
|
||||||
"""
|
"""
|
||||||
if not _can_use_ann() or centroids_np is None:
|
_ensure_perf_deps()
|
||||||
|
if centroids_np is None:
|
||||||
return None
|
return None
|
||||||
if len(centroids_np) == 0:
|
if len(centroids_np) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
@ -620,6 +631,9 @@ def mount_topic_cluster_routes(
|
||||||
"""
|
"""
|
||||||
Mount embedding + topic clustering endpoints (MVP: manual, iib_output only).
|
Mount embedding + topic clustering endpoints (MVP: manual, iib_output only).
|
||||||
"""
|
"""
|
||||||
|
# Fail fast at API layer if required perf deps are missing.
|
||||||
|
# (We don't crash the whole server at import time.)
|
||||||
|
_ensure_perf_deps()
|
||||||
|
|
||||||
async def _run_cluster_job(job_id: str, req) -> None:
|
async def _run_cluster_job(job_id: str, req) -> None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -1142,7 +1156,7 @@ def mount_topic_cluster_routes(
|
||||||
best_sim = -1.0
|
best_sim = -1.0
|
||||||
best_dot = 0.0
|
best_dot = 0.0
|
||||||
# Build / rebuild ANN index when helpful (many clusters)
|
# Build / rebuild ANN index when helpful (many clusters)
|
||||||
if _can_use_ann() and len(clusters) >= 64:
|
if len(clusters) >= 64:
|
||||||
if ann_idx is None or (idx % ann_rebuild_every == 0):
|
if ann_idx is None or (idx % ann_rebuild_every == 0):
|
||||||
# rebuild from current centroids
|
# rebuild from current centroids
|
||||||
cents = []
|
cents = []
|
||||||
|
|
@ -1226,7 +1240,7 @@ def mount_topic_cluster_routes(
|
||||||
new_large.append(c)
|
new_large.append(c)
|
||||||
# Build ANN over large centroids once (optional)
|
# Build ANN over large centroids once (optional)
|
||||||
ann_large = None
|
ann_large = None
|
||||||
if _can_use_ann() and len(new_large) >= 64:
|
if len(new_large) >= 64:
|
||||||
cents = []
|
cents = []
|
||||||
for c in new_large:
|
for c in new_large:
|
||||||
cv = _centroid_vec_np(c["sum"], c["norm_sq"])
|
cv = _centroid_vec_np(c["sum"], c["norm_sq"])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue