File size: 4,570 Bytes
fccfdf4
7ec6449
fccfdf4
 
ef628bc
 
fccfdf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef628bc
fccfdf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7077c22
fccfdf4
 
 
 
 
 
 
 
 
 
 
 
ef628bc
fccfdf4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import os

def initialize_model():
    """Load the 8-bit quantized chat model and its tokenizer.

    Reads a Hugging Face access token from the ``hf`` environment variable
    and, when one is set, logs in with it before loading
    ``meta-llama/Meta-Llama-3.1-8B-Instruct`` with ``device_map="auto"``.

    Returns:
        tuple: ``(model, tokenizer, device)`` where ``device`` is the CUDA
        device when ``torch.cuda.is_available()`` and the CPU otherwise.
    """
    # Only log in when a token is configured. ``login(None)`` falls back to
    # an interactive prompt, which cannot be answered from a Streamlit server.
    # NOTE(review): without a token the download presumably relies on
    # credentials already stored on the machine -- confirm for deployment.
    token = os.environ.get("hf")
    if token:
        login(token)

    # Define the model ID and the device input tensors will be moved to
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configure INT8 quantization, allowing fp32 CPU offload
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True,
    )

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )

    # generate() is later given ``pad_token_id``; fall back to EOS when the
    # tokenizer does not define a dedicated padding token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, device

def format_conversation(conversation_history):
    """Render past turns as a plain-text ``User:`` / ``Assistant:`` transcript.

    Each turn is a mapping with ``"user"`` and ``"assistant"`` entries and
    contributes two lines. Leading and trailing whitespace of the combined
    transcript is stripped, so an empty history yields an empty string.
    """
    rendered_turns = [
        f"User: {entry['user']}\nAssistant: {entry['assistant']}\n"
        for entry in conversation_history
    ]
    transcript = "".join(rendered_turns)
    return transcript.strip()

def generate_response(model, tokenizer, device, prompt, conversation_history):
    """Generate the assistant's reply to ``prompt``.

    Args:
        model: Causal language model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        device: Torch device the input tensors are moved to.
        prompt: Text of the current user message.
        conversation_history: List of ``{"user", "assistant"}`` turns. The
            last entry is treated as the in-progress turn for ``prompt`` and
            is excluded from the context.

    Returns:
        str: The newly generated reply, stripped of surrounding whitespace
        and cut off before any follow-up ``User:`` turn the model invents.
    """
    max_model_length = 2048
    max_reply_tokens = 200

    # Build the transcript and end it with an explicit "Assistant:" cue so
    # the model continues as the assistant rather than as the user.
    context = format_conversation(conversation_history[:-1])
    current_turn = f"User: {prompt}\nAssistant:"
    full_prompt = f"{context}\n{current_turn}" if context else current_turn

    # Tokenize input
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Reserve room for the reply. When the prompt is too long keep its most
    # recent tokens (the latest message matters most); this guarantees
    # max_new_tokens never drops to zero or below.
    max_input_tokens = max_model_length - max_reply_tokens
    if input_ids.shape[1] > max_input_tokens:
        input_ids = input_ids[:, -max_input_tokens:]
        attention_mask = attention_mask[:, -max_input_tokens:]
    input_length = input_ids.shape[1]

    # Generate response
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_reply_tokens,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        min_length=20,
        no_repeat_ngram_size=3,
    )

    # Decode only the newly generated tokens so the echoed prompt can never
    # leak into the reply.
    generated_ids = outputs[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Drop any follow-up user turn the model continued with.
    model_response = response.split("User:")[0].strip()

    return model_response

def _load_model():
    """Return the ``(model, tokenizer, device)`` triple via Streamlit's resource cache.

    Routing the load through ``st.cache_resource`` lets sessions reuse one
    loaded model instead of each session loading its own copy. The wrapper is
    built at call time so this module has no Streamlit work at import time.
    """
    cached_loader = st.cache_resource(
        show_spinner="Loading the model... This might take a minute..."
    )(initialize_model)
    return cached_loader()


def main():
    """Render the chat page and handle at most one new user message per run."""
    max_turns = 5  # number of most recent turns kept as context and on screen

    st.set_page_config(page_title="LLM Chat Interface", page_icon="🤖")
    st.title("Chat with LLM 🤖")

    # Initialize session state for chat history
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Load the model (cached, so the heavy load is not repeated per session)
    model, tokenizer, device = _load_model()

    # Display chat messages
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("What would you like to know?"):
        # Display user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate and display assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                # generate_response() treats the last history entry as the
                # in-progress turn, so pass a copy with the pending turn
                # appended. The stored history is only updated on success,
                # so a failed generation leaves no half-filled turn behind.
                pending_turn = {"user": prompt, "assistant": ""}
                response = generate_response(
                    model,
                    tokenizer,
                    device,
                    prompt,
                    st.session_state.chat_history + [pending_turn],
                )
            st.write(response)

        pending_turn["assistant"] = response
        st.session_state.chat_history.append(pending_turn)

        # Manage context window
        if len(st.session_state.chat_history) > max_turns:
            st.session_state.chat_history = st.session_state.chat_history[-max_turns:]

    # Add a clear chat button
    if st.sidebar.button("Clear Chat"):
        st.session_state.chat_history = []
        st.rerun()

# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()