Cortex rework in progress
@@ -1,137 +1,102 @@
 import os
-import httpx
+import requests
 
-# ============================================================
-# Backend config lookup
-# ============================================================
-
-def get_backend_config(name: str):
-    """
-    Reads provider/URL/model for a backend.
-    Example env:
-        LLM_PRIMARY_PROVIDER=vllm
-        LLM_PRIMARY_URL=http://10.0.0.43:8000
-        LLM_PRIMARY_MODEL=/model
-    """
-    key = name.upper()
-    provider = os.getenv(f"LLM_{key}_PROVIDER", "vllm").lower()
-    base_url = os.getenv(f"LLM_{key}_URL", "").rstrip("/")
-    model = os.getenv(f"LLM_{key}_MODEL", "/model")
-
-    if not base_url:
-        raise RuntimeError(f"Backend {name} has no URL configured.")
-
-    return provider, base_url, model
+# ---------------------------------------------
+# Load backend definition from .env
+# ---------------------------------------------
+
+def load_backend_config(name: str):
+    """
+    Given a backend name like 'PRIMARY' or 'OPENAI',
+    load the matching provider / url / model from env.
+    """
+    prefix = f"LLM_{name.upper()}"
+
+    provider = os.getenv(f"{prefix}_PROVIDER")
+    url = os.getenv(f"{prefix}_URL")
+    model = os.getenv(f"{prefix}_MODEL")
+
+    if not provider or not url or not model:
+        raise RuntimeError(
+            f"Backend '{name}' is missing configuration. "
+            f"Expected {prefix}_PROVIDER / URL / MODEL in .env"
+        )
+
+    return provider, url.rstrip("/"), model
 
 
-# ============================================================
-# Build the final API URL
-# ============================================================
-
-def build_url(provider: str, base_url: str):
-    """
-    Provider → correct endpoint.
-    """
-    if provider == "vllm":
-        return f"{base_url}/v1/completions"
-
-    if provider == "openai_completions":
-        return f"{base_url}/v1/completions"
-
-    if provider == "openai_chat":
-        return f"{base_url}/v1/chat/completions"
-
-    if provider == "ollama":
-        return f"{base_url}/api/generate"
-
-    raise RuntimeError(f"Unknown provider: {provider}")
+# ---------------------------------------------
+# Core call_llm() — fail hard, no fallback
+# ---------------------------------------------
+
+def call_llm(prompt: str, backend_env_var: str):
+    """
+    Example:
+        call_llm(prompt, backend_env_var="CORTEX_LLM")
+
+    backend_env_var should contain one of:
+        PRIMARY, SECONDARY, OPENAI, FALLBACK, etc
+    """
+    backend_name = os.getenv(backend_env_var)
+    if not backend_name:
+        raise RuntimeError(f"{backend_env_var} is not set in .env")
+
+    provider, base_url, model = load_backend_config(backend_name)
 
-# ============================================================
-# Build the payload depending on provider
-# ============================================================
-
-def build_payload(provider: str, model: str, prompt: str, temperature: float):
+    # ---------------------------------------------
+    # Provider-specific behavior
+    # ---------------------------------------------
 
     if provider == "vllm":
-        return {
-            "model": model,
-            "prompt": prompt,
-            "max_tokens": 512,
-            "temperature": temperature
-        }
+        # vLLM OpenAI-compatible API
+        response = requests.post(
+            f"{base_url}/v1/completions",
+            json={
+                "model": model,
+                "prompt": prompt,
+                "max_tokens": 1024,
+                "temperature": float(os.getenv("LLM_TEMPERATURE", "0.7"))
+            },
+            timeout=30
+        )
+        response.raise_for_status()
+        data = response.json()
+        return data["choices"][0]["text"]
 
-    if provider == "openai_completions":
-        return {
-            "model": model,
-            "prompt": prompt,
-            "max_tokens": 512,
-            "temperature": temperature
-        }
+    elif provider == "ollama":
+        response = requests.post(
+            f"{base_url}/api/chat",
+            json={
+                "model": model,
+                "messages": [{"role": "user", "content": prompt}],
+                "stream": False
+            },
+            timeout=30
+        )
+        response.raise_for_status()
+        data = response.json()
+        return data["message"]["content"]
 
-    if provider == "openai_chat":
-        return {
-            "model": model,
-            "messages": [{"role": "user", "content": prompt}],
-            "temperature": temperature
-        }
-
-    if provider == "ollama":
-        return {
-            "model": model,
-            "prompt": prompt,
-            "stream": False
-        }
-
-    raise RuntimeError(f"Unknown provider: {provider}")
-
-
-# ============================================================
-# Unified LLM call
-# ============================================================
-
-async def call_llm(prompt: str,
-                   backend: str = "primary",
-                   temperature: float = 0.7):
-
-    provider, base_url, model = get_backend_config(backend)
-    url = build_url(provider, base_url)
-    payload = build_payload(provider, model, prompt, temperature)
-
-    headers = {"Content-Type": "application/json"}
-
-    # Cloud auth (OpenAI)
-    if provider.startswith("openai"):
-        api_key = os.getenv("OPENAI_API_KEY")
-        if not api_key:
-            raise RuntimeError("OPENAI_API_KEY missing")
-        headers["Authorization"] = f"Bearer {api_key}"
-
-    async with httpx.AsyncClient() as client:
-        try:
-            resp = await client.post(url, json=payload, headers=headers, timeout=45)
-            resp.raise_for_status()
-            data = resp.json()
-        except Exception as e:
-            return f"[LLM-Error] {e}"
+    elif provider == "openai":
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            raise RuntimeError("OPENAI_API_KEY missing but provider=openai was selected")
+
+        response = requests.post(
+            f"{base_url}/chat/completions",
+            headers={"Authorization": f"Bearer {api_key}"},
+            json={
+                "model": model,
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": float(os.getenv("LLM_TEMPERATURE", "0.7"))
+            },
+            timeout=30
+        )
+        response.raise_for_status()
+        data = response.json()
+        return data["choices"][0]["message"]["content"]
 
-    # =======================================================
-    # Unified output extraction
-    # =======================================================
-    # vLLM + OpenAI completions
-    if provider in ["vllm", "openai_completions"]:
-        return (
-            data["choices"][0].get("text") or
-            data["choices"][0].get("message", {}).get("content", "")
-        ).strip()
-
-    # OpenAI chat
-    if provider == "openai_chat":
-        return data["choices"][0]["message"]["content"].strip()
-
-    # Ollama
-    if provider == "ollama":
-        # Ollama returns: {"model": "...", "created_at": ..., "response": "..."}
-        return data.get("response", "").strip()
-
-    return str(data).strip()
+    else:
+        raise RuntimeError(f"Unknown LLM provider: {provider}")
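
A short usage sketch of the new synchronous call_llm(), assuming the function above is in scope and the .env entries from the sketch above are loaded; the prompt text and the CORTEX_LLM name are illustrative only:

    import requests

    prompt = "Summarize the current system status."  # illustrative prompt

    try:
        # CORTEX_LLM is assumed to hold the backend name (e.g. CORTEX_LLM=PRIMARY)
        reply = call_llm(prompt, backend_env_var="CORTEX_LLM")
        print(reply)
    except RuntimeError as err:
        # missing .env configuration or an unknown provider fails hard, as intended
        print(f"config error: {err}")
    except requests.RequestException as err:
        # HTTP/network failures are not swallowed either; raise_for_status() propagates them
        print(f"backend request failed: {err}")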