llama-2 / app.py
# (This file was originally written out from a notebook cell via the %%writefile app.py magic.)
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
# Fine-tuning-only imports (not used during inference in this app)
from peft import LoraConfig
from trl import SFTTrainer
from datasets import load_dataset
# Define Streamlit interface
st.title("Llama-2-7b-Chat Fine-Tuned Model")
st.write("This app demonstrates a fine-tuned Llama-2-7b model using QLoRA.")
# Input text prompt
prompt = st.text_input("Enter your prompt:", value="What is an open-source LLM?")
# Model settings
st.write("Loading the model...")
# Load model and tokenizer
model_name = "NousResearch/Llama-2-7b-chat-hf"
dataset_name = "mlabonne/guanaco-llama2-1k"  # used only for the fine-tuning step sketched at the end of this file
# LoRA adapter hyperparameters
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
# bitsandbytes 4-bit quantization (QLoRA) parameters
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
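# The LoRA hyperparameters above are not used at inference time; they describe how
# the adapter was presumably trained. A minimal sketch of the corresponding config,
# assuming the standard peft LoraConfig API:
peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)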
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
load_in_4bit=use_4bit,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=use_nested_quant,
)
device_map = {"": 0}  # place the entire model on GPU 0
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
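# Note: Streamlit reruns this script on every interaction, so the 7B model above is
# reloaded each time. A sketch of how the loading could be cached instead, assuming
# a Streamlit version that provides st.cache_resource (>= 1.18):
#
# @st.cache_resource
# def load_model_and_tokenizer(name: str):
#     mdl = AutoModelForCausalLM.from_pretrained(
#         name, quantization_config=bnb_config, device_map=device_map
#     )
#     tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
#     tok.pad_token = tok.eos_token
#     tok.padding_side = "right"
#     return mdl, tok
#
# model, tokenizer = load_model_and_tokenizer(model_name)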
# Run inference
def generate_text(prompt: str) -> str:
    pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
    # Wrap the prompt in the Llama-2 chat instruction format
    result = pipe(f"<s>[INST] {prompt} [/INST]")
    return result[0]["generated_text"]

if st.button("Generate"):
    st.write(generate_text(prompt))

# Example usage outside Streamlit:
# print(generate_text("What is an open-source LLM?"))
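# The SFTTrainer / load_dataset imports and dataset_name above are only needed for the
# fine-tuning step, which is not run inside this app. A minimal sketch of that step,
# assuming the trl 0.7-era SFTTrainer API used by the original QLoRA tutorials
# (newer trl versions move these arguments into SFTConfig):
#
# from transformers import TrainingArguments
# dataset = load_dataset(dataset_name, split="train")
# training_arguments = TrainingArguments(
#     output_dir="./results",
#     num_train_epochs=1,
#     per_device_train_batch_size=4,
#     learning_rate=2e-4,
# )
# trainer = SFTTrainer(
#     model=model,
#     train_dataset=dataset,
#     peft_config=peft_config,
#     dataset_text_field="text",
#     tokenizer=tokenizer,
#     args=training_arguments,
#     packing=False,
# )
# trainer.train()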