oist committed on
Commit 6ffd2bc · 1 Parent(s): 8328e47

Initial commit of MMNLI model with LFS

Files changed (5)
  1. .gitattributes +1 -0
  2. README.md +146 -0
  3. config.json +15 -0
  4. model.stateforce +3 -0
  5. modeling_mmnli.py +98 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model.stateforce filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,146 @@
# Multilingual & Multimodal NLI (MMNLI)

This repository provides the **MMNLI model**, a multilingual and multimodal Natural Language Inference classifier.
It extends the BLASER architecture to **multiclass NLI**, predicting entailment, neutral, and contradiction for text–text, text–speech, speech–text, and speech–speech input pairs.
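
Concretely, the classifier head consumes a pair of SONAR sentence embeddings and, following the BLASER-style recipe implemented in `modeling_mmnli.py` below, concatenates the premise, the hypothesis, their element-wise product, and their absolute difference before a small MLP. A minimal sketch of that featurization:

```python
import torch

def featurize(premise: torch.Tensor, hypothesis: torch.Tensor) -> torch.Tensor:
    # [p, h, p*h, |p-h|] -> shape [batch_size, 4 * embedding_dim]
    return torch.cat(
        [premise, hypothesis, premise * hypothesis, torch.abs(premise - hypothesis)],
        dim=-1,
    )
```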

The model is trained on the [oist/multimodal_nli_dataset](https://huggingface.co/datasets/oist/multimodal_nli_dataset).
Please refer to that dataset card for details.

---

## Usage

The model depends on **SONAR embeddings**. You can use the official SONAR encoders (for text and speech) or the **ported SONAR text encoder** [`cointegrated/SONAR_200_text_encoder`](https://huggingface.co/cointegrated/SONAR_200_text_encoder).
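
Premise and hypothesis can each be text or speech, so all four modality combinations are handled the same way once both sides are encoded. A small dispatch helper can keep this uniform; this is a sketch (the `encode` helper itself is illustrative), using only the pipeline calls shown in the examples below:

```python
from typing import List

from sonar.inference_pipelines.speech import SpeechToEmbeddingModelPipeline
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline

speech_encoder = SpeechToEmbeddingModelPipeline(encoder="sonar_speech_encoder_eng")
text_encoder = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder"
)

def encode(inputs: List[str], modality: str, lang: str = "eng_Latn"):
    """Embed text strings or paths to audio files with SONAR."""
    if modality == "speech":
        return speech_encoder.predict(inputs)  # list of audio file paths
    return text_encoder.predict(inputs, source_lang=lang)  # list of sentences
```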

---

### Example 1: Speech–Text Inference

```python
import torch
from sonar.inference_pipelines.speech import SpeechToEmbeddingModelPipeline
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from transformers import AutoModel

# 1. Load SONAR encoders
speech_encoder = SpeechToEmbeddingModelPipeline(encoder="sonar_speech_encoder_eng")
text_encoder = TextToEmbeddingModelPipeline(encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder")

# 2. Encode premise (speech) and hypothesis (text)
premise_embs = speech_encoder.predict(["audio.wav"])
hypothesis_embs = text_encoder.predict(["The cat sat on the mat."], source_lang="eng_Latn")

# 3. Load MMNLI model
mmnli_model_name = "oist/multimodal_nli_model"
mmnli_model = AutoModel.from_pretrained(mmnli_model_name, trust_remote_code=True)
mmnli_model.eval()

# 4. Run inference
with torch.inference_mode():
    probs = mmnli_model(premise_embs, hypothesis_embs)  # class probabilities, shape [batch_size, 3]
    pred_class = torch.argmax(probs, dim=-1).item()

print("Prediction:", pred_class)
# 0 = Entailment, 1 = Neutral, 2 = Contradiction
```
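
The same setup works on batches; a short sketch reusing the encoders and model loaded above (the audio file names are placeholders):

```python
premises = speech_encoder.predict(["audio_1.wav", "audio_2.wav"])
hypotheses = text_encoder.predict(
    ["The cat sat on the mat.", "A dog is barking."], source_lang="eng_Latn"
)

with torch.inference_mode():
    preds = torch.argmax(mmnli_model(premises, hypotheses), dim=-1).tolist()

print(preds)  # one class index (0/1/2) per premise-hypothesis pair
```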

### Example 2: Text–Text Inference (Official SONAR)

```python
import torch
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from transformers import AutoModel

# 1. Load official SONAR text encoder
text_encoder = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder"
)

# 2. Encode premise and hypothesis
premise_texts = ["Le chat s'assit sur le tapis."]
hypothesis_texts = ["The cat sat on the mat."]

premise_embs = text_encoder.predict(premise_texts, source_lang="fra_Latn")
hypothesis_embs = text_encoder.predict(hypothesis_texts, source_lang="eng_Latn")

# 3. Load MMNLI model
mmnli_model = AutoModel.from_pretrained("oist/multimodal_nli_model", trust_remote_code=True)
mmnli_model.eval()

# 4. Run inference
with torch.inference_mode():
    probs = mmnli_model(premise_embs, hypothesis_embs)  # class probabilities, shape [batch_size, 3]
    pred_class = torch.argmax(probs, dim=-1).item()

print("Prediction:", pred_class)
# 0 = Entailment, 1 = Neutral, 2 = Contradiction
```

### Example 3: Text–Text Inference (Ported SONAR)

```python
import torch
from transformers import AutoTokenizer, AutoModel
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder

# 1. Load ported SONAR text encoder
sonar_model_name = "cointegrated/SONAR_200_text_encoder"
encoder = M2M100Encoder.from_pretrained(sonar_model_name)
tokenizer = AutoTokenizer.from_pretrained(sonar_model_name)

def encode_mean_pool(texts, tokenizer, encoder, lang='eng_Latn', norm=False):
    """Mean-pool the encoder's token states into one sentence embedding per text."""
    tokenizer.src_lang = lang
    with torch.inference_mode():
        batch = tokenizer(texts, return_tensors='pt', padding=True)
        seq_embs = encoder(**batch).last_hidden_state
        mask = batch.attention_mask
        mean_emb = (seq_embs * mask.unsqueeze(-1)).sum(1) / mask.unsqueeze(-1).sum(1)
        if norm:
            mean_emb = torch.nn.functional.normalize(mean_emb)
    return mean_emb

# Example sentences
premise_sentences = ["Le chat s'assit sur le tapis."]
hypothesis_sentences = ["The cat sat on the mat."]

# 2. Encode premise and hypothesis
# (norm=False is fine here: the MMNLI model normalizes embeddings itself because norm_emb=True in its config)
premise_embs = encode_mean_pool(premise_sentences, tokenizer, encoder, lang="fra_Latn")
hypothesis_embs = encode_mean_pool(hypothesis_sentences, tokenizer, encoder, lang="eng_Latn")

# 3. Load MMNLI model
mmnli_model_name = "oist/multimodal_nli_model"
mmnli_model = AutoModel.from_pretrained(mmnli_model_name, trust_remote_code=True)
mmnli_model.eval()

# 4. Run inference
with torch.inference_mode():
    probs = mmnli_model(premise_embs, hypothesis_embs)  # class probabilities, shape [batch_size, 3]
    pred_class = torch.argmax(probs, dim=-1).item()

print("Prediction:", pred_class)
# 0 = Entailment, 1 = Neutral, 2 = Contradiction
```
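
To sanity-check the ported encoder against the official one, a quick cosine-similarity comparison can be used. This is a sketch that reuses `encode_mean_pool`, `tokenizer`, and `encoder` from Example 3 and assumes the `sonar` package is installed alongside `transformers`:

```python
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline

official_encoder = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder"
)

sentence = ["The cat sat on the mat."]
official_emb = official_encoder.predict(sentence, source_lang="eng_Latn")
ported_emb = encode_mean_pool(sentence, tokenizer, encoder, lang="eng_Latn")

# Similarity should be high if the ported encoder tracks the official embeddings
print(torch.nn.functional.cosine_similarity(official_emb, ported_emb))
```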

---

## Labels

- 0 = Entailment
- 1 = Neutral
- 2 = Contradiction
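
In code, the same mapping (in the order listed above) is simply:

```python
ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}
label = ID2LABEL[pred_class]  # pred_class as computed in the examples above
```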

---

## Citation

If you use this model, please cite:

```bibtex
@inproceedings{istaiteh2025beyond,
  title={Beyond Similarity Scoring: Detecting Entailment and Contradiction in Multilingual and Multimodal Contexts},
  author={Istaiteh, Othman and Mdhaffar, Salima and Est{\`e}ve, Yannick},
  booktitle={Proc. Interspeech 2025},
  pages={286--290},
  year={2025}
}
```
config.json ADDED
@@ -0,0 +1,15 @@
{
  "activation": "TANH",
  "architectures": ["MMNLIModel"],
  "dropout": 0.1,
  "embedding_dim": 1024,
  "hidden_dims": [3072, 1536],
  "model_type": "mmnli",
  "norm_emb": true,
  "output_dim": 3,
  "transformers_version": "4.56.1",
  "auto_map": {
    "AutoConfig": "modeling_mmnli.MMNLIConfig",
    "AutoModel": "modeling_mmnli.MMNLIModel"
  }
}
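
The `auto_map` entries are what make `AutoConfig`/`AutoModel` resolve to the classes in `modeling_mmnli.py` when `trust_remote_code=True` is passed. A minimal sketch for loading and inspecting this configuration (assumes access to the Hub):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("oist/multimodal_nli_model", trust_remote_code=True)
print(config.embedding_dim, config.hidden_dims, config.output_dim)  # 1024 [3072, 1536] 3
```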
model.stateforce ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d1b0f69053bbbb0e4b1a4577014eda15b94030c506dbc212726cf2919128751d
size 69245364
modeling_mmnli.py ADDED
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel


# ---------------- CONFIG ---------------- #
class MMNLIConfig(PretrainedConfig):
    model_type = "mmnli"

    def __init__(
        self,
        embedding_dim: int = 1024,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        activation: str = "TANH",
        norm_emb: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
        self.dropout = dropout
        self.activation = activation
        self.norm_emb = norm_emb
        self.output_dim = 3  # entailment, neutral, contradiction


# ---------------- CORE MODEL ---------------- #
ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}


class MMNLICore(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        norm_emb: bool,
    ):
        super().__init__()
        self.norm_emb = norm_emb

        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        # Input: concatenation of [p, h, p*h, |p-h|] => 4 * embedding_dim
        input_dim = embedding_dim * 4

        modules: List[nn.Module] = []
        if dropout > 0:
            modules.append(nn.Dropout(p=dropout))

        nprev = input_dim
        for h in hidden_dims:
            modules.append(nn.Linear(nprev, h))
            modules.append(ACTIVATIONS[activation]())
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = h

        # Final classifier layer: 3-way softmax
        modules.append(nn.Linear(nprev, 3))
        modules.append(nn.Softmax(dim=-1))

        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        return F.normalize(emb) if (emb is not None and self.norm_emb) else emb

    def featurize(self, premise: Tensor, hypothesis: Tensor) -> Tensor:
        return torch.cat(
            [premise, hypothesis, premise * hypothesis, torch.abs(premise - hypothesis)],
            dim=-1,
        )


# ---------------- HF MODEL WRAPPER ---------------- #
class MMNLIModel(PreTrainedModel):
    config_class = MMNLIConfig

    def __init__(self, config: MMNLIConfig):
        super().__init__(config)
        self.core = MMNLICore(
            embedding_dim=config.embedding_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            norm_emb=config.norm_emb,
        )

    def forward(self, premise: Tensor, hypothesis: Tensor):
        premise = self.core._norm(premise)
        hypothesis = self.core._norm(hypothesis)
        proc = self.core.featurize(premise, hypothesis)
        return self.core.mlp(proc)
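
A minimal local smoke test for the classes above (a sketch: random tensors stand in for real SONAR embeddings, and no pretrained weights are loaded):

```python
import torch
from modeling_mmnli import MMNLIConfig, MMNLIModel

config = MMNLIConfig()
model = MMNLIModel(config).eval()

premise = torch.randn(2, config.embedding_dim)
hypothesis = torch.randn(2, config.embedding_dim)

with torch.inference_mode():
    probs = model(premise, hypothesis)

print(probs.shape)        # torch.Size([2, 3])
print(probs.sum(dim=-1))  # ~1.0 per row, since the MLP ends in a softmax
```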