import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import os
def initialize_model():
    """Initialize the model and tokenizer."""
    # Log in to Hugging Face; the token is expected in the "hf"
    # environment variable (e.g., set as a Space secret)
    token = os.environ.get("hf")
    if not token:
        raise RuntimeError("Missing Hugging Face token: set the 'hf' environment variable.")
    login(token)

    # Define the model ID and device
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configure INT8 quantization
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True
    )
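    # Rough sizing note (an estimate, not measured on this Space): 8-bit
    # weights take about 1 byte per parameter, so an 8B model needs roughly
    # 8-9 GB plus activation overhead, versus ~16 GB in fp16.
    # llm_int8_enable_fp32_cpu_offload lets layers that don't fit on the
    # GPU fall back to fp32 on the CPU.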
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto"
    )

    # Ensure a padding token is defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, device
def format_conversation(conversation_history):
    """Format the conversation history into a single string."""
    formatted = ""
    for turn in conversation_history:
        formatted += f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
    return formatted.strip()
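# Illustrative example (not from a real session) of the flat prompt format
# that format_conversation produces:
#   User: What is the capital of France?
#   Assistant: Paris.
#   User: And of Italy?
#   Assistant: Rome.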
def generate_response(model, tokenizer, device, prompt, conversation_history):
    """Generate a model response for the latest user prompt."""
    # Format the prior conversation context (everything but the current turn)
    context = format_conversation(conversation_history[:-1])
    if context:
        full_prompt = f"{context}\nUser: {prompt}"
    else:
        full_prompt = f"User: {prompt}"

    # Tokenize input
    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True).to(device)

    # Calculate the new-token budget, clamped so it never goes negative
    # when the prompt approaches the cap
    input_length = inputs["input_ids"].shape[1]
    max_model_length = 2048
    max_new_tokens = max(1, min(200, max_model_length - input_length))
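    # Note: Llama 3.1 itself supports a much longer context window; 2048 is
    # assumed here to be a deliberately conservative cap to keep memory use
    # and latency manageable on a small Space.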
    # Generate the response; sampling (do_sample=True) with moderate
    # temperature/top_p trades determinism for more varied replies
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        min_new_tokens=20,  # min_length counts the prompt too, so use min_new_tokens
        no_repeat_ngram_size=3
    )
    # Decode and strip the prompt: keep only the text after the final
    # "Assistant:" marker (simple string parsing; fragile if the model
    # emits extra "User:"/"Assistant:" markers itself)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response_parts = response.split("User: ")
    model_response = response_parts[-1].split("Assistant: ")[-1].strip()
    return model_response
def main():
    st.set_page_config(page_title="LLM Chat Interface", page_icon="🤖")
    st.title("Chat with LLM 🤖")

    # Initialize session state for chat history
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Initialize the model only once per session
    if "model" not in st.session_state:
        with st.spinner("Loading the model... This might take a minute..."):
            model, tokenizer, device = initialize_model()
            st.session_state.model = model
            st.session_state.tokenizer = tokenizer
            st.session_state.device = device
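    # st.session_state persists across Streamlit reruns within a browser
    # session, so caching the model here avoids reloading the weights on
    # every user interaction.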
    # Display chat messages
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("What would you like to know?"):
        # Display the user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate and display the assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                current_turn = {"user": prompt, "assistant": ""}
                st.session_state.chat_history.append(current_turn)
                response = generate_response(
                    st.session_state.model,
                    st.session_state.tokenizer,
                    st.session_state.device,
                    prompt,
                    st.session_state.chat_history
                )
                st.write(response)
                st.session_state.chat_history[-1]["assistant"] = response

        # Manage the context window by keeping only the last five turns
        if len(st.session_state.chat_history) > 5:
            st.session_state.chat_history = st.session_state.chat_history[-5:]
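        # Trimming to the last 5 turns bounds the prompt that
        # format_conversation builds, which keeps the tokenized input well
        # under the 2048-token cap used in generate_response.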
    # Add a clear-chat button
    if st.sidebar.button("Clear Chat"):
        st.session_state.chat_history = []
        st.rerun()


if __name__ == "__main__":
    main()