import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from peft import PeftModel  # Import PeftModel

# Define your fine-tuned model ID on Hugging Face Hub
# Make sure this matches the repo_id you used when pushing
model_id = "whidbeysea/gemma-2b-it-fine-tuned-catechism"

# Define the base model ID (the original model you fine-tuned)
base_model_id = "google/gemma-2b-it"

# Set the device to use (GPU if available, otherwise CPU)
# You might need to adjust this based on your Space hardware and configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    print("Tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading tokenizer from {model_id}: {e}")
    tokenizer = None

# Load the base model
# You might need to specify the dtype and quantization based on how you trained.
# On a T4 High-RAM instance, bfloat16 may work (T4s lack native bfloat16 support,
# so float16 or 4-bit loading can be safer); a commented 4-bit sketch follows this block.
# Since we fine-tuned with LoRA, we load the base model first.
base_model = None
if device == "cuda":
    try:
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,  # Try bfloat16 for T4 High-RAM
            # load_in_4bit=True  # You might need this if bfloat16 is not enough
        )
        print(f"Base model '{base_model_id}' loaded successfully on GPU.")
    except Exception as e:
        print(f"Error loading base model '{base_model_id}' on GPU: {e}")
        print("Trying to load on CPU or with different settings...")
        try:
            base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
            print(f"Base model '{base_model_id}' loaded successfully on CPU.")
        except Exception as e_cpu:
            print(f"Error loading base model '{base_model_id}' on CPU: {e_cpu}")
            base_model = None
else:  # Load on CPU
    try:
        base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
        print(f"Base model '{base_model_id}' loaded successfully on CPU.")
    except Exception as e:
        print(f"Error loading base model '{base_model_id}' on CPU: {e}")
        base_model = None
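
# The 4-bit alternative mentioned above, for when bfloat16 still exhausts GPU memory.
# This is only a sketch and not part of the original setup; it assumes the
# bitsandbytes package is installed in the Space:
#
# from transformers import BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )
# base_model = AutoModelForCausalLM.from_pretrained(
#     base_model_id,
#     quantization_config=bnb_config,
# )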

model = None
if base_model and tokenizer:
    # Load the PEFT model (LoRA adapters)
    try:
        model = PeftModel.from_pretrained(base_model, model_id)
        print(f"PEFT model loaded from {model_id}.")
        # Move the model to the specified device
        model.to(device)
        print("Model moved to device.")
        # Optional: merge the LoRA adapters with the base model for potentially faster inference.
        # This might require more memory, so test if it works on your Space hardware.
        # print("Merging LoRA adapters...")
        # model = model.merge_and_unload()
        # print("LoRA adapters merged.")
    except Exception as e:
        print(f"Error loading PEFT model or moving to device: {e}")
        model = None

generator = None
if model and tokenizer:
    # Create a Hugging Face pipeline for text generation
    try:
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if device == "cuda" else -1,  # Use GPU device 0 if cuda, else CPU (-1)
            # Add other parameters as needed for generation (e.g., max_new_tokens, temperature, top_p, top_k)
            max_new_tokens=200,
            do_sample=True,  # Sampling must be enabled for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,  # Set pad token id for generation
        )
        print("Text generation pipeline created.")
    except Exception as e:
        print(f"Error creating text generation pipeline: {e}")
        generator = None
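
# Optional sanity check of the raw pipeline output: a list of dicts with a
# 'generated_text' key, which the post-processing below relies on. The example
# question is only illustrative:
#
# if generator is not None:
#     sample = generator("Question: What is the First Commandment?\nAnswer:")
#     print(sample[0]["generated_text"])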

def generate_answer(question):
    """
    Generates an answer using the fine-tuned model based on the input question.
    """
    if generator is None:
        return "Error: Model or pipeline not loaded."

    # Format the prompt to match the training data format:
    # we used "Question: [question]\nAnswer: [answer]" during training
    prompt = f"Question: {question}\nAnswer:"
    try:
        # Generate text; the pipeline handles tokenization and generation
        response = generator(prompt)

        # Extract the generated text
        # The pipeline output is typically a list of dictionaries
        if response and len(response) > 0 and 'generated_text' in response[0]:
            generated_text = response[0]['generated_text']

            # Post-process the generated text to extract only the answer part.
            # This depends on how your model was trained to respond after "Answer:".
            # Find the start of the answer after the prompt.
            answer_start_marker = "Answer:"
            if answer_start_marker in generated_text:
                answer = generated_text.split(answer_start_marker, 1)[1].strip()
                # Clean up the answer: drop blank lines and any follow-on
                # "Question:" turns the model may have generated
                lines = answer.split('\n')
                cleaned_answer_lines = []
                for line in lines:
                    cleaned_line = line.strip()
                    if cleaned_line and not cleaned_line.startswith("Question:"):  # Avoid including subsequent Q&A if generated
                        cleaned_answer_lines.append(cleaned_line)
                answer = "\n".join(cleaned_answer_lines)
            else:
                # If the marker is not found, return the generated text after the prompt.
                # This might happen if the model doesn't follow the expected format.
                answer = generated_text.split(prompt, 1)[-1].strip()  # Attempt to remove the prompt

            return answer if answer else "Could not generate a relevant answer."
        else:
            return "Error: Could not generate text from the model."
    except Exception as e:
        return f"An error occurred during text generation: {e}"

# Create the Gradio interface
iface = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(label="Enter your question:"),
    outputs=gr.Textbox(label="Generated Answer:"),
    title="LutherAI Catechism Chatbot",
    description="Ask questions about Luther's Large Catechism.",
    allow_flagging="never",  # Disable flagging for this example
)

# Launch the Gradio interface
iface.launch()