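# NOTE(review): the original first lines ("Spaces:", "Sleeping", "Sleeping") were
# page-status text copied from the hosting site, not program code; kept as a comment.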
import os

import easyocr
import numpy as np
import streamlit as st
from groq import APIConnectionError, AuthenticationError, Groq
from PIL import Image
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
# OCR Function
@st.cache_resource
def _get_ocr_reader(languages):
    """Create the easyocr Reader for *languages* once and reuse it across reruns."""
    return easyocr.Reader(list(languages))


def extract_text_from_image(image, languages=("en",)):
    """Run OCR on *image* and return the detected text fragments joined by spaces.

    Args:
        image: A PIL image, e.g. the result of ``Image.open`` on an uploaded file.
        languages: easyocr language codes to recognise; defaults to English only.

    Returns:
        The recognised fragments in the order easyocr reports them, separated by
        single spaces, or ``""`` when nothing is detected.
    """
    # Building a Reader loads its models, so it is cached rather than rebuilt per call.
    reader = _get_ocr_reader(tuple(languages))
    # easyocr's readtext accepts paths, URLs, bytes or numpy arrays (plus JPEG PIL
    # images); converting to a 3-channel array also makes PNG/RGBA/palette uploads work.
    pixels = np.array(image.convert("RGB"))
    result = reader.readtext(pixels)
    # Each detection is (bounding_box, text, confidence); keep only the text.
    return " ".join(detection[1] for detection in result)
# Question Answering Function (DistilBERT)
@st.cache_resource
def load_qa_model(model_name="distilbert/distilbert-base-cased-distilled-squad"):
    """Load and return a Hugging Face extractive question-answering pipeline.

    Streamlit re-executes the script on every widget interaction. ``st.cache_resource``
    keeps one pipeline per ``model_name`` instead of reloading the weights each rerun.

    Args:
        model_name: Hugging Face model id of a question-answering checkpoint;
            defaults to the DistilBERT SQuAD model the app was built with.

    Returns:
        A ``transformers`` ``question-answering`` pipeline.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    return pipeline("question-answering", model=model, tokenizer=tokenizer)
def answer_question(context, question, qa_model):
    """Return the answer that *qa_model* extracts from *context* for *question*.

    Args:
        context: Text to search for the answer (in this app, the OCR output).
        question: The user's question.
        qa_model: Callable taking a dict with ``question`` and ``context`` keys and
            returning a mapping with an ``"answer"`` key, e.g. ``load_qa_model()``.

    Returns:
        Whatever the model reports under ``"answer"``.
    """
    query = dict(question=question, context=context)
    prediction = qa_model(query)
    answer = prediction['answer']
    return answer
# Groq API Function
def groq_chat(prompt, model="llama-3.3-70b-versatile"):
    """Send *prompt* to Groq as a single user message and return the reply text.

    Failures are not raised. They are turned into a readable string so the
    Streamlit UI can display them in place of a reply.

    Args:
        prompt: The user message.
        model: Groq chat model id; defaults to the model the app was built with.

    Returns:
        The assistant's reply text (``""`` if the API returned no content), or an
        error description starting with "Groq API" / "General Groq API Error".
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        # Fail early with an actionable message instead of a generic client error.
        return "Groq API Authentication Error: the GROQ_API_KEY environment variable is not set."
    try:
        client = Groq(api_key=api_key)
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
        )
        # content may be None; normalize so the UI never prints "None".
        return chat_completion.choices[0].message.content or ""
    except APIConnectionError as e:
        return f"Groq API Connection Error: {e}"
    except AuthenticationError as e:
        return f"Groq API Authentication Error: {e}"
    except Exception as e:  # UI boundary: surface any other failure as text.
        return f"General Groq API Error: {e}"
# Streamlit App
def _ocr_section(uploaded_file):
    """Show the uploaded image and, on request, store its OCR text in session state."""
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Text extracted from a previous upload must not be used to answer questions
    # about a different image, so forget it whenever the uploaded file changes.
    file_key = (uploaded_file.name, hash(uploaded_file.getvalue()))
    if st.session_state.get("extracted_from") != file_key:
        st.session_state.pop("extracted_text", None)
        st.session_state.extracted_from = file_key

    if st.button("Extract Text and Enable Question Answering"):
        with st.spinner("Extracting text..."):
            st.session_state.extracted_text = extract_text_from_image(image)

    # Keep showing the extracted text on later reruns (e.g. after "Answer").
    if "extracted_text" in st.session_state:
        st.write("Extracted Text:")
        st.write(st.session_state.extracted_text)


def _qa_section():
    """Answer a user question against the OCR text stored in session state."""
    question = st.text_input("Ask a question about the image text:")
    if not st.button("Answer"):
        return
    if not question.strip():
        st.warning("Please enter a question.")
        return
    context = st.session_state.extracted_text
    if not context.strip():
        # An empty context cannot be answered from; warn instead of calling the model.
        st.warning("No text was extracted from the image, so there is nothing to answer from.")
        return
    with st.spinner("Answering..."):
        # Loaded only when a question is actually asked, not on every rerun.
        qa_model = load_qa_model()
        answer = answer_question(context, question, qa_model)
    st.write("Answer:", answer)


def _groq_section():
    """Render the free-form Groq chat box."""
    st.subheader("General Chat (Powered by Groq)")
    groq_prompt = st.text_input("Enter your message:")
    if not st.button("Send"):
        return
    if not groq_prompt.strip():
        st.warning("Please enter a message.")
        return
    with st.spinner("Generating response..."):
        groq_response = groq_chat(groq_prompt)
    st.write("Response:", groq_response)


def main():
    """Render the app: image upload with OCR and extractive QA, plus a Groq chat box."""
    st.title("Image Text & Question Answering Chatbot")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        _ocr_section(uploaded_file)
        if "extracted_text" in st.session_state:
            _qa_section()
    # Groq Chat Section: shown whether or not an image has been uploaded.
    _groq_section()


if __name__ == "__main__":
    main()