import os

import requests


# ---------------------------------------------
# Load backend definition from .env
# ---------------------------------------------
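# The loader below expects variables of the form LLM_<NAME>_PROVIDER / _URL / _MODEL,
# plus a pointer variable (e.g. CORTEX_LLM) naming which backend to use.
# A hypothetical .env sketch matching that convention (values are illustrative,
# not taken from this module):
#
#   CORTEX_LLM=PRIMARY
#   LLM_PRIMARY_PROVIDER=vllm
#   LLM_PRIMARY_URL=http://localhost:8000
#   LLM_PRIMARY_MODEL=my-local-model
#   LLM_OPENAI_PROVIDER=openai
#   LLM_OPENAI_URL=https://api.openai.com/v1
#   LLM_OPENAI_MODEL=gpt-4o-mini
#
# Note: this module only reads the process environment; it assumes the .env file
# has already been exported (e.g. by the shell, docker-compose, or python-dotenv).
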
def load_backend_config(name: str):
    """
    Given a backend name like 'PRIMARY' or 'OPENAI',
    load the matching provider / url / model from env.
    """
    prefix = f"LLM_{name.upper()}"

    provider = os.getenv(f"{prefix}_PROVIDER")
    url = os.getenv(f"{prefix}_URL")
    model = os.getenv(f"{prefix}_MODEL")

    if not provider or not url or not model:
        raise RuntimeError(
            f"Backend '{name}' is missing configuration. "
            f"Expected {prefix}_PROVIDER / URL / MODEL in .env"
        )

    return provider, url.rstrip("/"), model

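# Example (assuming the hypothetical LLM_PRIMARY_* values sketched above):
#
#   load_backend_config("PRIMARY")
#   -> ("vllm", "http://localhost:8000", "my-local-model")
#
# Any trailing "/" on the URL is stripped so callers can safely append paths
# such as /v1/completions.
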
# ---------------------------------------------
# Core call_llm() — fail hard, no fallback
# ---------------------------------------------

def call_llm(prompt: str, backend_env_var: str):
    """
    Example:
        call_llm(prompt, backend_env_var="CORTEX_LLM")

    The env var named by backend_env_var should be set to one of:
        PRIMARY, SECONDARY, OPENAI, FALLBACK, etc.
    """
    backend_name = os.getenv(backend_env_var)
    if not backend_name:
        raise RuntimeError(f"{backend_env_var} is not set in .env")

    provider, base_url, model = load_backend_config(backend_name)

    # ---------------------------------------------
    # Provider-specific behavior
    # ---------------------------------------------

    if provider == "vllm":
        # vLLM OpenAI-compatible API
        response = requests.post(
            f"{base_url}/v1/completions",
            json={
                "model": model,
                "prompt": prompt,
                "max_tokens": 1024,
                "temperature": float(os.getenv("LLM_TEMPERATURE", "0.7"))
            },
            timeout=30
        )
        response.raise_for_status()
        data = response.json()
        return data["choices"][0]["text"]
elif provider == "ollama":
|
|
response = requests.post(
|
|
f"{base_url}/api/chat",
|
|
json={
|
|
"model": model,
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"stream": False
|
|
},
|
|
timeout=30
|
|
)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
return data["message"]["content"]
|
|
|
|
elif provider == "openai":
|
|
api_key = os.getenv("OPENAI_API_KEY")
|
|
if not api_key:
|
|
raise RuntimeError("OPENAI_API_KEY missing but provider=openai was selected")
|
|
|
|
response = requests.post(
|
|
f"{base_url}/chat/completions",
|
|
headers={"Authorization": f"Bearer {api_key}"},
|
|
json={
|
|
"model": model,
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"temperature": float(os.getenv("LLM_TEMPERATURE", "0.7"))
|
|
},
|
|
timeout=30
|
|
)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
return data["choices"][0]["message"]["content"]
|
|
|
|
else:
|
|
raise RuntimeError(f"Unknown LLM provider: {provider}")
|
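

# ---------------------------------------------
# Minimal usage sketch: assumes CORTEX_LLM and
# the matching LLM_<NAME>_* variables are already
# exported, as in the hypothetical .env above.
# ---------------------------------------------

if __name__ == "__main__":
    # Fails hard with RuntimeError or requests.HTTPError if anything is
    # misconfigured, matching the no-fallback behavior of call_llm().
    print(call_llm("Say hello in one sentence.", backend_env_var="CORTEX_LLM"))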