XLS-R L15 Phonetic Projection Head
A small MLP projection head trained on top of a frozen, truncated XLS-R 300M backbone to produce phone-discriminative embeddings. Intended for multilingual phone clustering / inventory-discovery research; no vocabulary commitment, no per-language classifier.
Architecture
- Backbone: facebook/wav2vec2-xls-r-300m, frozen and truncated to the first 15 transformer layers. Layers 16–24 are dropped because features are read from layer 15 only (the phonetic peak reported by Pasad et al. 2023 for contrastive-objective 24-layer SSL models).
- Pooling: full-span mean over the gold phone-span frames at layer 15.
- Projection head: 2-layer MLP, 1024 → 1024 → 256, ReLU + dropout 0.1, L2-normalised output.
- Loss: supervised contrastive / NT-Xent (SupCon; Khosla et al. 2020), temperature τ = 0.07. For each anchor, all in-batch samples with the same IPA label are pulled together; all others are pushed apart.
The XLS-R weights are unchanged; only the 1.3M-parameter MLP is trained.
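
The loss described above can be sketched as follows. This is a minimal illustration of the SupCon objective from Khosla et al. 2020, written in plain Python so that each term is visible; it is not the training code behind this model. The function name `supcon_loss`, the helper `unit`, the toy 2-d vectors, and the choice to skip anchors that have no same-label partner are all assumptions made for the example.

```python
import math

def supcon_loss(embeddings, labels, temperature=0.07):
    """Supervised contrastive loss over one batch (illustrative sketch).

    embeddings: list of L2-normalised vectors (lists of floats).
    labels:     one class label per embedding (e.g. an IPA symbol).
    Anchors with no same-label partner in the batch are skipped (assumption).
    """
    n = len(embeddings)
    total, n_anchors = 0.0, 0
    for i in range(n):
        # Temperature-scaled dot products between anchor i and every other sample.
        sims = {
            a: sum(x * y for x, y in zip(embeddings[i], embeddings[a])) / temperature
            for a in range(n) if a != i
        }
        positives = [p for p in sims if labels[p] == labels[i]]
        if not positives:
            continue
        # Numerically stable log-sum-exp: the softmax denominator over all non-anchors.
        m = max(sims.values())
        log_denom = m + math.log(sum(math.exp(s - m) for s in sims.values()))
        # Mean negative log-probability assigned to the positives.
        total += -sum(sims[p] - log_denom for p in positives) / len(positives)
        n_anchors += 1
    return total / n_anchors if n_anchors else 0.0

def unit(v):
    """L2-normalise a vector."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

# Toy batch: two phones, two samples each (hypothetical 2-d embeddings).
labels = ["a", "a", "i", "i"]
clustered = [unit(v) for v in ([1, 0.1], [1, -0.1], [0.1, 1], [-0.1, 1])]
mixed = [unit(v) for v in ([1, 0.1], [0.1, 1], [1, -0.1], [-0.1, 1])]

loss_clustered = supcon_loss(clustered, labels)  # same-label samples are close
loss_mixed = supcon_loss(mixed, labels)          # same-label samples are far apart
```

The loss is lower when samples sharing a label already sit close together, which is the geometry the projection head is trained towards.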
Training data
DoReCo [Paschen et al. 2020], a multilingual field-recording corpus of phone-aligned narrow-IPA transcriptions across 45 language sources.
- Training pool: 40 DoReCo sources.
- Held-out (never seen during training): 5 typologically diverse sources: yong1270 (Yongning Na, tonal Sino-Tibetan), pnar1238 (Pnar, tonal Austroasiatic), arap1274 (Arapaho, Algonquian, noisy recordings), cash1254 (Cashinahua, Panoan with implosives/ejectives), savo1255 (Savosavo, Papuan isolate).
- Total spans extracted: 1,048,186 (all 45 sources); held-out subset used for evaluation: 169,127 spans.
- Batch: 32 classes × 8 instances per class = 256 samples per step; 5000 training steps; Adam, lr = 1e-3.
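
One way to build a 32 × 8 batch is a class-balanced sampler like the sketch below. The function name `sample_balanced_batch` is hypothetical, and sampling with replacement for phones that have fewer than 8 spans is an assumption: the card does not say how rare classes were handled during training.

```python
import random
from collections import Counter, defaultdict

def sample_balanced_batch(labels, n_classes=32, n_per_class=8, rng=random):
    """Return span indices for one batch of n_classes x n_per_class samples.

    labels: list of phone labels, one per extracted span.
    Classes with fewer than n_per_class spans are sampled with
    replacement (assumption, not taken from the source).
    """
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    # Pick the classes for this step, then draw instances within each class.
    classes = rng.sample(sorted(by_label), k=min(n_classes, len(by_label)))
    batch = []
    for lab in classes:
        pool = by_label[lab]
        if len(pool) >= n_per_class:
            batch.extend(rng.sample(pool, n_per_class))
        else:
            batch.extend(rng.choices(pool, k=n_per_class))
    return batch

# Toy label list: 40 hypothetical phone classes with 50 spans each.
toy_labels = [f"p{i % 40}" for i in range(2000)]
batch = sample_balanced_batch(toy_labels, rng=random.Random(0))
per_class = Counter(toy_labels[i] for i in batch)
```

Guaranteeing several instances per class in every step matters for a supervised contrastive loss, because an anchor without a same-label partner contributes no positive pair.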
Results: held-out 5 languages, K-means clustering
All metrics computed on the 169,127-span held-out subset that was never seen during MLP training.
| K | Baseline PNMI (frozen L15) | Projected PNMI (L15 → MLP) | ΔPNMI | Baseline purity | Projected purity | Δpurity |
|---|---|---|---|---|---|---|
| 50 | 0.2952 | 0.4059 | +0.1106 | 0.3219 | 0.4392 | +0.1172 |
| 100 | 0.3548 | 0.4350 | +0.0803 | 0.3745 | 0.4415 | +0.0669 |
| 200 | 0.4152 | 0.4607 | +0.0455 | 0.4262 | 0.4585 | +0.0323 |
| 500 | 0.4997 | 0.4978 | -0.0019 | 0.4990 | 0.4779 | -0.0211 |
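
For readers who want to compute comparable numbers on their own clusterings, here is a minimal sketch of the two metrics. The definitions are assumptions, since this card does not spell them out: PNMI is taken as the mutual information between phone labels and cluster IDs divided by the entropy of the phone labels, and purity as the fraction of spans whose phone is the majority phone of their cluster. The function name `pnmi_and_purity` is hypothetical.

```python
import math
from collections import Counter

def pnmi_and_purity(phones, clusters):
    """Phone-normalised mutual information and cluster purity (assumed definitions)."""
    n = len(phones)
    joint = Counter(zip(phones, clusters))
    n_phone = Counter(phones)
    n_cluster = Counter(clusters)
    # Mutual information I(phone; cluster) in nats.
    mi = sum(
        (c / n) * math.log(c * n / (n_phone[y] * n_cluster[k]))
        for (y, k), c in joint.items()
    )
    # Entropy of the phone labels H(phone) in nats.
    h_phone = -sum((c / n) * math.log(c / n) for c in n_phone.values())
    pnmi = mi / h_phone if h_phone > 0 else 0.0
    # Purity: count each cluster's most frequent phone, then normalise by n.
    best = Counter()
    for (y, k), c in joint.items():
        best[k] = max(best[k], c)
    purity = sum(best.values()) / n
    return pnmi, purity

# Toy check: a perfect clustering versus an uninformative one.
perfect = pnmi_and_purity(["a", "a", "i", "i"], [0, 0, 1, 1])
uninformative = pnmi_and_purity(["a", "a", "i", "i"], [0, 1, 0, 1])
```

Under these definitions a clustering that matches the phone labels exactly scores 1.0 on both metrics, while one that is independent of the labels scores a PNMI of 0.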
Interpretation. The projection head produces large PNMI and purity gains at low-to-moderate K, where the baseline features do not separate phones well. At K=500 the gains vanish: with that many clusters the frozen L15 representation is already granular enough that the MLP's additional compression does not help, and it slightly hurts purity through over-splitting (reflected in a higher avg_spread). The projection is therefore most useful for downstream clustering at K ≤ 200: the absolute PNMI gain is +0.11 at K=50 (about 37% relative to baseline), +0.08 at K=100 (about 23%) and +0.05 at K=200 (about 11%).
Rare-phone-tail purity improves by +0.10 at K=50, +0.06 at K=100 and +0.03 at K=200, so the projection does not collapse the long tail.
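
The relative size of the PNMI change at each K can be recomputed from the results table with a few lines of arithmetic. The values below are copied from the table; the variable names are illustrative.

```python
# (K, baseline PNMI, projected PNMI) copied from the results table.
rows = [
    (50, 0.2952, 0.4059),
    (100, 0.3548, 0.4350),
    (200, 0.4152, 0.4607),
    (500, 0.4997, 0.4978),
]

# Relative change of projected PNMI with respect to the frozen-L15 baseline.
relative_gain = {k: (proj - base) / base for k, base, proj in rows}

for k, gain in relative_gain.items():
    print(f"K={k}: {gain:+.1%}")
```

The relative improvement shrinks steadily as K grows and turns slightly negative at K=500.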
Inference code
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio  # only needed to load/resample audio; not used below
from huggingface_hub import hf_hub_download
from transformers import Wav2Vec2Model

TARGET_LAYER = 15


class ProjectionHead(nn.Module):
    # dropout defaults to 0.0 here; it is inactive in eval mode anyway
    def __init__(self, in_dim=1024, hidden=1024, out_dim=256, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.net(x)


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load + truncate XLS-R to L15
xlsr = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-xls-r-300m').to(device).eval()
xlsr.encoder.layers = xlsr.encoder.layers[:TARGET_LAYER]
xlsr.config.num_hidden_layers = TARGET_LAYER
for p in xlsr.parameters():
    p.requires_grad = False

# Load the trained MLP head from this repo
ckpt = torch.load(hf_hub_download('breezywaves/xlsr-l15-phonetic-mlp', 'mlp_head.pt'),
                  map_location=device, weights_only=False)
cfg = ckpt['config']
head = ProjectionHead(**{k: cfg[k] for k in ['in_dim', 'hidden', 'out_dim']}).to(device)
head.load_state_dict(ckpt['state_dict'])
head.eval()


# Inference over a phone span
def embed(audio, phone_start_sec, phone_end_sec, sr=16000):
    # audio: assumed to be a 1-D mono waveform tensor at 16 kHz (sr is not used to resample)
    with torch.no_grad():
        hs = xlsr(audio.unsqueeze(0).to(device),
                  output_hidden_states=True).hidden_states
        l15 = hs[TARGET_LAYER][0]                        # (time, 1024)
        # Convert [start, end] seconds to frame indices at XLS-R's ~50 Hz output rate
        sf = min(int(phone_start_sec * 50), l15.shape[0] - 1)  # clamp so the span is never empty
        ef = max(sf + 1, int(phone_end_sec * 50))
        pooled = l15[sf:ef].mean(dim=0)                  # (1024,)
        z = F.normalize(head(pooled.unsqueeze(0)), dim=-1)
    return z.squeeze(0).cpu()                            # (256,)
```
embed returns an L2-normalised vector; use cosine similarity or k-means on these vectors.
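
Because the output vectors have unit length, cosine similarity reduces to a dot product, and squared Euclidean distance is a simple function of it. The sketch below checks that identity on hypothetical 3-d stand-ins for the 256-d embeddings; the helper names are illustrative and not part of this repo.

```python
import math

def l2_normalise(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def cosine(u, v):
    """Cosine similarity; for unit vectors this is just the dot product."""
    return sum(a * b for a, b in zip(u, v))

def sq_euclidean(u, v):
    """Squared Euclidean distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

# Hypothetical 3-d stand-ins for two 256-d embeddings.
z1 = l2_normalise([0.9, 0.1, 0.2])
z2 = l2_normalise([0.8, 0.3, 0.1])

# For unit vectors: ||u - v||^2 = 2 - 2 * cos(u, v)
lhs = sq_euclidean(z1, z2)
rhs = 2 - 2 * cosine(z1, z2)
```

Since the two quantities rank pairs of embeddings identically, standard Euclidean k-means is a reasonable stand-in for cosine-based clustering on these vectors.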
Caveats
- Pseudo-labelled DoReCo boundaries. The phone-span timestamps come from a forced-aligner consensus on DoReCo, not gold human annotation. Systematic alignment biases are present in both train and held-out evaluation.
- Single seed. Results reported here are from one training run with seed=42; variance across seeds has not been characterised.
- Full-span mean pooling. We pool the whole phone span; coarticulation edge contamination is not actively mitigated.
- No speaker-invariant adaptation. Downstream clustering still reflects speaker/recording-channel variation to some extent.
- K=500 gains disappear. The projection is most valuable at K ≤ 200 for downstream clustering; at finer K the frozen L15 representation is already competitive.
Citations
If you use this model, please cite the underlying works:
Babu et al. 2022. XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale. Interspeech.
Pasad, Shi & Livescu 2023. Comparative layer-wise analysis of self-supervised speech models. ICASSP.
Khosla et al. 2020. Supervised Contrastive Learning. NeurIPS.
Paschen et al. 2020. DoReCo: Building a Time-Aligned Cross-Linguistic Reference Corpus from Language Documentation Data. LREC.
License
Apache-2.0 (matches the base model).