Multi-Scale Multimodal Pose HAR

A custom PyTorch model for Human Activity Recognition (HAR) that integrates:

  • Short-term pose transformers (factorized temporal + spatial attention)
  • Long-term temporal aggregation
  • Optional multimodal fusion with RGB images
  • Multi-stage self-supervised + supervised training pipeline

🧠 Model Overview

Architecture Summary

Pose stream

  • Input: (B, L, T, J, C)
  • Short-term encoder: PoseFormerFactorized
    • Temporal attention (per joint)
    • Spatial attention (per frame)
  • Long-term encoder: Transformer over segment-level features
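
A minimal sketch of one factorized block follows (illustrative only, not the actual PoseFormerFactorized code): temporal attention folds joints into the batch axis, spatial attention folds frames.

import torch
import torch.nn as nn

class FactorizedPoseBlock(nn.Module):
    # Sketch of factorized attention: temporal attention per joint,
    # then spatial attention per frame, each with a pre-norm residual.
    def __init__(self, dim, heads=4):
        super().__init__()
        self.temporal = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.spatial = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm_t = nn.LayerNorm(dim)
        self.norm_s = nn.LayerNorm(dim)

    def forward(self, x):                      # x: (N, T, J, D), N = B*L segments
        N, T, J, D = x.shape
        # Temporal attention per joint: fold J into the batch axis
        t = x.permute(0, 2, 1, 3).reshape(N * J, T, D)
        q = self.norm_t(t)
        t = t + self.temporal(q, q, q, need_weights=False)[0]
        x = t.reshape(N, J, T, D).permute(0, 2, 1, 3)
        # Spatial attention per frame: fold T into the batch axis
        s = x.reshape(N * T, J, D)
        q = self.norm_s(s)
        s = s + self.spatial(q, q, q, need_weights=False)[0]
        return s.reshape(N, T, J, D)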

Image stream (optional)

  • Backbone: ResNet18 / ResNet50
  • Temporal pooling per segment
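
As a sketch (assumed behavior, not the actual ImageEncoder code), the image stream amounts to a ResNet backbone followed by mean pooling over each segment's frames:

import torch
import torch.nn as nn
from torchvision.models import resnet18

class SegmentImageEncoder(nn.Module):
    # Sketch: per-frame ResNet features, mean-pooled within each segment.
    def __init__(self, out_dim=128):
        super().__init__()
        backbone = resnet18(weights=None)
        self.features = nn.Sequential(*list(backbone.children())[:-1])  # drop the fc head
        self.proj = nn.Linear(512, out_dim)

    def forward(self, imgs):                   # imgs: (B, L, T, 3, H, W)
        B, L, T, C, H, W = imgs.shape
        f = self.features(imgs.reshape(B * L * T, C, H, W)).flatten(1)
        f = self.proj(f).reshape(B, L, T, -1)
        return f.mean(dim=2)                   # temporal pooling -> (B, L, D)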

Fusion

  • concat (default): feature concatenation + MLP
  • xattn: shallow cross-attention (pose tokens ↔ image token)
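
A minimal sketch of the default concat fusion (assumed behavior of MMFusionConcatLN; the real module may differ):

import torch
import torch.nn as nn

class ConcatFusion(nn.Module):
    # Sketch: concatenate pose and image features, LayerNorm, then MLP.
    def __init__(self, pose_dim, img_dim, out_dim):
        super().__init__()
        self.norm = nn.LayerNorm(pose_dim + img_dim)
        self.mlp = nn.Sequential(
            nn.Linear(pose_dim + img_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, pose_feat, img_feat):    # (B, L, Dp), (B, L, Di)
        return self.mlp(self.norm(torch.cat([pose_feat, img_feat], dim=-1)))

In xattn mode, pose tokens instead attend to the image token, e.g. through a single shallow cross-attention layer.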

Output

  • Activity classification logits
  • Optional intermediate embeddings / tokens

πŸ—οΈ Model Components

Module                     Description
-------------------------  -----------------------------------------------------
PoseFormerFactorized       Short-term pose transformer
LongTermTemporalBlock      Long-range temporal modeling
ImageEncoder               CNN-based RGB feature extractor
MMFusionConcatLN           Concatenation-based multimodal fusion
MMFusionCrossAttnShallow   Cross-attention multimodal fusion
SSLHeads                   Contrastive, reconstruction, and temporal-order heads

πŸ“₯ Input Format

Pose Input

(B, L, T, J, C)
  • B: batch size
  • L: number of temporal segments
  • T: frames per segment
  • J: number of joints (e.g. 17)
  • C: channels per joint (2 for 2D, 3 for 3D poses)

Image Input (optional)

(B, L, T, 3, H, W)
  • 3: RGB channels
  • H: image height
  • W: image width
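
For a quick shape check, dummy inputs can be built with torch.randn (the sizes below are illustrative; J, C, H, and W must match the training configuration):

import torch

B, L, T, J, C = 2, 4, 16, 17, 3                # illustrative sizes
pose_seq = torch.randn(B, L, T, J, C)          # pose input
img_seq = torch.randn(B, L, T, 3, 224, 224)    # optional RGB input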

πŸš€ Usage

1️⃣ Load Model Code

from model_har_final import (
    PoseFormerFactorized,
    MultiScaleTemporalModel
)

2️⃣ Build Model

pose_backbone = PoseFormerFactorized(
    joints=17,          # J: number of joints
    in_ch=3,            # C: channels per joint (3 for 3D poses)
    dim=128,
    layers=4,
    num_classes=6,
    return_tokens=True
)

model = MultiScaleTemporalModel(
    short_seq_model=pose_backbone,
    num_classes=6,
    enable_long_term=True,
    multimodal=True,
    fusion_mode="concat"  # or "xattn"
)

3️⃣ Load Weights

import torch

ckpt = torch.load("best_stage3_dual_sched.pth", map_location="cpu")  # checkpoint is a plain state_dict
model.load_state_dict(ckpt)
model.eval()

ℹ️ The checkpoint stores only a state_dict (model weights), not a pickled model object, for maximum compatibility across code versions.


4️⃣ Inference

with torch.no_grad():
    # pose_seq: (B, L, T, J, C); img_seq: (B, L, T, 3, H, W); see Input Format above
    logits = model(pose_seq, img_seq)
    preds = logits.argmax(dim=-1)

πŸ§ͺ Training Strategy

The model is trained in three stages:

Stage 1 β€” Pose SSL + Weak Supervision

  • Masked joint modeling (MJM)
  • Contrastive learning (InfoNCE)
  • Temporal order prediction
  • Optional labeled supervision
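
As an example, the contrastive term can be written as a standard InfoNCE loss (a minimal sketch, not the actual SSLHeads implementation):

import torch
import torch.nn.functional as F

def info_nce(z1, z2, temperature=0.07):
    # z1, z2: (B, D) embeddings of two augmented views of the same clip.
    # Positives sit on the diagonal of the similarity matrix.
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    logits = z1 @ z2.t() / temperature
    targets = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, targets)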

Stage 2 β€” Pose-only SSL Refinement

  • Contrastive + reconstruction losses
  • Temporal attention disabled for stability
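
The reconstruction term follows the masked-joint-modeling idea from Stage 1; a hedged sketch (the masking scheme and the reconstructor callable are assumptions):

import torch
import torch.nn.functional as F

def masked_joint_loss(reconstructor, pose, mask_ratio=0.3):
    # pose: (B, L, T, J, C). Zero out random joints, reconstruct them,
    # and score only the masked positions with an L2 loss.
    mask = torch.rand(pose.shape[:-1], device=pose.device).unsqueeze(-1) < mask_ratio
    recon = reconstructor(pose.masked_fill(mask, 0.0))   # same shape as pose
    m = mask.expand_as(pose)
    return F.mse_loss(recon[m], pose[m])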

Stage 3 β€” Multimodal Supervised Fine-tuning

  • Label smoothing CE
  • Semantic prototype distillation
  • Metric learning (triplet)
  • Optional knowledge distillation
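
The supervised objective combines these terms; a minimal sketch using standard PyTorch losses (the 0.1 smoothing, 0.3 margin, and 0.5 weight are placeholders, not the values used in training):

import torch.nn as nn

ce = nn.CrossEntropyLoss(label_smoothing=0.1)   # label smoothing CE
triplet = nn.TripletMarginLoss(margin=0.3)      # metric learning term

def stage3_loss(logits, labels, anchor, positive, negative):
    return ce(logits, labels) + 0.5 * triplet(anchor, positive, negative)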

πŸ“Š Outputs

Output                  Shape
----------------------  ----------------
Logits                  (B, num_classes)
Pose embedding          (B, D)
Pose tokens (optional)  (B, T, J, D)

πŸ“Ž Limitations

  • Not loadable via Hugging Face AutoModel.from_pretrained
  • Requires the custom code in this repository to instantiate the architecture
  • Input pose format (J, C, and segment layout) must match the training configuration

πŸ“œ License

Apache License 2.0


πŸ“Œ Citation

If you use this work, please cite:

@misc{kim2025multiscalehar,
  title        = {Multi-Scale Multimodal Pose Transformer for Human Activity Recognition},
  author       = {Minjae Kim},
  year         = {2025},
  howpublished = {\url{https://huggingface.co/m97j/har-safety-model}}
}