import re
from typing import List, Dict, Any, Union, Optional
import io
import os
import base64
from PIL import Image
import mimetypes
import litellm
from litellm import completion, completion_cost
from dotenv import load_dotenv
import random
load_dotenv()
class LiteLLMWrapper:
"""Wrapper for LiteLLM to support multiple models and logging"""
def __init__(
self,
model_name: str = "gpt-4-vision-preview",
temperature: float = 0.7,
print_cost: bool = False,
verbose: bool = False,
use_langfuse: bool = True,
):
"""
Initialize the LiteLLM wrapper
Args:
model_name: Name of the model to use (e.g. "azure/gpt-4", "vertex_ai/gemini-pro")
temperature: Temperature for completion
print_cost: Whether to print the cost of the completion
verbose: Whether to print verbose output
use_langfuse: Whether to enable Langfuse logging
"""
self.model_name = model_name
self.temperature = temperature
self.print_cost = print_cost
self.verbose = verbose
self.accumulated_cost = 0
# Handle Gemini API key fallback mechanism
if "gemini" in model_name.lower():
self._setup_gemini_api_key()
if self.verbose:
os.environ['LITELLM_LOG'] = 'DEBUG'
# Set langfuse callback only if enabled
if use_langfuse:
litellm.success_callback = ["langfuse"]
litellm.failure_callback = ["langfuse"]
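            # Note: litellm's Langfuse callback reads credentials from the
            # environment (LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY, and
            # optionally LANGFUSE_HOST); no keys are passed explicitly here.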
def _setup_gemini_api_key(self):
"""Setup Gemini API key with fallback mechanism for multiple keys."""
        # load_dotenv is already imported at module level; re-read .env with
        # override so rotated keys take effect
        load_dotenv(override=True)
gemini_key_env = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
if not gemini_key_env:
raise ValueError("No API_KEY found. Please set the `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable.")
# Support comma-separated list of API keys with random selection
if ',' in gemini_key_env:
keys = [key.strip() for key in gemini_key_env.split(',') if key.strip()]
if not keys:
raise ValueError("No valid API keys found in GEMINI_API_KEY list.")
api_key = random.choice(keys)
print(f"Selected random Gemini API key from {len(keys)} available keys: {api_key[:20]}...")
else:
api_key = gemini_key_env
print(f"Using single Gemini API key: {api_key[:20]}...")
# Set the selected API key for LiteLLM
os.environ["GEMINI_API_KEY"] = api_key
os.environ["GOOGLE_API_KEY"] = api_key
def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
"""
Encode local file or PIL Image to base64 string
Args:
file_path: Path to local file or PIL Image object
Returns:
Base64 encoded file string
"""
if isinstance(file_path, Image.Image):
buffered = io.BytesIO()
file_path.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
else:
with open(file_path, "rb") as file:
return base64.b64encode(file.read()).decode("utf-8")
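    # e.g. _encode_file("photo.jpg") or _encode_file(Image.new("RGB", (1, 1)))
    # both return a bare base64 string; the caller prepends the data: URL prefix.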
def _get_mime_type(self, file_path: str) -> str:
"""
Get the MIME type of a file based on its extension
Args:
file_path: Path to the file
Returns:
MIME type as a string (e.g., "image/jpeg", "audio/mp3")
"""
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type is None:
raise ValueError(f"Unsupported file type: {file_path}")
return mime_type
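    # e.g. _get_mime_type("clip.mp4") -> "video/mp4"; extensions that
    # mimetypes cannot resolve raise ValueError rather than guessing.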
def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
"""
Process messages and return completion
Args:
messages: List of message dictionaries with 'type' and 'content' keys
metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking
Returns:
Generated text response
"""
if metadata is None:
print("No metadata provided, using empty metadata")
metadata = {}
metadata["trace_name"] = f"litellm-completion-{self.model_name}"
# Convert messages to LiteLLM format
formatted_messages = []
for msg in messages:
if msg["type"] == "text":
formatted_messages.append({
"role": "user",
"content": [{"type": "text", "text": msg["content"]}]
})
elif msg["type"] in ["image", "audio", "video"]:
# Check if content is a local file path or PIL Image
if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
try:
if isinstance(msg["content"], Image.Image):
mime_type = "image/png"
else:
mime_type = self._get_mime_type(msg["content"])
base64_data = self._encode_file(msg["content"])
data_url = f"data:{mime_type};base64,{base64_data}"
except ValueError as e:
print(f"Error processing file {msg['content']}: {e}")
continue
else:
data_url = msg["content"]
# Append the formatted message based on the model
if "gemini" in self.model_name:
formatted_messages.append({
"role": "user",
"content": [
{
"type": "image_url",
"image_url": data_url
}
]
})
elif "gpt" in self.model_name:
                    # GPT-style content parts use the OpenAI image_url format
                    if msg["type"] == "image":
                        formatted_messages.append({
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": data_url,
                                        "detail": "high"
                                    }
                                }
                            ]
                        })
                    else:
                        raise ValueError("GPT models support only text and image inputs")
                else:
                    raise ValueError("Only Gemini and GPT models are currently supported for multimodal inputs")
        try:
            # OpenAI o-series models do not accept a temperature setting, so
            # drop it and request medium reasoning effort instead (kept local
            # rather than mutating instance state)
            if re.match(r"^o\d+.*$", self.model_name) or re.match(r"^openai/o.*$", self.model_name):
                response = completion(
                    model=self.model_name,
                    messages=formatted_messages,
                    temperature=None,
                    reasoning_effort="medium",
                    metadata=metadata,
                    max_retries=99
                )
            else:
                response = completion(
                    model=self.model_name,
                    messages=formatted_messages,
                    temperature=self.temperature,
                    metadata=metadata,
                    max_retries=99
                )
            if self.print_cost:
                # pass the completion response to completion_cost to get the
                # per-call cost, then track the running total
                cost = completion_cost(completion_response=response)
                self.accumulated_cost += cost
                print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
content = response.choices[0].message.content
if content is None:
print(f"Got null response from model. Full response: {response}")
return content
except Exception as e:
print(f"Error in model completion: {e}")
return str(e)
if __name__ == "__main__":
pass
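    # Minimal usage sketch; the model name and image path below are
    # placeholder assumptions, not values shipped with this module:
    # llm = LiteLLMWrapper(model_name="gemini/gemini-1.5-flash", print_cost=True)
    # reply = llm([
    #     {"type": "text", "content": "Describe this image in one sentence."},
    #     {"type": "image", "content": "example.png"},
    # ])
    # print(reply)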