# llm_router.py
"""Route prompts to one of several LLM backends configured via environment.

Each backend is described by LLM_<NAME>_PROVIDER / LLM_<NAME>_URL /
LLM_<NAME>_MODEL variables read once at import time (expected to come from
the root .env — TODO confirm the loader calls dotenv before importing this).
"""

import asyncio
import os
import json

import requests

# ------------------------------------------------------------
# Load backend registry from root .env
# ------------------------------------------------------------
BACKENDS = {
    "PRIMARY": {
        "provider": os.getenv("LLM_PRIMARY_PROVIDER", "").lower(),
        "url": os.getenv("LLM_PRIMARY_URL", ""),
        "model": os.getenv("LLM_PRIMARY_MODEL", ""),
    },
    "SECONDARY": {
        "provider": os.getenv("LLM_SECONDARY_PROVIDER", "").lower(),
        "url": os.getenv("LLM_SECONDARY_URL", ""),
        "model": os.getenv("LLM_SECONDARY_MODEL", ""),
    },
    "OPENAI": {
        "provider": os.getenv("LLM_OPENAI_PROVIDER", "").lower(),
        "url": os.getenv("LLM_OPENAI_URL", ""),
        "model": os.getenv("LLM_OPENAI_MODEL", ""),
        # Only the OPENAI entry carries a key; other providers are local.
        "api_key": os.getenv("OPENAI_API_KEY", ""),
    },
    "FALLBACK": {
        "provider": os.getenv("LLM_FALLBACK_PROVIDER", "").lower(),
        "url": os.getenv("LLM_FALLBACK_URL", ""),
        "model": os.getenv("LLM_FALLBACK_MODEL", ""),
    },
}

DEFAULT_BACKEND = "PRIMARY"

# Per-request HTTP timeout in seconds, shared by every provider branch.
_REQUEST_TIMEOUT = 120


# ------------------------------------------------------------
# Public call
# ------------------------------------------------------------
async def call_llm(
    prompt: str,
    backend: str | None = None,
    temperature: float = 0.7,
    max_tokens: int = 512,
) -> str:
    """Send *prompt* to the chosen backend and return the completion text.

    Args:
        prompt: User prompt, forwarded verbatim to the provider.
        backend: Key into BACKENDS (case-insensitive); defaults to
            DEFAULT_BACKEND when omitted or None.
        temperature: Sampling temperature forwarded to the provider.
        max_tokens: Generation cap forwarded to the provider.

    Returns:
        The generated text extracted from the provider's JSON response.

    Raises:
        RuntimeError: Unknown backend name, missing url/model config, or an
            unimplemented provider value.
        requests.HTTPError: The provider answered with a non-2xx status.
    """
    backend = (backend or DEFAULT_BACKEND).upper()
    if backend not in BACKENDS:
        raise RuntimeError(f"Unknown backend '{backend}'")

    cfg = BACKENDS[backend]
    provider = cfg["provider"]
    url = cfg["url"]
    model = cfg["model"]
    if not url or not model:
        raise RuntimeError(f"Backend '{backend}' missing url/model in env")

    async def _post_json(post_url: str, payload: dict, headers: dict | None = None) -> dict:
        """POST *payload* and return the decoded JSON response.

        ``requests`` is synchronous; run it in a worker thread via
        asyncio.to_thread so this coroutine does not block the event loop
        for the duration of the HTTP round-trip (the original code did).
        """
        r = await asyncio.to_thread(
            requests.post,
            post_url,
            json=payload,
            headers=headers,
            timeout=_REQUEST_TIMEOUT,
        )
        # Fail loudly on 4xx/5xx instead of a confusing KeyError below.
        r.raise_for_status()
        return r.json()

    # -------------------------------
    # Provider: VLLM (your MI50)
    # -------------------------------
    if provider == "vllm":
        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        data = await _post_json(url, payload)
        return data["choices"][0]["text"]

    # -------------------------------
    # Provider: OLLAMA (your 3090)
    # -------------------------------
    if provider == "ollama":
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "stream": False,  # <-- critical fix: one JSON object, not a stream
            # Forward the sampling settings; the original silently dropped
            # temperature/max_tokens on this path. Ollama's generation cap
            # is called "num_predict".
            "options": {"temperature": temperature, "num_predict": max_tokens},
        }
        data = await _post_json(f"{url}/api/chat", payload)
        return data["message"]["content"]

    # -------------------------------
    # Provider: OPENAI
    # -------------------------------
    if provider == "openai":
        headers = {
            "Authorization": f"Bearer {cfg['api_key']}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        data = await _post_json(f"{url}/chat/completions", payload, headers)
        return data["choices"][0]["message"]["content"]

    # -------------------------------
    # Unknown provider
    # -------------------------------
    raise RuntimeError(f"Provider '{provider}' not implemented.")