diff --git a/scripts/iib/db/datamodel.py b/scripts/iib/db/datamodel.py
index 90bf3f1..e35af0f 100644
--- a/scripts/iib/db/datamodel.py
+++ b/scripts/iib/db/datamodel.py
@@ -516,6 +516,38 @@ class TopicTitleCache:
             kw = []
         return {"title": title, "keywords": kw, "model": model, "updated_at": updated_at}
 
+    @classmethod
+    def get_all_keywords_frequency(cls, conn: Connection, model: Optional[str] = None) -> Dict[str, int]:
+        """
+        Get keyword frequency from all cached clusters.
+        Optionally filter by model.
+        Returns a dictionary mapping keyword -> frequency.
+        """
+        with closing(conn.cursor()) as cur:
+            if model:
+                cur.execute(
+                    "SELECT keywords FROM topic_title_cache WHERE model = ?",
+                    (model,),
+                )
+            else:
+                cur.execute(
+                    "SELECT keywords FROM topic_title_cache",
+                )
+            rows = cur.fetchall()
+
+        keyword_frequency: Dict[str, int] = {}
+        for row in rows:
+            keywords_str = row[0] if row else None
+            try:
+                keywords = json.loads(keywords_str) if isinstance(keywords_str, str) else []
+            except Exception:
+                keywords = []
+            if isinstance(keywords, list):
+                for kw in keywords:
+                    if isinstance(kw, str) and kw.strip():
+                        keyword_frequency[kw] = keyword_frequency.get(kw, 0) + 1
+        return keyword_frequency
+
     @classmethod
     def upsert(
         cls,
diff --git a/scripts/iib/tag_graph.py b/scripts/iib/tag_graph.py
index 48b338e..d480c00 100644
--- a/scripts/iib/tag_graph.py
+++ b/scripts/iib/tag_graph.py
@@ -10,18 +10,19 @@ import time
 from typing import Dict, List, Optional
 from pydantic import BaseModel
 from fastapi import Depends, FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
 from scripts.iib.db.datamodel import DataBase, GlobalSetting
 from scripts.iib.tool import normalize_output_lang
 from scripts.iib.logger import logger
 
 # Cache version for tag abstraction - increment to invalidate all caches
-TAG_ABSTRACTION_CACHE_VERSION = 3
+TAG_ABSTRACTION_CACHE_VERSION = 4
 TAG_GRAPH_CACHE_VERSION = 1
 
-_MAX_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_MAX_TAGS_FOR_LLM", "200") or "200")
-_TOPK_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_TOPK_TAGS_FOR_LLM", "200") or "200")
-_LLM_REQUEST_TIMEOUT_SEC = int(os.getenv("IIB_TAG_GRAPH_LLM_TIMEOUT_SEC", "30") or "30")
-_LLM_MAX_ATTEMPTS = int(os.getenv("IIB_TAG_GRAPH_LLM_MAX_ATTEMPTS", "2") or "2")
+_MAX_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_MAX_TAGS_FOR_LLM", "300") or "300")
+_TOPK_TAGS_FOR_LLM = int(os.getenv("IIB_TAG_GRAPH_TOPK_TAGS_FOR_LLM", "300") or "300")
+_LLM_REQUEST_TIMEOUT_SEC = int(os.getenv("IIB_TAG_GRAPH_LLM_TIMEOUT_SEC", "180") or "180")
+_LLM_MAX_ATTEMPTS = int(os.getenv("IIB_TAG_GRAPH_LLM_MAX_ATTEMPTS", "5") or "5")
 
 
 class TagGraphReq(BaseModel):
@@ -95,6 +96,7 @@
 
         # Normalize language for consistent LLM output
         normalized_lang = normalize_output_lang(lang)
+        print(f"tags length: {len(tags)}")
 
         sys_prompt = f"""You are a tag categorization assistant. Organize tags into hierarchical categories.
 
 STRICT RULES:
@@ -118,88 +120,110 @@ If unsure about Level 2, OMIT it entirely. Start response with {{ and end with }
                 {"role": "user", "content": user_prompt}
             ],
             "temperature": 0.0,
-            "max_tokens": 2048,
+            "max_tokens": 8096,
+            "stream": True,
         }
 
         # Retry a few times then fallback quickly (to avoid frontend timeout on large datasets).
+        # Use streaming requests to avoid blocking too long on a single non-stream response.
        last_error = ""
         for attempt in range(1, _LLM_MAX_ATTEMPTS + 1):
             try:
-                resp = requests.post(url, json=payload, headers=headers, timeout=_LLM_REQUEST_TIMEOUT_SEC)
+                resp = requests.post(
+                    url,
+                    json=payload,
+                    headers=headers,
+                    timeout=_LLM_REQUEST_TIMEOUT_SEC,
+                    stream=True,
+                )
+                # Check status early
+                if resp.status_code != 200:
+                    body = (resp.text or "")[:400]
+                    if resp.status_code == 429 or resp.status_code >= 500:
+                        last_error = f"api_error_retriable: status={resp.status_code}"
+                        logger.warning("[tag_graph] llm_http_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
+                        continue
+                    logger.error("[tag_graph] llm_http_client_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
+                    raise Exception(f"API client error: {resp.status_code} {body}")
+
+                # Accumulate streamed content chunks
+                content_buffer = ""
+                for raw in resp.iter_lines(decode_unicode=False):
+                    if not raw:
+                        continue
+                    # Ensure explicit UTF-8 decoding to avoid mojibake
+                    try:
+                        line = raw.decode('utf-8') if isinstance(raw, (bytes, bytearray)) else str(raw)
+                    except Exception:
+                        line = raw.decode('utf-8', errors='replace') if isinstance(raw, (bytes, bytearray)) else str(raw)
+                    line = line.strip()
+                    if line.startswith('data: '):
+                        line = line[6:].strip()
+                    if line == '[DONE]':
+                        break
+                    try:
+                        obj = json.loads(line)
+                    except Exception:
+                        # Some providers may return partial JSON or non-JSON lines; skip
+                        continue
+                    # Try to extract incremental content (compat with OpenAI-style streaming)
+                    delta = (obj.get('choices') or [{}])[0].get('delta') or {}
+                    chunk_text = delta.get('content') or ''
+                    if chunk_text:
+                        # Print when data chunk received (truncate for safety)
+                        try:
+                            print(f"[tag_graph] stream_chunk_received len={len(chunk_text)} snippet={chunk_text[:200]}")
+                        except Exception:
+                            pass
+                        content_buffer += chunk_text
+
+                content = content_buffer.strip()
+                print(f"content: {content}")
+                # Strategy 1: Direct parse (if response is pure JSON)
+                try:
+                    result = json.loads(content)
+                    if isinstance(result, dict) and 'layers' in result:
+                        return result
+                except Exception:
+                    pass
+
+                # Strategy 2: Extract JSON from markdown code blocks
+                json_str = None
+                code_block = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", content)
+                if code_block:
+                    json_str = code_block.group(1)
+                else:
+                    m = re.search(r"\{[\s\S]*\}", content)
+                    if m:
+                        json_str = m.group(0)
+
+                if not json_str:
+                    last_error = f"no_json_found: {content[:200]}"
+                    logger.warning("[tag_graph] llm_no_json attempt=%s err=%s", attempt, last_error)
+                    continue
+
+                # Clean up common JSON issues
+                json_str = json_str.strip()
+                json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
+
+                try:
+                    result = json.loads(json_str)
+                except json.JSONDecodeError as e:
+                    last_error = f"json_parse_error: {e}"
+                    logger.warning("[tag_graph] llm_json_parse_error attempt=%s err=%s json=%s", attempt, last_error, json_str[:400])
+                    continue
+
+                if not isinstance(result, dict) or 'layers' not in result:
+                    last_error = f"invalid_structure: {str(result)[:200]}"
+                    logger.warning("[tag_graph] llm_invalid_structure attempt=%s err=%s", attempt, last_error)
+                    continue
+
+                return result
             except requests.RequestException as e:
                 last_error = f"network_error: {type(e).__name__}: {e}"
                 logger.warning("[tag_graph] llm_request_error attempt=%s err=%s", attempt, last_error)
                 continue
-            # Retry on 429 or 5xx, fail immediately on other 4xx
-            if resp.status_code != 200:
-                body = (resp.text or "")[:400]
-                if resp.status_code == 429 or resp.status_code >= 500:
-                    last_error = f"api_error_retriable: status={resp.status_code}"
-                    logger.warning("[tag_graph] llm_http_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
-                    continue
-                # 4xx client error - fail immediately
-                logger.error("[tag_graph] llm_http_client_error attempt=%s status=%s body=%s", attempt, resp.status_code, body)
-                raise Exception(f"API client error: {resp.status_code} {body}")
-
-            try:
-                data = resp.json()
-            except Exception as e:
-                last_error = f"response_not_json: {type(e).__name__}"
-                logger.warning("[tag_graph] llm_response_not_json attempt=%s err=%s body=%s", attempt, last_error, (resp.text or "")[:400])
-                continue
-
-            choice0 = (data.get("choices") or [{}])[0]
-            msg = (choice0 or {}).get("message") or {}
-            content = (msg.get("content") or "").strip()
-
-            # Extract JSON from content - try multiple strategies
-            json_str = None
-
-            # Strategy 1: Direct parse (if response is pure JSON)
-            try:
-                result = json.loads(content)
-                if isinstance(result, dict) and "layers" in result:
-                    return result
-            except:
-                pass
-
-            # Strategy 2: Extract JSON from markdown code blocks
-            code_block = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", content)
-            if code_block:
-                json_str = code_block.group(1)
-            else:
-                # Strategy 3: Find largest JSON object
-                m = re.search(r"\{[\s\S]*\}", content)
-                if m:
-                    json_str = m.group(0)
-
-            if not json_str:
-                last_error = f"no_json_found: {content[:200]}"
-                logger.warning("[tag_graph] llm_no_json attempt=%s err=%s", attempt, last_error)
-                continue
-
-            # Clean up common JSON issues
-            json_str = json_str.strip()
-            # Remove trailing commas before closing braces/brackets
-            json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
-
-            try:
-                result = json.loads(json_str)
-            except json.JSONDecodeError as e:
-                last_error = f"json_parse_error: {e}"
-                logger.warning("[tag_graph] llm_json_parse_error attempt=%s err=%s json=%s", attempt, last_error, json_str[:400])
-                continue
-
-            # Validate structure
-            if not isinstance(result, dict) or "layers" not in result:
-                last_error = f"invalid_structure: {str(result)[:200]}"
-                logger.warning("[tag_graph] llm_invalid_structure attempt=%s err=%s", attempt, last_error)
-                continue
-
-            # Success!
-            return result
-
         # No fallback: expose error to frontend, but log enough info for debugging.
         logger.error(
             "[tag_graph] llm_failed attempts=%s timeout_sec=%s last_error=%s",
@@ -218,7 +242,7 @@ If unsure about Level 2, OMIT it entirely. Start response with {{ and end with }
         )
 
         return await asyncio.to_thread(_call_sync)
-    
+
     @app.post(
         f"{db_api_base}/cluster_tag_graph",
         dependencies=[Depends(verify_secret)],
diff --git a/scripts/iib/topic_cluster.py b/scripts/iib/topic_cluster.py
index de79350..1b53a5b 100644
--- a/scripts/iib/topic_cluster.py
+++ b/scripts/iib/topic_cluster.py
@@ -448,6 +448,7 @@ def _call_chat_title_sync(
     model: str,
     prompt_samples: List[str],
    output_lang: str,
+    existing_keywords: Optional[List[str]] = None,
 ) -> Optional[Dict]:
     """
     Ask LLM to generate a short topic title and a few keywords. Returns dict or None.
@@ -477,9 +478,16 @@
         "- Do NOT output explanations. Do NOT output markdown/code fences.\n"
         "- The output MUST start with '{' and end with '}' (no leading/trailing characters).\n"
         "\n"
-        "Output STRICT JSON only:\n"
-        + json_example
     )
+    if existing_keywords:
+        top_keywords = existing_keywords[:100]
+        sys += (
+            f"IMPORTANT: You MUST prioritize selecting keywords from this existing list. "
+            f"Use keywords from the list that best match the current theme. "
+            f"Only create new keywords when absolutely necessary.\n"
+            f"Existing keywords (top {len(top_keywords)}): {', '.join(top_keywords)}\n\n"
+        )
+    sys += "Output STRICT JSON only:\n" + json_example
 
     user = "Prompt snippets:\n" + "\n".join([f"- {s}" for s in samples])
     payload = {
@@ -590,6 +598,7 @@ async def _call_chat_title(
     model: str,
     prompt_samples: List[str],
     output_lang: str,
+    existing_keywords: Optional[List[str]] = None,
 ) -> Dict:
     """
     Same rationale as embeddings:
@@ -603,6 +612,7 @@
         model=model,
         prompt_samples=prompt_samples,
         output_lang=output_lang,
+        existing_keywords=existing_keywords,
     )
     if not isinstance(ret, dict):
         raise HTTPException(status_code=502, detail="Chat API returned empty title payload")
@@ -1412,6 +1422,15 @@ def mount_topic_cluster_routes(
         if progress_cb:
             progress_cb({"stage": "titling", "clusters_total": len(clusters)})
 
+        existing_keywords: List[str] = []
+        keyword_frequency: Dict[str, int] = TopicTitleCache.get_all_keywords_frequency(conn, model)
+
+        def _get_top_keywords() -> List[str]:
+            if not keyword_frequency:
+                return []
+            sorted_keywords = sorted(keyword_frequency.items(), key=lambda x: x[1], reverse=True)
+            return [k for k, v in sorted_keywords[:100]]
+
         for cidx, c in enumerate(clusters):
             if len(c["members"]) < min_cluster_size:
                 for mi in c["members"]:
@@ -1441,12 +1460,14 @@
                 title = str(cached.get("title"))
                 keywords = cached.get("keywords") or []
             else:
+                top_keywords = _get_top_keywords()
                 llm = await _call_chat_title(
                     base_url=openai_base_url,
                     api_key=openai_api_key,
                     model=title_model,
                     prompt_samples=[rep] + texts[:5],
                     output_lang=output_lang,
+                    existing_keywords=top_keywords,
                 )
                 title = (llm or {}).get("title")
                 keywords = (llm or {}).get("keywords", [])
@@ -1459,6 +1480,9 @@
             except Exception:
                 pass
 
+            for kw in keywords or []:
+                keyword_frequency[kw] = keyword_frequency.get(kw, 0) + 1
+
             out_clusters.append(
                 {
                     "id": f"topic_{cidx}",