import torch
import numpy as np
import pandas as pd
import gradio as gr
import tempfile
import subprocess
from matplotlib.animation import FFMpegWriter, PillowWriter
import matplotlib.pyplot as plt
from matplotlib import animation
from config import MAX_TEXT_LEN
from data import selective_smoothing, GLOBAL_MEAN_T, GLOBAL_STD_T
from model import TextToPoseSeq2Seq
from transformers import BertTokenizer
# Speech-to-text (faster-whisper) for the audio input tab
from faster_whisper import WhisperModel
# === Tokenizer and Model Init ===
tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2")

# === Check FFmpeg availability ===
# Informational only: predict() falls back to a GIF when MP4 export fails, so a
# missing ffmpeg binary must not abort start-up. subprocess.run raises an
# OSError (FileNotFoundError) when the executable cannot be found.
try:
    _ffmpeg_lines = (
        subprocess.run(["ffmpeg", "-version"], capture_output=True)
        .stdout.decode(errors="replace")
        .splitlines()
    )
    print("FFmpeg check:", _ffmpeg_lines[0] if _ffmpeg_lines else "ffmpeg produced no version output")
except OSError as err:
    print("FFmpeg check: ffmpeg not available, MP4 export will fall back to GIF:", err)

# === Load model ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TextToPoseSeq2Seq(tokenizer.vocab_size).to(device)
# map_location remaps the stored tensors onto `device` (e.g. GPU-saved weights
# on a CPU-only host); load_state_dict copies them into the parameters that
# already live on `device`, so no second .to(device) is needed.
model.load_state_dict(torch.load("best_seq2seq_model_mask.pth", map_location=device))
print("Loaded pretrained weights.")

# === Load annotated.csv ===
# Normalised (stripped, lower-cased) vocabulary of annotated entries; used to
# decide whether a whole input is generated as a single unit.
annot_df = pd.read_csv("annotated.csv")
annotated_words = set(annot_df["text"].str.strip().str.lower().unique())

# Load the annotated CSV that carries the reference video links.
video_df = pd.read_csv("annotated_vid_link.csv")
video_df["text_clean"] = video_df["text"].str.strip().str.lower()
# normalised text -> reference video URL (cells may be NaN when missing — TODO confirm)
video_lookup = dict(zip(video_df["text_clean"], video_df["Video URL"]))

# === Load Whisper ===
whisper_model = WhisperModel("small", compute_type="int8")
def transcribe_audio(audio_path):
    """Transcribe a Malay audio file with the module-level Whisper model.

    Args:
        audio_path: Path to the audio file; may be None or "" when nothing
            was recorded or uploaded.

    Returns:
        The transcript (segment texts stripped and joined by single spaces),
        or "" when no audio was given or transcription failed (the error is
        printed).
    """
    if not audio_path:
        # Nothing to transcribe; avoid handing None to Whisper and logging a
        # misleading error for a simple empty click.
        return ""
    try:
        # language="ms" forces Malay decoding instead of auto-detection.
        segments, _ = whisper_model.transcribe(audio_path, language="ms", beam_size=5)
        # `segments` is lazy: decoding happens while iterating, so the join
        # must stay inside the try block.
        return " ".join(segment.text.strip() for segment in segments).strip()
    except Exception as e:
        print("Whisper Error:", e)
        return ""
def get_youtube_link(input_text):
    """Look up the reference video URL for *input_text*.

    The text is stripped and lower-cased first, matching how the keys of
    ``video_lookup`` were normalised. Returns None for unknown text.
    """
    normalised_key = input_text.strip().lower()
    return video_lookup.get(normalised_key)
# === Keypoint setup ===
# Positions, in the full per-frame landmark array, of the joints the model
# predicts:
#   0..24    -> 25 body-pose landmarks
#   501..521 -> first hand (21 landmarks)
#   522..542 -> second hand (21 landmarks)
# NOTE(review): presumably MediaPipe Holistic ordering (pose, face, left hand,
# right hand); confirm against the data pipeline.
selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
NUM_KEYPOINTS = len(selected_keypoint_indices)  # 25 + 21 + 21 = 67
POSE_DIM = 3 * NUM_KEYPOINTS  # three values per joint in one frame

# === Connection graph ===
# Skeleton edges as (start, end) positions within the selected keypoints; the
# body edges only use positions 0-24 (the pose block). Group labels assume
# MediaPipe Pose landmark numbering.
mediapipe_connections = [
    # head
    (0, 1), (1, 2), (2, 3), (3, 7),
    (0, 4), (4, 5), (5, 6), (6, 8),
    (9, 10),
    # shoulders and arms
    (11, 12), (12, 14), (14, 16),
    (11, 13), (13, 15),
    # hips and torso sides
    (23, 24), (11, 23), (12, 24),
]
def add_hand_connections(base_index):
    """Return the 20 finger edges of one hand, shifted by *base_index*.

    Local landmark 0 is the hand root (presumably the wrist in MediaPipe Hands
    numbering). Each of the five fingers is a chain of four consecutive local
    landmarks starting at 1, 5, 9, 13 and 17, linked to the root. Edges come
    finger by finger, root to tip, as ``(start, end)`` tuples.
    """
    edges = []
    for finger_start in range(1, 21, 4):
        chain = (0, finger_start, finger_start + 1, finger_start + 2, finger_start + 3)
        for a, b in zip(chain, chain[1:]):
            edges.append((base_index + a, base_index + b))
    return edges
# Locate each hand block's root landmark (501 and 522 in the full array) among
# the selected keypoints, then append that hand's finger edges to the skeleton.
hand1, hand2 = (selected_keypoint_indices.index(root) for root in (501, 522))
mediapipe_connections.extend(add_hand_connections(hand1))
mediapipe_connections.extend(add_hand_connections(hand2))
# === Pose inference ===
def concatenate_and_smooth_sequences(sentence, tokenizer, model, device, GLOBAL_MEAN_T, GLOBAL_STD_T):
    """Generate one continuous, de-normalised pose sequence for *sentence*.

    If the whole whitespace-normalised sentence, lower-cased, is in
    ``annotated_words``, it is generated as a single unit. Otherwise each
    whitespace-separated word is generated on its own. The per-unit outputs
    are concatenated along the frame axis, passed through
    ``selective_smoothing`` and de-normalised with the given statistics.

    Args:
        sentence: Input text.
        tokenizer: Tokenizer returning ``input_ids`` / ``attention_mask``
            tensors (called with ``return_tensors="pt"``).
        model: Called as ``model(input_ids, attention_mask=...)``; returns
            ``(pose_norm, conf)`` with frames on dim 1 and batch on dim 0.
        device: Device the model and its inputs live on.
        GLOBAL_MEAN_T, GLOBAL_STD_T: Normalisation statistics. Tensors, arrays
            or scalars are accepted and moved to the prediction's device.

    Returns:
        ``(pose, conf, frame_labels)``:

        - ``pose`` is a NumPy array of shape ``(frames, NUM_KEYPOINTS, 3)``.
        - ``conf`` holds the frame-concatenated confidences as a NumPy array
          (presumably ``(frames, NUM_KEYPOINTS)`` — TODO confirm).
        - ``frame_labels`` holds ``(unit, first_frame, last_frame)`` tuples
          with inclusive bounds.

        Returns ``(None, None, None)`` for blank input.
    """
    # Collapse runs of whitespace so "selamat  pagi" still matches an
    # annotated phrase "selamat pagi".
    sentence = " ".join(sentence.split())
    if not sentence:
        return None, None, None
    words = [sentence] if sentence.lower() in annotated_words else sentence.split()

    pose_preds, conf_preds, frame_labels = [], [], []
    current_frame = 0
    model.eval()
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, padding="max_length", truncation=True, max_length=MAX_TEXT_LEN, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attn_mask = inputs.attention_mask.to(device)
            pred_pose_norm, pred_conf = model(input_ids, attention_mask=attn_mask)
            num_frames = pred_pose_norm.shape[1]
            pose_preds.append(pred_pose_norm[0])
            conf_preds.append(pred_conf[0])
            frame_labels.append((word, current_frame, current_frame + num_frames - 1))
            current_frame += num_frames
    if not pose_preds:
        return None, None, None

    full_pose = torch.cat(pose_preds, dim=0).unsqueeze(0)
    full_conf = torch.cat(conf_preds, dim=0)
    smoothed = selective_smoothing(full_pose).squeeze(0)
    # The statistics may be CPU tensors / NumPy arrays while the predictions
    # live on the model's device; move them over to avoid a device mismatch
    # when running on CUDA. as_tensor does not copy when already in place.
    std = torch.as_tensor(GLOBAL_STD_T, device=smoothed.device)
    mean = torch.as_tensor(GLOBAL_MEAN_T, device=smoothed.device)
    unnormalized = smoothed * std + mean
    # reshape (not view): the smoothing output is not guaranteed contiguous.
    return unnormalized.reshape(-1, NUM_KEYPOINTS, 3).cpu().numpy(), full_conf.cpu().numpy(), frame_labels
# === Animation ===
def animate_pose(pred_pose, pred_conf=None, frame_labels=None, interval=150, conf_threshold=0.3):
    """Build a matplotlib animation of a predicted skeleton.

    Args:
        pred_pose: Array indexed ``[frame, joint, channel]``. Channels 0 and 1
            are drawn as x and y, with y negated so larger values plot lower
            (presumably image-style coordinates).
        pred_conf: Optional confidences indexed ``[frame][joint]``. Joints at or
            below *conf_threshold* are hidden, together with every edge that
            touches them.
        frame_labels: Optional ``(word, first_frame, last_frame)`` tuples; the
            frame's word is shown in the title.
        interval: Delay between frames in milliseconds for live playback.
        conf_threshold: Confidence cut-off for showing a joint.

    Returns:
        The ``FuncAnimation``. Its figure stays open so it can be saved; the
        caller should close it afterwards. *pred_pose* is never modified.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    fig.subplots_adjust(top=0.85)
    num_joints = pred_pose.shape[1]

    # Fixed limits that fit every joint of every frame: 10% margin, plus 0.1
    # so a motionless clip still gets a non-empty range.
    all_x = pred_pose[:, :, 0].flatten()
    all_y = pred_pose[:, :, 1].flatten()
    x_buf = (all_x.max() - all_x.min()) * 0.1 + 0.1
    y_buf = (all_y.max() - all_y.min()) * 0.1 + 0.1
    ax.set_xlim(all_x.min() - x_buf, all_x.max() + x_buf)
    ax.set_ylim(-all_y.max() - y_buf, -all_y.min() + y_buf)
    ax.set_aspect('equal')
    ax.axis('off')

    pred_lines = [ax.plot([], [], color='red', lw=2)[0] for _ in mediapipe_connections]
    pred_pts = ax.plot([], [], 'ko', markersize=3)[0]
    # One index label per joint, created once and repositioned every frame,
    # so no stale labels accumulate between frames.
    joint_labels = [
        ax.text(0, 0, str(i), fontsize=5, color='black', visible=False)
        for i in range(num_joints)
    ]

    def init():
        for line in pred_lines:
            line.set_data([], [])
        pred_pts.set_data([], [])
        for label in joint_labels:
            label.set_visible(False)
        return pred_lines + [pred_pts] + joint_labels

    def update(frame):
        # astype copies: pred_pose[frame, :, 0] is a view, and masking below
        # must not write NaN into the caller's array.
        px = pred_pose[frame, :, 0].astype(float)
        py = -pred_pose[frame, :, 1].astype(float)
        if pred_conf is not None:
            visible = np.asarray(pred_conf[frame]) > conf_threshold
            px[~visible] = np.nan
            py[~visible] = np.nan
        else:
            visible = np.ones(num_joints, dtype=bool)
        pred_pts.set_data(px, py)

        for i, label in enumerate(joint_labels):
            if visible[i] and np.isfinite(px[i]) and np.isfinite(py[i]):
                label.set_position((px[i], py[i]))
                label.set_visible(True)
            else:
                label.set_visible(False)

        for line, (start, end) in zip(pred_lines, mediapipe_connections):
            if visible[start] and visible[end]:
                line.set_data([px[start], px[end]], [py[start], py[end]])
            else:
                line.set_data([], [])

        if frame_labels:
            for word, start, end in frame_labels:
                if start <= frame <= end:
                    ax.set_title(f'Prediction: “{word}” (Frames {start}–{end})', fontsize=12, pad=15)
                    break
        return pred_lines + [pred_pts] + joint_labels

    ani = animation.FuncAnimation(fig, update, frames=len(pred_pose), init_func=init, blit=True, interval=interval)
    # Act on this figure explicitly rather than pyplot's "current" figure,
    # which another concurrent request may have replaced.
    fig.tight_layout(pad=2.0)
    return ani
# === Save video or fallback to GIF ===
def save_animation(anim, format="mp4"):
    """Render *anim* to a new temporary file.

    Args:
        anim: A matplotlib animation.
        format: ``"mp4"`` writes an MP4 via FFmpeg; any other value writes a
            GIF via Pillow.

    Returns:
        Path of the written file, or None if rendering failed (the error is
        printed). On failure the temp file is deleted, so repeated failures
        do not litter the temp directory.
    """
    import os

    path = None
    try:
        ext = ".mp4" if format == "mp4" else ".gif"
        writer = FFMpegWriter(fps=10, bitrate=1800) if format == "mp4" else PillowWriter(fps=10)
        # delete=False: the file must outlive this handle so the writer can
        # reopen it by name and the caller can return it to the UI.
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as f:
            path = f.name
        anim.save(path, writer=writer)
        return path
    except Exception as e:
        print(f"{format.upper()} save failed:", e)
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; the save error is already reported
        return None
def find_all_matches(text, video_lookup, max_ngram=3):
    """Greedily split *text* into the longest phrases that are keys of *video_lookup*.

    The scan runs left to right. At each position, the longest n-gram that is
    a key is consumed, trying lengths from *max_ngram* words down to 2. If
    none matches, the single word is emitted with its URL, or with None when
    it is not a key either.

    Args:
        text: Free text; it is stripped, lower-cased and split on whitespace.
        video_lookup: Mapping of normalised phrase -> video URL.
        max_ngram: Longest phrase length (in words) to try. Defaults to 3.

    Returns:
        List of ``(phrase, url_or_None)`` tuples covering every word, in order.
    """
    words = text.strip().lower().split()
    matched = []
    i = 0
    while i < len(words):
        # Longest candidate first, never running past the end of the input.
        for n in range(min(max_ngram, len(words) - i), 1, -1):
            phrase = " ".join(words[i:i + n])
            if phrase in video_lookup:
                matched.append((phrase, video_lookup[phrase]))
                i += n
                break
        else:
            # No multi-word phrase matched here: fall back to the single word.
            word = words[i]
            matched.append((word, video_lookup[word] if word in video_lookup else None))
            i += 1
    return matched
def predict(text, threshold, show_videos=True):
    """Gradio handler: generate the pose animation and the match report for *text*.

    Args:
        text: Malay word or sentence (None or blank is rejected).
        threshold: Confidence threshold for drawing joints.
        show_videos: Whether to include reference-video previews.

    Returns:
        ``(animation_path_or_None, markdown_message, video_html)``. Failures are
        reported as messages instead of being raised.
    """
    # Guard against None as well as blank strings before calling .strip().
    if text is None or not text.strip():
        return None, "⚠️ Please enter valid text.", ""
    anim = None
    try:
        pose, conf, labels = concatenate_and_smooth_sequences(text, tokenizer, model, device, GLOBAL_MEAN_T, GLOBAL_STD_T)
        if pose is None:
            return None, "⚠️ No pose predicted.", ""
        anim = animate_pose(pose, pred_conf=conf, frame_labels=labels, conf_threshold=threshold)
        # Prefer MP4; fall back to GIF when FFmpeg is missing or fails.
        path = save_animation(anim, format="mp4") or save_animation(anim, format="gif")
        if not path:
            return None, "❌ Failed to save animation.", ""
        # === Generate Text Output + Video Preview ===
        cleaned_text = text.strip().lower()
        result_text, video_html = build_result_with_video_links(cleaned_text, video_lookup, show_videos)
        return path, result_text, video_html
    except Exception as e:
        print("Error during prediction:", e)
        return None, f"❌ Runtime error: {str(e)}", ""
    finally:
        # animate_pose leaves its pyplot figure open for saving; close it so a
        # long-running server does not keep one figure alive per request.
        # (_fig is the figure FuncAnimation was built on.)
        fig = getattr(anim, "_fig", None)
        if fig is not None:
            plt.close(fig)
def _youtube_video_id(url):
    """Return the YouTube video id contained in *url*, or None.

    Handles ``watch?v=<id>`` URLs, ignoring extra parameters such as
    ``&t=30s``. Also handles path-style URLs such as ``youtu.be/<id>`` or
    ``/embed/<id>``. Non-string values, e.g. a NaN read from an empty CSV
    cell, yield None.
    """
    from urllib.parse import parse_qs, urlparse

    if not isinstance(url, str) or not url.strip():
        return None
    parsed = urlparse(url.strip())
    query_ids = parse_qs(parsed.query).get("v")
    if query_ids:
        return query_ids[0]
    segments = [part for part in parsed.path.split("/") if part]
    return segments[-1] if segments else None


def _youtube_embed_html(url):
    """Return an ``<iframe>`` embedding the video at *url*, or None if no id is found."""
    import html

    video_id = _youtube_video_id(url)
    if not video_id:
        return None
    src = f"https://www.youtube.com/embed/{html.escape(video_id, quote=True)}"
    return (
        f'<iframe width="320" height="180" src="{src}" frameborder="0" '
        'allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" '
        'allowfullscreen></iframe>'
    )


def build_result_with_video_links(cleaned_text, video_lookup, show_videos=True, known_words=None):
    """Build the match report and the reference-video HTML for *cleaned_text*.

    If the whole text is a key of *video_lookup*, it is reported as one
    dataset entry. Otherwise it is reported as broken into words, and each
    whitespace-separated word is marked as found in *known_words* or as
    approximated.

    Args:
        cleaned_text: Stripped, lower-cased input text.
        video_lookup: Mapping of normalised text -> video URL (entries whose
            URL is missing or unparseable simply get no preview).
        show_videos: Whether to build embedded video previews.
        known_words: Vocabulary for the per-word check; defaults to the
            module-level ``annotated_words``.

    Returns:
        ``(markdown_report, html)``; ``html`` is "" when *show_videos* is false.
    """
    checks = ["**Match Check (Phrase + Word Level):**"]
    html_blocks = []

    def add_preview(key):
        if show_videos:
            embed = _youtube_embed_html(video_lookup.get(key))
            if embed:
                html_blocks.append(embed)

    # NOTE(review): the phrase-level check uses video_lookup, while pose
    # generation decides phrase-vs-words with annotated_words; confirm the two
    # CSVs list the same entries.
    if cleaned_text in video_lookup:
        checks.append(f'- “{cleaned_text}” ✅ in dataset')
        add_preview(cleaned_text)
    else:
        vocabulary = annotated_words if known_words is None else known_words
        checks.append(f'- “{cleaned_text}” ⚠️ not in dataset — broken into words')
        for word in cleaned_text.split():
            if word in vocabulary:
                checks.append(f' - “{word}” ✅ in dataset')
                add_preview(word)
            else:
                checks.append(f' - “{word}” ❌ not found — generated by approximation')
    return "\n".join(checks), ("<br>".join(html_blocks) if show_videos else "")
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Malay Sign Pose Generator")
    gr.Markdown("Generate Malaysian Sign Language (BIM) pose animation from Malay text. Checks which words were seen in training and shows reference YouTube video (if available).")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Text Input"):
                text_input = gr.Textbox(label="Enter Malay Word or Sentence")
            # Speech input: record/upload audio, transcribe it with Whisper,
            # then generate from the (editable) transcript.
            with gr.Tab("Speech Input"):
                audio_input = gr.Audio(type="filepath", label="Upload or Record Malay Audio")
                audio_transcript = gr.Textbox(label="Transcribed Text", interactive=True)
                transcribe_btn = gr.Button("Transcribe")
            threshold_slider = gr.Slider(0.0, 1.0, value=0.05, step=0.05, label="Confidence Threshold (for displaying joints)")
            show_video_toggle = gr.Checkbox(label="Show Video Previews", value=True)
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            video_output = gr.Video(label="Generated Pose Animation", height=270)
            text_output = gr.Markdown(label="Match Check (Phrase + Word Level):")
            youtube_output = gr.HTML()

    pose_outputs = [video_output, text_output, youtube_output]

    submit_btn.click(fn=predict,
                     inputs=[text_input, threshold_slider, show_video_toggle],
                     outputs=pose_outputs)
    # Generate once the transcript is ready. Manual edits to the transcript
    # re-run generation on Enter; .change would fire on every keystroke and
    # re-render the whole video each time.
    transcribe_btn.click(fn=transcribe_audio,
                         inputs=audio_input, outputs=audio_transcript).then(
        fn=predict,
        inputs=[audio_transcript, threshold_slider, show_video_toggle],
        outputs=pose_outputs)
    audio_transcript.submit(fn=predict,
                            inputs=[audio_transcript, threshold_slider, show_video_toggle],
                            outputs=pose_outputs)
    # None (not "") empties gr.Video; the text input is cleared as well.
    clear_btn.click(lambda: ("", None, "", ""),
                    inputs=[],
                    outputs=[text_input, video_output, text_output, youtube_output])

if __name__ == "__main__":
    demo.launch(debug=True)