# llm_engine.py
# IMPROVED: Multi-provider LLM engine with CACHING to prevent reloading
# This version fixes the critical issue where LocalLLM was reloading on every call
# Features:
# - Provider caching (models stay in memory)
# - Unified OpenAI-style chat() API
# - Providers: OpenAI, Anthropic, HuggingFace, Nebius, SambaNova, Local (transformers)
# - Automatic fallback to local model on errors
# - JSON-based credit tracking
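#
# Typical usage (illustrative sketch; mirrors the local-provider test in main() below):
#   engine = LLMEngine()
#   reply = engine.chat(
#       provider="local",
#       model="meta-llama/Llama-3.2-3B-Instruct",
#       messages=[{"role": "user", "content": "Hello"}],
#   )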

from dotenv import load_dotenv
import json
import os
import traceback
from typing import List, Dict, Optional

load_dotenv()

hf_token = os.getenv('HUGGINGFACE_TOKEN')
if hf_token:
    from huggingface_hub import login
    try:
        login(token=hf_token)
        # logger.info("[HF] Logged in")
    except Exception as e:
        # logger.warning(f"[HF] Login failed: {e}")
        pass

###########################################################
# SIMPLE JSON CREDIT STORE
###########################################################

CREDITS_DB_PATH = "credits.json"

DEFAULT_CREDITS = {
    "openai": 25,
    "anthropic": 25000,
    "huggingface": 25,
    "nebius": 50,
    "modal": 250,
    "blaxel": 250,
    "elevenlabs": 44,
    "sambanova": 25,
    "local": 9999999
}


def load_credits():
    if not os.path.exists(CREDITS_DB_PATH):
        with open(CREDITS_DB_PATH, "w") as f:
            json.dump(DEFAULT_CREDITS, f)
        return DEFAULT_CREDITS.copy()
    with open(CREDITS_DB_PATH, "r") as f:
        return json.load(f)


def save_credits(data):
    with open(CREDITS_DB_PATH, "w") as f:
        json.dump(data, f, indent=2)
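
# Illustrative round trip of the credit store (assumes credits.json is writable):
#   credits = load_credits()       # -> {"openai": 25, ..., "local": 9999999} on first run
#   credits["openai"] -= 0.001     # LLMEngine.deduct() applies this kind of bookkeeping per call
#   save_credits(credits)          # written back to credits.json with indent=2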

###########################################################
# BASE PROVIDER INTERFACE
###########################################################

class BaseProvider:
    def chat(self, model: str, messages: List[Dict], **kwargs) -> str:
        raise NotImplementedError
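

# Illustrative only: a new backend plugs in by subclassing BaseProvider and registering
# the class in ProviderFactory.providers further below. The hypothetical echo provider
# here just demonstrates the contract; it is not registered and is never used by the engine.
class _EchoProviderExample(BaseProvider):
    def chat(self, model: str, messages: List[Dict], **kwargs) -> str:
        # Echo the last user message back; real providers call an API or a local model here.
        return messages[-1].get("content", "") if messages else ""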

###########################################################
# PROVIDER: OPENAI
###########################################################

try:
    from openai import OpenAI
except Exception:
    OpenAI = None


class OpenAIProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed or not importable")
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", ""))

    def chat(self, model, messages, **kwargs):
        try:
            from openai.types.chat import (
                ChatCompletionUserMessageParam,
                ChatCompletionAssistantMessageParam,
                ChatCompletionSystemMessageParam,
            )
        except Exception:
            ChatCompletionUserMessageParam = dict
            ChatCompletionAssistantMessageParam = dict
            ChatCompletionSystemMessageParam = dict
        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
            raise TypeError("messages must be a list of dicts with 'role' and 'content'")
        safe_messages = []
        for m in messages:
            role = str(m.get("role", "user"))
            content = str(m.get("content", ""))
            if role == "user":
                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
            elif role == "assistant":
                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
            elif role == "system":
                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
            else:
                safe_messages.append({"role": role, "content": content})
        response = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return response.choices[0].message.content
        except Exception:
            return str(response)

###########################################################
# PROVIDER: ANTHROPIC
###########################################################

try:
    from anthropic import Anthropic
except Exception:
    Anthropic = None


class AnthropicProvider(BaseProvider):
    def __init__(self):
        if Anthropic is None:
            raise RuntimeError("anthropic library not installed or not importable")
        self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY", ""))

    def chat(self, model, messages, **kwargs):
        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
            raise TypeError("messages must be a list of dicts with 'role' and 'content'")
        user_text = "\n".join([m.get("content", "") for m in messages if m.get("role") == "user"])
        reply = self.client.messages.create(
            model=model,
            max_tokens=300,
            messages=[{"role": "user", "content": user_text}]
        )
        if hasattr(reply, "content"):
            content = reply.content
            if isinstance(content, list) and content:
                block = content[0]
                if hasattr(block, "text"):
                    return getattr(block, "text", str(block))
                elif isinstance(block, dict) and "text" in block:
                    return block["text"]
                else:
                    return str(block)
            elif isinstance(content, str):
                return content
        if isinstance(reply, dict) and "completion" in reply:
            return reply["completion"]
        return str(reply)

###########################################################
# PROVIDER: HUGGINGFACE INFERENCE API
###########################################################

import requests


class HuggingFaceProvider(BaseProvider):
    def __init__(self):
        self.key = os.getenv("HF_API_KEY", "")

    def chat(self, model, messages, **kwargs):
        if not messages:
            raise ValueError("messages is empty")
        text = messages[-1].get("content", "")
        r = requests.post(
            f"https://api-inference.huggingface.co/models/{model}",
            headers={"Authorization": f"Bearer {self.key}"} if self.key else {},
            json={"inputs": text},
            timeout=60
        )
        r.raise_for_status()
        out = r.json()
        if isinstance(out, list) and out and isinstance(out[0], dict):
            return out[0].get("generated_text") or str(out[0])
        return str(out)

###########################################################
# PROVIDER: NEBIUS (OpenAI-compatible)
###########################################################

class NebiusProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed; Nebius wrapper expects OpenAI-compatible client")
        self.client = OpenAI(
            api_key=os.getenv("NEBIUS_API_KEY", ""),
            base_url=os.getenv("NEBIUS_BASE_URL", "https://api.studio.nebius.ai/v1")
        )

    def chat(self, model, messages, **kwargs):
        try:
            from openai.types.chat import (
                ChatCompletionUserMessageParam,
                ChatCompletionAssistantMessageParam,
                ChatCompletionSystemMessageParam,
            )
        except Exception:
            ChatCompletionUserMessageParam = dict
            ChatCompletionAssistantMessageParam = dict
            ChatCompletionSystemMessageParam = dict
        safe_messages = []
        for m in messages:
            role = str(m.get("role", "user"))
            content = str(m.get("content", ""))
            if role == "user":
                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
            elif role == "assistant":
                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
            elif role == "system":
                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
            else:
                safe_messages.append({"role": role, "content": content})
        r = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return r.choices[0].message.content
        except Exception:
            return str(r)

###########################################################
# PROVIDER: SAMBANOVA (OpenAI-compatible)
###########################################################

class SambaNovaProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed; SambaNova wrapper expects OpenAI-compatible client")
        self.client = OpenAI(
            api_key=os.getenv("SAMBANOVA_API_KEY", ""),
            base_url=os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
        )

    def chat(self, model, messages, **kwargs):
        try:
            from openai.types.chat import (
                ChatCompletionUserMessageParam,
                ChatCompletionAssistantMessageParam,
                ChatCompletionSystemMessageParam,
            )
        except Exception:
            ChatCompletionUserMessageParam = dict
            ChatCompletionAssistantMessageParam = dict
            ChatCompletionSystemMessageParam = dict
        safe_messages = []
        for m in messages:
            role = str(m.get("role", "user"))
            content = str(m.get("content", ""))
            if role == "user":
                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
            elif role == "assistant":
                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
            elif role == "system":
                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
            else:
                safe_messages.append({"role": role, "content": content})
        r = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return r.choices[0].message.content
        except Exception:
            return str(r)

###########################################################
# PROVIDER: LOCAL TRANSFORMERS (CACHED)
###########################################################

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False


class LocalLLMProvider(BaseProvider):
    """
    Local LLM provider with caching - MODEL LOADS ONCE
    """
    def __init__(self, model_name: str = "meta-llama/Llama-3.2-3B-Instruct"):
        print(f"[LocalLLM] Initializing with model: {model_name}")
        self.model_name = os.getenv("LOCAL_MODEL", model_name)
        self.model = None
        self.tokenizer = None
        self.device = None
        self._initialize_model()

    def _initialize_model(self):
        """Initialize model ONCE - this is called only during __init__"""
        if not TRANSFORMERS_AVAILABLE:
            print("[LocalLLM] ❌ transformers/torch not installed; local provider disabled")
            return
        try:
            print(f"[LocalLLM] Loading model {self.model_name}...")
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"[LocalLLM] Using device: {self.device}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                trust_remote_code=True
            )
            print("[LocalLLM] ✅ Model loaded successfully!")
        except Exception as e:
            print(f"[LocalLLM] ❌ Failed to load model: {e}")
            self.model = None
            traceback.print_exc()

    def chat(self, model, messages, **kwargs):
        """
        Generate response - MODEL ALREADY LOADED
        """
        if self.model is None or self.tokenizer is None:
            return "Error: Model or tokenizer not loaded."
        if not messages:
            return "Error: messages is empty."
        # Extract text from the last message
        last = messages[-1]
        text = last["content"] if isinstance(last, dict) and "content" in last else str(last)
        max_tokens = kwargs.get("max_tokens", 128)
        temperature = kwargs.get("temperature", 0.7)
        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to(self.device)
        # Generate (model is already loaded, just inference)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens (skip the prompt portion)
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()
        return response
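
# Note: LocalLLMProvider.chat() feeds the raw last-message text to the model. For
# chat-tuned checkpoints one could instead build the prompt with the tokenizer's chat
# template, e.g. (sketch only, not used by this engine):
#   prompt = provider.tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True)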

###########################################################
# PROVIDER CACHE - CRITICAL FIX
###########################################################

class ProviderCache:
    """
    Cache provider instances to avoid reloading models
    This is the KEY fix - providers are created ONCE and reused
    """
    _cache = {}

    @classmethod
    def get_provider(cls, provider_name: str) -> BaseProvider:
        """Get or create cached provider instance"""
        if provider_name not in cls._cache:
            print(f"[ProviderCache] Creating new instance of {provider_name}")
            provider_class = ProviderFactory.providers[provider_name]
            cls._cache[provider_name] = provider_class()
        else:
            print(f"[ProviderCache] Using cached instance of {provider_name}")
        return cls._cache[provider_name]

    @classmethod
    def clear_cache(cls):
        """Clear all cached providers (useful for debugging)"""
        cls._cache.clear()
        print("[ProviderCache] Cache cleared")

###########################################################
# PROVIDER FACTORY (IMPROVED WITH CACHING)
###########################################################

class ProviderFactory:
    providers = {
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
        "huggingface": HuggingFaceProvider,
        "nebius": NebiusProvider,
        "sambanova": SambaNovaProvider,
        "local": LocalLLMProvider,
    }

    @staticmethod
    def get(provider_name: str) -> BaseProvider:
        """
        Get provider instance - NOW USES CACHING
        This prevents reloading the model on every call
        """
        provider_name = provider_name.lower()
        if provider_name not in ProviderFactory.providers:
            raise ValueError(f"Unknown provider: {provider_name}")
        # USE CACHE instead of creating new instance every time
        return ProviderCache.get_provider(provider_name)
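
# Illustrative extension point: a custom backend such as _EchoProviderExample above
# would be registered like this (not done by default):
#   ProviderFactory.providers["echo"] = _EchoProviderExample
#   ProviderCache.get_provider("echo")   # instantiated once on first use, then reused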

###########################################################
# MAIN ENGINE WITH FALLBACK + OPENAI-STYLE API
###########################################################

class LLMEngine:
    def __init__(self):
        self.credits = load_credits()

    def deduct(self, provider, amount):
        if provider not in self.credits:
            self.credits[provider] = 0
        self.credits[provider] = max(0, self.credits[provider] - amount)
        save_credits(self.credits)

    def chat(self, provider: str, model: str, messages: List[Dict], fallback: bool = True, **kwargs):
        """
        Main chat method - providers are now cached
        """
        try:
            p = ProviderFactory.get(provider)  # This now returns cached instance
            result = p.chat(model=model, messages=messages, **kwargs)
            try:
                self.deduct(provider, 0.001)
            except Exception:
                pass
            return result
        except Exception as exc:
            print(f"⚠ Provider '{provider}' failed → fallback activated: {exc}")
            traceback.print_exc()
            if fallback:
                try:
                    lp = ProviderFactory.get("local")  # Gets cached local provider
                    return lp.chat(model="local", messages=messages, **kwargs)
                except Exception as le:
                    print("Fallback to local provider failed:", le)
                    traceback.print_exc()
                    raise
            raise
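
# Illustrative fallback flow (the provider/model names are examples, not project defaults):
#   engine = LLMEngine()
#   # If the remote call raises (missing key, exhausted quota, network error, ...),
#   # the engine retries the same messages on the cached local provider before re-raising:
#   reply = engine.chat(
#       provider="openai",
#       model="gpt-4o-mini",   # hypothetical model id, assumed available on the account
#       messages=[{"role": "user", "content": "Hi"}],
#   )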

###########################################################
# EXAMPLES + SIMPLE TESTS
###########################################################

def main():
    engine = LLMEngine()
    print("=== Testing Provider Caching ===")
    print("\nFirst call (should load model):")
    result1 = engine.chat(
        provider="local",
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[{"role": "user", "content": "Say hello"}]
    )
    print(f"Response: {result1[:100]}")
    print("\nSecond call (should use cached model - NO RELOAD):")
    result2 = engine.chat(
        provider="local",
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[{"role": "user", "content": "Say goodbye"}]
    )
    print(f"Response: {result2[:100]}")
    print("\n✅ If you didn't see 'Loading model' twice, caching works!")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", action="store_true", help="run examples and simple tests")
    args = parser.parse_args()
    # The examples run with or without --test; the flag is kept for CLI compatibility.
    main()