|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from peft import PeftModel |
|
from huggingface_hub import InferenceClient |
|
import re |
|
import torch |
|
|
|
|
|
# Load the Sarvam-1 base model, attach the text-normalization LoRA adapter,
# and merge the adapter weights into the base for plain inference.
# On any failure all three globals end up as None so respond() can return a
# friendly error instead of crashing.
tokenizer = None
base_model = None
peft_model = None
try:
    tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
    base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
    adapted = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
    peft_model = adapted.merge_and_unload()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    tokenizer = None
    base_model = None
    peft_model = None
|
|
|
|
|
|
|
def respond(message, history, system_message, max_tokens, temperature, top_p=0.95):
    """Normalize a Hindi message using the fine-tuned Sarvam-1 model.

    Parameters
    ----------
    message : str
        User input text to normalize.
    history : list
        Chat history supplied by gr.ChatInterface (not used in the prompt).
    system_message : str
        Task instructions prepended to the prompt.
    max_tokens : int
        Maximum number of new tokens to generate.
    temperature : float
        Sampling temperature.
    top_p : float, optional
        Nucleus-sampling cutoff. Defaults to 0.95 so the function still works
        when the UI does not provide a top_p control (the original signature
        required it positionally, which made gr.ChatInterface raise a
        TypeError when only three additional inputs were configured).

    Returns
    -------
    str
        The extracted normalized output, or a human-readable error message.
    """
    global tokenizer, peft_model

    # Model loading may have failed at import time; degrade gracefully.
    if tokenizer is None or peft_model is None:
        return "Model loading failed. Please check the logs."

    # Prompt layout expected by the fine-tuned adapter:
    # system instructions, then the user turn with an empty "output:" slot.
    prompt = f"{system_message}\n<user> input:{message} output:"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    try:
        outputs = peft_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,       # previously accepted but never forwarded
            do_sample=True,    # without sampling, temperature/top_p are ignored
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Generation error: {e}"

    def extract_user_content(text):
        """Return the text between the first <user>...</user> pair
        (case-insensitive, spanning newlines), stripped of whitespace.
        If no pair is found, return a canned retry message.
        """
        pattern = re.compile('<user>(.*?)</user>', re.IGNORECASE | re.DOTALL)
        match = pattern.search(text)
        if match:
            return match.group(1).strip()
        return "Retry to get output, the model failed to generated required output(This occurs rarely🤷♂️)"

    # Log the raw decode and the extracted answer for debugging in the Space logs.
    print(generated_text)
    lines = extract_user_content(generated_text)
    print(lines)

    return lines
|
|
|
|
|
|
|
# Gradio chat UI. The additional inputs map 1:1, in order, onto the extra
# parameters of respond() after (message, history):
#   system_message, max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Take the user input in Hindi language and normalize specific entities, Only including: Dates (any format) Currencies Scientific units, Here's an example input and output of the task <Example> Exampleinput : 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोइ़ थी, Exampleoutput: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example>. Understand the task and Only provide the normalized output with atmost accuracy", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.95, step=0.1, label="Temperature"),
        # Fix: respond() declares a top_p parameter, but no component supplied
        # it, so gr.ChatInterface called respond() with one argument too few.
        gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
|
|
|
# Launch the Gradio app only when this file is executed directly
# (e.g. `python app.py` or as a Hugging Face Space entry point),
# not when it is imported as a module.
if __name__ == "__main__":

    demo.launch()