safiaa02 commited on
Commit
1fa3d49
·
verified ·
1 Parent(s): f173d67

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import streamlit as st
4
+ import faiss
5
+ import numpy as np
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from sentence_transformers import SentenceTransformer
9
+ from reportlab.lib.pagesizes import A4
10
+ from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
11
+ from reportlab.lib.styles import getSampleStyleSheet
12
+
13
+ # Load milestones data
14
+ with open('milestones.json', 'r') as f:
15
+ milestones = json.load(f)
16
+
17
+ # Age categories for dropdown selection
18
+ age_categories = {
19
+ "Up to 2 months": 2, "Up to 4 months": 4, "Up to 6 months": 6,
20
+ "Up to 9 months": 9, "Up to 1 year": 12, "Up to 15 months": 15,
21
+ "Up to 18 months": 18, "Up to 2 years": 24, "Up to 30 months": 30,
22
+ "Up to 3 years": 36, "Up to 4 years": 48, "Up to 5 years": 60
23
+ }
24
+
25
+ # Initialize FAISS and Sentence Transformer
26
+ model = SentenceTransformer('all-MiniLM-L6-v2')
27
+
28
+ def create_faiss_index(data):
29
+ descriptions, age_keys = [], []
30
+ for age, categories in data.items():
31
+ for entry in categories:
32
+ descriptions.append(entry['description'])
33
+ age_keys.append(int(age)) # Convert age to int
34
+
35
+ embeddings = model.encode(descriptions, convert_to_numpy=True)
36
+ index = faiss.IndexFlatL2(embeddings.shape[1])
37
+ index.add(embeddings)
38
+ return index, descriptions, age_keys
39
+
40
+ index, descriptions, age_keys = create_faiss_index(milestones)
41
+
42
+ # Function to retrieve the closest milestone
43
+ def retrieve_milestone(user_input):
44
+ user_embedding = model.encode([user_input], convert_to_numpy=True)
45
+ _, indices = index.search(user_embedding, 1)
46
+ return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
47
+
48
+ # Load IBM Granite model and tokenizer
49
+ model_name = "ibm/granite-13b-chat"
50
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
51
+ granite_model = AutoModelForCausalLM.from_pretrained(
52
+ model_name, torch_dtype=torch.float16, device_map="auto"
53
+ )
54
+
55
+ def generate_response(user_input, child_age):
56
+ relevant_milestone = retrieve_milestone(user_input)
57
+ prompt = (
58
+ f"The child is {child_age} months old. Based on the given traits: {user_input}, "
59
+ f"determine whether the child is meeting expected milestones. "
60
+ f"Relevant milestone: {relevant_milestone}. "
61
+ "If there are any concerns, suggest steps the parents can take."
62
+ )
63
+
64
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
65
+ output = granite_model.generate(**inputs, max_length=512)
66
+ return tokenizer.decode(output[0], skip_special_tokens=True)
67
+
68
+ # Streamlit UI Styling
69
+ st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
70
+
71
+ st.markdown("""
72
+ <style>
73
+ .stApp { background-color: #1e1e2e; color: #ffffff; }
74
+ .stTitle { text-align: center; color: #ffcc00; font-size: 36px; font-weight: bold; }
75
+ .stButton > button { background-color: #ffcc00; color: #000; border-radius: 5px; font-weight: bold; }
76
+ .stSelectbox, .stTextArea { background-color: #2e2e42; color: #ffffff; border-radius: 5px; }
77
+ </style>
78
+ """, unsafe_allow_html=True)
79
+
80
+ st.markdown("<h1 class='stTitle'>👶 Tiny Triumphs Tracker</h1>", unsafe_allow_html=True)
81
+ st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.", unsafe_allow_html=True)
82
+
83
+ # User selects child's age
84
+ selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
85
+ child_age = age_categories[selected_age]
86
+
87
+ # User input for traits and skills
88
+ placeholder_text = "For example, your child might say simple words like 'mama' and 'dada' and smile when spoken to. They may grasp small objects with their fingers and show excitement during playtime."
89
+ user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
90
+
91
+ def generate_pdf_report(ai_response):
92
+ pdf_file = "progress_report.pdf"
93
+ doc = SimpleDocTemplate(pdf_file, pagesize=A4)
94
+ styles = getSampleStyleSheet()
95
+
96
+ elements = [
97
+ Paragraph("Child Development Progress Report", styles['Title']),
98
+ Spacer(1, 12),
99
+ Paragraph("Development Insights:", styles['Heading2']),
100
+ Spacer(1, 10)
101
+ ]
102
+
103
+ for part in ai_response.split('\n'):
104
+ part = part.strip().lstrip('0123456789.- ')
105
+ if part:
106
+ elements.append(Paragraph(f"• {part}", styles['Normal']))
107
+ elements.append(Spacer(1, 5))
108
+
109
+ disclaimer = ("This report is AI-generated and is for informational purposes only. "
110
+ "It should not be considered a substitute for professional medical advice. "
111
+ "Always consult a qualified pediatrician for expert guidance on your child's development.")
112
+ elements.append(Spacer(1, 12))
113
+ elements.append(Paragraph(disclaimer, styles['Italic']))
114
+
115
+ doc.build(elements)
116
+ return pdf_file
117
+
118
+ if st.button("🔍 Analyze", help="Click to analyze the child's development milestones"):
119
+ ai_response = generate_response(user_input, child_age)
120
+
121
+ st.subheader("📊 Development Insights:")
122
+ st.markdown(f"<div style='background-color:#44475a; color:#ffffff; padding: 15px; border-radius: 10px;'>{ai_response}</div>", unsafe_allow_html=True)
123
+
124
+ pdf_file = generate_pdf_report(ai_response)
125
+ with open(pdf_file, "rb") as f:
126
+ st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
127
+
128
+ st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")