Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import streamlit as st
|
|
| 4 |
import faiss
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
-
from transformers import
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
from reportlab.lib.pagesizes import A4
|
| 10 |
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
|
@@ -23,7 +23,7 @@ age_categories = {
|
|
| 23 |
}
|
| 24 |
|
| 25 |
# Initialize FAISS and Sentence Transformer
|
| 26 |
-
|
| 27 |
|
| 28 |
def create_faiss_index(data):
|
| 29 |
descriptions, age_keys = [], []
|
|
@@ -32,7 +32,7 @@ def create_faiss_index(data):
|
|
| 32 |
descriptions.append(entry['description'])
|
| 33 |
age_keys.append(int(age)) # Convert age to int
|
| 34 |
|
| 35 |
-
embeddings =
|
| 36 |
index = faiss.IndexFlatL2(embeddings.shape[1])
|
| 37 |
index.add(embeddings)
|
| 38 |
return index, descriptions, age_keys
|
|
@@ -41,14 +41,14 @@ index, descriptions, age_keys = create_faiss_index(milestones)
|
|
| 41 |
|
| 42 |
# Function to retrieve the closest milestone
|
| 43 |
def retrieve_milestone(user_input):
|
| 44 |
-
user_embedding =
|
| 45 |
_, indices = index.search(user_embedding, 1)
|
| 46 |
return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
|
| 47 |
|
| 48 |
-
# Load IBM Granite model and tokenizer
|
| 49 |
-
model_name = "ibm/granite-
|
| 50 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 51 |
-
granite_model =
|
| 52 |
model_name, torch_dtype=torch.float16, device_map="auto"
|
| 53 |
)
|
| 54 |
|
|
@@ -125,4 +125,4 @@ if st.button("🔍 Analyze", help="Click to analyze the child's development mile
|
|
| 125 |
with open(pdf_file, "rb") as f:
|
| 126 |
st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
|
| 127 |
|
| 128 |
-
st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")
|
|
|
|
| 4 |
import faiss
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
from reportlab.lib.pagesizes import A4
|
| 10 |
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
|
|
|
| 23 |
}
|
| 24 |
|
| 25 |
# Initialize FAISS and Sentence Transformer
|
| 26 |
+
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 27 |
|
| 28 |
def create_faiss_index(data):
|
| 29 |
descriptions, age_keys = [], []
|
|
|
|
| 32 |
descriptions.append(entry['description'])
|
| 33 |
age_keys.append(int(age)) # Convert age to int
|
| 34 |
|
| 35 |
+
embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
|
| 36 |
index = faiss.IndexFlatL2(embeddings.shape[1])
|
| 37 |
index.add(embeddings)
|
| 38 |
return index, descriptions, age_keys
|
|
|
|
| 41 |
|
| 42 |
# Function to retrieve the closest milestone
|
| 43 |
def retrieve_milestone(user_input):
|
| 44 |
+
user_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
|
| 45 |
_, indices = index.search(user_embedding, 1)
|
| 46 |
return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
|
| 47 |
|
| 48 |
+
# Load IBM Granite 3.1 model and tokenizer
|
| 49 |
+
model_name = "ibm-granite/granite-3.1-8b-instruct"
|
| 50 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 51 |
+
granite_model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 52 |
model_name, torch_dtype=torch.float16, device_map="auto"
|
| 53 |
)
|
| 54 |
|
|
|
|
| 125 |
with open(pdf_file, "rb") as f:
|
| 126 |
st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
|
| 127 |
|
| 128 |
+
st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")
|