Improve keyword consistency in topic clustering
- Add get_all_keywords_frequency method to TopicTitleCache
- Initialize keyword frequency from historical cached cluster keywords
- Prioritize top 100 high-frequency keywords when generating new keywords
- Update LLM prompt to prefer existing keywords from frequency list
- Reduce duplicate/similar keyword generation across clusters
- Add streaming support for tag_graph LLM requests
- Increase LLM timeout and retry limits for better reliability

Branch: feature/tag-relationship-graph
parent
6c16bd0d82
commit
0c1998bdb0
|
|
@ -516,6 +516,38 @@ class TopicTitleCache:
|
|||
kw = []
|
||||
return {"title": title, "keywords": kw, "model": model, "updated_at": updated_at}
|
||||
|
||||
@classmethod
|
||||
def get_all_keywords_frequency(cls, conn: Connection, model: Optional[str] = None) -> Dict[str, int]:
|
||||
"""
|
||||
Get keyword frequency from all cached clusters.
|
||||
Optionally filter by model.
|
||||
Returns a dictionary mapping keyword -> frequency.
|
||||
"""
|
||||
with closing(conn.cursor()) as cur:
|
||||
if model:
|
||||
cur.execute(
|
||||
"SELECT keywords FROM topic_title_cache WHERE model = ?",
|
||||
(model,),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"SELECT keywords FROM topic_title_cache",
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
keyword_frequency: Dict[str, int] = {}
|
||||
for row in rows:
|
||||
keywords_str = row[0] if row else None
|
||||
try:
|
||||
keywords = json.loads(keywords_str) if isinstance(keywords_str, str) else []
|
||||
except Exception:
|
||||
keywords = []
|
||||
if isinstance(keywords, list):
|
||||
for kw in keywords:
|
||||
if isinstance(kw, str) and kw.strip():
|
||||
keyword_frequency[kw] = keyword_frequency.get(kw, 0) + 1
|
||||
return keyword_frequency
|
||||
|
||||
@classmethod
|
||||
def upsert(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -10,18 +10,19 @@ import time
|
|||
from typing import Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from scripts.iib.db.datamodel import DataBase, GlobalSetting
|
||||
from scripts.iib.tool import normalize_output_lang
|
||||
from scripts.iib.logger import logger
|
||||
|
||||
# Cache version for tag abstraction - increment to invalidate all caches
|
||||
TAG_ABSTRACTION_CACHE_VERSION = 3
|
||||
TAG_ABSTRACTION_CACHE_VERSION = 2.1
|
||||
TAG_GRAPH_CACHE_VERSION = 1
|
||||
_MAX_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_MAX_TAGS_FOR_LLM", "200") or "200")
|
||||
_TOPK_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_TOPK_TAGS_FOR_LLM", "200") or "200")
|
||||
_LLM_REQUEST_TIMEOUT_SEC = int(os.getenv("IIB_TAG_GRAPH_LLM_TIMEOUT_SEC", "30") or "30")
|
||||
_LLM_MAX_ATTEMPTS = int(os.getenv("IIB_TAG_GRAPH_LLM_MAX_ATTEMPTS", "2") or "2")
|
||||
_MAX_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_MAX_TAGS_FOR_LLM", "300") or "300")
|
||||
_TOPK_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_TOPK_TAGS_FOR_LLM", "300") or "300")
|
||||
_LLM_REQUEST_TIMEOUT_SEC = int(os.getenv("IIB_TAG_GRAPH_LLM_TIMEOUT_SEC", "180") or "180")
|
||||
_LLM_MAX_ATTEMPTS = int(os.getenv("IIB_TAG_GRAPH_LLM_MAX_ATTEMPTS", "5") or "5")
|
||||
|
||||
|
||||
class TagGraphReq(BaseModel):
|
||||
|
|
@ -95,6 +96,7 @@ def mount_tag_graph_routes(
|
|||
|
||||
# Normalize language for consistent LLM output
|
||||
normalized_lang = normalize_output_lang(lang)
|
||||
print(f"tags length: {len(tags)}")
|
||||
sys_prompt = f"""You are a tag categorization assistant. Organize tags into hierarchical categories.
|
||||
|
||||
STRICT RULES:
|
||||
|
|
@ -118,88 +120,110 @@ If unsure about Level 2, OMIT it entirely. Start response with {{ and end with }
|
|||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 2048,
|
||||
"max_tokens": 8096,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
# Retry a few times then fallback quickly (to avoid frontend timeout on large datasets).
|
||||
# Use streaming requests to avoid blocking too long on a single non-stream response.
|
||||
last_error = ""
|
||||
for attempt in range(1, _LLM_MAX_ATTEMPTS + 1):
|
||||
try:
|
||||
resp = requests.post(url, json=payload, headers=headers, timeout=_LLM_REQUEST_TIMEOUT_SEC)
|
||||
resp = requests.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=_LLM_REQUEST_TIMEOUT_SEC,
|
||||
stream=True,
|
||||
)
|
||||
# Check status early
|
||||
if resp.status_code != 200:
|
||||
body = (resp.text or "")[:400]
|
||||
if resp.status_code == 429 or resp.status_code >= 500:
|
||||
last_error = f"api_error_retriable: status={resp.status_code}"
|
||||
logger.warning("[tag_graph] llm_http_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
|
||||
continue
|
||||
logger.error("[tag_graph] llm_http_client_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
|
||||
raise Exception(f"API client error: {resp.status_code} {body}")
|
||||
|
||||
# Accumulate streamed content chunks
|
||||
content_buffer = ""
|
||||
for raw in resp.iter_lines(decode_unicode=False):
|
||||
if not raw:
|
||||
continue
|
||||
# Ensure explicit UTF-8 decoding to avoid mojibake
|
||||
try:
|
||||
line = raw.decode('utf-8') if isinstance(raw, (bytes, bytearray)) else str(raw)
|
||||
except Exception:
|
||||
line = raw.decode('utf-8', errors='replace') if isinstance(raw, (bytes, bytearray)) else str(raw)
|
||||
line = line.strip()
|
||||
if line.startswith('data: '):
|
||||
line = line[6:].strip()
|
||||
if line == '[DONE]':
|
||||
break
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except Exception:
|
||||
# Some providers may return partial JSON or non-JSON lines; skip
|
||||
continue
|
||||
# Try to extract incremental content (compat with OpenAI-style streaming)
|
||||
delta = (obj.get('choices') or [{}])[0].get('delta') or {}
|
||||
chunk_text = delta.get('content') or ''
|
||||
if chunk_text:
|
||||
# Print when data chunk received (truncate for safety)
|
||||
try:
|
||||
print(f"[tag_graph] stream_chunk_received len={len(chunk_text)} snippet={chunk_text[:200]}")
|
||||
except Exception:
|
||||
pass
|
||||
content_buffer += chunk_text
|
||||
|
||||
content = content_buffer.strip()
|
||||
print(f"content: {content}")
|
||||
# Strategy 1: Direct parse (if response is pure JSON)
|
||||
try:
|
||||
result = json.loads(content)
|
||||
if isinstance(result, dict) and 'layers' in result:
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Strategy 2: Extract JSON from markdown code blocks
|
||||
json_str = None
|
||||
code_block = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", content)
|
||||
if code_block:
|
||||
json_str = code_block.group(1)
|
||||
else:
|
||||
m = re.search(r"\{[\s\S]*\}", content)
|
||||
if m:
|
||||
json_str = m.group(0)
|
||||
|
||||
if not json_str:
|
||||
last_error = f"no_json_found: {content[:200]}"
|
||||
logger.warning("[tag_graph] llm_no_json attempt=%s err=%s", attempt, last_error)
|
||||
continue
|
||||
|
||||
# Clean up common JSON issues
|
||||
json_str = json_str.strip()
|
||||
json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
|
||||
|
||||
try:
|
||||
result = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
last_error = f"json_parse_error: {e}"
|
||||
logger.warning("[tag_graph] llm_json_parse_error attempt=%s err=%s json=%s", attempt, last_error, json_str[:400])
|
||||
continue
|
||||
|
||||
if not isinstance(result, dict) or 'layers' not in result:
|
||||
last_error = f"invalid_structure: {str(result)[:200]}"
|
||||
logger.warning("[tag_graph] llm_invalid_structure attempt=%s err=%s", attempt, last_error)
|
||||
continue
|
||||
|
||||
return result
|
||||
except requests.RequestException as e:
|
||||
last_error = f"network_error: {type(e).__name__}: {e}"
|
||||
logger.warning("[tag_graph] llm_request_error attempt=%s err=%s", attempt, last_error)
|
||||
continue
|
||||
|
||||
# Retry on 429 or 5xx, fail immediately on other 4xx
|
||||
if resp.status_code != 200:
|
||||
body = (resp.text or "")[:400]
|
||||
if resp.status_code == 429 or resp.status_code >= 500:
|
||||
last_error = f"api_error_retriable: status={resp.status_code}"
|
||||
logger.warning("[tag_graph] llm_http_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
|
||||
continue
|
||||
# 4xx client error - fail immediately
|
||||
logger.error("[tag_graph] llm_http_client_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
|
||||
raise Exception(f"API client error: {resp.status_code} {body}")
|
||||
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception as e:
|
||||
last_error = f"response_not_json: {type(e).__name__}"
|
||||
logger.warning("[tag_graph] llm_response_not_json attempt=%s err=%s body=%s", attempt, last_error, (resp.text or "")[:400])
|
||||
continue
|
||||
|
||||
choice0 = (data.get("choices") or [{}])[0]
|
||||
msg = (choice0 or {}).get("message") or {}
|
||||
content = (msg.get("content") or "").strip()
|
||||
|
||||
# Extract JSON from content - try multiple strategies
|
||||
json_str = None
|
||||
|
||||
# Strategy 1: Direct parse (if response is pure JSON)
|
||||
try:
|
||||
result = json.loads(content)
|
||||
if isinstance(result, dict) and "layers" in result:
|
||||
return result
|
||||
except:
|
||||
pass
|
||||
|
||||
# Strategy 2: Extract JSON from markdown code blocks
|
||||
code_block = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", content)
|
||||
if code_block:
|
||||
json_str = code_block.group(1)
|
||||
else:
|
||||
# Strategy 3: Find largest JSON object
|
||||
m = re.search(r"\{[\s\S]*\}", content)
|
||||
if m:
|
||||
json_str = m.group(0)
|
||||
|
||||
if not json_str:
|
||||
last_error = f"no_json_found: {content[:200]}"
|
||||
logger.warning("[tag_graph] llm_no_json attempt=%s err=%s", attempt, last_error)
|
||||
continue
|
||||
|
||||
# Clean up common JSON issues
|
||||
json_str = json_str.strip()
|
||||
# Remove trailing commas before closing braces/brackets
|
||||
json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
|
||||
|
||||
try:
|
||||
result = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
last_error = f"json_parse_error: {e}"
|
||||
logger.warning("[tag_graph] llm_json_parse_error attempt=%s err=%s json=%s", attempt, last_error, json_str[:400])
|
||||
continue
|
||||
|
||||
# Validate structure
|
||||
if not isinstance(result, dict) or "layers" not in result:
|
||||
last_error = f"invalid_structure: {str(result)[:200]}"
|
||||
logger.warning("[tag_graph] llm_invalid_structure attempt=%s err=%s", attempt, last_error)
|
||||
continue
|
||||
|
||||
# Success!
|
||||
return result
|
||||
|
||||
# No fallback: expose error to frontend, but log enough info for debugging.
|
||||
logger.error(
|
||||
"[tag_graph] llm_failed attempts=%s timeout_sec=%s last_error=%s",
|
||||
|
|
@ -218,7 +242,7 @@ If unsure about Level 2, OMIT it entirely. Start response with {{ and end with }
|
|||
)
|
||||
|
||||
return await asyncio.to_thread(_call_sync)
|
||||
|
||||
|
||||
@app.post(
|
||||
f"{db_api_base}/cluster_tag_graph",
|
||||
dependencies=[Depends(verify_secret)],
|
||||
|
|
|
|||
|
|
@ -448,6 +448,7 @@ def _call_chat_title_sync(
|
|||
model: str,
|
||||
prompt_samples: List[str],
|
||||
output_lang: str,
|
||||
existing_keywords: Optional[List[str]] = None,
|
||||
) -> Optional[Dict]:
|
||||
"""
|
||||
Ask LLM to generate a short topic title and a few keywords. Returns dict or None.
|
||||
|
|
@ -477,9 +478,16 @@ def _call_chat_title_sync(
|
|||
"- Do NOT output explanations. Do NOT output markdown/code fences.\n"
|
||||
"- The output MUST start with '{' and end with '}' (no leading/trailing characters).\n"
|
||||
"\n"
|
||||
"Output STRICT JSON only:\n"
|
||||
+ json_example
|
||||
)
|
||||
if existing_keywords:
|
||||
top_keywords = existing_keywords[:100]
|
||||
sys += (
|
||||
f"IMPORTANT: You MUST prioritize selecting keywords from this existing list. "
|
||||
f"Use keywords from the list that best match the current theme. "
|
||||
f"Only create new keywords when absolutely necessary.\n"
|
||||
f"Existing keywords (top {len(top_keywords)}): {', '.join(top_keywords)}\n\n"
|
||||
)
|
||||
sys += "Output STRICT JSON only:\n" + json_example
|
||||
user = "Prompt snippets:\n" + "\n".join([f"- {s}" for s in samples])
|
||||
|
||||
payload = {
|
||||
|
|
@ -590,6 +598,7 @@ async def _call_chat_title(
|
|||
model: str,
|
||||
prompt_samples: List[str],
|
||||
output_lang: str,
|
||||
existing_keywords: Optional[List[str]] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Same rationale as embeddings:
|
||||
|
|
@ -603,6 +612,7 @@ async def _call_chat_title(
|
|||
model=model,
|
||||
prompt_samples=prompt_samples,
|
||||
output_lang=output_lang,
|
||||
existing_keywords=existing_keywords,
|
||||
)
|
||||
if not isinstance(ret, dict):
|
||||
raise HTTPException(status_code=502, detail="Chat API returned empty title payload")
|
||||
|
|
@ -1412,6 +1422,15 @@ def mount_topic_cluster_routes(
|
|||
if progress_cb:
|
||||
progress_cb({"stage": "titling", "clusters_total": len(clusters)})
|
||||
|
||||
existing_keywords: List[str] = []
|
||||
keyword_frequency: Dict[str, int] = TopicTitleCache.get_all_keywords_frequency(conn, model)
|
||||
|
||||
def _get_top_keywords() -> List[str]:
|
||||
if not keyword_frequency:
|
||||
return []
|
||||
sorted_keywords = sorted(keyword_frequency.items(), key=lambda x: x[1], reverse=True)
|
||||
return [k for k, v in sorted_keywords[:100]]
|
||||
|
||||
for cidx, c in enumerate(clusters):
|
||||
if len(c["members"]) < min_cluster_size:
|
||||
for mi in c["members"]:
|
||||
|
|
@ -1441,12 +1460,14 @@ def mount_topic_cluster_routes(
|
|||
title = str(cached.get("title"))
|
||||
keywords = cached.get("keywords") or []
|
||||
else:
|
||||
top_keywords = _get_top_keywords()
|
||||
llm = await _call_chat_title(
|
||||
base_url=openai_base_url,
|
||||
api_key=openai_api_key,
|
||||
model=title_model,
|
||||
prompt_samples=[rep] + texts[:5],
|
||||
output_lang=output_lang,
|
||||
existing_keywords=top_keywords,
|
||||
)
|
||||
title = (llm or {}).get("title")
|
||||
keywords = (llm or {}).get("keywords", [])
|
||||
|
|
@ -1459,6 +1480,9 @@ def mount_topic_cluster_routes(
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
for kw in keywords or []:
|
||||
keyword_frequency[kw] = keyword_frequency.get(kw, 0) + 1
|
||||
|
||||
out_clusters.append(
|
||||
{
|
||||
"id": f"topic_{cidx}",
|
||||
|
|
|
|||
Loading…
Reference in New Issue