# llmEngine.py
# IMPROVED: Multi-provider LLM engine with provider CACHING to prevent model reloading.
# This version fixes the critical issue where the local LLM was reloaded on every call.
#
# Features:
#   - Provider caching (models stay in memory)
#   - Unified OpenAI-style chat() API
#   - Providers: OpenAI, Anthropic, HuggingFace, Nebius, SambaNova, Local (transformers)
#   - Automatic fallback to the local model on errors
#   - JSON-based credit tracking

from dotenv import load_dotenv
import json
import os
import traceback
from typing import List, Dict, Optional

load_dotenv()

hf_token = os.getenv('HUGGINGFACE_TOKEN')
if hf_token:
    from huggingface_hub import login
    try:
        login(token=hf_token)
        # logger.info("[HF] Logged in")
    except Exception as e:
        # logger.warning(f"[HF] Login failed: {e}")
        pass

###########################################################
# SIMPLE JSON CREDIT STORE
###########################################################

CREDITS_DB_PATH = "credits.json"

DEFAULT_CREDITS = {
    "openai": 25,
    "anthropic": 25000,
    "huggingface": 25,
    "nebius": 50,
    "modal": 250,
    "blaxel": 250,
    "elevenlabs": 44,
    "sambanova": 25,
    "local": 9999999
}


def load_credits():
    """Load the credit ledger, creating it with defaults on first run."""
    if not os.path.exists(CREDITS_DB_PATH):
        with open(CREDITS_DB_PATH, "w") as f:
            json.dump(DEFAULT_CREDITS, f)
        return DEFAULT_CREDITS.copy()
    with open(CREDITS_DB_PATH, "r") as f:
        return json.load(f)


def save_credits(data):
    """Persist the credit ledger to disk."""
    with open(CREDITS_DB_PATH, "w") as f:
        json.dump(data, f, indent=2)

###########################################################
# BASE PROVIDER INTERFACE
###########################################################

class BaseProvider:
    def chat(self, model: str, messages: List[Dict], **kwargs) -> str:
        raise NotImplementedError

###########################################################
# PROVIDER: OPENAI
###########################################################

try:
    from openai import OpenAI
except Exception:
    OpenAI = None


class OpenAIProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed or not importable")
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", ""))

    def chat(self, model, messages, **kwargs):
        # Prefer the typed message params when available; fall back to plain dicts.
        try:
            from openai.types.chat import (
                ChatCompletionUserMessageParam,
                ChatCompletionAssistantMessageParam,
                ChatCompletionSystemMessageParam,
            )
        except Exception:
            ChatCompletionUserMessageParam = dict
            ChatCompletionAssistantMessageParam = dict
            ChatCompletionSystemMessageParam = dict

        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
            raise TypeError("messages must be a list of dicts with 'role' and 'content'")

        safe_messages = []
        for m in messages:
            role = str(m.get("role", "user"))
            content = str(m.get("content", ""))
            if role == "user":
                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
            elif role == "assistant":
                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
            elif role == "system":
                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
            else:
                safe_messages.append({"role": role, "content": content})

        response = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return response.choices[0].message.content
        except Exception:
            return str(response)
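# Illustrative shape of the `messages` argument every provider's chat() expects
# (the same OpenAI-style list that main() below passes in):
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Summarize the following paragraph ..."},
#   ]
# In OpenAIProvider above (and the OpenAI-compatible Nebius/SambaNova providers
# below), roles other than user/assistant/system are passed through as plain dicts.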
###########################################################
# PROVIDER: ANTHROPIC
###########################################################

try:
    from anthropic import Anthropic
except Exception:
    Anthropic = None


class AnthropicProvider(BaseProvider):
    def __init__(self):
        if Anthropic is None:
            raise RuntimeError("anthropic library not installed or not importable")
        self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY", ""))

    def chat(self, model, messages, **kwargs):
        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
            raise TypeError("messages must be a list of dicts with 'role' and 'content'")

        # Call the Messages API with the concatenated user turns.
        user_text = "\n".join([m.get("content", "") for m in messages if m.get("role") == "user"])
        reply = self.client.messages.create(
            model=model,
            max_tokens=300,
            messages=[{"role": "user", "content": user_text}]
        )

        # Normalize the reply: content is usually a list of text blocks.
        if hasattr(reply, "content"):
            content = reply.content
            if isinstance(content, list) and content:
                block = content[0]
                if hasattr(block, "text"):
                    return getattr(block, "text", str(block))
                elif isinstance(block, dict) and "text" in block:
                    return block["text"]
                else:
                    return str(block)
            elif isinstance(content, str):
                return content
        if isinstance(reply, dict) and "completion" in reply:
            return reply["completion"]
        return str(reply)

###########################################################
# PROVIDER: HUGGINGFACE INFERENCE API
###########################################################

import requests


class HuggingFaceProvider(BaseProvider):
    def __init__(self):
        self.key = os.getenv("HF_API_KEY", "")

    def chat(self, model, messages, **kwargs):
        if not messages:
            raise ValueError("messages is empty")
        text = messages[-1].get("content", "")
        r = requests.post(
            f"https://api-inference.huggingface.co/models/{model}",
            headers={"Authorization": f"Bearer {self.key}"} if self.key else {},
            json={"inputs": text},
            timeout=60
        )
        r.raise_for_status()
        out = r.json()
        if isinstance(out, list) and out and isinstance(out[0], dict):
            return out[0].get("generated_text") or str(out[0])
        return str(out)

###########################################################
# PROVIDER: NEBIUS (OpenAI-compatible)
###########################################################

class NebiusProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed; Nebius wrapper expects an OpenAI-compatible client")
        self.client = OpenAI(
            api_key=os.getenv("NEBIUS_API_KEY", ""),
            base_url=os.getenv("NEBIUS_BASE_URL", "https://api.studio.nebius.ai/v1")
        )

    def chat(self, model, messages, **kwargs):
        # Same message coercion as OpenAIProvider.chat.
        try:
            from openai.types.chat import (
                ChatCompletionUserMessageParam,
                ChatCompletionAssistantMessageParam,
                ChatCompletionSystemMessageParam,
            )
        except Exception:
            ChatCompletionUserMessageParam = dict
            ChatCompletionAssistantMessageParam = dict
            ChatCompletionSystemMessageParam = dict

        safe_messages = []
        for m in messages:
            role = str(m.get("role", "user"))
            content = str(m.get("content", ""))
            if role == "user":
                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
            elif role == "assistant":
                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
            elif role == "system":
                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
            else:
                safe_messages.append({"role": role, "content": content})

        r = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return r.choices[0].message.content
        except Exception:
            return str(r)
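# The provider classes read their credentials and endpoints from environment
# variables (loaded from a local .env by load_dotenv above). Illustrative .env
# sketch; every value shown is a placeholder, not a real key:
#
#   OPENAI_API_KEY=...
#   ANTHROPIC_API_KEY=...
#   HF_API_KEY=...
#   HUGGINGFACE_TOKEN=...
#   NEBIUS_API_KEY=...
#   NEBIUS_BASE_URL=https://api.studio.nebius.ai/v1
#   SAMBANOVA_API_KEY=...
#   SAMBANOVA_BASE_URL=https://api.sambanova.ai/v1
#   LOCAL_MODEL=meta-llama/Llama-3.2-3B-Instruct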
###########################################################
# PROVIDER: SAMBANOVA (OpenAI-compatible)
###########################################################

class SambaNovaProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed; SambaNova wrapper expects an OpenAI-compatible client")
        self.client = OpenAI(
            api_key=os.getenv("SAMBANOVA_API_KEY", ""),
            base_url=os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
        )

    def chat(self, model, messages, **kwargs):
        # Same message coercion as OpenAIProvider.chat.
        try:
            from openai.types.chat import (
                ChatCompletionUserMessageParam,
                ChatCompletionAssistantMessageParam,
                ChatCompletionSystemMessageParam,
            )
        except Exception:
            ChatCompletionUserMessageParam = dict
            ChatCompletionAssistantMessageParam = dict
            ChatCompletionSystemMessageParam = dict

        safe_messages = []
        for m in messages:
            role = str(m.get("role", "user"))
            content = str(m.get("content", ""))
            if role == "user":
                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
            elif role == "assistant":
                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
            elif role == "system":
                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
            else:
                safe_messages.append({"role": role, "content": content})

        r = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return r.choices[0].message.content
        except Exception:
            return str(r)

###########################################################
# PROVIDER: LOCAL TRANSFORMERS (CACHED)
###########################################################

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False


class LocalLLMProvider(BaseProvider):
    """Local LLM provider with caching - the model is loaded ONCE, at construction time."""

    def __init__(self, model_name: str = "meta-llama/Llama-3.2-3B-Instruct"):
        print(f"[LocalLLM] Initializing with model: {model_name}")
        self.model_name = os.getenv("LOCAL_MODEL", model_name)
        self.model = None
        self.tokenizer = None
        self.device = None
        self._initialize_model()

    def _initialize_model(self):
        """Load tokenizer and weights ONCE - called only from __init__."""
        if not TRANSFORMERS_AVAILABLE:
            print("[LocalLLM] ❌ transformers/torch not available; local provider disabled")
            return
        try:
            print(f"[LocalLLM] Loading model {self.model_name}...")
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"[LocalLLM] Using device: {self.device}")

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                trust_remote_code=True
            )
            print("[LocalLLM] ✅ Model loaded successfully!")
        except Exception as e:
            print(f"[LocalLLM] ❌ Failed to load model: {e}")
            self.model = None
            traceback.print_exc()

    def chat(self, model, messages, **kwargs):
        """Generate a response - the model is already loaded, so this is inference only."""
        if self.model is None or self.tokenizer is None:
            return "Error: Model or tokenizer not loaded."

        # Use the last message's content as the prompt.
        last = messages[-1]
        text = last["content"] if isinstance(last, dict) and "content" in last else str(last)

        max_tokens = kwargs.get("max_tokens", 128)
        temperature = kwargs.get("temperature", 0.7)

        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to(self.device)

        # Generate (weights are already in memory; no reload happens here)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens.
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()
        return response
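# Optional sketch (not wired into LocalLLMProvider.chat above, which prompts with
# only the last message): render the full conversation through the tokenizer's
# chat template instead. Assumes the tokenizer ships a chat template, as the
# Llama-3.2 Instruct checkpoints do; the helper name is illustrative.
def render_chat_prompt(tokenizer, messages: List[Dict]) -> str:
    """Build a single prompt string from an OpenAI-style message list."""
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Fall back to the last message's content, matching the behaviour above.
        return str(messages[-1].get("content", "")) if messages else ""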
###########################################################
# PROVIDER CACHE - CRITICAL FIX
###########################################################

class ProviderCache:
    """
    Cache provider instances to avoid reloading models.
    This is the KEY fix - providers are created ONCE and reused.
    """
    _cache = {}

    @classmethod
    def get_provider(cls, provider_name: str) -> BaseProvider:
        """Get or create cached provider instance"""
        if provider_name not in cls._cache:
            print(f"[ProviderCache] Creating new instance of {provider_name}")
            provider_class = ProviderFactory.providers[provider_name]
            cls._cache[provider_name] = provider_class()
        else:
            print(f"[ProviderCache] Using cached instance of {provider_name}")
        return cls._cache[provider_name]

    @classmethod
    def clear_cache(cls):
        """Clear all cached providers (useful for debugging)"""
        cls._cache.clear()
        print("[ProviderCache] Cache cleared")

###########################################################
# PROVIDER FACTORY (IMPROVED WITH CACHING)
###########################################################

class ProviderFactory:
    providers = {
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
        "huggingface": HuggingFaceProvider,
        "nebius": NebiusProvider,
        "sambanova": SambaNovaProvider,
        "local": LocalLLMProvider,
    }

    @staticmethod
    def get(provider_name: str) -> BaseProvider:
        """
        Get provider instance - NOW USES CACHING.
        This prevents reloading the model on every call.
        """
        provider_name = provider_name.lower()
        if provider_name not in ProviderFactory.providers:
            raise ValueError(f"Unknown provider: {provider_name}")
        # USE CACHE instead of creating a new instance every time
        return ProviderCache.get_provider(provider_name)
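# Illustration of the caching contract (hypothetical REPL session):
#   >>> ProviderFactory.get("local") is ProviderFactory.get("local")
#   True
# The second call returns the instance created by the first, so the local
# model's weights are loaded into memory only once per process.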
###########################################################
# MAIN ENGINE WITH FALLBACK + OPENAI-STYLE API
###########################################################

class LLMEngine:
    def __init__(self):
        self.credits = load_credits()

    def deduct(self, provider, amount):
        """Deduct credits for a provider and persist the ledger."""
        if provider not in self.credits:
            self.credits[provider] = 0
        self.credits[provider] = max(0, self.credits[provider] - amount)
        save_credits(self.credits)

    def chat(self, provider: str, model: str, messages: List[Dict], fallback: bool = True, **kwargs):
        """Main chat method - provider instances are cached, so models are not reloaded."""
        try:
            p = ProviderFactory.get(provider)  # returns a cached instance
            result = p.chat(model=model, messages=messages, **kwargs)
            try:
                self.deduct(provider, 0.001)
            except Exception:
                pass
            return result
        except Exception as exc:
            print(f"⚠ Provider '{provider}' failed → fallback activated: {exc}")
            traceback.print_exc()
            if fallback:
                try:
                    lp = ProviderFactory.get("local")  # gets the cached local provider
                    return lp.chat(model="local", messages=messages, **kwargs)
                except Exception as le:
                    print("Fallback to local provider failed:", le)
                    traceback.print_exc()
                    raise
            raise

###########################################################
# EXAMPLES + SIMPLE TESTS
###########################################################

def main():
    engine = LLMEngine()

    print("=== Testing Provider Caching ===")

    print("\nFirst call (should load the model):")
    result1 = engine.chat(
        provider="local",
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[{"role": "user", "content": "Say hello"}]
    )
    print(f"Response: {result1[:100]}")

    print("\nSecond call (should use the cached model - NO RELOAD):")
    result2 = engine.chat(
        provider="local",
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[{"role": "user", "content": "Say goodbye"}]
    )
    print(f"Response: {result2[:100]}")

    print("\n✅ If 'Loading model' was printed only once, caching works!")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", action="store_true", help="run examples and simple tests")
    args = parser.parse_args()
    if args.test:
        main()
    else:
        # Without --test, print a hint instead of loading the local model.
        print("Nothing to do. Run with --test to execute the caching examples.")
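# Typical library usage (model name is illustrative):
#   from llmEngine import LLMEngine
#   engine = LLMEngine()
#   reply = engine.chat(provider="openai", model="gpt-4o-mini",
#                       messages=[{"role": "user", "content": "Hello"}])
#
# CLI:
#   python llmEngine.py --test   # runs the caching demo in main()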