Files
project-lyra/cortex/llm/llm_router.py
2025-12-21 14:30:32 -05:00

176 lines
6.5 KiB
Python

# llm_router.py
import os
import httpx
import json
import logging
logger = logging.getLogger(__name__)
# ------------------------------------------------------------
# Load backend registry from root .env
# ------------------------------------------------------------
# Every backend is described by three environment variables,
# LLM_<NAME>_PROVIDER / LLM_<NAME>_URL / LLM_<NAME>_MODEL. Unset variables
# fall back to "" and the provider name is normalised to lower case.
BACKENDS = {
    name: {
        "provider": os.getenv(f"LLM_{name}_PROVIDER", "").lower(),
        "url": os.getenv(f"LLM_{name}_URL", ""),
        "model": os.getenv(f"LLM_{name}_MODEL", ""),
    }
    for name in ("PRIMARY", "SECONDARY", "OPENAI", "FALLBACK")
}
# Only the OPENAI entry carries a credential.
BACKENDS["OPENAI"]["api_key"] = os.getenv("OPENAI_API_KEY", "")

# Backend used when a caller does not name one.
DEFAULT_BACKEND = "PRIMARY"
# Reusable async HTTP client: created once when this module is imported, with
# a 120-second timeout, and used for every request call_llm sends.
http_client = httpx.AsyncClient(timeout=120.0)
# ------------------------------------------------------------
# Public call
# ------------------------------------------------------------
async def call_llm(
    prompt: str | None = None,
    messages: list | None = None,
    backend: str | None = None,
    temperature: float = 0.7,
    max_tokens: int = 512,
) -> str:
    """
    Call an LLM backend and return the generated text.

    Args:
        prompt: String prompt. Required by the completion-style provider
            (mi50). For chat providers it is wrapped in a single user
            message when ``messages`` is not given.
        messages: List of message dicts for chat-style providers
            (ollama, openai). Takes precedence over ``prompt`` for those
            providers; not used by mi50.
        backend: Key into BACKENDS (PRIMARY, SECONDARY, OPENAI, etc.),
            matched case-insensitively. Defaults to DEFAULT_BACKEND.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.

    Returns:
        The content field extracted from the provider's JSON response.

    Raises:
        RuntimeError: If the backend is unknown or lacks url/model, the
            provider is not implemented, the required input is missing,
            the HTTP request fails, or the response cannot be parsed.
    """
    backend = (backend or DEFAULT_BACKEND).upper()
    if backend not in BACKENDS:
        raise RuntimeError(f"Unknown backend '{backend}'")

    cfg = BACKENDS[backend]
    provider = cfg["provider"]
    # Tolerate a trailing slash in the configured URL so the joined
    # endpoint never contains "//".
    url = cfg["url"].rstrip("/")
    model = cfg["model"]
    if not url or not model:
        raise RuntimeError(f"Backend '{backend}' missing url/model in env")

    # -------------------------------
    # Provider: MI50 (llama.cpp server)
    # -------------------------------
    if provider == "mi50":
        # The completion endpoint takes a raw prompt; fail clearly instead
        # of posting a null prompt when only `messages` was supplied.
        if prompt is None:
            raise RuntimeError("Provider 'mi50' requires a 'prompt' string")
        payload = {
            "prompt": prompt,
            "n_predict": max_tokens,
            "temperature": temperature,
            "stop": ["User:", "\nUser:", "Assistant:", "\n\n\n"],
        }
        return await _post_and_extract(
            provider,
            f"{url}/completion",
            payload,
            lambda data: data.get("content", ""),
        )

    # -------------------------------
    # Provider: OLLAMA
    # -------------------------------
    if provider == "ollama":
        payload = {
            "model": model,
            "messages": _chat_messages(prompt, messages),
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
        }
        return await _post_and_extract(
            provider,
            f"{url}/api/chat",
            payload,
            lambda data: data["message"]["content"],
        )

    # -------------------------------
    # Provider: OPENAI
    # -------------------------------
    if provider == "openai":
        # Only the OPENAI registry entry stores "api_key"; fall back to the
        # environment so any backend slot can use the openai provider.
        api_key = cfg.get("api_key") or os.getenv("OPENAI_API_KEY", "")
        headers = {"Content-Type": "application/json"}
        # Send the Authorization header only when a key is configured, so an
        # empty bearer token is never put on the wire.
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        payload = {
            "model": model,
            "messages": _chat_messages(prompt, messages),
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        return await _post_and_extract(
            provider,
            f"{url}/chat/completions",
            payload,
            lambda data: data["choices"][0]["message"]["content"],
            headers=headers,
        )

    # -------------------------------
    # Unknown provider
    # -------------------------------
    raise RuntimeError(f"Provider '{provider}' not implemented.")


def _chat_messages(prompt: str | None, messages: list | None) -> list:
    """Return `messages` if given, else wrap `prompt` as a single user message."""
    if messages:
        return messages
    if prompt is None:
        raise RuntimeError("call_llm requires either 'prompt' or 'messages'")
    return [{"role": "user", "content": prompt}]


async def _post_and_extract(provider, endpoint, payload, extract, headers=None):
    """POST `payload` as JSON to `endpoint` and return `extract(decoded_json)`.

    Every failure is logged and re-raised as RuntimeError, chained to the
    original exception so the traceback keeps the root cause.
    """
    try:
        response = await http_client.post(endpoint, json=payload, headers=headers)
        response.raise_for_status()
        data = response.json()
    except httpx.HTTPError as e:
        logger.error("HTTP error calling %s: %s: %s", provider, type(e).__name__, e)
        raise RuntimeError(
            f"LLM API error ({provider}): {type(e).__name__}: {e}"
        ) from e
    except json.JSONDecodeError as e:
        logger.error("Response parsing error from %s: %s", provider, e)
        raise RuntimeError(f"Invalid response format ({provider}): {e}") from e
    except Exception as e:
        logger.error(
            "Unexpected error calling %s: %s: %s", provider, type(e).__name__, e
        )
        raise RuntimeError(
            f"Unexpected error ({provider}): {type(e).__name__}: {e}"
        ) from e

    try:
        return extract(data)
    except (LookupError, TypeError, AttributeError) as e:
        # Missing keys, an empty "choices" list, or a non-dict body all mean
        # the response did not have the expected shape.
        logger.error("Response parsing error from %s: %s", provider, e)
        raise RuntimeError(f"Invalid response format ({provider}): {e}") from e