Improve keyword consistency in topic clustering

- Add get_all_keywords_frequency method to TopicTitleCache
- Initialize keyword frequency from historical cached cluster keywords
- Prioritize top 100 high-frequency keywords when generating new keywords
- Update LLM prompt to prefer existing keywords from frequency list
- Reduce duplicate/similar keyword generation across clusters
- Add streaming support for tag_graph LLM requests
- Increase LLM timeout and retry limits for better reliability
feature/tag-relationship-graph
zanllp 2026-01-12 00:56:03 +08:00
parent 6c16bd0d82
commit 0c1998bdb0
3 changed files with 159 additions and 79 deletions

View File

@ -516,6 +516,38 @@ class TopicTitleCache:
kw = []
return {"title": title, "keywords": kw, "model": model, "updated_at": updated_at}
@classmethod
def get_all_keywords_frequency(cls, conn: Connection, model: Optional[str] = None) -> Dict[str, int]:
    """
    Count how often each keyword appears across all cached clusters.

    When *model* is given, only rows produced by that model are counted.
    Rows whose ``keywords`` column is not valid JSON (or not a JSON list)
    are silently skipped; empty/whitespace-only entries are ignored.

    Returns a mapping of keyword -> occurrence count.
    """
    sql = "SELECT keywords FROM topic_title_cache"
    params: tuple = ()
    if model:
        sql += " WHERE model = ?"
        params = (model,)
    freq: Dict[str, int] = {}
    with closing(conn.cursor()) as cur:
        cur.execute(sql, params)
        for row in cur.fetchall():
            raw = row[0] if row else None
            try:
                parsed = json.loads(raw) if isinstance(raw, str) else []
            except Exception:
                # Corrupt cache entry - ignore rather than fail the whole scan.
                parsed = []
            if not isinstance(parsed, list):
                continue
            for kw in parsed:
                if isinstance(kw, str) and kw.strip():
                    freq[kw] = freq.get(kw, 0) + 1
    return freq
@classmethod
def upsert(
cls,

View File

@ -10,18 +10,19 @@ import time
from typing import Dict, List, Optional
from pydantic import BaseModel
from fastapi import Depends, FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from scripts.iib.db.datamodel import DataBase, GlobalSetting
from scripts.iib.tool import normalize_output_lang
from scripts.iib.logger import logger
# Cache version for tag abstraction - change the value to invalidate all caches.
TAG_ABSTRACTION_CACHE_VERSION = 2.1
TAG_GRAPH_CACHE_VERSION = 1

# Tunables, overridable via environment variables. The trailing `or "default"`
# also covers the case where the variable is set but empty.
_MAX_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_MAX_TAGS_FOR_LLM", "300") or "300")
_TOPK_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_TOPK_TAGS_FOR_LLM", "300") or "300")
_LLM_REQUEST_TIMEOUT_SEC = int(os.getenv("IIB_TAG_GRAPH_LLM_TIMEOUT_SEC", "180") or "180")
_LLM_MAX_ATTEMPTS = int(os.getenv("IIB_TAG_GRAPH_LLM_MAX_ATTEMPTS", "5") or "5")
class TagGraphReq(BaseModel):
@ -95,6 +96,7 @@ def mount_tag_graph_routes(
# Normalize language for consistent LLM output
normalized_lang = normalize_output_lang(lang)
print(f"tags length: {len(tags)}")
sys_prompt = f"""You are a tag categorization assistant. Organize tags into hierarchical categories.
STRICT RULES:
@ -118,88 +120,110 @@ If unsure about Level 2, OMIT it entirely. Start response with {{ and end with }
{"role": "user", "content": user_prompt}
],
"temperature": 0.0,
"max_tokens": 2048,
"max_tokens": 8096,
"stream": True,
}
# Retry a few times then fallback quickly (to avoid frontend timeout on large datasets).
# Use streaming requests to avoid blocking too long on a single non-stream response.
last_error = ""
for attempt in range(1, _LLM_MAX_ATTEMPTS + 1):
try:
resp = requests.post(url, json=payload, headers=headers, timeout=_LLM_REQUEST_TIMEOUT_SEC)
resp = requests.post(
url,
json=payload,
headers=headers,
timeout=_LLM_REQUEST_TIMEOUT_SEC,
stream=True,
)
# Check status early
if resp.status_code != 200:
body = (resp.text or "")[:400]
if resp.status_code == 429 or resp.status_code >= 500:
last_error = f"api_error_retriable: status={resp.status_code}"
logger.warning("[tag_graph] llm_http_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
continue
logger.error("[tag_graph] llm_http_client_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
raise Exception(f"API client error: {resp.status_code} {body}")
# Accumulate streamed content chunks
content_buffer = ""
for raw in resp.iter_lines(decode_unicode=False):
if not raw:
continue
# Ensure explicit UTF-8 decoding to avoid mojibake
try:
line = raw.decode('utf-8') if isinstance(raw, (bytes, bytearray)) else str(raw)
except Exception:
line = raw.decode('utf-8', errors='replace') if isinstance(raw, (bytes, bytearray)) else str(raw)
line = line.strip()
if line.startswith('data: '):
line = line[6:].strip()
if line == '[DONE]':
break
try:
obj = json.loads(line)
except Exception:
# Some providers may return partial JSON or non-JSON lines; skip
continue
# Try to extract incremental content (compat with OpenAI-style streaming)
delta = (obj.get('choices') or [{}])[0].get('delta') or {}
chunk_text = delta.get('content') or ''
if chunk_text:
# Print when data chunk received (truncate for safety)
try:
print(f"[tag_graph] stream_chunk_received len={len(chunk_text)} snippet={chunk_text[:200]}")
except Exception:
pass
content_buffer += chunk_text
content = content_buffer.strip()
print(f"content: {content}")
# Strategy 1: Direct parse (if response is pure JSON)
try:
result = json.loads(content)
if isinstance(result, dict) and 'layers' in result:
return result
except Exception:
pass
# Strategy 2: Extract JSON from markdown code blocks
json_str = None
code_block = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", content)
if code_block:
json_str = code_block.group(1)
else:
m = re.search(r"\{[\s\S]*\}", content)
if m:
json_str = m.group(0)
if not json_str:
last_error = f"no_json_found: {content[:200]}"
logger.warning("[tag_graph] llm_no_json attempt=%s err=%s", attempt, last_error)
continue
# Clean up common JSON issues
json_str = json_str.strip()
json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
try:
result = json.loads(json_str)
except json.JSONDecodeError as e:
last_error = f"json_parse_error: {e}"
logger.warning("[tag_graph] llm_json_parse_error attempt=%s err=%s json=%s", attempt, last_error, json_str[:400])
continue
if not isinstance(result, dict) or 'layers' not in result:
last_error = f"invalid_structure: {str(result)[:200]}"
logger.warning("[tag_graph] llm_invalid_structure attempt=%s err=%s", attempt, last_error)
continue
return result
except requests.RequestException as e:
last_error = f"network_error: {type(e).__name__}: {e}"
logger.warning("[tag_graph] llm_request_error attempt=%s err=%s", attempt, last_error)
continue
# Retry on 429 or 5xx, fail immediately on other 4xx
if resp.status_code != 200:
body = (resp.text or "")[:400]
if resp.status_code == 429 or resp.status_code >= 500:
last_error = f"api_error_retriable: status={resp.status_code}"
logger.warning("[tag_graph] llm_http_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
continue
# 4xx client error - fail immediately
logger.error("[tag_graph] llm_http_client_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
raise Exception(f"API client error: {resp.status_code} {body}")
try:
data = resp.json()
except Exception as e:
last_error = f"response_not_json: {type(e).__name__}"
logger.warning("[tag_graph] llm_response_not_json attempt=%s err=%s body=%s", attempt, last_error, (resp.text or "")[:400])
continue
choice0 = (data.get("choices") or [{}])[0]
msg = (choice0 or {}).get("message") or {}
content = (msg.get("content") or "").strip()
# Extract JSON from content - try multiple strategies
json_str = None
# Strategy 1: Direct parse (if response is pure JSON)
try:
result = json.loads(content)
if isinstance(result, dict) and "layers" in result:
return result
except:
pass
# Strategy 2: Extract JSON from markdown code blocks
code_block = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", content)
if code_block:
json_str = code_block.group(1)
else:
# Strategy 3: Find largest JSON object
m = re.search(r"\{[\s\S]*\}", content)
if m:
json_str = m.group(0)
if not json_str:
last_error = f"no_json_found: {content[:200]}"
logger.warning("[tag_graph] llm_no_json attempt=%s err=%s", attempt, last_error)
continue
# Clean up common JSON issues
json_str = json_str.strip()
# Remove trailing commas before closing braces/brackets
json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
try:
result = json.loads(json_str)
except json.JSONDecodeError as e:
last_error = f"json_parse_error: {e}"
logger.warning("[tag_graph] llm_json_parse_error attempt=%s err=%s json=%s", attempt, last_error, json_str[:400])
continue
# Validate structure
if not isinstance(result, dict) or "layers" not in result:
last_error = f"invalid_structure: {str(result)[:200]}"
logger.warning("[tag_graph] llm_invalid_structure attempt=%s err=%s", attempt, last_error)
continue
# Success!
return result
# No fallback: expose error to frontend, but log enough info for debugging.
logger.error(
"[tag_graph] llm_failed attempts=%s timeout_sec=%s last_error=%s",
@ -218,7 +242,7 @@ If unsure about Level 2, OMIT it entirely. Start response with {{ and end with }
)
return await asyncio.to_thread(_call_sync)
@app.post(
f"{db_api_base}/cluster_tag_graph",
dependencies=[Depends(verify_secret)],

View File

@ -448,6 +448,7 @@ def _call_chat_title_sync(
model: str,
prompt_samples: List[str],
output_lang: str,
existing_keywords: Optional[List[str]] = None,
) -> Optional[Dict]:
"""
Ask LLM to generate a short topic title and a few keywords. Returns dict or None.
@ -477,9 +478,16 @@ def _call_chat_title_sync(
"- Do NOT output explanations. Do NOT output markdown/code fences.\n"
"- The output MUST start with '{' and end with '}' (no leading/trailing characters).\n"
"\n"
"Output STRICT JSON only:\n"
+ json_example
)
if existing_keywords:
top_keywords = existing_keywords[:100]
sys += (
f"IMPORTANT: You MUST prioritize selecting keywords from this existing list. "
f"Use keywords from the list that best match the current theme. "
f"Only create new keywords when absolutely necessary.\n"
f"Existing keywords (top {len(top_keywords)}): {', '.join(top_keywords)}\n\n"
)
sys += "Output STRICT JSON only:\n" + json_example
user = "Prompt snippets:\n" + "\n".join([f"- {s}" for s in samples])
payload = {
@ -590,6 +598,7 @@ async def _call_chat_title(
model: str,
prompt_samples: List[str],
output_lang: str,
existing_keywords: Optional[List[str]] = None,
) -> Dict:
"""
Same rationale as embeddings:
@ -603,6 +612,7 @@ async def _call_chat_title(
model=model,
prompt_samples=prompt_samples,
output_lang=output_lang,
existing_keywords=existing_keywords,
)
if not isinstance(ret, dict):
raise HTTPException(status_code=502, detail="Chat API returned empty title payload")
@ -1412,6 +1422,15 @@ def mount_topic_cluster_routes(
if progress_cb:
progress_cb({"stage": "titling", "clusters_total": len(clusters)})
existing_keywords: List[str] = []
keyword_frequency: Dict[str, int] = TopicTitleCache.get_all_keywords_frequency(conn, model)
def _get_top_keywords() -> List[str]:
    """Return up to 100 keywords from the running frequency map, most frequent first."""
    if not keyword_frequency:
        return []
    # Stable sort: keywords with equal counts keep their insertion order.
    ranked = sorted(keyword_frequency, key=lambda kw: keyword_frequency[kw], reverse=True)
    return ranked[:100]
for cidx, c in enumerate(clusters):
if len(c["members"]) < min_cluster_size:
for mi in c["members"]:
@ -1441,12 +1460,14 @@ def mount_topic_cluster_routes(
title = str(cached.get("title"))
keywords = cached.get("keywords") or []
else:
top_keywords = _get_top_keywords()
llm = await _call_chat_title(
base_url=openai_base_url,
api_key=openai_api_key,
model=title_model,
prompt_samples=[rep] + texts[:5],
output_lang=output_lang,
existing_keywords=top_keywords,
)
title = (llm or {}).get("title")
keywords = (llm or {}).get("keywords", [])
@ -1459,6 +1480,9 @@ def mount_topic_cluster_routes(
except Exception:
pass
for kw in keywords or []:
keyword_frequency[kw] = keyword_frequency.get(kw, 0) + 1
out_clusters.append(
{
"id": f"topic_{cidx}",