import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
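
# Assumed runtime dependencies (not pinned here): transformers, plus
# accelerate (needed for device_map="auto") and bitsandbytes (needed for
# 4-bit loading) in the endpoint image.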


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initializes the model and tokenizer. `path` points at the local
        repo checkout; the model is loaded by hub name instead.
        """
        model_name = "patrikpavlov/llama-finance-sentiment"

        # Load the weights in 4-bit NF4 precision so the model fits on a
        # single GPU (bitsandbytes 4-bit loading requires a CUDA device).
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Llama tokenizers ship without a pad token; reuse EOS so that
        # generation works without errors.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

    def __call__(self, data: dict) -> list:
        """
        Handles an incoming request, runs inference, and returns the response.
        """
        inputs = data.pop("inputs", "")
        if not inputs:
            return [{"error": "Input 'inputs' is required."}]

        # Default to a short generation budget; callers can override it
        # via the standard "parameters" field of the request payload.
        parameters = data.pop("parameters", {"max_new_tokens": 50})

        # Prompt strategy: give the model a very specific instruction and
        # the exact JSON schema to follow. This is a more reliable way to
        # get JSON output than the 'response_format' parameter, which
        # fails with this model.
        prompt = f"""
Analyze the sentiment of the financial news text provided. You must respond with only a valid JSON object. Do not add any other text before or after the JSON.
The JSON object must follow this exact schema:
{{
    "sentiment": "string"
}}
The value for "sentiment" must be one of the following three strings: "Positive", "Negative", or "Neutral".
Here is the financial news text to analyze:
---
{inputs}
---
"""

        messages = [
            {"role": "user", "content": prompt}
        ]

        # Render the conversation with the model's chat template, then
        # tokenize. The template already inserts the BOS token, so skip
        # the tokenizer's own special tokens to avoid duplicating it.
        chat_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        encoded = self.tokenizer(
            chat_prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )
        input_ids = encoded.input_ids.to(self.model.device)
        attention_mask = encoded.attention_mask.to(self.model.device)

        with torch.no_grad():
            # Generate without the unsupported 'response_format' argument;
            # JSON output is enforced via the prompt instead. Passing
            # pad_token_id explicitly silences the missing-pad-token warning.
            output_tokens = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.eos_token_id,
                **parameters
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        newly_generated_tokens = output_tokens[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(
            newly_generated_tokens,
            skip_special_tokens=True,
        )

        # The model should emit bare JSON, but guard against stray text by
        # extracting the first '{' .. last '}' span before parsing.
        try:
            json_start = generated_text.find('{')
            json_end = generated_text.rfind('}') + 1
            # rfind returns -1 when no '}' exists, so json_end would be 0;
            # compare against json_start rather than -1.
            if json_start != -1 and json_end > json_start:
                json_string = generated_text[json_start:json_end]
                json_output = json.loads(json_string)
                return [json_output]
            else:
                raise ValueError("No JSON object found in the output.")
        except (json.JSONDecodeError, ValueError) as e:
            return [{"error": f"Failed to parse JSON from model output: {e}", "raw_output": generated_text}]