aikanava commited on
Commit
6412482
·
verified ·
1 Parent(s): 59b18a9

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +64 -0
  2. requirements.txt +5 -3
  3. train_model.py +104 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # === CONFIG ===
7
+ MODEL_PATH = 'trained_model/asl_model.h5'
8
+ IMG_SIZE = 64
9
+ CLASS_NAMES = [chr(i) for i in range(65, 91)] # A-Z
10
+
11
+ # Load model once
12
+ @st.cache_resource(show_spinner=False)
13
+ def load_model():
14
+ return tf.keras.models.load_model(MODEL_PATH)
15
+
16
+ model = load_model()
17
+
18
+ # === UI Header ===
19
+ st.set_page_config(page_title="ASL Recognition", page_icon="🧠", layout="centered")
20
+ st.markdown("<h1 style='text-align: center;'>🧠 ASL Alphabet Recognition</h1>", unsafe_allow_html=True)
21
+ st.markdown("<p style='text-align: center;'>Upload a hand gesture image and get instant letter prediction.</p>", unsafe_allow_html=True)
22
+ st.divider()
23
+
24
+ # === Helper Functions ===
25
+ def preprocess_image(image: Image.Image):
26
+ img = image.convert("RGB")
27
+ img = img.resize((IMG_SIZE, IMG_SIZE))
28
+ img = np.array(img) / 255.0
29
+ img = np.expand_dims(img, axis=0)
30
+ return img
31
+
32
+ def predict(img: Image.Image):
33
+ processed = preprocess_image(img)
34
+ preds = model.predict(processed)
35
+ class_idx = np.argmax(preds)
36
+ confidence = preds[0][class_idx]
37
+ return CLASS_NAMES[class_idx], confidence
38
+
39
+ # === Upload UI ===
40
+ uploaded_file = st.file_uploader("📁 Upload a hand gesture image", type=['png', 'jpg', 'jpeg'])
41
+
42
+ if uploaded_file:
43
+ col1, col2 = st.columns([1, 2])
44
+ with col1:
45
+ img = Image.open(uploaded_file)
46
+ st.image(img, caption="📷 Uploaded Image", use_column_width=True)
47
+ with col2:
48
+ st.write("### 🔍 Prediction")
49
+ label, confidence = predict(img)
50
+ st.success(f"Predicted Letter: **:blue[{label}]**")
51
+ st.metric(label="Confidence Score", value=f"{confidence * 100:.2f}%", delta=None)
52
+
53
+ # Optional: show full probabilities as a horizontal bar chart
54
+ preds = model.predict(preprocess_image(img))[0]
55
+ top_indices = np.argsort(preds)[::-1][:5]
56
+ st.write("#### 🔢 Top 5 Predictions")
57
+ for i in top_indices:
58
+ st.progress(float(preds[i]), text=f"{CLASS_NAMES[i]}: {preds[i]*100:.2f}%")
59
+ else:
60
+ st.info("📸 Upload a clear image showing a single hand gesture on a plain background.")
61
+
62
+ # === Footer ===
63
+ st.divider()
64
+ st.markdown("<small style='text-align:center; display:block;'>Developed with ❤️ using TensorFlow & Streamlit</small>", unsafe_allow_html=True)
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
1
+ tensorflow
2
+ streamlit
3
+ opencv-python-headless
4
+ numpy
5
+ matplotlib
train_model.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
+ from tensorflow.keras.models import Sequential
6
+ from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
7
+ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
8
+ import matplotlib.pyplot as plt
9
+
10
+ # === CONFIGURATION ===
11
+ DATA_DIR = 'asl_alphabet_train' # Folder with A-Z subfolders containing images
12
+ MODEL_SAVE_PATH = 'trained_model/asl_model.h5'
13
+ IMG_SIZE = 64
14
+ BATCH_SIZE = 32
15
+ EPOCHS = 20
16
+ NUM_CLASSES = 26
17
+
18
+ # Create output directories if they don't exist
19
+ os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
20
+ os.makedirs('outputs', exist_ok=True)
21
+
22
+ # === DATA GENERATORS ===
23
+ train_datagen = ImageDataGenerator(
24
+ rescale=1./255,
25
+ validation_split=0.2,
26
+ rotation_range=15,
27
+ zoom_range=0.1,
28
+ width_shift_range=0.1,
29
+ height_shift_range=0.1,
30
+ horizontal_flip=True
31
+ )
32
+
33
+ train_generator = train_datagen.flow_from_directory(
34
+ DATA_DIR,
35
+ target_size=(IMG_SIZE, IMG_SIZE),
36
+ batch_size=BATCH_SIZE,
37
+ class_mode='categorical',
38
+ subset='training',
39
+ shuffle=True,
40
+ seed=42
41
+ )
42
+
43
+ validation_generator = train_datagen.flow_from_directory(
44
+ DATA_DIR,
45
+ target_size=(IMG_SIZE, IMG_SIZE),
46
+ batch_size=BATCH_SIZE,
47
+ class_mode='categorical',
48
+ subset='validation',
49
+ shuffle=False,
50
+ seed=42
51
+ )
52
+
53
+ # === MODEL ARCHITECTURE ===
54
+ model = Sequential([
55
+ Conv2D(32, (3,3), activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 3)),
56
+ MaxPooling2D(2,2),
57
+
58
+ Conv2D(64, (3,3), activation='relu'),
59
+ MaxPooling2D(2,2),
60
+
61
+ Conv2D(128, (3,3), activation='relu'),
62
+ MaxPooling2D(2,2),
63
+
64
+ Flatten(),
65
+ Dense(128, activation='relu'),
66
+ Dropout(0.5),
67
+ Dense(NUM_CLASSES, activation='softmax')
68
+ ])
69
+
70
+ model.compile(optimizer='adam',
71
+ loss='categorical_crossentropy',
72
+ metrics=['accuracy'])
73
+
74
+ model.summary()
75
+
76
+ # === CALLBACKS ===
77
+ checkpoint = ModelCheckpoint(MODEL_SAVE_PATH, save_best_only=True, monitor='val_accuracy', mode='max')
78
+ early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
79
+
80
+ # === TRAINING ===
81
+ history = model.fit(
82
+ train_generator,
83
+ validation_data=validation_generator,
84
+ epochs=EPOCHS,
85
+ callbacks=[checkpoint, early_stop]
86
+ )
87
+
88
+ # === PLOT TRAINING HISTORY ===
89
+ plt.figure(figsize=(12,5))
90
+
91
+ plt.subplot(1,2,1)
92
+ plt.plot(history.history['accuracy'], label='Train Accuracy')
93
+ plt.plot(history.history['val_accuracy'], label='Val Accuracy')
94
+ plt.legend()
95
+ plt.title('Accuracy')
96
+
97
+ plt.subplot(1,2,2)
98
+ plt.plot(history.history['loss'], label='Train Loss')
99
+ plt.plot(history.history['val_loss'], label='Val Loss')
100
+ plt.legend()
101
+ plt.title('Loss')
102
+
103
+ plt.savefig('outputs/training_plot.png')
104
+ plt.show()