# llmEngine.py
# IMPROVED: Multi-provider LLM engine with CACHING to prevent reloading
# This version fixes the critical issue where LocalLLM was reloading on every call
# Features:
# - Provider caching (models stay in memory)
# - Unified OpenAI-style chat() API
# - Providers: OpenAI, Anthropic, HuggingFace, Nebius, SambaNova, Local (transformers)
# - Automatic fallback to local model on errors
# - JSON-based credit tracking
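#
# Expected environment variables (a sketch based on the os.getenv() calls in
# this file; only the keys for the providers you actually use are required):
#   HUGGINGFACE_TOKEN=...     # optional, enables huggingface_hub login
#   OPENAI_API_KEY=...
#   ANTHROPIC_API_KEY=...
#   HF_API_KEY=...            # HuggingFace Inference API
#   NEBIUS_API_KEY=...        # plus optional NEBIUS_BASE_URL
#   SAMBANOVA_API_KEY=...     # plus optional SAMBANOVA_BASE_URL
#   LOCAL_MODEL=...           # optional override of the default local model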

from dotenv import load_dotenv
import json
import os
import traceback
from typing import List, Dict, Optional

load_dotenv()

hf_token = os.getenv('HUGGINGFACE_TOKEN')
if hf_token:
    from huggingface_hub import login
    try:
        login(token=hf_token)
        # logger.info("[HF] Logged in")
    except Exception as e:
        # logger.warning(f"[HF] Login failed: {e}")
        pass

###########################################################
# SIMPLE JSON CREDIT STORE
###########################################################

CREDITS_DB_PATH = "credits.json"

DEFAULT_CREDITS = {
    "openai": 25,
    "anthropic": 25000,
    "huggingface": 25,
    "nebius": 50,
    "modal": 250,
    "blaxel": 250,
    "elevenlabs": 44,
    "sambanova": 25,
    "local": 9999999
}


def load_credits():
    if not os.path.exists(CREDITS_DB_PATH):
        with open(CREDITS_DB_PATH, "w") as f:
            json.dump(DEFAULT_CREDITS, f)
        return DEFAULT_CREDITS.copy()
    with open(CREDITS_DB_PATH, "r") as f:
        return json.load(f)


def save_credits(data):
    with open(CREDITS_DB_PATH, "w") as f:
        json.dump(data, f, indent=2)
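
# Illustrative on-disk format (a sketch; actual values depend on prior usage,
# since LLMEngine.deduct() below rewrites the file after each call):
#   {
#     "openai": 24.999,
#     "local": 9999999
#   }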

###########################################################
# BASE PROVIDER INTERFACE
###########################################################

class BaseProvider:
    def chat(self, model: str, messages: List[Dict], **kwargs) -> str:
        raise NotImplementedError
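

# A minimal sketch of a custom provider (EchoProvider is a hypothetical name
# used for illustration only; it is NOT registered in ProviderFactory below):
class EchoProvider(BaseProvider):
    def chat(self, model, messages, **kwargs):
        # Providers receive OpenAI-style message dicts and return a plain string.
        return messages[-1].get("content", "") if messages else ""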

###########################################################
# PROVIDER: OPENAI
###########################################################

try:
    from openai import OpenAI
except Exception:
    OpenAI = None


def _to_openai_messages(messages: List[Dict]):
    """Normalize plain message dicts into OpenAI-style chat message params.

    Shared by the OpenAI-compatible providers (OpenAI, Nebius, SambaNova).
    """
    try:
        from openai.types.chat import (
            ChatCompletionUserMessageParam,
            ChatCompletionAssistantMessageParam,
            ChatCompletionSystemMessageParam,
        )
    except Exception:
        ChatCompletionUserMessageParam = dict
        ChatCompletionAssistantMessageParam = dict
        ChatCompletionSystemMessageParam = dict
    if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
        raise TypeError("messages must be a list of dicts with 'role' and 'content'")
    safe_messages = []
    for m in messages:
        role = str(m.get("role", "user"))
        content = str(m.get("content", ""))
        if role == "user":
            safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
        elif role == "assistant":
            safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
        elif role == "system":
            safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
        else:
            safe_messages.append({"role": role, "content": content})
    return safe_messages


class OpenAIProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed or not importable")
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", ""))

    def chat(self, model, messages, **kwargs):
        safe_messages = _to_openai_messages(messages)
        response = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return response.choices[0].message.content
        except Exception:
            return str(response)

###########################################################
# PROVIDER: ANTHROPIC
###########################################################

try:
    from anthropic import Anthropic
except Exception:
    Anthropic = None


class AnthropicProvider(BaseProvider):
    def __init__(self):
        if Anthropic is None:
            raise RuntimeError("anthropic library not installed or not importable")
        self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY", ""))

    def chat(self, model, messages, **kwargs):
        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
            raise TypeError("messages must be a list of dicts with 'role' and 'content'")
        # This simple wrapper concatenates the user turns and drops system/assistant turns.
        user_text = "\n".join([m.get("content", "") for m in messages if m.get("role") == "user"])
        reply = self.client.messages.create(
            model=model,
            max_tokens=kwargs.get("max_tokens", 300),
            messages=[{"role": "user", "content": user_text}]
        )
        if hasattr(reply, "content"):
            content = reply.content
            if isinstance(content, list) and content:
                block = content[0]
                if hasattr(block, "text"):
                    return getattr(block, "text", str(block))
                elif isinstance(block, dict) and "text" in block:
                    return block["text"]
                else:
                    return str(block)
            elif isinstance(content, str):
                return content
        if isinstance(reply, dict) and "completion" in reply:
            return reply["completion"]
        return str(reply)

###########################################################
# PROVIDER: HUGGINGFACE INFERENCE API
###########################################################

import requests


class HuggingFaceProvider(BaseProvider):
    def __init__(self):
        self.key = os.getenv("HF_API_KEY", "")

    def chat(self, model, messages, **kwargs):
        if not messages:
            raise ValueError("messages is empty")
        text = messages[-1].get("content", "")
        r = requests.post(
            f"https://api-inference.huggingface.co/models/{model}",
            headers={"Authorization": f"Bearer {self.key}"} if self.key else {},
            json={"inputs": text},
            timeout=60
        )
        r.raise_for_status()
        out = r.json()
        if isinstance(out, list) and out and isinstance(out[0], dict):
            return out[0].get("generated_text") or str(out[0])
        return str(out)
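
# Note: for text-generation models the Inference API typically returns a list
# like [{"generated_text": "..."}], which is what chat() above unwraps; other
# pipeline types return different shapes and simply fall through to str(out).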

###########################################################
# PROVIDER: NEBIUS (OpenAI-compatible)
###########################################################

class NebiusProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed; Nebius wrapper expects OpenAI-compatible client")
        self.client = OpenAI(
            api_key=os.getenv("NEBIUS_API_KEY", ""),
            base_url=os.getenv("NEBIUS_BASE_URL", "https://api.studio.nebius.ai/v1")
        )

    def chat(self, model, messages, **kwargs):
        safe_messages = _to_openai_messages(messages)
        r = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return r.choices[0].message.content
        except Exception:
            return str(r)

###########################################################
# PROVIDER: SAMBANOVA (OpenAI-compatible)
###########################################################

class SambaNovaProvider(BaseProvider):
    def __init__(self):
        if OpenAI is None:
            raise RuntimeError("openai library not installed; SambaNova wrapper expects OpenAI-compatible client")
        self.client = OpenAI(
            api_key=os.getenv("SAMBANOVA_API_KEY", ""),
            base_url=os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
        )

    def chat(self, model, messages, **kwargs):
        safe_messages = _to_openai_messages(messages)
        r = self.client.chat.completions.create(model=model, messages=safe_messages)
        try:
            return r.choices[0].message.content
        except Exception:
            return str(r)

###########################################################
# PROVIDER: LOCAL TRANSFORMERS (CACHED)
###########################################################

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False


class LocalLLMProvider(BaseProvider):
    """
    Local LLM provider with caching - MODEL LOADS ONCE
    """
    def __init__(self, model_name: str = "meta-llama/Llama-3.2-3B-Instruct"):
        print(f"[LocalLLM] Initializing with model: {model_name}")
        self.model_name = os.getenv("LOCAL_MODEL", model_name)
        self.model = None
        self.tokenizer = None
        self.device = None
        self._initialize_model()

    def _initialize_model(self):
        """Initialize model ONCE - this is called only during __init__"""
        if not TRANSFORMERS_AVAILABLE:
            print("[LocalLLM] ❌ transformers/torch not installed; local provider disabled")
            return
        try:
            print(f"[LocalLLM] Loading model {self.model_name}...")
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"[LocalLLM] Using device: {self.device}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                trust_remote_code=True
            )
            print("[LocalLLM] ✅ Model loaded successfully!")
        except Exception as e:
            print(f"[LocalLLM] ❌ Failed to load model: {e}")
            self.model = None
            traceback.print_exc()

    def chat(self, model, messages, **kwargs):
        """
        Generate response - MODEL ALREADY LOADED
        """
        if self.model is None or self.tokenizer is None:
            return "Error: Model or tokenizer not loaded."
        # Extract text from the last message
        text = messages[-1]["content"] if isinstance(messages[-1], dict) and "content" in messages[-1] else str(messages[-1])
        max_tokens = kwargs.get("max_tokens", 128)
        temperature = kwargs.get("temperature", 0.7)
        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to(self.device)
        # Generate (model is already loaded, just inference)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()
        return response
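
# Direct usage sketch (assumes transformers/torch are installed and the model
# is downloadable; in practice this class is reached through LLMEngine below):
#   llm = LocalLLMProvider()
#   print(llm.chat("local", [{"role": "user", "content": "Hi"}], max_tokens=32))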

###########################################################
# PROVIDER CACHE - CRITICAL FIX
###########################################################

class ProviderCache:
    """
    Cache provider instances to avoid reloading models
    This is the KEY fix - providers are created ONCE and reused
    """
    _cache = {}

    @classmethod
    def get_provider(cls, provider_name: str) -> BaseProvider:
        """Get or create cached provider instance"""
        if provider_name not in cls._cache:
            print(f"[ProviderCache] Creating new instance of {provider_name}")
            provider_class = ProviderFactory.providers[provider_name]
            cls._cache[provider_name] = provider_class()
        else:
            print(f"[ProviderCache] Using cached instance of {provider_name}")
        return cls._cache[provider_name]

    @classmethod
    def clear_cache(cls):
        """Clear all cached providers (useful for debugging)"""
        cls._cache.clear()
        print("[ProviderCache] Cache cleared")

###########################################################
# PROVIDER FACTORY (IMPROVED WITH CACHING)
###########################################################

class ProviderFactory:
    providers = {
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
        "huggingface": HuggingFaceProvider,
        "nebius": NebiusProvider,
        "sambanova": SambaNovaProvider,
        "local": LocalLLMProvider,
    }

    @staticmethod
    def get(provider_name: str) -> BaseProvider:
        """
        Get provider instance - NOW USES CACHING
        This prevents reloading the model on every call
        """
        provider_name = provider_name.lower()
        if provider_name not in ProviderFactory.providers:
            raise ValueError(f"Unknown provider: {provider_name}")
        # USE CACHE instead of creating a new instance every time
        return ProviderCache.get_provider(provider_name)

###########################################################
# MAIN ENGINE WITH FALLBACK + OPENAI-STYLE API
###########################################################

class LLMEngine:
    def __init__(self):
        self.credits = load_credits()

    def deduct(self, provider, amount):
        if provider not in self.credits:
            self.credits[provider] = 0
        self.credits[provider] = max(0, self.credits[provider] - amount)
        save_credits(self.credits)

    def chat(self, provider: str, model: str, messages: List[Dict], fallback: bool = True, **kwargs):
        """
        Main chat method - providers are now cached
        """
        try:
            p = ProviderFactory.get(provider)  # Returns the cached instance
            result = p.chat(model=model, messages=messages, **kwargs)
            try:
                self.deduct(provider, 0.001)
            except Exception:
                pass
            return result
        except Exception as exc:
            print(f"⚠ Provider '{provider}' failed → fallback activated: {exc}")
            traceback.print_exc()
            if fallback:
                try:
                    lp = ProviderFactory.get("local")  # Gets the cached local provider
                    return lp.chat(model="local", messages=messages, **kwargs)
                except Exception as le:
                    print("Fallback to local provider failed:", le)
                    traceback.print_exc()
                    raise
            raise
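
# Illustrative usage (a sketch; the model name is an assumption, substitute one
# your account can access; API keys are read from .env at import time):
#   engine = LLMEngine()
#   reply = engine.chat(
#       provider="openai",
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   # If the OpenAI call raises, fallback=True retries on the cached local model.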

###########################################################
# EXAMPLES + SIMPLE TESTS
###########################################################

def main():
    engine = LLMEngine()
    print("=== Testing Provider Caching ===")
    print("\nFirst call (should load model):")
    result1 = engine.chat(
        provider="local",
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[{"role": "user", "content": "Say hello"}]
    )
    print(f"Response: {result1[:100]}")
    print("\nSecond call (should use cached model - NO RELOAD):")
    result2 = engine.chat(
        provider="local",
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[{"role": "user", "content": "Say goodbye"}]
    )
    print(f"Response: {result2[:100]}")
    print("\n✅ If you didn't see 'Loading model' twice, caching works!")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", action="store_true", help="run examples and simple tests")
    args = parser.parse_args()
    # Only the caching demo ships right now, so it runs with or without --test.
    main()