import subprocess
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from faster_whisper import WhisperModel  # speech-to-text for the "Speech Input" tab
from matplotlib import animation
from matplotlib.animation import FFMpegWriter, PillowWriter
from transformers import BertTokenizer

from config import MAX_TEXT_LEN
from data import selective_smoothing, GLOBAL_MEAN_T, GLOBAL_STD_T
from model import TextToPoseSeq2Seq

# === Tokenizer and Model Init ===
tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2")


# === Check FFmpeg availability ===
def _report_ffmpeg():
    """Print the first line of ``ffmpeg -version``, or a notice that it is unavailable.

    MP4 export needs the ffmpeg binary, but ``predict`` falls back to a GIF
    (PillowWriter) when MP4 saving fails, so a missing or broken ffmpeg must
    not abort startup.
    """
    try:
        result = subprocess.run(["ffmpeg", "-version"], capture_output=True, check=False)
    except OSError as e:  # e.g. FileNotFoundError when ffmpeg is not on PATH
        print("FFmpeg check: ffmpeg not available -", e)
        return
    lines = result.stdout.decode(errors="replace").splitlines()
    print("FFmpeg check:", lines[0] if lines else f"no version output (exit code {result.returncode})")


_report_ffmpeg()

# === Load model ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TextToPoseSeq2Seq(tokenizer.vocab_size).to(device)
# map_location lets a checkpoint saved on another device load on this host.
model.load_state_dict(torch.load("best_seq2seq_model_mask.pth", map_location=device))
model.eval()  # inference only; the model was already moved to `device` above
print("Loaded pretrained weights.")

# === Load annotated.csv ===
# Stripped, lower-cased texts from annotated.csv. Inputs found here are fed to
# the pose model as a single unit; anything else is predicted word by word.
annot_df = pd.read_csv("annotated.csv")
annotated_words = set(annot_df["text"].str.strip().str.lower().unique())

# Load the final annotated CSV with video links
video_df = pd.read_csv("annotated_vid_link.csv")
video_df["text_clean"] = video_df["text"].str.strip().str.lower()
# Normalised text -> value of the "Video URL" column.
video_lookup = dict(zip(video_df["text_clean"], video_df["Video URL"]))

# === Load Whisper ===
whisper_model = WhisperModel("small", compute_type="int8")


def transcribe_audio(audio_path):
    """Transcribe a Malay audio file with Whisper.

    Args:
        audio_path: Filesystem path from the ``gr.Audio(type="filepath")``
            component; None/empty when no audio was provided.

    Returns:
        The stripped transcript, or "" when there is no audio or Whisper fails
        (the error is printed, not raised, so the UI stays responsive).
    """
    if not audio_path:
        return ""
    try:
        segments, _ = whisper_model.transcribe(audio_path, language="ms", beam_size=5)
        # Joining consumes the segments, so it stays inside the try block.
        return " ".join(segment.text.strip() for segment in segments).strip()
    except Exception as e:
        print("Whisper Error:", e)
        return ""


def get_youtube_link(input_text):
    """Return the "Video URL" entry for ``input_text`` (stripped, lower-cased), or None."""
    if not input_text:
        return None
    return video_lookup.get(input_text.strip().lower())


# === Keypoint setup ===
# Landmark indices kept from each frame: 0-24, 501-521 and 522-542.
# NOTE(review): presumably MediaPipe Holistic ordering (pose first, then face,
# then the two 21-point hands at 501 and 522) — confirm against the data pipeline.
selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
NUM_KEYPOINTS = len(selected_keypoint_indices)  # 25 + 21 + 21 = 67
POSE_DIM = NUM_KEYPOINTS * 3  # three values per keypoint
# === Connection graph ===
# Edges are indices into the *selected* keypoint array (0..NUM_KEYPOINTS-1),
# not into the original landmark numbering.
mediapipe_connections = [
    (0, 1), (1, 2), (2, 3), (3, 7), (0, 4), (4, 5), (5, 6), (6, 8),
    (9, 10), (11, 12), (12, 14), (14, 16), (11, 13), (13, 15),
    (23, 24), (11, 23), (12, 24),
]


def add_hand_connections(base_index):
    """Return the 20 edges of a 21-point hand skeleton (points 0..20), offset by ``base_index``."""
    return [(base_index + i, base_index + j) for i, j in [
        (0, 1), (1, 2), (2, 3), (3, 4),
        (0, 5), (5, 6), (6, 7), (7, 8),
        (0, 9), (9, 10), (10, 11), (11, 12),
        (0, 13), (13, 14), (14, 15), (15, 16),
        (0, 17), (17, 18), (18, 19), (19, 20),
    ]]


# Position of each hand's first landmark inside the selected keypoint array.
hand1 = selected_keypoint_indices.index(501)
hand2 = selected_keypoint_indices.index(522)
mediapipe_connections += add_hand_connections(hand1)
mediapipe_connections += add_hand_connections(hand2)


# === Pose inference ===
def concatenate_and_smooth_sequences(sentence, tokenizer, model, device, GLOBAL_MEAN_T, GLOBAL_STD_T):
    """Predict a pose sequence for ``sentence`` and return it de-normalised.

    If the whole stripped sentence (lower-cased) is in ``annotated_words`` it
    is fed to the model as one unit; otherwise each whitespace-separated word
    is predicted separately. Per-unit sequences are concatenated along time and
    ``selective_smoothing`` is applied to the joined sequence before the
    normalisation is undone with ``GLOBAL_STD_T`` / ``GLOBAL_MEAN_T``.

    Returns:
        ``(pose, conf, frame_labels)``: ``pose`` is a numpy array of shape
        ``(frames, NUM_KEYPOINTS, 3)``; ``conf`` is the concatenated confidence
        output as numpy; ``frame_labels`` is a list of
        ``(word, first_frame, last_frame)`` tuples (inclusive).
        ``(None, None, None)`` when there is nothing to predict.
    """
    sentence = sentence.strip()
    words = [sentence] if sentence.lower() in annotated_words else sentence.split()

    pose_preds, conf_preds, frame_labels = [], [], []
    current_frame = 0
    model.eval()
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, padding="max_length", truncation=True,
                               max_length=MAX_TEXT_LEN, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attn_mask = inputs.attention_mask.to(device)
            pred_pose_norm, pred_conf = model(input_ids, attention_mask=attn_mask)
            # assumes outputs are (batch, time, ...) — axis 1 is the frame count
            T = pred_pose_norm.shape[1]
            pose_preds.append(pred_pose_norm[0])
            conf_preds.append(pred_conf[0])
            frame_labels.append((word, current_frame, current_frame + T - 1))
            current_frame += T

    if not pose_preds:
        return None, None, None

    full_pose = torch.cat(pose_preds, dim=0).unsqueeze(0)
    full_conf = torch.cat(conf_preds, dim=0)
    smoothed = selective_smoothing(full_pose).squeeze(0)
    # Put the normalisation stats on the prediction's device so CUDA inference
    # does not fail with a device mismatch (no-op when they already match).
    std = torch.as_tensor(GLOBAL_STD_T, device=smoothed.device)
    mean = torch.as_tensor(GLOBAL_MEAN_T, device=smoothed.device)
    unnormalized = smoothed * std + mean
    return unnormalized.reshape(-1, NUM_KEYPOINTS, 3).cpu().numpy(), full_conf.cpu().numpy(), frame_labels


# === Animation ===
def animate_pose(pred_pose, pred_conf=None, frame_labels=None, interval=150, conf_threshold=0.3):
    """Build a ``FuncAnimation`` drawing the skeleton for every frame of ``pred_pose``.

    Args:
        pred_pose: Array indexed ``[frame, keypoint, coord]``; coords 0 and 1
            are plotted (y is negated so the figure is upright).
        pred_conf: Optional per-frame confidences; ``pred_conf[frame]`` is
            compared element-wise to ``conf_threshold`` and used as a keypoint
            mask (assumes shape ``(frames, NUM_KEYPOINTS)`` — confirm).
        frame_labels: Optional ``(word, first_frame, last_frame)`` tuples used
            for the per-frame title.
        interval: Delay between frames in milliseconds.
        conf_threshold: Joints/edges at or below this confidence are hidden.

    Returns:
        The ``matplotlib.animation.FuncAnimation``. ``pred_pose`` is not modified.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    fig.subplots_adjust(top=0.85)

    def setup():
        # Fixed axis limits over the whole sequence, with a 10% (+0.1) margin.
        all_x, all_y = pred_pose[:, :, 0].flatten(), pred_pose[:, :, 1].flatten()
        x_buf = (all_x.max() - all_x.min()) * 0.1 + 0.1
        y_buf = (all_y.max() - all_y.min()) * 0.1 + 0.1
        ax.set_xlim(all_x.min() - x_buf, all_x.max() + x_buf)
        ax.set_ylim(-all_y.max() - y_buf, -all_y.min() + y_buf)
        ax.set_aspect('equal')
        ax.axis('off')

    setup()

    pred_lines = [ax.plot([], [], color='red', lw=2)[0] for _ in mediapipe_connections]
    pred_pts = ax.plot([], [], 'ko', markersize=3)[0]

    def init():
        for line in pred_lines:
            line.set_data([], [])
        pred_pts.set_data([], [])
        return pred_lines + [pred_pts]

    def update(frame):
        # .copy(): the slice is a view, and the NaN masking below must not write
        # back into the caller's pred_pose. (Negation already copies py.)
        px = pred_pose[frame, :, 0].copy()
        py = -pred_pose[frame, :, 1]
        if pred_conf is not None:
            mask = pred_conf[frame] > conf_threshold
            px[~mask], py[~mask] = np.nan, np.nan
        pred_pts.set_data(px, py)

        # Iterate over a snapshot: removing artists while iterating ax.texts
        # skips every other label.
        for text in list(ax.texts):
            text.remove()
        for i in range(NUM_KEYPOINTS):
            # Masked (NaN) points get no index label; matplotlib warns on
            # non-finite text positions.
            if np.isfinite(px[i]) and np.isfinite(py[i]):
                ax.text(px[i], py[i], str(i), fontsize=5, color='black')

        for i, (start, end) in enumerate(mediapipe_connections):
            visible = pred_conf is None or (
                pred_conf[frame][start] > conf_threshold and pred_conf[frame][end] > conf_threshold
            )
            if visible:
                pred_lines[i].set_data([px[start], px[end]], [py[start], py[end]])
            else:
                pred_lines[i].set_data([], [])

        if frame_labels:
            for word, start, end in frame_labels:
                if start <= frame <= end:
                    ax.set_title(f'Prediction: “{word}” (Frames {start}–{end})', fontsize=12, pad=15)
                    break
        return pred_lines + [pred_pts]

    ani = animation.FuncAnimation(fig, update, frames=len(pred_pose), init_func=init,
                                  blit=True, interval=interval)
    fig.tight_layout(pad=2.0)
    return ani


# === Save video or fallback to GIF ===
def save_animation(anim, format="mp4"):
    """Render ``anim`` at 10 fps to a new temporary file and return its path.

    ``format == "mp4"`` uses FFMpegWriter; any other value writes a GIF with
    PillowWriter. On failure the error is printed, the partial temp file is
    removed, and None is returned (so callers can chain a fallback with ``or``).
    """
    import contextlib
    import os

    ext = ".mp4" if format == "mp4" else ".gif"
    path = None
    try:
        writer = FFMpegWriter(fps=10, bitrate=1800) if format == "mp4" else PillowWriter(fps=10)
        # Only reserve a unique name here; the handle is closed before the
        # writer opens the path (re-opening an open temp file fails on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as f:
            path = f.name
        anim.save(path, writer=writer)
        return path
    except Exception as e:
        print(f"{format.upper()} save failed:", e)
        if path is not None:
            with contextlib.suppress(OSError):
                os.remove(path)
        return None


def find_all_matches(text, video_lookup):
    """Greedily split ``text`` into the longest phrases (3, 2, then 1 word) found in ``video_lookup``.

    Returns:
        A list of ``(phrase, url)`` tuples in input order. Trigrams and bigrams
        are only emitted when present in ``video_lookup``; leftover single words
        are always emitted, with ``url`` None when they have no entry.
    """
    words = text.strip().lower().split()
    matched = []
    i = 0
    while i < len(words):
        for n in (3, 2):
            if i + n <= len(words):
                phrase = " ".join(words[i:i + n])
                if phrase in video_lookup:
                    matched.append((phrase, video_lookup[phrase]))
                    i += n
                    break
        else:
            # No multi-word match starting here: fall back to a single word.
            word = words[i]
            matched.append((word, video_lookup.get(word)))
            i += 1
    return matched


def predict(text, threshold, show_videos=True):
    """Gradio handler: text -> (animation path, match-check Markdown, preview HTML).

    Errors are reported in the Markdown output instead of being raised, so
    the UI never crashes; the matplotlib figure created for each call is
    always closed afterwards.
    """
    import traceback

    if not text or not text.strip():
        return None, "⚠️ Please enter valid text.", ""
    anim = None
    try:
        pose, conf, labels = concatenate_and_smooth_sequences(
            text, tokenizer, model, device, GLOBAL_MEAN_T, GLOBAL_STD_T)
        if pose is None:
            return None, "⚠️ No pose predicted.", ""

        anim = animate_pose(pose, pred_conf=conf, frame_labels=labels, conf_threshold=threshold)
        path = save_animation(anim, format="mp4") or save_animation(anim, format="gif")
        if not path:
            return None, "❌ Failed to save animation.", ""

        # === Generate Text Output + Video Preview ===
        cleaned_text = text.strip().lower()
        result_text, video_html = build_result_with_video_links(cleaned_text, video_lookup, show_videos)
        return path, result_text, video_html
    except Exception as e:
        print("Error during prediction:", e)
        traceback.print_exc()
        return None, f"❌ Runtime error: {str(e)}", ""
    finally:
        # Every call creates a new figure; close it so a long-running server
        # does not accumulate open figures.
        fig = getattr(anim, "_fig", None)
        if fig is not None:
            plt.close(fig)


def _youtube_embed_html(url):
    """Return an ``<iframe>`` embedding the YouTube video at ``url``, or "" if no ID is found.

    Handles ``watch?v=ID`` (extra query params are ignored), ``youtu.be/ID``
    and ``/embed|shorts|v|live/ID`` URLs; non-string values (e.g. NaN from an
    empty CSV cell) yield "".
    """
    from urllib.parse import parse_qs, quote, urlparse

    if not isinstance(url, str) or not url.strip():
        return ""
    parsed = urlparse(url.strip())
    video_id = parse_qs(parsed.query).get("v", [""])[0]
    path_parts = [p for p in parsed.path.split("/") if p]
    if not video_id and parsed.netloc.endswith("youtu.be") and path_parts:
        video_id = path_parts[0]
    elif not video_id and len(path_parts) >= 2 and path_parts[0] in ("embed", "shorts", "v", "live"):
        video_id = path_parts[1]
    if not video_id:
        return ""
    src = "https://www.youtube.com/embed/" + quote(video_id, safe="")
    return (
        f'<iframe width="320" height="180" src="{src}" title="YouTube reference video" '
        'frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; '
        'gyroscope; picture-in-picture" allowfullscreen></iframe>'
    )


def build_result_with_video_links(cleaned_text, video_lookup, show_videos=True):
    """Build the match-check Markdown and the reference-video HTML for ``cleaned_text``.

    The phrase/word split mirrors ``concatenate_and_smooth_sequences``: the
    whole text counts as one unit only when it is in ``annotated_words``;
    otherwise each word is checked individually. Reference videos come from
    ``video_lookup`` and are embedded only when ``show_videos`` is true.

    Returns:
        ``(markdown, html)``; ``html`` is "" when ``show_videos`` is false.
    """
    checks = ["**Match Check (Phrase + Word Level):**"]
    html_blocks = []

    def add_video(key):
        if show_videos and key in video_lookup:
            embed = _youtube_embed_html(video_lookup[key])
            if embed:
                html_blocks.append(embed)

    if cleaned_text in annotated_words:
        checks.append(f'- “{cleaned_text}” ✅ in dataset')
        add_video(cleaned_text)
    else:
        checks.append(f'- “{cleaned_text}” ⚠️ not in dataset — broken into words')
        for word in cleaned_text.split():
            # Two-space indent so Markdown renders these as a nested list.
            if word in annotated_words:
                checks.append(f'  - “{word}” ✅ in dataset')
                add_video(word)
            else:
                checks.append(f'  - “{word}” ❌ not found — generated by approximation')

    # Stack the previews vertically.
    return "\n".join(checks), "<br>".join(html_blocks) if show_videos else ""


# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Malay Sign Pose Generator")
    gr.Markdown("Generate Malaysian Sign Language (BIM) pose animation from Malay text. Checks which words were seen in training and shows reference YouTube video (if available).")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Text Input"):
                text_input = gr.Textbox(label="Enter Malay Word or Sentence")
            with gr.Tab("Speech Input"):
                audio_input = gr.Audio(type="filepath", label="Upload or Record Malay Audio")
                audio_transcript = gr.Textbox(label="Transcribed Text", interactive=True)
                transcribe_btn = gr.Button("Transcribe")
            threshold_slider = gr.Slider(0.0, 1.0, value=0.05, step=0.05,
                                         label="Confidence Threshold (for displaying joints)")
            show_video_toggle = gr.Checkbox(label="Show Video Previews", value=True)
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")

        with gr.Column(scale=1):
            video_output = gr.Video(label="Generated Pose Animation", height=270)
            text_output = gr.Markdown(label="Match Check (Phrase + Word Level):")
            youtube_output = gr.HTML()

    predict_outputs = [video_output, text_output, youtube_output]
    submit_btn.click(fn=predict, inputs=[text_input, threshold_slider, show_video_toggle],
                     outputs=predict_outputs)
    # Predict once after each transcription. A `.change` listener would also
    # fire on every keystroke in the editable transcript, re-running inference
    # and video rendering each time; manual edits are submitted with Enter.
    transcribe_btn.click(fn=transcribe_audio, inputs=audio_input, outputs=audio_transcript).then(
        fn=predict, inputs=[audio_transcript, threshold_slider, show_video_toggle], outputs=predict_outputs)
    audio_transcript.submit(fn=predict, inputs=[audio_transcript, threshold_slider, show_video_toggle],
                            outputs=predict_outputs)
    # None (not "") clears a gr.Video component.
    clear_btn.click(lambda: ("", None, "", ""), inputs=[],
                    outputs=[text_input, video_output, text_output, youtube_output])

demo.launch(debug=True)