932 lines
38 KiB
Python
932 lines
38 KiB
Python
import hashlib
|
||
import json
|
||
import math
|
||
import os
|
||
import re
|
||
from array import array
|
||
from contextlib import closing
|
||
from typing import Dict, List, Optional, Tuple
|
||
|
||
import requests
|
||
from fastapi import Depends, FastAPI, HTTPException
|
||
from pydantic import BaseModel
|
||
|
||
from scripts.iib.db.datamodel import DataBase, ImageEmbedding, TopicTitleCache
|
||
from scripts.iib.tool import cwd
|
||
|
||
|
||
def _normalize_base_url(base_url: str) -> str:
|
||
return base_url[:-1] if base_url.endswith("/") else base_url
|
||
|
||
|
||
# Master switch for prompt normalization; any of "0"/"false"/"no"/"off" disables it.
_PROMPT_NORMALIZE_ENABLED = os.getenv("IIB_PROMPT_NORMALIZE", "1").strip().lower() not in ["0", "false", "no", "off"]
# balanced: keep some discriminative style words (e.g. "科学插图/纪实/胶片") while dropping boilerplate quality/camera terms
# theme_only: more aggressive removal, closer to "only subject/theme nouns"
_PROMPT_NORMALIZE_MODE = (os.getenv("IIB_PROMPT_NORMALIZE_MODE", "balanced") or "balanced").strip().lower()

# Remove common SD boilerplate / quality descriptors.
# Patterns are OR-combined (case-insensitive) by _build_drop_re().
_DROP_PATTERNS_COMMON = [
    # SD / A1111 tags
    r"<lora:[^>]+>",
    r"<lyco:[^>]+>",
    # quality / resolution / generic (the escapes cover unicode hyphen variants)
    r"\b(masterpiece|best\s*quality|high\s*quality|best\s*rating|(?:highly|ultra|hyper)[-\s\u2010\u2011\u2012\u2013\u2212]*detailed|absurdres|absurd\s*res|hires|hdr|uhd|8k|4k|2k|raw\s*photo|photorealistic|realistic|cinematic)\b",
    # photography / camera / lens
    r"\b(film\s+photography|photography|dslr|camera|canon|nikon|sony|sigma|leica|lens|bokeh|depth\s+of\s+field|dof|sharp\s+focus|wide\s+angle|fisheye)\b",
    # camera settings: ISO values, f-stops, focal lengths
    r"\b(iso\s*\d{2,5}|f\/\d+(?:\.\d+)?|\d{2,4}mm)\b",
]
# chinese quality descriptors (common) — always dropped
_DROP_PATTERNS_ZH_COMMON = [
    r"(超高分辨率|高分辨率|高清|超清|8K|4K|2K|照片级|高质量|最佳质量|大师作品|杰作|超细节|细节丰富|极致细节|极致|完美)",
]
# chinese "style/photography" words: keep in balanced mode (discriminative), drop in theme_only mode
_DROP_PATTERNS_ZH_STYLE = [
    r"(电影质感|写真|写实|真实感|摄影|摄影作品|摄影图像|摄影图|镜头|景深|胶片|光圈|光影|构图|色彩|渲染|纪实|插图|科学插图)",
]
|
||
|
||
def _build_drop_re() -> re.Pattern:
    """Compile the combined boilerplate-removal regex for the active normalize mode.

    Always includes the common English and Chinese quality patterns; the Chinese
    style/photography patterns are added only in the aggressive theme modes.
    """
    patterns = [*_DROP_PATTERNS_COMMON, *_DROP_PATTERNS_ZH_COMMON]
    if _PROMPT_NORMALIZE_MODE in ("theme", "theme_only", "strict"):
        patterns.extend(_DROP_PATTERNS_ZH_STYLE)
    combined = "|".join(f"(?:{p})" for p in patterns)
    return re.compile(combined, flags=re.IGNORECASE)
|
||
|
||
_DROP_RE = _build_drop_re()
|
||
|
||
|
||
def _compute_prompt_normalize_version() -> str:
    """
    Derive a deterministic fingerprint ("nv_" + 12 hex chars) of the
    normalization rules.

    IMPORTANT:
    - Do NOT allow users to override normalize-version via environment variables.
    - The version is a pure function of the rule lists and mode, so cache
      invalidation happens automatically when the rules change in code
      (or when the mode is switched).
    """
    snapshot = {
        "enabled": bool(_PROMPT_NORMALIZE_ENABLED),
        "mode": str(_PROMPT_NORMALIZE_MODE),
        "drop_common": list(_DROP_PATTERNS_COMMON),
        "drop_zh_common": list(_DROP_PATTERNS_ZH_COMMON),
        "drop_zh_style": list(_DROP_PATTERNS_ZH_STYLE),
    }
    serialized = json.dumps(snapshot, ensure_ascii=False, sort_keys=True)
    digest = hashlib.sha1(serialized.encode("utf-8")).hexdigest()
    return "nv_" + digest[:12]
|
||
|
||
|
||
# Derived normalize version fingerprint (for embedding/title cache invalidation).
# Mixed into embedding text hashes and cluster signatures below.
_PROMPT_NORMALIZE_VERSION = _compute_prompt_normalize_version()
|
||
|
||
|
||
def _extract_prompt_text(raw_exif: str, max_chars: int = 4000) -> str:
|
||
"""
|
||
Extract the natural-language prompt part from stored exif text.
|
||
Keep text before 'Negative prompt:' to preserve semantics.
|
||
"""
|
||
if not isinstance(raw_exif, str):
|
||
return ""
|
||
s = raw_exif.strip()
|
||
if not s:
|
||
return ""
|
||
idx = s.lower().find("negative prompt:")
|
||
if idx != -1:
|
||
s = s[:idx].strip()
|
||
if len(s) > max_chars:
|
||
s = s[:max_chars]
|
||
return s.strip()
|
||
|
||
|
||
def _clean_prompt_for_semantic(text: str) -> str:
    """
    Light, dependency-free prompt normalization:
    - remove lora tags / SD boilerplate / quality & photography descriptors
    - keep remaining text as 'theme' semantic signal for embeddings/clustering
    """
    if not isinstance(text, str):
        return ""
    # Safety: drop the negative-prompt tail even if the caller passes raw exif.
    cleaned = re.sub(r"(negative prompt:).*", " ", text, flags=re.IGNORECASE | re.DOTALL)
    # Unwrap attention weights such as (foo:1.2) -> foo.
    cleaned = re.sub(r"\(([^()]{1,80}):\s*\d+(?:\.\d+)?\)", r"\1", cleaned)
    # Drop boilerplate patterns (quality/camera terms, lora tags, ...).
    cleaned = _DROP_RE.sub(" ", cleaned)
    # Normalize separators and brackets, collapse whitespace.
    cleaned = cleaned.replace("**", " ")
    cleaned = re.sub(r"[\[\]{}()]", " ", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    # For comma-tag style prompts, drop empty / tiny leftover segments
    # (<= 2 chars is treated as residual boilerplate).
    segments = [seg.strip() for seg in re.split(r"[,\n,;;]+", cleaned)]
    meaningful = [seg for seg in segments if len(seg) > 2]
    result = ",".join(meaningful) if meaningful else cleaned
    return result.strip()
|
||
|
||
|
||
def _clean_for_title(text: str) -> str:
|
||
if not isinstance(text, str):
|
||
return ""
|
||
s = text
|
||
s = s.replace("**", " ")
|
||
s = re.sub(r"<lora:[^>]+>", " ", s, flags=re.IGNORECASE)
|
||
s = re.sub(r"<lyco:[^>]+>", " ", s, flags=re.IGNORECASE)
|
||
s = re.sub(r"(negative prompt:).*", " ", s, flags=re.IGNORECASE | re.DOTALL)
|
||
s = re.sub(r"(prompt:|提示词[::]|提示[::]|输出[::])", " ", s, flags=re.IGNORECASE)
|
||
s = re.sub(r"\s+", " ", s).strip()
|
||
return s
|
||
|
||
|
||
def _title_from_representative_prompt(text: str, max_len: int = 18) -> str:
    """
    Local fallback title: take the first sentence/clause and truncate.
    This is much more readable than token n-grams without Chinese word segmentation.

    :param text: representative prompt text for a cluster.
    :param max_len: hard upper bound on title length (characters).
    :return: a non-empty title; "主题" ("topic") when nothing usable remains.
    """
    # Use the same semantic cleaner as embeddings to avoid boilerplate titles like
    # "masterpiece, best quality" / "A highly detailed ..."
    base = _clean_for_title(text)
    s = _clean_prompt_for_semantic(base) if _PROMPT_NORMALIZE_ENABLED else base
    # Fall back to progressively less-cleaned text if cleaning removed everything.
    if not s:
        s = base
    if not s:
        return "主题"
    # Split by common sentence punctuations, keep the first segment.
    seg = re.split(r"[。!?!?\n\r;;]+", s)[0].strip()
    # Remove leading punctuation / separators
    seg = re.sub(r"^[,,;;::\s-]+", "", seg).strip()
    # Remove leading english articles for nicer titles
    seg = re.sub(r"^(a|an|the)\s+", "", seg, flags=re.IGNORECASE).strip()
    # Strip common boilerplate templates in titles while keeping discriminative words.
    # English: "highly detailed scientific rendering/illustration of ..."
    # (the escape range covers unicode hyphen variants in "ultra-detailed" etc.)
    seg = re.sub(
        r"^(?:(?:highly|ultra|hyper)[-\s\u2010\u2011\u2012\u2013\u2212]*detailed\s+)?(?:scientific\s+)?(?:rendering|illustration|image|depiction|scene)\s+of\s+",
        "",
        seg,
        flags=re.IGNORECASE,
    ).strip()
    # Chinese: "一张...图像/插图/照片..." template ("an image/illustration/photo of ...")
    seg = re.sub(r"^一[张幅]\s*[^,,。]{0,20}(?:图像|插图|照片|摄影图像|摄影作品)\s*[,, ]*", "", seg).strip()
    # Remove trailing commas/colons
    seg = re.sub(r"[,:,:]\s*$", "", seg).strip()
    # If still too long, hard truncate.
    if len(seg) > max_len:
        seg = seg[:max_len].rstrip()
    return seg or "主题"
|
||
|
||
|
||
def _vec_to_blob_f32(vec: List[float]) -> bytes:
|
||
arr = array("f", vec)
|
||
return arr.tobytes()
|
||
|
||
|
||
def _blob_to_vec_f32(blob: bytes) -> array:
|
||
arr = array("f")
|
||
arr.frombytes(blob)
|
||
return arr
|
||
|
||
|
||
def _l2_norm_sq(vec: array) -> float:
|
||
return sum((x * x for x in vec))
|
||
|
||
|
||
def _dot(a: array, b: array) -> float:
|
||
return sum((x * y for x, y in zip(a, b)))
|
||
|
||
|
||
def _cos_sum(a_sum: array, a_norm_sq: float, b_sum: array, b_norm_sq: float) -> float:
|
||
if a_norm_sq <= 0 or b_norm_sq <= 0:
|
||
return 0.0
|
||
dotv = sum((x * y for x, y in zip(a_sum, b_sum)))
|
||
return dotv / (math.sqrt(a_norm_sq) * math.sqrt(b_norm_sq))
|
||
|
||
|
||
def _call_embeddings(
    *,
    inputs: List[str],
    model: str,
    base_url: str,
    api_key: str,
) -> List[List[float]]:
    """
    Call an OpenAI-compatible ``/embeddings`` endpoint for a batch of texts.

    :param inputs: texts to embed (one vector is returned per text).
    :param model: embedding model name.
    :param base_url: API base URL (trailing slash tolerated).
    :param api_key: bearer token; a 500 HTTPException is raised when missing.
    :return: embedding vectors ordered by the response ``index`` field.
    :raises HTTPException: 502 on transport failure, upstream status on non-200,
        500 on malformed response items.
    """
    if not api_key:
        raise HTTPException(status_code=500, detail="OpenAI API Key not configured")

    url = f"{_normalize_base_url(base_url)}/embeddings"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {"model": model, "input": inputs}
    try:
        resp = requests.post(url, json=payload, headers=headers, timeout=120)
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Embedding API request failed: {e}")
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    data = resp.json()
    items = data.get("data") or []
    # Providers may return items out of order; restore input order by 'index'.
    items.sort(key=lambda x: x.get("index", 0))
    embeddings = [x.get("embedding") for x in items]
    if any((not isinstance(v, list) for v in embeddings)):
        raise HTTPException(status_code=500, detail="Invalid embeddings response format")
    return embeddings
|
||
|
||
|
||
def _call_chat_title(
    *,
    base_url: str,
    api_key: str,
    model: str,
    prompt_samples: List[str],
    output_lang: str,
) -> Optional[Dict]:
    """
    Ask LLM to generate a short topic title and a few keywords. Returns dict or None.

    :param base_url: OpenAI-compatible API base URL.
    :param api_key: bearer token; a 500 HTTPException is raised when missing.
    :param model: chat model used for title generation.
    :param prompt_samples: prompt snippets that belong to the SAME theme.
    :param output_lang: human-readable output-language instruction for the LLM.
    :return: ``{"title": str (<=24 chars), "keywords": List[str] (<=6)}``.
    :raises HTTPException: 400 when no usable samples, 502 / upstream status on
        transport errors or unparseable responses.
    """
    if not api_key:
        raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
    url = f"{_normalize_base_url(base_url)}/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    # Normalize samples with the same cleaners used for embeddings; cap at
    # 6 samples of <= 400 chars each to bound prompt size.
    samples = [(_clean_prompt_for_semantic(_clean_for_title(s) or s) or s).strip() for s in prompt_samples if (s or "").strip()]
    samples = [s[:400] for s in samples][:6]
    if not samples:
        raise HTTPException(status_code=400, detail="No prompt samples for title generation")

    json_example = '{"title":"...","keywords":["...","..."]}'
    sys = (
        "You are a topic naming assistant for image-generation prompts.\n"
        "Given several prompt snippets that belong to the SAME theme, output:\n"
        "- a short topic title\n"
        "- 3–6 keywords.\n"
        "\n"
        "Rules:\n"
        f"- Output language MUST be: {output_lang}\n"
        "- Prefer 4–12 characters for Chinese (Simplified/Traditional), otherwise 2–6 English/German words.\n"
        "- Avoid generic boilerplate like: masterpiece, best quality, highly detailed, cinematic, etc.\n"
        "- Keep distinctive terms if they help differentiate themes (e.g., Warhammer 40K, Lolita, scientific illustration).\n"
        "- Do NOT output explanations. Do NOT output markdown/code fences.\n"
        "- The output MUST start with '{' and end with '}' (no leading/trailing characters).\n"
        "\n"
        "Output STRICT JSON only:\n"
        + json_example
    )
    user = "Prompt snippets:\n" + "\n".join([f"- {s}" for s in samples])

    payload = {
        "model": model,
        "messages": [{"role": "system", "content": sys}, {"role": "user", "content": user}],
        # Prefer deterministic, JSON-only output
        "temperature": 0.0,
        "top_p": 1.0,
        # Give enough room for JSON across providers.
        "max_tokens": 2048,
        # Prefer tool/function call to force structured output across providers (e.g. Gemini).
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "set_topic",
                    "description": "Return a concise topic title and 3-6 keywords.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string"},
                            "keywords": {"type": "array", "items": {"type": "string"}},
                        },
                        "required": ["title", "keywords"],
                        "additionalProperties": False,
                    },
                },
            }
        ],
        "tool_choice": {"type": "function", "function": {"name": "set_topic"}},
    }
    # Some OpenAI-compatible providers may use different token limit fields / casing.
    # Set them all (still a single request; no retry/fallback).
    payload["max_output_tokens"] = payload["max_tokens"]
    payload["max_completion_tokens"] = payload["max_tokens"]
    payload["maxOutputTokens"] = payload["max_tokens"]
    payload["maxCompletionTokens"] = payload["max_tokens"]

    def _post_and_parse(payload_obj: Dict) -> Dict:
        # POST one chat request and extract/validate the title JSON object.
        try:
            resp = requests.post(url, json=payload_obj, headers=headers, timeout=60)
        except requests.RequestException as e:
            raise HTTPException(status_code=502, detail=f"Chat API request failed: {e}")
        if resp.status_code != 200:
            # keep response body for debugging (truncated)
            body = (resp.text or "")[:600]
            raise HTTPException(status_code=resp.status_code, detail=body)
        try:
            data = resp.json()
        except Exception as e:
            txt = (resp.text or "")[:600]
            raise HTTPException(status_code=502, detail=f"Chat API response is not JSON: {e}; body={txt}")
        choice0 = (data.get("choices") or [{}])[0] if isinstance(data.get("choices"), list) else {}
        msg = (choice0 or {}).get("message") or {}

        # OpenAI-compatible providers may return JSON in different places:
        # - message.tool_calls[].function.arguments (JSON string) <-- preferred when tools are used
        # - message.function_call.arguments (legacy)
        # - message.content (common)
        # - choice.text (legacy completions)
        raw = ""
        if not raw and isinstance(msg, dict):
            tcs = msg.get("tool_calls") or []
            if isinstance(tcs, list) and tcs:
                fn = ((tcs[0] or {}).get("function") or {}) if isinstance(tcs[0], dict) else {}
                args = (fn.get("arguments") or "") if isinstance(fn, dict) else ""
                if isinstance(args, str) and args.strip():
                    raw = args.strip()
        if not raw and isinstance(msg, dict):
            fc = msg.get("function_call") or {}
            args = (fc.get("arguments") or "") if isinstance(fc, dict) else ""
            if isinstance(args, str) and args.strip():
                raw = args.strip()
        if not raw:
            content = (msg.get("content") or "") if isinstance(msg, dict) else ""
            if isinstance(content, str) and content.strip():
                raw = content.strip()
        if not raw:
            txt = (choice0.get("text") or "") if isinstance(choice0, dict) else ""
            if isinstance(txt, str) and txt.strip():
                raw = txt.strip()

        # Grab the outermost {...} span; some providers wrap the JSON in prose.
        m = re.search(r"\{[\s\S]*\}", raw)
        if not m:
            snippet = (raw or "")[:400].replace("\n", "\\n")
            choice_dump = json.dumps(choice0, ensure_ascii=False)[:600] if isinstance(choice0, dict) else str(choice0)[:600]
            raise HTTPException(
                status_code=502,
                detail=f"Chat API response has no JSON object; content_snippet={snippet}; choice0={choice_dump}",
            )
        try:
            obj = json.loads(m.group(0))
        except Exception as e:
            snippet = (m.group(0) or "")[:400].replace("\n", "\\n")
            raise HTTPException(status_code=502, detail=f"Chat API JSON parse failed: {e}; json_snippet={snippet}")
        if not isinstance(obj, dict):
            raise HTTPException(status_code=502, detail="Chat API response JSON is not an object")
        title = str(obj.get("title") or "").strip()
        keywords = obj.get("keywords") or []
        if not title:
            raise HTTPException(status_code=502, detail="Chat API response missing title")
        if not isinstance(keywords, list):
            keywords = []
        keywords = [str(x).strip() for x in keywords if str(x).strip()][:6]
        # Cap title length at 24 chars; keep at most 6 non-empty keywords.
        return {"title": title[:24], "keywords": keywords}

    # No fallback / no retry: fail fast if provider/model doesn't support response_format or returns invalid output.
    return _post_and_parse(payload)
|
||
|
||
|
||
def mount_topic_cluster_routes(
    app: FastAPI,
    db_api_base: str,
    verify_secret,
    write_permission_required,
    *,
    openai_base_url: str,
    openai_api_key: str,
    embedding_model: str,
    ai_model: str,
):
    """
    Mount embedding + topic clustering endpoints (MVP: manual, iib_output only).

    Registers three POST routes under ``db_api_base``, each guarded by
    ``verify_secret`` and ``write_permission_required``:
    - ``build_iib_output_embeddings``: incremental prompt-embedding builder
    - ``search_iib_output_by_prompt``: brute-force cosine top-k search
    - ``cluster_iib_output``: centroid-based incremental clustering + LLM titles
    """

    class BuildIibOutputEmbeddingReq(BaseModel):
        folder: Optional[str] = None  # default: {cwd}/iib_output
        model: Optional[str] = None
        force: Optional[bool] = False
        batch_size: Optional[int] = 64
        max_chars: Optional[int] = 4000

    @app.post(
        f"{db_api_base}/build_iib_output_embeddings",
        dependencies=[Depends(verify_secret), Depends(write_permission_required)],
    )
    async def build_iib_output_embeddings(req: BuildIibOutputEmbeddingReq):
        # Build/refresh prompt embeddings for every on-disk image under req.folder.
        if not openai_api_key:
            raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
        if not openai_base_url:
            raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")
        folder = req.folder or os.path.join(cwd, "iib_output")
        folder = os.path.normpath(folder)
        model = req.model or embedding_model
        # Clamp tunables to sane ranges.
        batch_size = max(1, min(int(req.batch_size or 64), 256))
        max_chars = max(256, min(int(req.max_chars or 4000), 8000))
        force = bool(req.force)

        if not os.path.exists(folder) or not os.path.isdir(folder):
            raise HTTPException(status_code=400, detail=f"Folder not found: {folder}")

        conn = DataBase.get_conn()
        like_prefix = os.path.join(folder, "%")
        with closing(conn.cursor()) as cur:
            cur.execute("SELECT id, path, exif FROM image WHERE path LIKE ?", (like_prefix,))
            rows = cur.fetchall()

        images = []
        for image_id, path, exif in rows:
            # Skip DB rows whose file no longer exists on disk.
            if not isinstance(path, str) or not os.path.exists(path):
                continue
            text_raw = _extract_prompt_text(exif, max_chars=max_chars)
            if _PROMPT_NORMALIZE_ENABLED:
                text = _clean_prompt_for_semantic(text_raw)
                # Fall back to the raw prompt when cleaning removed everything.
                if not text:
                    text = text_raw
            else:
                text = text_raw
            if not text:
                continue
            images.append({"id": int(image_id), "path": path, "text": text})

        if not images:
            return {"folder": folder, "count": 0, "updated": 0, "skipped": 0, "model": model}

        id_list = [x["id"] for x in images]
        existing = ImageEmbedding.get_by_image_ids(conn, id_list)

        to_embed = []
        skipped = 0
        for item in images:
            # include normalize version to force refresh when rules change
            text_hash = ImageEmbedding.compute_text_hash(f"{_PROMPT_NORMALIZE_VERSION}:{item['text']}")
            old = existing.get(item["id"])
            # Reuse existing vector only when model + text hash match and a vector exists.
            if (
                (not force)
                and old
                and old.get("model") == model
                and old.get("text_hash") == text_hash
                and old.get("vec")
            ):
                skipped += 1
                continue
            to_embed.append({**item, "text_hash": text_hash})

        updated = 0
        for i in range(0, len(to_embed), batch_size):
            batch = to_embed[i : i + batch_size]
            inputs = [x["text"] for x in batch]
            vectors = _call_embeddings(
                inputs=inputs,
                model=model,
                base_url=openai_base_url,
                api_key=openai_api_key,
            )
            if len(vectors) != len(batch):
                raise HTTPException(status_code=500, detail="Embeddings count mismatch")
            for item, vec in zip(batch, vectors):
                ImageEmbedding.upsert(
                    conn=conn,
                    image_id=item["id"],
                    model=model,
                    dim=len(vec),
                    text_hash=item["text_hash"],
                    vec_blob=_vec_to_blob_f32(vec),
                )
                updated += 1
            # Commit per batch so progress survives a mid-run failure.
            conn.commit()

        return {"folder": folder, "count": len(images), "updated": updated, "skipped": skipped, "model": model}

    class ClusterIibOutputReq(BaseModel):
        folder: Optional[str] = None
        folder_paths: Optional[List[str]] = None
        model: Optional[str] = None
        force_embed: Optional[bool] = False
        threshold: Optional[float] = 0.86
        batch_size: Optional[int] = 64
        max_chars: Optional[int] = 4000
        min_cluster_size: Optional[int] = 2
        # B: LLM title generation
        title_model: Optional[str] = None
        # Reduce noise by reassigning small-cluster members to best large cluster if similarity is high enough
        assign_noise_threshold: Optional[float] = None
        # Cache titles in sqlite to avoid repeated LLM calls
        use_title_cache: Optional[bool] = True
        force_title: Optional[bool] = False
        # Output language for titles/keywords (from frontend globalStore.lang)
        lang: Optional[str] = None

    class PromptSearchReq(BaseModel):
        query: str
        folder: Optional[str] = None
        folder_paths: Optional[List[str]] = None
        model: Optional[str] = None
        top_k: Optional[int] = 50
        min_score: Optional[float] = 0.0
        # Ensure embeddings exist/updated before searching
        ensure_embed: Optional[bool] = True
        # Use the same normalization as clustering
        max_chars: Optional[int] = 4000

    @app.post(
        f"{db_api_base}/search_iib_output_by_prompt",
        dependencies=[Depends(verify_secret), Depends(write_permission_required)],
    )
    async def search_iib_output_by_prompt(req: PromptSearchReq):
        # Embed the query prompt and return the top-k most similar images.
        if not openai_api_key:
            raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
        if not openai_base_url:
            raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")

        q = (req.query or "").strip()
        if not q:
            raise HTTPException(status_code=400, detail="query is required")

        folders: List[str] = []
        if req.folder_paths:
            for p in req.folder_paths:
                if isinstance(p, str) and p.strip():
                    folders.append(os.path.normpath(p.strip()))
        if req.folder and isinstance(req.folder, str) and req.folder.strip():
            folders.append(os.path.normpath(req.folder.strip()))
        # Users do not rely on the default iib_output folder: fail fast when no
        # search scope was specified.
        if not folders:
            raise HTTPException(status_code=400, detail="folder_paths is required (select folders to search)")

        # validate folders
        folders = list(dict.fromkeys(folders))  # de-dup keep order
        for f in folders:
            if not os.path.exists(f) or not os.path.isdir(f):
                raise HTTPException(status_code=400, detail=f"Folder not found: {f}")

        folder = folders[0]
        model = req.model or embedding_model
        top_k = max(1, min(int(req.top_k or 50), 500))
        min_score = float(req.min_score or 0.0)
        min_score = max(-1.0, min(min_score, 1.0))
        max_chars = max(256, min(int(req.max_chars or 4000), 8000))

        if bool(req.ensure_embed):
            # Incrementally refresh embeddings before searching.
            for f in folders:
                await build_iib_output_embeddings(
                    BuildIibOutputEmbeddingReq(folder=f, model=model, force=False, batch_size=64, max_chars=max_chars)
                )

        # Build query embedding
        q_text = _extract_prompt_text(q, max_chars=max_chars)
        if _PROMPT_NORMALIZE_ENABLED:
            q_text2 = _clean_prompt_for_semantic(q_text)
            if q_text2:
                q_text = q_text2
        vecs = _call_embeddings(inputs=[q_text], model=model, base_url=openai_base_url, api_key=openai_api_key)
        if not vecs or not isinstance(vecs[0], list) or not vecs[0]:
            raise HTTPException(status_code=502, detail="Embedding API returned empty vector")
        qv = array("f", [float(x) for x in vecs[0]])
        qn2 = _l2_norm_sq(qv)
        if qn2 <= 0:
            raise HTTPException(status_code=502, detail="Query embedding has zero norm")
        # Normalize the query vector to unit length in place.
        qinv = 1.0 / math.sqrt(qn2)
        for i in range(len(qv)):
            qv[i] *= qinv

        conn = DataBase.get_conn()
        like_prefixes = [os.path.join(f, "%") for f in folders]
        with closing(conn.cursor()) as cur:
            where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes))
            cur.execute(
                f"""SELECT image.id, image.path, image.exif, image_embedding.vec
                FROM image
                INNER JOIN image_embedding ON image_embedding.image_id = image.id
                WHERE ({where}) AND image_embedding.model = ?""",
                (*like_prefixes, model),
            )
            rows = cur.fetchall()

        # TopK by cosine similarity (brute force; MVP only)
        import heapq

        # NOTE(review): heap entries are (score, dict) tuples — two identical
        # scores would fall back to comparing dicts and raise TypeError;
        # consider adding a unique tiebreaker (e.g. image id) to the tuple.
        heap: List[Tuple[float, Dict]] = []
        total = 0
        for image_id, path, exif, vec_blob in rows:
            if not isinstance(path, str) or not os.path.exists(path):
                continue
            if not vec_blob:
                continue
            v = _blob_to_vec_f32(vec_blob)
            n2 = _l2_norm_sq(v)
            if n2 <= 0:
                continue
            # Normalize stored vector so the dot product equals cosine similarity.
            inv = 1.0 / math.sqrt(n2)
            for i in range(len(v)):
                v[i] *= inv
            score = _dot(qv, v)
            total += 1
            if score < min_score:
                continue
            item = {
                "id": int(image_id),
                "path": path,
                "score": float(score),
                "sample_prompt": _clean_for_title(_extract_prompt_text(exif, max_chars=max_chars))[:200],
            }
            # Min-heap of size top_k keeps the k best-scoring items.
            if len(heap) < top_k:
                heapq.heappush(heap, (score, item))
            else:
                if score > heap[0][0]:
                    heapq.heapreplace(heap, (score, item))

        heap.sort(key=lambda x: x[0], reverse=True)
        results = [x[1] for x in heap]
        return {
            "query": q,
            "folder": folder,
            "folders": folders,
            "model": model,
            "count": total,
            "top_k": top_k,
            "results": results,
        }

    def _cluster_sig(
        *,
        member_ids: List[int],
        model: str,
        threshold: float,
        min_cluster_size: int,
        title_model: str,
        lang: str,
    ) -> str:
        # Stable signature of a cluster: parameters + sorted member ids +
        # normalize version/mode, so the title cache invalidates when any change.
        h = hashlib.sha1()
        h.update(f"m={model}|t={threshold:.6f}|min={min_cluster_size}|tm={title_model}|lang={lang}|nv={_PROMPT_NORMALIZE_VERSION}|nm={_PROMPT_NORMALIZE_MODE}".encode("utf-8"))
        for iid in sorted(member_ids):
            h.update(b"|")
            h.update(str(int(iid)).encode("utf-8"))
        return h.hexdigest()

    def _normalize_output_lang(lang: Optional[str]) -> str:
        """
        Map frontend language keys to a human-readable instruction for LLM output language.
        Frontend uses: en / zhHans / zhHant / de
        """
        if not lang:
            return "English"
        l = str(lang).strip()
        ll = l.lower()
        if ll in ["zh", "zhhans", "zh-hans", "zh_cn", "zh-cn", "cn", "zh-hans-cns", "zhs"]:
            return "Chinese (Simplified)"
        if ll in ["zhhant", "zh-hant", "zh_tw", "zh-tw", "zh_hk", "zh-hk", "tw", "hk", "zht"]:
            return "Chinese (Traditional)"
        if ll.startswith("de"):
            return "German"
        if ll.startswith("en"):
            return "English"
        # fallback
        return "English"

    @app.post(
        f"{db_api_base}/cluster_iib_output",
        dependencies=[Depends(verify_secret), Depends(write_permission_required)],
    )
    async def cluster_iib_output(req: ClusterIibOutputReq):
        # Cluster image prompts by embedding similarity and title each cluster via LLM.
        if not openai_api_key:
            raise HTTPException(status_code=500, detail="OpenAI API Key not configured")
        if not openai_base_url:
            raise HTTPException(status_code=500, detail="OpenAI Base URL not configured")
        folders: List[str] = []
        if req.folder_paths:
            for p in req.folder_paths:
                if isinstance(p, str) and p.strip():
                    folders.append(os.path.normpath(p.strip()))
        if req.folder and isinstance(req.folder, str) and req.folder.strip():
            folders.append(os.path.normpath(req.folder.strip()))
        # Users do not rely on the default iib_output folder: fail fast when no
        # clustering scope was specified.
        if not folders:
            raise HTTPException(status_code=400, detail="folder_paths is required (select folders to cluster)")

        folders = list(dict.fromkeys(folders))
        for f in folders:
            if not os.path.exists(f) or not os.path.isdir(f):
                raise HTTPException(status_code=400, detail=f"Folder not found: {f}")

        # Ensure embeddings exist (incremental per folder)
        for f in folders:
            await build_iib_output_embeddings(
                BuildIibOutputEmbeddingReq(
                    folder=f,
                    model=req.model,
                    force=req.force_embed,
                    batch_size=req.batch_size,
                    max_chars=req.max_chars,
                )
            )

        folder = folders[0]
        model = req.model or embedding_model
        threshold = float(req.threshold or 0.86)
        threshold = max(0.0, min(threshold, 0.999))
        min_cluster_size = max(1, int(req.min_cluster_size or 2))
        title_model = req.title_model or os.getenv("TOPIC_TITLE_MODEL") or ai_model
        output_lang = _normalize_output_lang(req.lang)
        assign_noise_threshold = req.assign_noise_threshold
        if assign_noise_threshold is None:
            # conservative: only reassign if very likely belongs to a large topic
            assign_noise_threshold = max(0.72, min(threshold - 0.035, 0.93))
        else:
            assign_noise_threshold = max(0.0, min(float(assign_noise_threshold), 0.999))
        use_title_cache = bool(True if req.use_title_cache is None else req.use_title_cache)
        force_title = bool(req.force_title)

        conn = DataBase.get_conn()
        like_prefixes = [os.path.join(f, "%") for f in folders]
        with closing(conn.cursor()) as cur:
            where = " OR ".join(["image.path LIKE ?"] * len(like_prefixes))
            cur.execute(
                f"""SELECT image.id, image.path, image.exif, image_embedding.vec
                FROM image
                INNER JOIN image_embedding ON image_embedding.image_id = image.id
                WHERE ({where}) AND image_embedding.model = ?""",
                (*like_prefixes, model),
            )
            rows = cur.fetchall()

        items = []
        for image_id, path, exif, vec_blob in rows:
            if not isinstance(path, str) or not os.path.exists(path):
                continue
            if not vec_blob:
                continue
            vec = _blob_to_vec_f32(vec_blob)
            n2 = _l2_norm_sq(vec)
            if n2 <= 0:
                continue
            # Normalize each vector to unit length for cosine math below.
            inv = 1.0 / math.sqrt(n2)
            for i in range(len(vec)):
                vec[i] *= inv
            text_raw = _extract_prompt_text(exif, max_chars=int(req.max_chars or 4000))
            if _PROMPT_NORMALIZE_ENABLED:
                text = _clean_prompt_for_semantic(text_raw)
                if not text:
                    text = text_raw
            else:
                text = text_raw
            items.append({"id": int(image_id), "path": path, "text": text, "vec": vec})

        if not items:
            return {"folder": folder, "folders": folders, "model": model, "threshold": threshold, "clusters": [], "noise": []}

        # Incremental clustering by centroid-direction (sum vector)
        clusters = []  # {sum, norm_sq, members:[idx], sample_text}
        for idx, it in enumerate(items):
            v = it["vec"]
            best_ci = -1
            best_sim = -1.0
            best_dot = 0.0
            for ci, c in enumerate(clusters):
                # cos(v, centroid) == dot(v, sum) / |sum| because v is unit-length.
                dotv = _dot(v, c["sum"])
                denom = math.sqrt(c["norm_sq"]) if c["norm_sq"] > 0 else 1.0
                sim = dotv / denom
                if sim > best_sim:
                    best_sim = sim
                    best_ci = ci
                    best_dot = dotv
            if best_ci != -1 and best_sim >= threshold:
                c = clusters[best_ci]
                for i in range(len(v)):
                    c["sum"][i] += v[i]
                # |sum + v|^2 = |sum|^2 + 2*dot(sum, v) + |v|^2, and |v| == 1.
                c["norm_sq"] = c["norm_sq"] + 2.0 * best_dot + 1.0
                c["members"].append(idx)
            else:
                clusters.append({"sum": array("f", v), "norm_sq": 1.0, "members": [idx], "sample_text": it.get("text") or ""})

        # Merge highly similar clusters (fix: same theme split into multiple clusters)
        merge_threshold = min(0.995, max(threshold + 0.04, 0.90))
        merged = True
        while merged and len(clusters) > 1:
            merged = False
            best_i = best_j = -1
            best_sim = merge_threshold
            # Greedy: merge the single most-similar pair per pass, then rescan.
            for i in range(len(clusters)):
                ci = clusters[i]
                for j in range(i + 1, len(clusters)):
                    cj = clusters[j]
                    sim = _cos_sum(ci["sum"], ci["norm_sq"], cj["sum"], cj["norm_sq"])
                    if sim >= best_sim:
                        best_sim = sim
                        best_i, best_j = i, j
            if best_i != -1:
                a = clusters[best_i]
                b = clusters[best_j]
                for k in range(len(a["sum"])):
                    a["sum"][k] += b["sum"][k]
                a["norm_sq"] = _l2_norm_sq(a["sum"])
                a["members"].extend(b["members"])
                if not a.get("sample_text"):
                    a["sample_text"] = b.get("sample_text", "")
                clusters.pop(best_j)
                merged = True

        # Reassign members from small clusters into best large cluster to reduce noise
        if min_cluster_size > 1 and assign_noise_threshold > 0 and clusters:
            large = [c for c in clusters if len(c["members"]) >= min_cluster_size]
            if large:
                new_large = []
                # copy large clusters first
                for c in clusters:
                    if len(c["members"]) >= min_cluster_size:
                        new_large.append(c)
                # reassign items from small clusters
                for c in clusters:
                    if len(c["members"]) >= min_cluster_size:
                        continue
                    for mi in c["members"]:
                        v = items[mi]["vec"]
                        best_ci = -1
                        best_sim = -1.0
                        best_dot = 0.0
                        for ci, bigc in enumerate(new_large):
                            dotv = _dot(v, bigc["sum"])
                            denom = math.sqrt(bigc["norm_sq"]) if bigc["norm_sq"] > 0 else 1.0
                            sim = dotv / denom
                            if sim > best_sim:
                                best_sim = sim
                                best_ci = ci
                                best_dot = dotv
                        if best_ci != -1 and best_sim >= assign_noise_threshold:
                            bigc = new_large[best_ci]
                            for k in range(len(v)):
                                bigc["sum"][k] += v[k]
                            bigc["norm_sq"] = bigc["norm_sq"] + 2.0 * best_dot + 1.0
                            bigc["members"].append(mi)
                        # else: keep in small cluster -> will become noise below
                clusters = new_large

        # Split small clusters to noise, generate titles
        out_clusters = []
        noise = []
        for cidx, c in enumerate(clusters):
            if len(c["members"]) < min_cluster_size:
                for mi in c["members"]:
                    noise.append(items[mi]["path"])
                continue

            member_items = [items[mi] for mi in c["members"]]
            paths = [x["path"] for x in member_items]
            texts = [x.get("text") or "" for x in member_items]
            member_ids = [x["id"] for x in member_items]

            # Representative prompt for LLM title generation
            rep = (c.get("sample_text") or (texts[0] if texts else "")).strip()

            cached = None
            cluster_hash = _cluster_sig(
                member_ids=member_ids,
                model=model,
                threshold=threshold,
                min_cluster_size=min_cluster_size,
                title_model=title_model,
                lang=output_lang,
            )
            if use_title_cache and (not force_title):
                cached = TopicTitleCache.get(conn, cluster_hash)
            if cached and isinstance(cached, dict) and cached.get("title"):
                title = str(cached.get("title"))
                keywords = cached.get("keywords") or []
            else:
                llm = _call_chat_title(
                    base_url=openai_base_url,
                    api_key=openai_api_key,
                    model=title_model,
                    prompt_samples=[rep] + texts[:5],
                    output_lang=output_lang,
                )
                title = (llm or {}).get("title") if isinstance(llm, dict) else None
                keywords = (llm or {}).get("keywords", []) if isinstance(llm, dict) else []
                if not title:
                    raise HTTPException(status_code=502, detail="Chat API returned empty title")
                if use_title_cache and title:
                    # Best-effort cache write; never fail the request over it.
                    try:
                        TopicTitleCache.upsert(conn, cluster_hash, str(title), list(keywords or []), str(title_model))
                        conn.commit()
                    except Exception:
                        pass

            out_clusters.append(
                {
                    "id": f"topic_{cidx}",
                    "title": title,
                    "keywords": keywords,
                    "size": len(paths),
                    "paths": paths,
                    "sample_prompt": _clean_for_title(rep)[:200],
                }
            )

        out_clusters.sort(key=lambda x: x["size"], reverse=True)
        return {
            "folder": folder,
            "folders": folders,
            "model": model,
            "threshold": threshold,
            "min_cluster_size": min_cluster_size,
            "clusters": out_clusters,
            "noise": noise,
            "count": len(items),
            "title_model": title_model,
        }
|
||
|
||
|