Upload folder using huggingface_hub
- .gitattributes +9 -0
- README.md +5 -0
- requirements.txt +13 -0
- src/__init__.py +0 -0
- src/config.py +66 -0
- src/exp.py +276 -0
- whisper-alignment-results/base-to-large-probes.safetensors +3 -0
- whisper-alignment-results/base-to-medium-probes.safetensors +3 -0
- whisper-alignment-results/base-to-small-probes.safetensors +3 -0
- whisper-alignment-results/base-vs-large-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/base-vs-medium-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/base-vs-small-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/small-to-large-probes.safetensors +3 -0
- whisper-alignment-results/small-to-medium-probes.safetensors +3 -0
- whisper-alignment-results/small-vs-large-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/small-vs-medium-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/tiny-to-base-probes.safetensors +3 -0
- whisper-alignment-results/tiny-to-large-probes.safetensors +3 -0
- whisper-alignment-results/tiny-to-medium-probes.safetensors +3 -0
- whisper-alignment-results/tiny-to-small-probes.safetensors +3 -0
- whisper-alignment-results/tiny-vs-base-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/tiny-vs-large-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/tiny-vs-medium-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/tiny-vs-small-temporal-linear-mse-log.png +3 -0

.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/base-vs-large-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/base-vs-medium-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/base-vs-small-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/small-vs-large-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/small-vs-medium-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/tiny-vs-base-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/tiny-vs-large-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/tiny-vs-medium-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/tiny-vs-small-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text

README.md ADDED
@@ -0,0 +1,5 @@
+# Layer Alignment Analysis for Whisper Encoders
+
+This project analyzes and compares the internal representations of OpenAI Whisper encoder models; it is designed for research on model interpretability and transferability.
+
+All settings are adjustable in `src/config.py`.

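For a quick smoke test, the run can be shrunk before launching (a sketch; the values below are illustrative, the names are the ones defined in `src/config.py`):

    # example smoke-test overrides for src/config.py (illustrative values)
    MAX_SAMPLES = 50        # limit the dataset, as the config comment suggests
    TRAINING_STEPS = 50     # fewer optimization steps per probe
    ENABLED_MODELS = {"tiny": "openai/whisper-tiny", "base": "openai/whisper-base"}

Then install the packages listed in `requirements.txt` and launch `python -m src.exp` from the repository root; the relative import `from .config import *` in `src/exp.py` requires the module-style invocation.
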
requirements.txt ADDED
@@ -0,0 +1,13 @@
+joblib
+matplotlib
+numpy
+torch
+tqdm
+transformers
+datasets
+librosa
+soundfile
+safetensors
+
+isort
+black

src/__init__.py ADDED
File without changes

src/config.py ADDED
@@ -0,0 +1,66 @@
+# Memory Configuration for Whisper Alignment Analysis
+
+# Batch processing settings
+BATCH_SIZE = 16  # Reduce if you get OOM errors; increase for faster processing
+TRAINING_STEPS = 200  # Number of training steps for linear probes
+LEARNING_RATE = 1e-3
+
+# Model selection
+ENABLED_MODELS = {
+    "tiny": "openai/whisper-tiny",  # ~39M parameters
+    "base": "openai/whisper-base",  # ~74M parameters
+    "small": "openai/whisper-small",  # ~244M parameters
+    "medium": "openai/whisper-medium",  # ~769M parameters
+    "large": "openai/whisper-large-v3-turbo",  # ~809M parameters
+}
+
+# Memory optimization settings
+USE_HALF_PRECISION = (
+    True  # Use half precision (bfloat16 preferred, float16 fallback) instead of float32
+)
+AGGRESSIVE_CLEANUP = False  # Clear GPU cache after each operation
+
+# Dataset settings
+MAX_SAMPLES = None  # Set to a number to limit dataset size for testing (e.g., 50)
+
+# Output settings
+OUTPUT_DIR = "whisper-alignment-results"
+SAVE_PLOTS = True
+PLOT_DPI = 300
+
+
+# Half precision dtype selection (bfloat16 preferred if available, fallback to float16)
+def get_half_precision_dtype():
+    """
+    Determine the best half precision dtype based on hardware support.
+    bfloat16 is preferred when available as it has better numerical stability.
+    """
+    import torch
+
+    if not USE_HALF_PRECISION:
+        return torch.float32
+
+    # Check if bfloat16 is supported
+    if torch.cuda.is_available():
+        # Check GPU support for bfloat16
+        device_capability = torch.cuda.get_device_capability()
+        # bfloat16 is supported on Ampere (8.x) and newer GPUs
+        if device_capability[0] >= 8:
+            return torch.bfloat16
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        # Apple Silicon supports bfloat16
+        return torch.bfloat16
+    elif (
+        hasattr(torch, "backends")
+        and hasattr(torch.backends, "cpu")
+        and hasattr(torch.backends.cpu, "supports_bfloat16")
+    ):
+        # Check CPU support for bfloat16 (newer PyTorch versions)
+        if torch.backends.cpu.supports_bfloat16:
+            return torch.bfloat16
+
+    # Fallback to float16
+    return torch.float16
+
+
+HALF_PRECISION_DTYPE = get_half_precision_dtype()

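A quick way to see which dtype the heuristic resolves to on a given machine (a sketch; `check_dtype.py` is a hypothetical helper, run from the repository root so the `src` package imports cleanly):

    # check_dtype.py (hypothetical helper, not part of the repo)
    from src.config import HALF_PRECISION_DTYPE, USE_HALF_PRECISION

    # Prints torch.bfloat16 on Ampere-or-newer GPUs and Apple Silicon,
    # torch.float16 otherwise, and torch.float32 if USE_HALF_PRECISION is False.
    print(USE_HALF_PRECISION, HALF_PRECISION_DTYPE)
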
src/exp.py ADDED
@@ -0,0 +1,276 @@
+import os
+from itertools import combinations
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+from datasets import Audio, load_dataset
+from safetensors.torch import save_file
+from tqdm import tqdm
+from transformers import AutoFeatureExtractor, WhisperModel
+
+from .config import *
+
+model_ids = ENABLED_MODELS
+
+# Load dataset
+dataset = load_dataset("JacobLinCool/cv161-en-zh-subset-200", split="train")
+if MAX_SAMPLES is not None:
+    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
+    print(f"Limited dataset to {len(dataset)} samples for testing")
+
+dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
+
+device = torch.device(
+    "cuda"
+    if torch.cuda.is_available()
+    else "mps" if torch.backends.mps.is_available() else "cpu"
+)
+print(f"Using device: {device}")
+
+
+def extract_layer_reps_generator(model_id, batch_size=4):
+    """
+    Process samples in batches with a generator, avoiding loading all hidden states into memory at once.
+    Yields (sample_idx, layer_reps) pairs, where layer_reps is a list of all layer representations for the sample.
+    """
+    model = WhisperModel.from_pretrained(model_id).to(device)
+    feat_ext = AutoFeatureExtractor.from_pretrained(model_id)
+    model.eval()
+
+    for i in tqdm(
+        range(0, len(dataset), batch_size), desc=f"Processing {model_id} in batches"
+    ):
+        batch_end = min(i + batch_size, len(dataset))
+        batch_samples = dataset.select(range(i, batch_end))
+
+        # Process each sample in the batch
+        for j, sample in enumerate(batch_samples):
+            audio = sample["audio"]
+            samples = audio["array"]
+            sr = audio["sampling_rate"]
+
+            inputs = feat_ext(
+                samples, sampling_rate=sr, return_tensors="pt"
+            ).input_features.to(device)
+            with torch.no_grad():
+                outputs = model.encoder(
+                    inputs, return_dict=True, output_hidden_states=True
+                )
+
+            # Keep the full sequence for each layer; optionally cast to half precision to save memory
+            layer_reps_for_sample = []
+            for hs in outputs.hidden_states:
+                # hs: [1, T, D] -> [T, D]
+                layer_rep = hs.squeeze(0)
+                if USE_HALF_PRECISION:
+                    layer_rep = layer_rep.to(HALF_PRECISION_DTYPE)
+                layer_reps_for_sample.append(layer_rep)
+
+            yield i + j, layer_reps_for_sample
+
+            # Clean up GPU memory
+            del outputs, inputs
+            if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+    # Clean up model memory
+    del model, feat_ext
+    if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def compute_linear_mse_matrix_temporal_memory_efficient(
+    model_a_id, model_b_id, n_steps=200, lr=1e-3, batch_size=4
+):
+    """
+    Memory-efficient version: for each layer pair (i, j), trains a 1x1 convolution as a linear probe and computes MSE.
+    Uses a generator to process in batches, avoiding loading all representations into memory at once.
+    Returns an MSE matrix of shape (layers_a, layers_b) and all trained probes.
+    """
+    print(f"Computing alignment between {model_a_id} and {model_b_id}...")
+
+    # First, get the number of layers
+    sample_gen_a = extract_layer_reps_generator(model_a_id, batch_size=1)
+    _, sample_reps_a = next(sample_gen_a)
+    layers_a = len(sample_reps_a)
+
+    sample_gen_b = extract_layer_reps_generator(model_b_id, batch_size=1)
+    _, sample_reps_b = next(sample_gen_b)
+    layers_b = len(sample_reps_b)
+
+    mse_mat = np.zeros((layers_a, layers_b))
+    trained_probes = {}
+
+    pbar = tqdm(total=layers_a * layers_b, desc="Comparing layer pairs")
+
+    # Re-initialize generators to process all samples
+    gen_a = extract_layer_reps_generator(model_a_id, batch_size=batch_size)
+    gen_b = extract_layer_reps_generator(model_b_id, batch_size=batch_size)
+
+    # Collect all sample representations for every layer
+    reps_a_dict_all = {}
+    for sample_idx, layer_reps in gen_a:
+        reps_a_dict_all[sample_idx] = layer_reps
+
+    reps_b_dict_all = {}
+    for sample_idx, layer_reps in gen_b:
+        reps_b_dict_all[sample_idx] = layer_reps
+
+    for i in range(layers_a):
+        for j in range(layers_b):
+            # Collect all sample representations for the specified layer
+            reps_a_dict = {}
+            for sample_idx, layer_reps in reps_a_dict_all.items():
+                if i < len(layer_reps):
+                    reps_a_dict[sample_idx] = layer_reps[i]
+
+            reps_b_dict = {}
+            for sample_idx, layer_reps in reps_b_dict_all.items():
+                if j < len(layer_reps):
+                    reps_b_dict[sample_idx] = layer_reps[j]
+
+            # Concatenate representations in sample order
+            X_list = [reps_a_dict[idx] for idx in sorted(reps_a_dict.keys())]
+            Y_list = [reps_b_dict[idx] for idx in sorted(reps_b_dict.keys())]
+
+            # Concatenate all tokens for this layer pair and move to the compute device
+            X_cat = torch.cat(X_list, dim=0).to(device)
+            Y_cat = torch.cat(Y_list, dim=0).to(device)
+
+            dim_a = X_cat.shape[1]
+            dim_b = Y_cat.shape[1]
+
+            # 1. For Conv1d, reshape to [Batch, Channels, Length]
+            X = X_cat.T.unsqueeze(0)  # [1, Dim_A, Total_Tokens]
+            Y = Y_cat.T.unsqueeze(0)  # [1, Dim_B, Total_Tokens]
+
+            # 2. Define and train the linear probe (1x1 Conv)
+            probe = nn.Conv1d(
+                in_channels=dim_a, out_channels=dim_b, kernel_size=1, bias=False
+            ).to(device=device, dtype=HALF_PRECISION_DTYPE)
+            probe.train()
+
+            optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
+            loss_fn = nn.MSELoss()
+
+            for step in tqdm(range(n_steps), desc=f"Training probe {i}->{j}"):
+                optimizer.zero_grad()
+                Y_pred = probe(X)
+                loss = loss_fn(Y_pred, Y)
+                loss.backward()
+                optimizer.step()
+
+            # 3. Record the final MSE and the trained probe
+            final_mse = loss.item()
+            mse_mat[i, j] = final_mse
+            trained_probes[f"layer_{i}_to_{j}"] = probe.state_dict()["weight"]
+
+            # Clean up memory
+            del (
+                X_cat,
+                Y_cat,
+                X,
+                Y,
+                probe,
+                optimizer,
+                reps_a_dict,
+                reps_b_dict,
+                X_list,
+                Y_list,
+            )
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            pbar.update(1)
+            pbar.set_postfix({"layer_a": i, "layer_b": j, "mse": f"{final_mse:.4f}"})
+
+    pbar.close()
+    return mse_mat, trained_probes
+
+
+if __name__ == "__main__":
+    print("Memory optimization settings:")
+    print(f"  Batch size: {BATCH_SIZE}")
+    print(f"  Training steps: {TRAINING_STEPS}")
+    if USE_HALF_PRECISION:
+        dtype_name = "bfloat16" if HALF_PRECISION_DTYPE == torch.bfloat16 else "float16"
+        print(f"  Half precision: {USE_HALF_PRECISION} ({dtype_name})")
+    else:
+        print(f"  Half precision: {USE_HALF_PRECISION}")
+    print(f"  Aggressive cleanup: {AGGRESSIVE_CLEANUP}")
+    print(f"  Models: {list(model_ids.keys())}")
+    print(f"  Dataset size: {len(dataset)} samples")
+
+    # Create results directory
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+    # Compare all model pairs using the memory-efficient method
+    model_names = list(model_ids.keys())
+    all_pairs = list(combinations(model_names, 2))
+
+    print(
+        f"\nProcessing {len(all_pairs)} model pairs with memory-efficient approach..."
+    )
+
+    for pair_idx, (model_a, model_b) in enumerate(all_pairs):
+        print(
+            f"\n[{pair_idx + 1}/{len(all_pairs)}] Computing temporal linear MSE for whisper-{model_a} vs whisper-{model_b}..."
+        )
+
+        # Compute linear MSE along the temporal dimension and get the trained probes
+        mse_mat_temporal, trained_probes = (
+            compute_linear_mse_matrix_temporal_memory_efficient(
+                model_ids[model_a],
+                model_ids[model_b],
+                n_steps=TRAINING_STEPS,
+                lr=LEARNING_RATE,
+                batch_size=BATCH_SIZE,
+            )
+        )
+
+        # Save the trained probes
+        model_save_path = f"{OUTPUT_DIR}/{model_a}-to-{model_b}-probes.safetensors"
+        save_file(
+            trained_probes,
+            model_save_path,
+            {
+                "from_model": model_a,
+                "to_model": model_b,
+                "from_layers": str(len(mse_mat_temporal)),
+                "to_layers": str(len(mse_mat_temporal[0])),
+            },
+        )
+        print(f"Saved trained probes to: {model_save_path}")
+
+        if SAVE_PLOTS:
+            # Visualize results
+            # Avoid log(0) by adding a small value
+            eps = 1e-10
+            log_mse_mat = -np.log10(mse_mat_temporal + eps)
+
+            plt.figure(figsize=(8, 6))
+            plt.imshow(
+                log_mse_mat, aspect="auto", origin="lower"
+            )  # origin='lower' is more standard for matrices
+            plt.colorbar(label="-log10(MSE)")
+            plt.title(
+                f"Temporal Linear MSE (log scale): whisper-{model_a} vs whisper-{model_b}"
+            )
+            plt.xlabel(f"whisper-{model_b} layers")
+            plt.ylabel(f"whisper-{model_a} layers")
+            plt.tight_layout()
+
+            # Save visualization results
+            plot_save_path = (
+                f"{OUTPUT_DIR}/{model_a}-vs-{model_b}-temporal-linear-mse-log.png"
+            )
+            plt.savefig(plot_save_path, dpi=PLOT_DPI)
+            plt.close()  # Close the figure to free memory
+            print(f"Saved plot to: {plot_save_path}")
+
+    print(f"\nAll experiments complete! Results saved to '{OUTPUT_DIR}' directory")
+    print(
+        f"Generated {len(all_pairs)} visualization plots and {len(all_pairs)} trained probe models"
+    )

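Since each probe is a bias-free Conv1d with kernel size 1, every saved tensor is a plain linear map between layer dimensions. A minimal sketch of loading a probe back and applying it (assumes a probe file produced by this script; the key names follow the `layer_{i}_to_{j}` pattern above, and the input shape is a stand-in):

    # load_probe.py (hypothetical helper, not part of the repo)
    import torch
    from safetensors.torch import load_file

    probes = load_file("whisper-alignment-results/tiny-to-base-probes.safetensors")
    W = probes["layer_0_to_0"]               # Conv1d weight, shape [dim_b, dim_a, 1]
    x = torch.randn(1500, W.shape[1])        # stand-in for whisper-tiny layer-0 reps, [T, dim_a]
    y_hat = x.to(W.dtype) @ W.squeeze(-1).T  # mapped into whisper-base space, [T, dim_b]

The `squeeze(-1)` makes explicit that a kernel-size-1 convolution applies the same `dim_a -> dim_b` linear transform independently at every time step, which is why the probes double as layer-to-layer translation maps.
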
whisper-alignment-results/base-to-large-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:010e2238c4fd6584ed9e8b3cbd9a3379633c585a3d8f4dee2617679002193809
+size 302797200

whisper-alignment-results/base-to-medium-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f12bfaa489bcf81f3d0e51109e81e0ff2b03a31507c8742da5f4ea6e53b996b1
+size 183516544

whisper-alignment-results/base-to-small-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab25b9c2c210ebc1e0c3a035025322f949dc449b9f9e4a119dcb21e527545425
+size 71573320

whisper-alignment-results/base-vs-large-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)

whisper-alignment-results/base-vs-medium-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)

whisper-alignment-results/base-vs-small-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)

whisper-alignment-results/small-to-large-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e025459eeb94c6209e04d348e18eb3cb3912f128ba68adc544df9b3ce8c3ae3
+size 843487312

whisper-alignment-results/small-to-medium-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4225bd80384af8155badfbc9ace80114339e18c893c68ba787811aab776cf59a
+size 511210280

whisper-alignment-results/small-vs-large-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)

whisper-alignment-results/small-vs-medium-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)

whisper-alignment-results/tiny-to-base-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b10f010c37f5b884e9b2e5a05cbb87226d3a660c51c894e4ce175bd1093da7e
+size 13765648

whisper-alignment-results/tiny-to-large-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d363b8111d23d57e5c45a49e716a6c58c68c8f8c637175c4617ec0bed9f75cd9
+size 162216440

whisper-alignment-results/tiny-to-medium-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8879301313af0f749dee9587fd917bf6923477ede8984b1a4a8dbedb169d828f
+size 98315144

whisper-alignment-results/tiny-to-small-probes.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27d2ac798074676cf69de46554572c37279fc92ba55b459f892fee4eba686ca2
+size 38344296

whisper-alignment-results/tiny-vs-base-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)

whisper-alignment-results/tiny-vs-large-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)

whisper-alignment-results/tiny-vs-medium-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)

whisper-alignment-results/tiny-vs-small-temporal-linear-mse-log.png ADDED
(binary image, stored with Git LFS)
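
The pointer sizes can be cross-checked against the probe shapes (a sketch; it assumes whisper-tiny's encoder has 4 layers with width 384 and whisper-base has 6 layers with width 512, so 5 x 7 hidden-state pairs counting the embedding output, stored in bfloat16):

    # size_check.py (hypothetical sanity check, not part of the repo)
    n_probes = 5 * 7                   # one probe per (layer_a, layer_b) pair
    bytes_per_probe = 512 * 384 * 2    # [dim_b, dim_a, 1] weights, 2 bytes per bfloat16
    print(n_probes * bytes_per_probe)  # 13762560, vs. 13765648 on disk for
                                       # tiny-to-base-probes.safetensors; the small
                                       # difference is the safetensors header/metadata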