XLS-R L15 Phonetic Projection Head

A small MLP projection head trained on top of a frozen, truncated XLS-R 300M backbone to produce phone-discriminative embeddings. Intended for multilingual phone clustering / inventory-discovery research; no vocabulary commitment, no per-language classifier.

Architecture

  • Backbone: facebook/wav2vec2-xls-r-300m, frozen and truncated to the first 15 transformer layers. Layers 16–24 are dropped because we read features from layer 15 only (phonetic peak per Pasad et al. 2023 for contrastive-objective 24-layer SSL models).
  • Pooling: full-span mean over the frames of each reference phone span at layer 15 (see Caveats for how the span boundaries were obtained).
  • Projection head: 2-layer MLP, 1024 → 1024 → 256, ReLU + dropout 0.1, L2-normalised output.
  • Loss: Supervised contrastive / NT-Xent (SupCon; Khosla et al. 2020), temperature τ = 0.07. For each anchor, all in-batch samples with the same IPA label are pulled together; all others are pushed apart.
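
The loss described above can be sketched as follows. This is a minimal illustration of one common SupCon formulation (mean log-probability of the positives per anchor), not the training script used for this model, which is not part of this card. The function name supcon_loss and the handling of anchors without positives are assumptions.

```python
import torch
import torch.nn.functional as F

def supcon_loss(z, labels, temperature=0.07):
    """Supervised contrastive loss (sketch).

    z:      (N, D) L2-normalised embeddings
    labels: (N,) integer class ids (e.g. IPA label indices)
    """
    n = z.size(0)
    sim = z @ z.T / temperature                          # scaled cosine similarities
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float('-inf'))      # an anchor is never its own candidate
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    # Positives: same label as the anchor, excluding the anchor itself
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    pos_count = pos_mask.sum(dim=1)
    sum_log_prob_pos = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    # Assumption: anchors with no in-batch positive are skipped
    valid = pos_count > 0
    return -(sum_log_prob_pos[valid] / pos_count[valid]).mean()

# Example with the batch shape from this card: 32 classes x 8 instances, 256-dim outputs
z = F.normalize(torch.randn(256, 256), dim=-1)
labels = torch.arange(32).repeat_interleave(8)
loss = supcon_loss(z, labels)
```

With P positives per anchor the loss is bounded below by log P, which is reached only when the positives dominate the softmax completely.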

The XLS-R weights are unchanged; only the 1.3M-parameter MLP is trained.
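
The 1.3M figure follows directly from the layer sizes listed above. A quick check, assuming both linear layers carry bias terms (the nn.Linear default):

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# 1024 -> 1024 -> 256 projection head
total = linear_params(1024, 1024) + linear_params(1024, 256)
print(total)  # 1312000
```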

Training data

DoReCo (Paschen et al. 2020), a multilingual field-recording corpus with phone-aligned narrow-IPA transcriptions across 45 language sources.

  • Training pool: 40 DoReCo sources.
  • Held-out (never seen during training): 5 typologically diverse sources: yong1270 (Yongning Na, tonal Sino-Tibetan), pnar1238 (Pnar, tonal Austroasiatic), arap1274 (Arapaho, Algonquian, noisy recordings), cash1254 (Cashinahua, Panoan with implosives/ejectives), savo1255 (Savosavo, Papuan isolate).
  • Total spans extracted: 1,048,186 (all 45 sources); held-out subset used for evaluation: 169,127 spans.
  • Batch: 32 classes × 8 instances per class = 256 samples per step; 5000 training steps; Adam lr = 1e-3.
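
A class-balanced batch of this shape could be drawn roughly as below. This is a hypothetical sampler for illustration only: the function name, the eligibility rule (classes with at least 8 spans) and sampling without replacement are assumptions, since the card does not describe how batches were actually built.

```python
import random
from collections import defaultdict

def sample_balanced_batch(labels, n_classes=32, n_per_class=8, rng=random):
    """Return span indices for one batch of n_classes * n_per_class samples."""
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    # Assumption: only classes with enough spans are eligible, so no span repeats in a batch
    eligible = [lab for lab, idxs in by_label.items() if len(idxs) >= n_per_class]
    chosen = rng.sample(eligible, n_classes)   # raises ValueError if too few eligible classes
    batch = []
    for lab in chosen:
        batch.extend(rng.sample(by_label[lab], n_per_class))
    return batch

# Toy label list: 40 classes with 100 spans each
toy_labels = [i % 40 for i in range(4000)]
batch = sample_balanced_batch(toy_labels)
print(len(batch))  # 256
```

Drawing a fixed number of instances per class guarantees that every anchor has in-batch positives, which the contrastive loss needs.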

Results: 5 held-out languages, K-means clustering

All metrics computed on the 169,127-span held-out subset that was never seen during MLP training.

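| K   | Baseline PNMI (frozen L15) | Projected PNMI (L15 → MLP) | ΔPNMI   | Baseline purity | Projected purity | Δpurity |
|-----|----------------------------|----------------------------|---------|-----------------|------------------|---------|
| 50  | 0.2952                     | 0.4059                     | +0.1106 | 0.3219          | 0.4392           | +0.1172 |
| 100 | 0.3548                     | 0.4350                     | +0.0803 | 0.3745          | 0.4415           | +0.0669 |
| 200 | 0.4152                     | 0.4607                     | +0.0455 | 0.4262          | 0.4585           | +0.0323 |
| 500 | 0.4997                     | 0.4978                     | -0.0019 | 0.4990          | 0.4779           | -0.0211 |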
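
The card does not spell out its metric formulas. The sketch below assumes the usual definitions: PNMI is the mutual information between phone labels and cluster ids divided by the phone-label entropy, and purity is the fraction of spans whose phone is the majority phone of their cluster. The function name is hypothetical.

```python
import math
from collections import Counter

def pnmi_and_purity(phones, clusters):
    """phones, clusters: equal-length sequences of phone labels and cluster ids."""
    n = len(phones)
    joint = Counter(zip(phones, clusters))
    n_phone = Counter(phones)
    n_clust = Counter(clusters)
    # Entropy of the phone labels, H(phone)
    h_phone = -sum(c / n * math.log(c / n) for c in n_phone.values())
    # Mutual information I(phone; cluster) from joint and marginal counts
    mi = sum(c / n * math.log(c * n / (n_phone[y] * n_clust[z]))
             for (y, z), c in joint.items())
    pnmi = mi / h_phone if h_phone > 0 else 0.0
    # Purity: count of the most frequent phone in each cluster, summed and normalised
    best = {}
    for (y, z), c in joint.items():
        best[z] = max(best.get(z, 0), c)
    purity = sum(best.values()) / n
    return pnmi, purity

# A clustering that merges two of three phones into one cluster
print(pnmi_and_purity(['a', 'a', 'b', 'b', 'c', 'c'], [0, 0, 1, 1, 1, 1]))
```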

Interpretation. The projection head produces its largest PNMI/purity gains at low-to-moderate K, where the baseline features do not separate phones well: +0.11 PNMI at K=50, shrinking to +0.05 at K=200. At K=500 the gains vanish: with that many clusters, the frozen L15 representation is already granular enough that the MLP's additional compression does not help, and it slightly hurts purity (-0.02) via over-splitting (higher avg_spread). The sweet spot for downstream clustering is K ≈ 100–200, where the projection raises PNMI by about 0.05 to 0.08 over the baseline (roughly 11% to 23% relative).
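
The size of the gains can be recomputed from the PNMI columns of the results table. Because the inputs here are the rounded table values, the recomputed deltas may differ from the reported ones in the last digit.

```python
# (baseline PNMI, projected PNMI) per K, copied from the results table
pnmi = {
    50: (0.2952, 0.4059),
    100: (0.3548, 0.4350),
    200: (0.4152, 0.4607),
    500: (0.4997, 0.4978),
}

# Absolute and relative gain of the projected features over the frozen baseline
gains = {k: (proj - base, (proj - base) / base) for k, (base, proj) in pnmi.items()}

for k, (delta, rel) in gains.items():
    print(f"K={k}: delta={delta:+.4f}, relative={rel:+.1%}")
```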

Rare-phone-tail purity improves by +0.10 at K=50, +0.06 at K=100, and +0.03 at K=200, so the projection does not collapse the long tail.

Inference code

import torch, torch.nn as nn, torch.nn.functional as F
import torchaudio
from transformers import Wav2Vec2Model

TARGET_LAYER = 15

class ProjectionHead(nn.Module):
    def __init__(self, in_dim=1024, hidden=1024, out_dim=256, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )
    def forward(self, x): return self.net(x)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load + truncate XLS-R to L15
xlsr = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-xls-r-300m').to(device).eval()
xlsr.encoder.layers = xlsr.encoder.layers[:TARGET_LAYER]
xlsr.config.num_hidden_layers = TARGET_LAYER
for p in xlsr.parameters(): p.requires_grad = False

# Load the trained MLP head from this repo
from huggingface_hub import hf_hub_download
ckpt = torch.load(hf_hub_download('breezywaves/xlsr-l15-phonetic-mlp', 'mlp_head.pt'),
                  map_location=device, weights_only=False)
cfg = ckpt['config']
head = ProjectionHead(**{k: cfg[k] for k in ['in_dim','hidden','out_dim']}).to(device)
head.load_state_dict(ckpt['state_dict']); head.eval()

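# Inference over a phone span.
# `audio` is a 1-D mono waveform tensor; XLS-R expects 16 kHz input.
def embed(audio, phone_start_sec, phone_end_sec, sr=16000):
    assert sr == 16000, 'resample to 16 kHz before calling embed'
    with torch.no_grad():
        hs = xlsr(audio.unsqueeze(0).to(device),
                  output_hidden_states=True).hidden_states
        # hidden_states[0] is the encoder input, so index 15 is the output of layer 15
        l15 = hs[TARGET_LAYER][0]                  # (time, 1024)
        # Convert [start,end]s to frame indices at XLS-R's ~50 Hz output rate,
        # clamped so the slice always holds at least one valid frame
        sf = min(int(phone_start_sec * 50), l15.size(0) - 1)
        ef = min(max(sf + 1, int(phone_end_sec * 50)), l15.size(0))
        pooled = l15[sf:ef].mean(dim=0)            # (1024,)
        z = F.normalize(head(pooled.unsqueeze(0)), dim=-1)
    return z.squeeze(0).cpu()                      # (256,)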

The vector returned by embed is L2-normalised, so cosine similarity between two embeddings is simply their dot product; run k-means directly on these vectors.
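
As an illustration of downstream use, the sketch below computes cosine similarities and runs a plain Lloyd's k-means on a stack of embeddings. Random unit vectors stand in for real embed outputs, and the kmeans helper is an assumption written for this example; the clustering implementation behind the reported results is not specified in this card.

```python
import torch
import torch.nn.functional as F

def kmeans(x, k, n_iter=50, seed=0):
    """Plain Lloyd's k-means. x: (N, D) float tensor on CPU."""
    g = torch.Generator().manual_seed(seed)
    # Initialise centroids from k distinct data points
    centroids = x[torch.randperm(x.size(0), generator=g)[:k]].clone()
    for _ in range(n_iter):
        assign = torch.cdist(x, centroids).argmin(dim=1)
        for j in range(k):
            members = x[assign == j]
            if members.size(0) > 0:          # keep the old centroid if a cluster is empty
                centroids[j] = members.mean(dim=0)
    return torch.cdist(x, centroids).argmin(dim=1), centroids

# Stand-in for stacked embed() outputs: 1000 random unit vectors of dimension 256
z = F.normalize(torch.randn(1000, 256), dim=-1)
cos_sim = z @ z.T                            # cosine similarity, since rows are unit-norm
assign, centroids = kmeans(z, k=50)
```

For corpora of the size reported here, a dedicated k-means library will be much faster than this loop.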

Caveats

  • Pseudo-labelled DoReCo boundaries. The phone-span timestamps come from a forced-aligner consensus on DoReCo, not gold human annotation. Systematic alignment biases are present in both train and held-out evaluation.
  • Single seed. Results reported here are from one training run with seed=42; variance across seeds has not been characterised.
  • Full-span mean pooling. We pool the whole phone span; coarticulation edge contamination is not actively mitigated.
  • No speaker-invariant adaptation. Downstream clustering still reflects speaker/recording-channel variation to some extent.
  • K=500 gains disappear. The projection is most valuable at K ≤ 200 for downstream clustering; at finer K the frozen L15 is already competitive.

Citations

If you use this model, please cite the underlying works:

Babu et al. 2022. XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale. Interspeech.
Pasad, Shi & Livescu 2023. Comparative layer-wise analysis of self-supervised speech models. ICASSP.
Khosla et al. 2020. Supervised Contrastive Learning. NeurIPS.
Paschen et al. 2020. DoReCo: Building a Time-Aligned Cross-Linguistic Reference Corpus from Language Documentation Data. LREC.

License

Apache-2.0 (matches the base model).
