|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
import gradio as gr |
|
import tempfile |
|
import subprocess |
|
from matplotlib.animation import FFMpegWriter, PillowWriter |
|
import matplotlib.pyplot as plt |
|
from matplotlib import animation |
|
from config import MAX_TEXT_LEN |
|
from data import selective_smoothing, GLOBAL_MEAN_T, GLOBAL_STD_T |
|
from model import TextToPoseSeq2Seq |
|
from transformers import BertTokenizer |
|
|
|
from faster_whisper import WhisperModel |
|
|
|
|
|
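# IndoBERT tokenizer for the Malay input text (its vocab size is passed to the pose model below).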
tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2") |
|
|
|
|
|
# Sanity-check that FFmpeg is available for MP4 export (save_animation falls back to GIF otherwise).
try:
    print("FFmpeg check:", subprocess.run(["ffmpeg", "-version"], capture_output=True).stdout.decode().splitlines()[0])
except (FileNotFoundError, IndexError) as e:
    print("FFmpeg check failed:", e)
|
|
|
|
|
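# Load the text-to-pose seq2seq model and its trained weights.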
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
model = TextToPoseSeq2Seq(tokenizer.vocab_size).to(device) |
|
model.load_state_dict(torch.load("best_seq2seq_model_mask.pth", map_location=device))

print("Loaded pretrained weights.")
|
|
|
|
|
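# Words/phrases annotated during training; used to decide whether to sign a sentence
# as a single known gloss or word by word.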
annot_df = pd.read_csv("annotated.csv") |
|
annotated_words = set(annot_df["text"].str.strip().str.lower().unique()) |
|
|
|
|
|
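# Lookup table from gloss text to its reference YouTube video URL ("text" and "Video URL" columns).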
video_df = pd.read_csv("annotated_vid_link.csv") |
|
video_df["text_clean"] = video_df["text"].str.strip().str.lower() |
|
video_lookup = dict(zip(video_df["text_clean"], video_df["Video URL"])) |
|
|
|
|
|
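# Speech-to-text: faster-whisper "small" model with int8 quantization to keep CPU inference light.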
whisper_model = WhisperModel("small", compute_type="int8") |
|
|
|
def transcribe_audio(audio_path): |
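    """Transcribe a Malay audio file with faster-whisper; return "" if transcription fails."""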
|
try: |
|
segments, _ = whisper_model.transcribe(audio_path, language="ms", beam_size=5) |
|
full_text = " ".join([segment.text.strip() for segment in segments]) |
|
return full_text.strip() |
|
except Exception as e: |
|
print("Whisper Error:", e) |
|
return "" |
|
|
|
def get_youtube_link(input_text): |
|
return video_lookup.get(input_text.strip().lower()) |
|
|
|
|
|
|
|
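# Keep 25 body landmarks plus both 21-point hands; indices 501-542 appear to follow the flattened
# MediaPipe Holistic layout (33 pose + 468 face + 21 left hand + 21 right hand = 543 landmarks).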
selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543]) |
|
NUM_KEYPOINTS = len(selected_keypoint_indices) |
|
POSE_DIM = NUM_KEYPOINTS * 3 |
|
|
|
|
|
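# Skeleton edges to draw: upper-body pose connections, with the hand connection topology appended below.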
mediapipe_connections = [ |
|
(0, 1), (1, 2), (2, 3), (3, 7), (0, 4), (4, 5), (5, 6), (6, 8), |
|
(9, 10), (11, 12), (12, 14), (14, 16), (11, 13), (13, 15), |
|
(23, 24), (11, 23), (12, 24) |
|
] |
|
|
|
def add_hand_connections(base_index): |
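    """Return the standard 21-landmark hand skeleton, offset to the hand's position in the selected keypoint list."""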
|
return [(base_index + i, base_index + j) for i, j in [ |
|
(0,1), (1,2), (2,3), (3,4), |
|
(0,5), (5,6), (6,7), (7,8), |
|
(0,9), (9,10), (10,11), (11,12), |
|
(0,13), (13,14), (14,15), (15,16), |
|
(0,17), (17,18), (18,19), (19,20) |
|
]] |
|
|
|
hand1 = selected_keypoint_indices.index(501) |
|
hand2 = selected_keypoint_indices.index(522) |
|
mediapipe_connections += add_hand_connections(hand1) |
|
mediapipe_connections += add_hand_connections(hand2) |
|
|
|
|
|
def concatenate_and_smooth_sequences(sentence, tokenizer, model, device, GLOBAL_MEAN_T, GLOBAL_STD_T): |
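    """Generate a pose sequence for each word (or for the whole sentence, if it is a known gloss),
    concatenate them, apply selective smoothing, and unnormalize back to keypoint coordinates.
    Returns (poses, per-keypoint confidences, per-word frame ranges) or (None, None, None)."""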
|
sentence = sentence.strip() |
|
if sentence.lower() in annotated_words: |
|
words = [sentence] |
|
else: |
|
words = sentence.split() |
|
|
|
pose_preds, conf_preds, frame_labels = [], [], [] |
|
current_frame = 0 |
|
|
|
model.eval() |
|
with torch.no_grad(): |
|
for word in words: |
|
inputs = tokenizer(word, padding="max_length", truncation=True, max_length=MAX_TEXT_LEN, return_tensors="pt") |
|
input_ids = inputs.input_ids.to(device) |
|
attn_mask = inputs.attention_mask.to(device) |
|
pred_pose_norm, pred_conf = model(input_ids, attention_mask=attn_mask) |
|
|
|
T = pred_pose_norm.shape[1] |
|
pose_preds.append(pred_pose_norm[0]) |
|
conf_preds.append(pred_conf[0]) |
|
frame_labels.append((word, current_frame, current_frame + T - 1)) |
|
current_frame += T |
|
|
|
if not pose_preds: |
|
return None, None, None |
|
|
|
full_pose = torch.cat(pose_preds, dim=0).unsqueeze(0) |
|
full_conf = torch.cat(conf_preds, dim=0) |
|
smoothed = selective_smoothing(full_pose).squeeze(0) |
|
unnormalized = smoothed * GLOBAL_STD_T + GLOBAL_MEAN_T |
|
return unnormalized.view(-1, NUM_KEYPOINTS, 3).cpu().numpy(), full_conf.cpu().numpy(), frame_labels |
|
|
|
|
|
def animate_pose(pred_pose, pred_conf=None, frame_labels=None, interval=150, conf_threshold=0.3): |
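    """Animate the predicted skeleton frame by frame; joints whose confidence falls below
    conf_threshold are hidden, and the word currently being signed is shown in the title."""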
|
fig, ax = plt.subplots(figsize=(5, 5)) |
|
fig.subplots_adjust(top=0.85) |
|
|
|
def setup(): |
|
all_x, all_y = pred_pose[:,:,0].flatten(), pred_pose[:,:,1].flatten() |
|
x_buf = (all_x.max() - all_x.min()) * 0.1 + 0.1 |
|
y_buf = (all_y.max() - all_y.min()) * 0.1 + 0.1 |
|
ax.set_xlim(all_x.min() - x_buf, all_x.max() + x_buf) |
|
ax.set_ylim(-all_y.max() - y_buf, -all_y.min() + y_buf) |
|
ax.set_aspect('equal') |
|
ax.axis('off') |
|
|
|
setup() |
|
pred_lines = [ax.plot([], [], color='red', lw=2)[0] for _ in mediapipe_connections] |
|
pred_pts = ax.plot([], [], 'ko', markersize=3)[0] |
|
|
|
def init(): |
|
for line in pred_lines: line.set_data([], []) |
|
pred_pts.set_data([], []) |
|
return pred_lines + [pred_pts] |
|
|
|
def update(frame): |
|
        # Copy the x column: slicing returns a view, and writing NaN into it would corrupt pred_pose for later frames.
        px, py = pred_pose[frame, :, 0].copy(), -pred_pose[frame, :, 1]
|
if pred_conf is not None: |
|
mask = pred_conf[frame] > conf_threshold |
|
px[~mask], py[~mask] = np.nan, np.nan |
|
|
|
pred_pts.set_data(px, py) |
|
for text in ax.texts: text.remove() |
|
for i in range(NUM_KEYPOINTS): |
|
ax.text(px[i], py[i], str(i), fontsize=5, color='black') |
|
|
|
for i, (start, end) in enumerate(mediapipe_connections): |
|
if pred_conf is None or (pred_conf[frame][start] > conf_threshold and pred_conf[frame][end] > conf_threshold): |
|
pred_lines[i].set_data([px[start], px[end]], [py[start], py[end]]) |
|
else: |
|
pred_lines[i].set_data([], []) |
|
|
|
if frame_labels: |
|
for word, start, end in frame_labels: |
|
if start <= frame <= end: |
|
                    ax.set_title(f'Prediction: "{word}" (Frames {start}-{end})', fontsize=12, pad=15)
|
break |
|
|
|
return pred_lines + [pred_pts] |
|
|
|
ani = animation.FuncAnimation(fig, update, frames=len(pred_pose), init_func=init, blit=True, interval=interval) |
|
plt.tight_layout(pad=2.0) |
|
return ani |
|
|
|
|
|
def save_animation(anim, format="mp4"): |
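    """Save the animation to a temporary .mp4 (FFmpegWriter) or .gif (PillowWriter) and return its path, or None on failure."""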
|
try: |
|
ext = ".mp4" if format == "mp4" else ".gif" |
|
writer = FFMpegWriter(fps=10, bitrate=1800) if format == "mp4" else PillowWriter(fps=10) |
|
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as f: |
|
path = f.name |
|
anim.save(path, writer=writer) |
|
return path |
|
except Exception as e: |
|
print(f"{format.upper()} save failed:", e) |
|
return None |
|
|
|
def find_all_matches(text, video_lookup): |
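    """Greedy longest-phrase-first lookup: try 3-word, then 2-word phrases, then single words;
    returns (text, URL-or-None) pairs in input order."""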
|
words = text.strip().lower().split() |
|
matched = [] |
|
i = 0 |
|
while i < len(words): |
|
found = False |
|
|
|
if i + 2 < len(words): |
|
phrase3 = " ".join(words[i:i+3]) |
|
if phrase3 in video_lookup: |
|
matched.append((phrase3, video_lookup[phrase3])) |
|
i += 3 |
|
found = True |
|
continue |
|
|
|
if i + 1 < len(words): |
|
phrase2 = " ".join(words[i:i+2]) |
|
if phrase2 in video_lookup: |
|
matched.append((phrase2, video_lookup[phrase2])) |
|
i += 2 |
|
found = True |
|
continue |
|
|
|
word = words[i] |
|
if word in video_lookup: |
|
matched.append((word, video_lookup[word])) |
|
else: |
|
matched.append((word, None)) |
|
i += 1 |
|
return matched |
|
|
|
|
|
def predict(text, threshold, show_videos=True): |
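    """Gradio handler: generate the pose animation for the input text, save it to a video file,
    and report which phrases/words were found in the dataset (with optional video previews)."""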
|
if not text.strip(): |
|
        return None, "⚠️ Please enter valid text.", ""
|
|
|
try: |
|
pose, conf, labels = concatenate_and_smooth_sequences(text, tokenizer, model, device, GLOBAL_MEAN_T, GLOBAL_STD_T) |
|
        if pose is None: return None, "⚠️ No pose predicted.", ""
|
anim = animate_pose(pose, pred_conf=conf, frame_labels=labels, conf_threshold=threshold) |
|
path = save_animation(anim, format="mp4") or save_animation(anim, format="gif") |
|
if not path: |
|
            return None, "❌ Failed to save animation.", ""
|
|
|
|
|
cleaned_text = text.strip().lower() |
|
result_text, video_html = build_result_with_video_links(cleaned_text, video_lookup, show_videos) |
|
return path, result_text, video_html |
|
|
|
except Exception as e: |
|
print("Error during prediction:", e) |
|
        return None, f"❌ Runtime error: {str(e)}", ""
|
|
|
def build_result_with_video_links(cleaned_text, video_lookup, show_videos=True): |
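    """Build the match-check markdown plus embedded YouTube iframes, preferring a whole-phrase
    match and otherwise falling back to word-level matches."""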
|
checks = ["**Match Check (Phrase + Word Level):**"] |
|
html_blocks = [] |
|
|
|
if cleaned_text in video_lookup: |
|
        checks.append(f'- "{cleaned_text}" ✅ in dataset')
|
if show_videos: |
|
url = video_lookup[cleaned_text] |
|
video_id = url.split("v=")[-1] |
|
html_blocks.append( |
|
f'<iframe width="480" height="270" src="https://www.youtube.com/embed/{video_id}" frameborder="0" allowfullscreen></iframe>' |
|
) |
|
else: |
|
        checks.append(f'- "{cleaned_text}" ⚠️ not in dataset → broken into words')
|
for word in cleaned_text.split(): |
|
if word in annotated_words: |
|
                checks.append(f'  - "{word}" ✅ in dataset')
|
if show_videos and word in video_lookup: |
|
url = video_lookup[word] |
|
video_id = url.split("v=")[-1] |
|
html_blocks.append( |
|
f'<iframe width="240" height="135" src="https://www.youtube.com/embed/{video_id}" frameborder="0" allowfullscreen></iframe>' |
|
) |
|
else: |
|
                checks.append(f'  - "{word}" ❌ not found → generated by approximation')
|
|
|
return "\n".join(checks), "<br>".join(html_blocks) if show_videos else "" |
|
|
|
|
|
|
|
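# Gradio UI: text or speech input and generation settings on the left; the rendered animation,
# dataset match report, and reference videos on the right.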
with gr.Blocks() as demo: |
|
gr.Markdown("# Text-to-Malay Sign Pose Generator") |
|
gr.Markdown("Generate Malaysian Sign Language (BIM) pose animation from Malay text. Checks which words were seen in training and shows reference YouTube video (if available).") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
with gr.Tab("Text Input"): |
|
text_input = gr.Textbox(label="Enter Malay Word or Sentence") |
|
with gr.Tab("Speech Input"): |
|
audio_input = gr.Audio(type="filepath", label="Upload or Record Malay Audio") |
|
audio_transcript = gr.Textbox(label="Transcribed Text", interactive=True) |
|
transcribe_btn = gr.Button("Transcribe") |
|
|
|
threshold_slider = gr.Slider(0.0, 1.0, value=0.05, step=0.05, label="Confidence Threshold (for displaying joints)") |
|
show_video_toggle = gr.Checkbox(label="Show Video Previews", value=True) |
|
submit_btn = gr.Button("Submit") |
|
clear_btn = gr.Button("Clear") |
|
|
|
with gr.Column(scale=1): |
|
video_output = gr.Video(label="Generated Pose Animation", height=270) |
|
text_output = gr.Markdown(label="Match Check (Phrase + Word Level):") |
|
youtube_output = gr.HTML() |
|
|
|
submit_btn.click(fn=predict, |
|
inputs=[text_input, threshold_slider, show_video_toggle], |
|
outputs=[video_output, text_output, youtube_output]) |
|
|
|
transcribe_btn.click(fn=transcribe_audio, |
|
inputs=audio_input, outputs=audio_transcript) |
|
audio_transcript.change(fn=predict, |
|
inputs=[audio_transcript, threshold_slider, show_video_toggle], |
|
outputs=[video_output, text_output, youtube_output]) |
|
|
|
    clear_btn.click(lambda: (None, "", ""),
|
inputs=[], |
|
outputs=[video_output, text_output, youtube_output]) |
|
|
|
demo.launch(debug=True) |