32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
import os, requests
|
|
from typing import Dict, Any, List
|
|
|
|
# Base URL of the RAG service; read from the RAG_API_URL environment variable,
# falling back to a local instance on port 7090 when the variable is unset.
RAG_API_URL = os.getenv("RAG_API_URL", "http://localhost:7090")
|
|
|
|
def query_rag(query: str, where: Dict[str, Any] | None = None, k: int = 6) -> Dict[str, Any]:
    """POST a search request to the RAG service and return its JSON reply.

    The call is best-effort: it never raises. Any failure (connection error,
    timeout, non-2xx status, undecodable or non-object JSON body) is reported
    in-band through an ``"error"`` key.

    Args:
        query: Text sent as the ``query`` field of the request.
        where: Optional filter sent as the ``where`` field; left out of the
            request when ``None`` or empty.
        k: Value sent as the ``k`` field (presumably the number of results
            to retrieve -- confirm against the service).

    Returns:
        The JSON object returned by the service (``{}`` when the body decodes
        to a falsy value). On failure, the dict
        ``{"answer": "", "chunks": [], "error": <message>}``.
    """
    payload: Dict[str, Any] = {"query": query, "k": k}
    if where:
        payload["where"] = where
    try:
        # Tolerate a trailing slash in the configured base URL so the request
        # path does not become "//rag/search".
        url = f"{RAG_API_URL.rstrip('/')}/rag/search"
        r = requests.post(url, json=payload, timeout=8)
        r.raise_for_status()
        data = r.json() or {}
        # Callers use dict methods on the result; a JSON array or scalar body
        # would break them later, so treat it as a failure here.
        if not isinstance(data, dict):
            raise ValueError(f"unexpected response type: {type(data).__name__}")
    except Exception as e:
        # Deliberately broad: this is the boundary where every failure is
        # converted into an empty result carrying the error message.
        data = {"answer": "", "chunks": [], "error": str(e)}
    return data
|
|
|
|
def format_rag_block(
    result: Dict[str, Any],
    *,
    max_chunks: int = 5,
    max_chars: int = 220,
) -> str:
    """Render a RAG result dict as a plain-text block headed by ``[RAG]``.

    Args:
        result: Mapping that may contain an ``"answer"`` string and a
            ``"chunks"`` list of dicts with ``"text"`` and ``"metadata"``
            entries. Missing or ``None`` values are treated as empty.
        max_chunks: Maximum number of excerpts to list.
        max_chars: Maximum excerpt length before it is cut and suffixed
            with an ellipsis.

    Returns:
        Newline-terminated text: the ``[RAG]`` header, then the synthesized
        answer (if any), then up to ``max_chunks`` numbered excerpts, each
        followed by its source.
    """
    answer = (result.get("answer") or "").strip()
    chunks: List[Dict[str, Any]] = result.get("chunks") or []

    lines = ["[RAG]"]
    if answer:
        lines.append(f"Synthesized answer: {answer}")
    if chunks:
        lines.append("Top excerpts:")
        for i, c in enumerate(chunks[:max_chunks], 1):
            # "metadata" and "source" may be present but null; fall back
            # instead of failing on None or printing the literal "None".
            src = (c.get("metadata") or {}).get("source") or "unknown"
            # Flatten each excerpt onto a single line.
            txt = (c.get("text") or "").strip().replace("\n", " ")
            if len(txt) > max_chars:
                txt = txt[:max_chars] + "…"
            lines.append(f" {i}. {txt} — {src}")
    # `lines` always holds at least the header, so the block always ends
    # with a newline.
    return "\n".join(lines) + "\n"