saracandu committed on
Commit 93c9fd9 · verified · 1 Parent(s): 20a3982

Upload STLForCausalLM

Files changed (3)
  1. config.json +4 -0
  2. configuration_stldec.py +68 -0
  3. modeling_stldec.py +2166 -0
config.json CHANGED
@@ -5,6 +5,10 @@
     "STLForCausalLM"
   ],
   "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_stldec.STLConfig",
+    "AutoModelForCausalLM": "modeling_stldec.STLForCausalLM"
+  },
   "bos_token_id": 2,
   "d_model": 1024,
   "decoder_attention_heads": 16,
configuration_stldec.py ADDED
@@ -0,0 +1,68 @@
from transformers.configuration_utils import PretrainedConfig

class STLConfig(PretrainedConfig):

    model_type = "stldec"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=35,
        decoder_vocab_size=None,  # unused
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=3,
        scale_embedding=False,
        pad_token_id=1,
        eos_token_id=3,
        bos_token_id=2,
        forced_eos_token_id=3,
        share_encoder_decoder_embeddings=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.decoder_vocab_size = decoder_vocab_size or vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.share_encoder_decoder_embeddings = share_encoder_decoder_embeddings
        super().__init__(
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )
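
For reference, the configuration can also be instantiated directly; `attribute_map` makes the generic `hidden_size` and `num_attention_heads` names resolve to `d_model` and `encoder_attention_heads`. A minimal sketch:

from configuration_stldec import STLConfig

config = STLConfig()                     # defaults as declared above
assert config.model_type == "stldec"
assert config.hidden_size == 1024        # mapped to d_model by attribute_map
assert config.num_attention_heads == 16  # mapped to encoder_attention_heads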
modeling_stldec.py ADDED
@@ -0,0 +1,2166 @@
import ast
import copy
import math
import pickle
import os
from collections import deque
from typing import List, Optional, Tuple, Union

import numpy as np
from numpy import random as rnd  # assumed: StlGenerator below calls rnd.choice/rand/randint/normal
import pandas as pd
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
from torch.nn.functional import normalize  # assumed: used by anchorGeneration below
from torch.utils.data import Dataset

from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from transformers.generation import GenerationMixin
from transformers.utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)

from .configuration_stldec import STLConfig
from nltk.translate.bleu_score import sentence_bleu
# from stl import *
import networkx as nx
from datasets import load_dataset

### from custom_typing.py

realnum = Union[float, int]


### from stl.py

# For tensor functions
import torch
from torch import Tensor
import torch.nn.functional as F


def eventually(x: Tensor, time_span: int) -> Tensor:
    """
    STL operator 'eventually' in 1D.

    Parameters
    ----------
    x: torch.Tensor
        Signal
    time_span: any numeric type
        Timespan duration

    Returns
    -------
    torch.Tensor
        A tensor containing the result of the operation.
    """
    return F.max_pool1d(x, kernel_size=time_span, stride=1)

class Node:
    """Abstract node class for STL semantics tree."""

    def __init__(self) -> None:
        # Must be overloaded.
        pass

    def __str__(self) -> str:
        # Must be overloaded.
        pass

    def boolean(self, x: Tensor, evaluate_at_all_times: bool = False) -> Tensor:
        """
        Evaluates the boolean semantics at the node.

        Parameters
        ----------
        x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
            The input signals, stored as a batch tensor with three dimensions.
        evaluate_at_all_times: bool
            Whether to evaluate the semantics at all times (True) or
            just at t=0 (False).

        Returns
        -------
        torch.Tensor
            A tensor with the boolean semantics for the node.
        """
        z: Tensor = self._boolean(x)
        if evaluate_at_all_times:
            return z
        else:
            return self._extract_semantics_at_time_zero(z)

    def quantitative(
        self,
        x: Tensor,
        normalize: bool = False,
        evaluate_at_all_times: bool = False,
    ) -> Tensor:
        """
        Evaluates the quantitative semantics at the node.

        Parameters
        ----------
        x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
            The input signals, stored as a batch tensor with three dimensions.
        normalize: bool
            Whether the measure of robustness is normalized (True) or
            not (False). Currently not in use.
        evaluate_at_all_times: bool
            Whether to evaluate the semantics at all times (True) or
            just at t=0 (False).

        Returns
        -------
        torch.Tensor
            A tensor with the quantitative semantics for the node.
        """
        z: Tensor = self._quantitative(x, normalize)
        if evaluate_at_all_times:
            return z
        else:
            return self._extract_semantics_at_time_zero(z)

    def set_normalizing_flag(self, value: bool = True) -> None:
        """
        Setter for the 'normalization of robustness of the formula' flag.
        Currently not in use.
        """

    def time_depth(self) -> int:
        """Returns time depth of bounded temporal operators only."""
        # Must be overloaded.

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        """Private method equivalent to public one for inner call."""
        # Must be overloaded.

    def _boolean(self, x: Tensor) -> Tensor:
        """Private method equivalent to public one for inner call."""
        # Must be overloaded.

    @staticmethod
    def _extract_semantics_at_time_zero(x: Tensor) -> Tensor:
        """Extrapolates the vector of truth values at time zero"""
        return torch.reshape(x[:, 0, 0], (-1,))

class Atom(Node):
    """Atomic formula node; for now of the form X<=t or X>=t"""

    def __init__(self, var_index: int, threshold: realnum, lte: bool = False) -> None:
        super().__init__()
        self.var_index: int = var_index
        self.threshold: realnum = threshold
        self.lte: bool = lte

    def __str__(self) -> str:
        s: str = (
            "x_"
            + str(self.var_index)
            + (" <= " if self.lte else " >= ")
            + str(round(self.threshold, 4))
        )
        return s

    def time_depth(self) -> int:
        return 0

    def _boolean(self, x: Tensor) -> Tensor:
        # extract tensor of the same dimension as data, but with only one variable
        xj: Tensor = x[:, self.var_index, :]
        xj: Tensor = xj.view(xj.size()[0], 1, -1)
        if self.lte:
            z: Tensor = torch.le(xj, self.threshold)
        else:
            z: Tensor = torch.ge(xj, self.threshold)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        # extract tensor of the same dimension as data, but with only one variable
        xj: Tensor = x[:, self.var_index, :]
        xj: Tensor = xj.view(xj.size()[0], 1, -1)
        if self.lte:
            z: Tensor = -xj + self.threshold
        else:
            z: Tensor = xj - self.threshold
        if normalize:
            z: Tensor = torch.tanh(z)
        return z


class Not(Node):
    """Negation node."""

    def __init__(self, child: Node) -> None:
        super().__init__()
        self.child: Node = child

    def __str__(self) -> str:
        s: str = "not ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        return self.child.time_depth()

    def _boolean(self, x: Tensor) -> Tensor:
        z: Tensor = ~self.child._boolean(x)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z: Tensor = -self.child._quantitative(x, normalize)
        return z


class And(Node):
    """Conjunction node."""

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        s: str = (
            "( "
            + self.left_child.__str__()
            + " and "
            + self.right_child.__str__()
            + " )"
        )
        return s

    def time_depth(self) -> int:
        return max(self.left_child.time_depth(), self.right_child.time_depth())

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.left_child._boolean(x)
        z2: Tensor = self.right_child._boolean(x)
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.logical_and(z1, z2)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.left_child._quantitative(x, normalize)
        z2: Tensor = self.right_child._quantitative(x, normalize)
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.min(z1, z2)
        return z

class Or(Node):
    """Disjunction node."""

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        s: str = (
            "( "
            + self.left_child.__str__()
            + " or "
            + self.right_child.__str__()
            + " )"
        )
        return s

    def time_depth(self) -> int:
        return max(self.left_child.time_depth(), self.right_child.time_depth())

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.left_child._boolean(x)
        z2: Tensor = self.right_child._boolean(x)
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.logical_or(z1, z2)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.left_child._quantitative(x, normalize)
        z2: Tensor = self.right_child._quantitative(x, normalize)
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.max(z1, z2)
        return z


class Globally(Node):
    """Globally node."""

    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "always" + s0 + " ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
            return self.child.time_depth() + self.right_time_bound - 1
            # (self.right_time_bound - self.left_time_bound + 1) - diff

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])  # nested temporal parameters
        # z1 = z1[:, :, self.left_time_bound:]
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            z: Tensor = torch.ge(1.0 - eventually((~z1).double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        # z1 = z1[:, :, self.left_time_bound:]
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            z: Tensor = -eventually(-z1, self.right_time_bound - self.left_time_bound)
        return z


class Eventually(Node):
    """Eventually node."""

    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound

        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is higher than left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "eventually" + s0 + " ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
            return self.child.time_depth() + self.right_time_bound - 1
            # (self.right_time_bound - self.left_time_bound + 1) - diff

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.max(z1, 2, keepdim=True)
        else:
            z: Tensor = torch.ge(eventually(z1.double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.max(z1, 2, keepdim=True)
        else:
            z: Tensor = eventually(z1, self.right_time_bound - self.left_time_bound)
        return z


class Until(Node):
    """Until node."""

    def __init__(
        self,
        left_child: Node,
        right_child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
    ) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1

        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is higher than left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "( " + self.left_child.__str__() + " until" + s0 + " " + self.right_child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        sum_children_depth: int = self.left_child.time_depth() + self.right_child.time_depth()
        if self.unbound:
            return sum_children_depth
        elif self.right_unbound:
            return sum_children_depth + self.left_time_bound
        else:
            # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
            return sum_children_depth + self.right_time_bound - 1
            # (self.right_time_bound - self.left_time_bound + 1) - diff

    def _boolean(self, x: Tensor) -> Tensor:
        if self.unbound:
            z1: Tensor = self.left_child._boolean(x)
            z2: Tensor = self.right_child._boolean(x)
            size: int = min(z1.size()[2], z2.size()[2])
            z1: Tensor = z1[:, :, :size]
            z2: Tensor = z2[:, :, :size]
            z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
            z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
            z1_triu = torch.triu(z1_rep)
            z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]

            z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
            z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
            z2_triu = torch.triu(z2_rep)
            z2_def = z2_tril + z2_triu
            z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
                                  dim=-1)[0]
        elif self.right_unbound:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, right_unbound=True,
                                                   left_time_bound=self.left_time_bound),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._boolean(x)
        else:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
                                                   right_time_bound=self.right_time_bound - 1),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._boolean(x)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        if self.unbound:
            z1: Tensor = self.left_child._quantitative(x, normalize)
            z2: Tensor = self.right_child._quantitative(x, normalize)
            size: int = min(z1.size()[2], z2.size()[2])
            z1: Tensor = z1[:, :, :size]
            z2: Tensor = z2[:, :, :size]

            z: Tensor = torch.cat([torch.max(torch.min(
                torch.cat([torch.cummin(z1[:, :, t:].unsqueeze(-1), dim=2)[0], z2[:, :, t:].unsqueeze(-1)], dim=-1),
                dim=-1)[0], dim=2, keepdim=True)[0] for t in range(size)], dim=2)
        elif self.right_unbound:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, right_unbound=True,
                                                   left_time_bound=self.left_time_bound),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._quantitative(x, normalize=normalize)
        else:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
                                                   right_time_bound=self.right_time_bound - 1),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._quantitative(x, normalize=normalize)
        return z

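The node classes above compose into formula trees whose boolean and quantitative (robustness) semantics are evaluated on batched signals of shape (N_samples, N_vars, N_sampling_points). A minimal sketch of how they fit together, with arbitrary illustrative values:

# eventually[0,5] ( x_0 >= 0.2 and x_1 <= 0.8 )
phi = Eventually(And(Atom(0, 0.2, lte=False), Atom(1, 0.8, lte=True)),
                 left_time_bound=0, right_time_bound=5)
signals = torch.randn(16, 2, 50)   # 16 samples, 2 variables, 50 time points
sat = phi.boolean(signals)         # boolean satisfaction at t=0, shape (16,)
rob = phi.quantitative(signals)    # robustness at t=0, shape (16,)
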
# from anchor_set_generation import anchorGeneration

import re
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

#### utils ####

def load_pickle(path):
    with open(path, 'rb') as f:
        x = pickle.load(f)
    return x

def dump_pickle(name, thing):
    with open(name + '.pickle', 'wb') as f:
        pickle.dump(thing, f)

def from_string_to_formula(st):
    root_arity = 2 if st.startswith('(') else 1
    st_split = st.split()
    if root_arity <= 1:
        root_op_str = copy.deepcopy(st_split[0])
        if root_op_str.startswith('x'):
            atom_sign = True if st_split[1] == '<=' else False
            root_phi = Atom(var_index=int(st_split[0][2]), lte=atom_sign, threshold=float(st_split[2]))
            return root_phi
        else:
            assert (root_op_str.startswith('not') or root_op_str.startswith('eventually')
                    or root_op_str.startswith('always'))
            current_st = copy.deepcopy(st_split[2:-1])
            if root_op_str == 'not':
                root_phi = Not(child=from_string_to_formula(' '.join(current_st)))
            elif root_op_str.startswith('eventually'):
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                      right_unbound=right_unbound, left_time_bound=left_time_bound,
                                      right_time_bound=right_time_bound)
            else:
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                    right_unbound=right_unbound, left_time_bound=left_time_bound,
                                    right_time_bound=right_time_bound)
    else:
        # 1 - delete everything which is contained in other sets of parenthesis (if any)
        current_st = copy.deepcopy(st_split[1:-1])
        if '(' in current_st:
            par_queue = deque()
            par_idx_list = []
            for i, sub in enumerate(current_st):
                if sub == '(':
                    par_queue.append(i)
                elif sub == ')':
                    par_idx_list.append(tuple([par_queue.pop(), i]))
            # open_par_idx, close_par_idx = [current_st.index(p) for p in ['(', ')']]
            # union of parentheses range --> from these we may extract the substrings to be the children!!!
            children_range = []
            for begin, end in sorted(par_idx_list):
                if children_range and children_range[-1][1] >= begin - 1:
                    children_range[-1][1] = max(children_range[-1][1], end)
                else:
                    children_range.append([begin, end])
            n_children = len(children_range)
            assert (n_children in [1, 2])
            if n_children == 1:
                # one of the children is a variable --> need to individuate it
                var_child_idx = 1 if children_range[0][0] <= 1 else 0  # 0 is left child, 1 is right child
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                left_child_str = current_st[:3] if var_child_idx == 0 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[-3:] if var_child_idx == 1 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \
                    current_st[children_range[0][0] - 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
            else:
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                if current_st[children_range[1][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[1][0] -= 1
                # if there are two children, with parentheses, the element in the middle is the root
                root_op_str = current_st[children_range[0][1] + 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
                left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1]
        else:
            # no parentheses means that both children are variables
            left_child_str = current_st[:3]
            right_child_str = current_st[-3:]
            root_op_str = current_st[3]
        left_child_str = ' '.join(left_child_str)
        right_child_str = ' '.join(right_child_str)
        if root_op_str == 'and':
            root_phi = And(left_child=from_string_to_formula(left_child_str),
                           right_child=from_string_to_formula(right_child_str))
        elif root_op_str == 'or':
            root_phi = Or(left_child=from_string_to_formula(left_child_str),
                          right_child=from_string_to_formula(right_child_str))
        else:
            unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
            root_phi = Until(left_child=from_string_to_formula(left_child_str),
                             right_child=from_string_to_formula(right_child_str),
                             unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound,
                             right_time_bound=right_time_bound)
    return root_phi

def load_json(path: str) -> Union[Dict, List]:
    """
    Load a JSON file from the given path.
    Args:
        path (str): The path to the JSON file to be loaded.

    Returns:
        Union[Dict, List]: The parsed content of the JSON file, which could be a dictionary or a list.
    """
    with open(path, "r") as f:
        return json.load(f)

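The parser mirrors the `__str__` output of the node classes, so simple formulae round-trip; bounded temporal operators additionally go through `set_time_thresholds`, which is expected to be defined elsewhere in this file. A small sketch:

phi = from_string_to_formula("( x_0 >= 0.25 and not ( x_1 <= -0.5 ) )")
print(phi)  # ( x_0 >= 0.25 and not ( x_1 <= -0.5 ) )
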
#### phis_generator ####

class StlGenerator:
    def __init__(
        self,
        leaf_prob: float = 0.3,
        inner_node_prob: list = None,
        threshold_mean: float = 0.0,
        threshold_sd: float = 1.0,
        unbound_prob: float = 0.1,
        right_unbound_prob: float = 0.2,
        time_bound_max_range: float = 20,
        adaptive_unbound_temporal_ops: bool = True,
        max_timespan: int = 100,
    ):
        """
        leaf_prob
            probability of generating a leaf (always zero for root)
        node_types = ["not", "and", "or", "always", "eventually", "until"]
            Inner node types
        inner_node_prob
            probability vector for the different types of internal nodes
        threshold_mean
        threshold_sd
            mean and std for the normal distribution of the thresholds of atoms
        unbound_prob
            probability of a temporal operator to have a time bound of the type [0, infty]
        right_unbound_prob
            probability of a temporal operator to be right-unbounded, i.e. of the type [t, infty]
        time_bound_max_range
            maximum value of time span of a temporal operator (i.e. max value of t in [0,t])
        adaptive_unbound_temporal_ops
            if true, unbounded temporal operators are computed from current point to the end of the signal, otherwise
            they are evaluated only at time zero.
        max_timespan
            maximum time depth of a formula.
        """

        # Address the mutability of default arguments
        if inner_node_prob is None:
            inner_node_prob = [0.166, 0.166, 0.166, 0.17, 0.166, 0.166]

        self.leaf_prob = leaf_prob
        self.inner_node_prob = inner_node_prob
        self.threshold_mean = threshold_mean
        self.threshold_sd = threshold_sd
        self.unbound_prob = unbound_prob
        self.right_unbound_prob = right_unbound_prob
        self.time_bound_max_range = time_bound_max_range
        self.adaptive_unbound_temporal_ops = adaptive_unbound_temporal_ops
        self.node_types = ["not", "and", "or", "always", "eventually", "until"]
        self.max_timespan = max_timespan

    def sample(self, nvars):
        """
        Samples a random formula with distribution defined in class instance parameters
        Parameters
        ----------
        nvars : number of variables of input signals
            how many variables the formula is expected to consider.
        Returns
        -------
        TYPE
            A random formula.
        """
        return self._sample_internal_node(nvars)

    def bag_sample(self, bag_size, nvars):
        """
        Samples a bag of bag_size formulae
        Parameters
        ----------
        bag_size : INT
            number of formulae.
        nvars : INT
            number of vars in formulae.
        Returns
        -------
        a list of formulae.
        """
        formulae = []
        for _ in range(bag_size):
            phi = self.sample(nvars)
            formulae.append(phi)
        return formulae

    def _sample_internal_node(self, nvars):
        # Declare & dummy-assign "idiom"
        node: Union[None, Node]
        node = None
        # choose node type
        nodetype = rnd.choice(self.node_types, p=self.inner_node_prob)
        while True:
            if nodetype == "not":
                n = self._sample_node(nvars)
                node = Not(n)
            elif nodetype == "and":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = And(n1, n2)
            elif nodetype == "or":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = Or(n1, n2)
            elif nodetype == "always":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Globally(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "eventually":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Eventually(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "until":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Until(
                    n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
                )

            if (node is not None) and (node.time_depth() < self.max_timespan):
                return node

    def _sample_node(self, nvars):
        if rnd.rand() < self.leaf_prob:
            # sample a leaf
            var, thr, lte = self._get_atom(nvars)
            return Atom(var, thr, lte)
        else:
            return self._sample_internal_node(nvars)

    def _get_temporal_parameters(self):
        if rnd.rand() < self.unbound_prob:
            return True, False, 0, 0
        elif rnd.rand() < self.right_unbound_prob:
            return False, True, rnd.randint(self.time_bound_max_range), 1
        else:
            left_bound = rnd.randint(self.time_bound_max_range)
            return False, False, left_bound, rnd.randint(left_bound, self.time_bound_max_range) + 1

    def _get_atom(self, nvars):
        variable = rnd.randint(nvars)
        lte = rnd.rand() > 0.5
        threshold = rnd.normal(self.threshold_mean, self.threshold_sd)
        return variable, threshold, lte

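Sampling random formulae with the generator then looks like the following (this assumes `rnd` is NumPy's random module, as the calls above suggest):

sampler = StlGenerator(leaf_prob=0.4)
phis = sampler.bag_sample(bag_size=5, nvars=3)  # five random formulae over 3-variable signals
for phi in phis:
    print(phi, "| time depth:", phi.time_depth())
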
#### traj_measure ####

class Measure:
    def sample(self, samples=100000, varn=2, points=100):
        # Must be overridden
        pass

class BaseMeasure(Measure):
    def __init__(
        self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
    ):
        """
        Parameters
        ----------
        mu0 : mean of normal distribution of initial state, optional
            The default is 0.0.
        sigma0 : standard deviation of normal distribution of initial state, optional
            The default is 1.0.
        mu1 : DOUBLE, optional
            mean of normal distribution of total variation. The default is 0.0.
        sigma1 : standard deviation of normal distribution of total variation, optional
            The default is 1.0.
        q : DOUBLE, optional
            probability of change of sign in derivative. The default is 0.1.
        q0 : DOUBLE, optional
            probability of initial sign of derivative. The default is 0.5.
        device : 'cpu' or 'cuda', optional
            device on which to run the algorithm. The default is 'cpu'.
        Returns
        -------
        None.
        """
        self.mu0 = mu0
        self.sigma0 = sigma0
        self.mu1 = mu1
        self.sigma1 = sigma1
        self.q = q
        self.q0 = q0
        self.device = device

    def sample(self, samples=100000, varn=2, points=100):
        """
        Samples a set of trajectories from the basic measure space, with parameters
        passed to the sampler
        Parameters
        ----------
        points : INT, optional
            number of points per trajectory, including initial one. The default is 100.
        samples : INT, optional
            number of trajectories. The default is 100000.
        varn : INT, optional
            number of variables per trajectory. The default is 2.
        Returns
        -------
        signal : samples x varn x points double pytorch tensor
            The sampled signals.
        """
        if self.device == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("GPU card or CUDA library not available!")

        # generate unif RN
        signal = torch.rand(samples, varn, points, device=self.device)
        # first point is special - set to zero for the moment, and set one point to 1
        signal[:, :, 0] = 0.0
        signal[:, :, -1] = 1.0
        # sorting each trajectory
        signal, _ = torch.sort(signal, 2)
        # computing increments and storing them in points 1 to end
        signal[:, :, 1:] = signal[:, :, 1:] - signal[:, :, :-1]
        # generate initial state, according to a normal distribution
        signal[:, :, 0] = self.mu0 + self.sigma0 * torch.randn(signal[:, :, 0].size())

        # sampling change signs from bernoulli in -1, 1
        derivs = (1 - self.q) * torch.ones(samples, varn, points, device=self.device)
        derivs = 2 * torch.bernoulli(derivs) - 1
        # sampling initial derivative
        derivs[:, :, 0] = self.q0
        derivs[:, :, 0] = 2 * torch.bernoulli(derivs[:, :, 0]) - 1
        # taking the cumulative product along axis 2
        derivs = torch.cumprod(derivs, 2)

        # sampling total variation
        totvar = torch.pow(
            self.mu1 + self.sigma1 * torch.randn(samples, varn, 1, device=self.device),
            2,
        )
        # multiplying total variation and derivatives and making initial point non-invasive
        derivs = derivs * totvar
        derivs[:, :, 0] = 1.0

        # computing trajectories by multiplying and then doing a cumulative sum
        signal = signal * derivs
        signal = torch.cumsum(signal, 2)
        return signal

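A quick sketch of drawing signals from the base measure and scoring a formula against them:

mu = BaseMeasure(device="cpu")
signals = mu.sample(samples=1000, varn=3, points=100)  # shape (1000, 3, 100)
phi = Globally(Atom(0, 0.0, lte=True), unbound=True)   # always ( x_0 <= 0.0 )
robustness = phi.quantitative(signals)                 # robustness at t=0, shape (1000,)
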
#### kernel ####

realnum = Union[float, int]

class StlKernel:
    def __init__(
        self,
        measure,
        normalize=True,
        exp_kernel=True,
        sigma2=0.2,  # 0.5 works better; this was initially 0.2
        integrate_time=False,
        samples=100000,
        varn=2,
        points=100,
        boolean=False,
        signals=None,
    ):
        self.traj_measure = measure
        self.exp_kernel = exp_kernel
        self.normalize = normalize
        self.sigma2 = sigma2
        self.samples = samples
        self.varn = varn
        self.points = points
        self.integrate_time = integrate_time
        if signals is not None:
            self.signals = signals
        else:
            self.signals = measure.sample(points=points, samples=samples, varn=varn)
        self.boolean = boolean

    def compute(self, phi1, phi2):
        return self.compute_one_one(phi1, phi2)

    def compute_one_one(self, phi1, phi2):
        phis1: list = [phi1]
        phis2: list = [phi2]
        ker = self.compute_bag_bag(phis1, phis2)
        return ker[0, 0]

    def compute_bag(self, phis, return_robustness=True):
        if self.integrate_time:
            rhos, selfk, len0 = self._compute_robustness_time(phis)
            kernel_matrix = self._compute_kernel_time(
                rhos, rhos, selfk, selfk, len0, len0
            )
        else:
            rhos, selfk = self._compute_robustness_no_time(phis)
            kernel_matrix = self._compute_kernel_no_time(rhos, rhos, selfk, selfk)
            len0 = None
        if return_robustness:
            return kernel_matrix.cpu(), rhos, selfk, len0
        else:
            return kernel_matrix.cpu()

    def compute_one_bag(self, phi1, phis2, return_robustness=False):
        phis1: list = [phi1]
        return self.compute_bag_bag(phis1, phis2, return_robustness)

    def compute_bag_bag(self, phis1, phis2, return_robustness=False):
        if self.integrate_time:
            rhos1, selfk1, len1 = self._compute_robustness_time(phis1)
            rhos2, selfk2, len2 = self._compute_robustness_time(phis2)
            kernel_matrix = self._compute_kernel_time(
                rhos1, rhos2, selfk1, selfk2, len1, len2
            )
        else:
            rhos1, selfk1 = self._compute_robustness_no_time(phis1)
            rhos2, selfk2 = self._compute_robustness_no_time(phis2)
            len1, len2 = [None, None]
            kernel_matrix = self._compute_kernel_no_time(rhos1, rhos2, selfk1, selfk2)
        if return_robustness:
            return kernel_matrix.cpu(), rhos1, rhos2, selfk1, selfk2, len1, len2
        else:
            return kernel_matrix.cpu()

    def compute_one_from_robustness(self, phi, rhos, rho_self, lengths=None, return_robustness=False):
        phis: list = [phi]
        return self.compute_bag_from_robustness(phis, rhos, rho_self, lengths, return_robustness)

    def compute_bag_from_robustness(self, phis, rhos, rho_self, lengths=None, return_robustness=False):
        if self.integrate_time:
            rhos1, selfk1, len1 = self._compute_robustness_time(phis)
            kernel_matrix = self._compute_kernel_time(
                rhos1, rhos, selfk1, rho_self, len1, lengths
            )
        else:
            rhos1, selfk1 = self._compute_robustness_no_time(phis)
            len1 = None
            kernel_matrix = self._compute_kernel_no_time(rhos1, rhos, selfk1, rho_self)
        if return_robustness:
            return kernel_matrix.cpu(), rhos1, selfk1, len1
        else:
            return kernel_matrix.cpu()

    # The block below was unreachable in the upload (it followed a return statement with no
    # def line); from its callers it is the body of _compute_robustness_time, so the
    # signature is reconstructed here.
    def _compute_robustness_time(self, phis):
        n = self.samples
        p = self.points
        k = len(phis)
        rhos = torch.zeros((k, n, p), device="cpu")
        lengths = torch.zeros(k)
        self_kernels = torch.zeros((k, 1))
        for i, phi in enumerate(phis):
            if self.boolean:
                rho = phi.boolean(self.signals, evaluate_at_all_times=True).float()
                rho[rho == 0.0] = -1.0
            else:
                rho = phi.quantitative(self.signals, evaluate_at_all_times=True)
            actual_p = rho.size()[2]
            rho = rho.reshape(n, actual_p).cpu()
            rhos[i, :, :actual_p] = rho
            lengths[i] = actual_p
            self_kernels[i] = torch.tensordot(
                rho.reshape(1, n, -1), rho.reshape(1, n, -1), dims=[[1, 2], [1, 2]]
            ) / (actual_p * n)
        return rhos, self_kernels, lengths

    def _compute_robustness_no_time(self, phis):
        n = self.samples
        k = len(phis)
        rhos = torch.zeros((k, n), device=self.traj_measure.device)
        self_kernels = torch.zeros((k, 1), device=self.traj_measure.device)
        for i, phi in enumerate(phis):
            if self.boolean:
                rho = phi.boolean(self.signals, evaluate_at_all_times=False).float()
                rho[rho == 0.0] = -1.0
            else:
                rho = phi.quantitative(self.signals, evaluate_at_all_times=False)
            self_kernels[i] = rho.dot(rho) / n
            rhos[i, :] = rho
        return rhos, self_kernels

    def _compute_kernel_time(self, rhos1, rhos2, selfk1, selfk2, len1, len2):
        kernel_matrix = torch.tensordot(rhos1, rhos2, [[1, 2], [1, 2]])
        length_normalizer = self._compute_trajectory_length_normalizer(len1, len2)
        kernel_matrix = kernel_matrix * length_normalizer / self.samples
        if self.normalize:
            kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
        if self.exp_kernel:
            kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
        return kernel_matrix

    def _compute_kernel_no_time(self, rhos1, rhos2, selfk1, selfk2):
        kernel_matrix = torch.tensordot(rhos1, rhos2, [[1], [1]])
        kernel_matrix = kernel_matrix / self.samples
        if self.normalize:
            kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
        if self.exp_kernel:
            kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
        return kernel_matrix

    @staticmethod
    def _normalize(kernel_matrix, selfk1, selfk2):
        normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
        kernel_matrix = kernel_matrix / normalize
        return kernel_matrix

    def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
        if sigma2 is None:
            sigma2 = self.sigma2
        if self.normalize:
            # selfk is (1.0^2 + 1.0^2)
            selfk = 2.0
        else:
            k1 = selfk1.size()[0]
            k2 = selfk2.size()[0]
            selfk = (selfk1 * selfk1).repeat(1, k2) + torch.transpose(
                selfk2 * selfk2, 0, 1
            ).repeat(k1, 1)
        return torch.exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2))

    @staticmethod
    def _compute_trajectory_length_normalizer(len1, len2):
        k1 = len1.size()[0]
        k2 = len2.size()[0]
        y1 = len1.reshape(-1, 1)
        y1 = y1.repeat(1, k2)
        y2 = len2.repeat(k1, 1)
        return 1.0 / torch.min(y1, y2)

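A sketch of wiring the kernel together with the generator and measure defined above:

mu = BaseMeasure(device="cpu")
kernel = StlKernel(mu, samples=5000, varn=3, points=100)
gen = StlGenerator(leaf_prob=0.4)
phi_a, phi_b = gen.sample(nvars=3), gen.sample(nvars=3)
k_ab = kernel.compute(phi_a, phi_b)  # scalar kernel value between the two formulae
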
class GramMatrix:
    def __init__(self, kernel, formulae, store_robustness=True, sample=False, sampler=None, bag_size=None):
        self.kernel = kernel
        self.formulae_list = formulae
        # if kernel is computed from robustness at time zero only,
        # we store the robustness for each formula and each sample
        # to speed up computation later
        self.store_robustness = store_robustness
        self.dim = len(self.formulae_list) if not bag_size else int(bag_size)
        self.sample = sample  # whether to generate formulae in a controlled manner
        if self.sample:
            self.t = 0.99 if self.kernel.boolean else 0.85
        self.sampler = sampler  # stl formulae generator
        self._compute_gram_matrix()

    def _compute_gram_matrix(self):
        if self.sample:
            gram = torch.zeros(self.dim, self.dim)
            rhos = torch.zeros((self.dim, self.kernel.samples), device=self.kernel.traj_measure.device) if \
                not self.kernel.integrate_time else torch.zeros((self.dim, self.kernel.samples, self.kernel.points),
                                                                device=self.kernel.traj_measure.device)
            lengths = torch.zeros(self.dim) if self.kernel.integrate_time else np.zeros(self.dim)
            kernels = torch.zeros((self.dim, 1), device=self.kernel.traj_measure.device)
            phis = [self.sampler.sample(nvars=self.kernel.varn)]
            gram[0, :1], rhos[0], kernels[0, :], lengths[0] = self.kernel.compute_bag(phis, return_robustness=True)
            while len(phis) < self.dim:
                i = len(phis)
                phi = self.sampler.sample(nvars=self.kernel.varn)
                gram[i, :i], rhos[i], kernels[i, :], lengths[i] = self.kernel.compute_one_from_robustness(
                    phi, rhos[:i, :], kernels[:i, :], lengths[:i], return_robustness=True)
                if torch.sum(gram[i, :i + 1] >= self.t) < 3:
                    phis.append(phi)
                    gram[:i, i] = gram[i, :i]
                    gram[i, i] = kernels[i, :]

            self.formulae_list = phis
            self.gram = gram.cpu()
            self.robustness = rhos if self.store_robustness else None
            self.self_kernels = kernels if self.store_robustness else None
            self.robustness_lengths = lengths if self.store_robustness else None
        else:
            if self.store_robustness:
                k_matrix, rhos, selfk, len0 = self.kernel.compute_bag(
                    self.formulae_list, return_robustness=True
                )
                self.gram = k_matrix
                self.robustness = rhos
                self.self_kernels = selfk
                self.robustness_lengths = len0
            else:
                self.gram = self.kernel.compute_bag(
                    self.formulae_list, return_robustness=False
                )
                self.robustness = None
                self.self_kernels = None
                self.robustness_lengths = None

    def compute_kernel_vector(self, phi):
        if self.store_robustness:
            return self.kernel.compute_one_from_robustness(
                phi, self.robustness, self.self_kernels, self.robustness_lengths
            )
        else:
            return self.kernel.compute_one_bag(phi, self.formulae_list)

    def compute_bag_kernel_vector(self, phis, generate_phis=False, bag_size=None):
        if generate_phis:
            gram_test = torch.zeros(bag_size, self.dim)  # self.dim, bag_size
            rhos_test = torch.zeros((bag_size, self.kernel.samples), device=self.kernel.traj_measure.device) if \
                not self.kernel.integrate_time else torch.zeros((bag_size, self.kernel.samples, self.kernel.points),
                                                                device=self.kernel.traj_measure.device)
            lengths_test = torch.zeros(bag_size) if self.kernel.integrate_time else np.zeros(bag_size)
            kernels_test = torch.zeros((bag_size, 1), device=self.kernel.traj_measure.device)
            phi_test = []
            while len(phi_test) < bag_size:
                i = len(phi_test)
                phi = self.sampler.sample(nvars=self.kernel.varn)
                if self.store_robustness:
                    gram_test[i, :], rhos_test[i], kernels_test[i, :], lengths_test[i] = \
                        self.kernel.compute_one_from_robustness(phi, self.robustness, self.self_kernels,
                                                                self.robustness_lengths, return_robustness=True)
                else:
                    gram_test[i, :], rhos_test[i], _, kernels_test[i, :], _, lengths_test[i], _ = \
                        self.kernel.compute_one_bag(phi, self.formulae_list, return_robustness=True)
                if not ((rhos_test[i] > 0).all() or (rhos_test[i] < 0).all()):
                    phi_test.append(phi)
            return phi_test, gram_test.cpu()
        else:
            if self.store_robustness:
                return self.kernel.compute_bag_from_robustness(
                    phis, self.robustness, self.self_kernels, self.robustness_lengths
                )
            else:
                return self.kernel.compute_bag_bag(phis, self.formulae_list)

    def invert_regularized(self, alpha):
        regularizer = abs(pow(10, alpha)) * torch.eye(self.dim)
        return torch.inverse(self.gram + regularizer)

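Building a Gram matrix over a fixed bag of formulae and projecting a new formula onto it can then be sketched as follows (reusing `kernel` and `gen` from the sketch above):

bag = gen.bag_sample(bag_size=10, nvars=3)
gram = GramMatrix(kernel, bag, store_robustness=True)
new_phi = gen.sample(nvars=3)
k_vec = gram.compute_kernel_vector(new_phi)  # 1 x 10 kernel vector against the bag
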
1275
+ #### anchor_generation ####
1276
+
1277
+ def anchorGeneration(diff_init = False, # to control whether we want formulae to be semantically different by construction
1278
+ embed_dim: int = 30, # embedding dimension, aka number of generated formulae in the anchor set
1279
+ n_vars: int = 3, # dimension of the input signal (3D in this case)
1280
+ leaf_prob: float = 0.4, # complexity of the generated formula
1281
+ cosine_similarity_threshold: float = 0.8 # if two formulae cosine similarity exceeds 0.9, then discard one of the two
1282
+ ) -> str:
1283
+
1284
+ # initialize STL formula generator
1285
+ sampler = StlGenerator(leaf_prob)
1286
+
1287
+ # effective anchor set generation
1288
+ if diff_init:
1289
+
1290
+ # initialize the anchor set with a randomly sampled formula
1291
+ diff_anchor_set = [sampler.sample(nvars=n_vars)]
1292
+
1293
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1294
+ mu = BaseMeasure(device=device)
1295
+
1296
+ # generates a set of random signals working as a tester for the formulae testing
1297
+ signals = mu.sample(samples=10000, varn=n_vars)
1298
+
1299
+ # computes robustness value for the initial set of formulae in the anchor set
1300
+ anchor_rob_vectors = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in diff_anchor_set], 0)
1301
+
1302
+ while len(diff_anchor_set) < embed_dim:
1303
+ # sample the 'remaining' formulae to reach the desired number of `embed_dim` formulae:
1304
+ candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars = n_vars)
1305
+
1306
+ # compute robustness of candidate anchor formulae on the same signals as previous anchor set
1307
+ candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)
1308
+
1309
+ # compute cosine similarity between current anchor set and candidate new formulae
1310
+ cos_simil = torch.tril(normalize(candidate_robs) @ normalize(anchor_rob_vectors).t(), diagonal=-1)
1311
+
1312
+ # check which formulae are similar (i.e. cosine similarity greater than the threshold) w.r.t. current anchors
1313
+ # NOTE: ask Gaia whether negative cosine similarities should be taken in absolute value or not!
1314
+ similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])]
1315
+
1316
+ # keep only those that are semantically distant
1317
+ keep_idx = list(set(np.arange(len(candidate_anchors)).tolist()).difference(set([i for sublist in similar_idx for i in sublist])))
1318
+
1319
+ diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]
1320
+
1321
+ # Convert keep_idx to a tensor on the same device as candidate_robs
1322
+ keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)
1323
+
1324
+ # Use index_select to pick the relevant rows
1325
+ selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)
1326
+
1327
+ # Concatenate on the same device
1328
+ anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)
1329
+
1330
+ anchor_set = diff_anchor_set[:embed_dim]
1331
+
1332
+ else:
1333
+ anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars)
1334
+
1335
+ filename = f'anchor_set_no_diff_{embed_dim}_dim'
1336
+ dump_pickle(filename, anchor_set)
1337
+ return filename
1338
+
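+ # --- Hedged usage sketch (illustrative only, not part of the uploaded module) ---
+ # Shows how an anchor set produced by `anchorGeneration` might be reloaded. It assumes
+ # `load_pickle` is the counterpart of `dump_pickle` used above and that, as in
+ # `STLEncoder` below, the '.pickle' extension has to be appended to the returned name.
+ def _anchor_set_roundtrip_sketch(embed_dim: int = 30, n_vars: int = 3):
+     filename = anchorGeneration(diff_init=False, embed_dim=embed_dim, n_vars=n_vars)
+     return load_pickle(filename + '.pickle')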
1339
+ ####
1340
+
1341
+ """
1342
+ A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process.
1343
+ This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs,
1344
+ and handle padding and special tokens.
1345
+ """
1346
+
1347
+ def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
1348
+ bos_token: str = "/s", eos_token: str = "s", model_max_length = 512, *args, **kwargs):
1349
+ """
1350
+ Initializes the STLTokenizer with a given vocabulary and special tokens.
1351
+ Args:
1352
+ vocab_path (str): The path to the JSON file containing the vocabulary.
1353
+ unk_token (str, optional): The token used for unknown words. Defaults to "unk".
1354
+ pad_token (str, optional): The token used for padding. Defaults to "pad".
1355
+ bos_token (str, optional): The token used for the beginning of a sequence. Defaults to "/s".
1356
+ eos_token (str, optional): The token used for the end of a sequence. Defaults to "s".
+ model_max_length (int, optional): The maximum sequence length handled by the tokenizer. Defaults to 512.
1357
+ """
1358
+ self.vocab = load_json(vocab_path)
1359
+ self.unk_token = unk_token
1360
+ self.pad_token = pad_token
1361
+ self.bos_token = bos_token
1362
+ self.eos_token = eos_token
1363
+ self.model_max_length = model_max_length
1364
+ self.id_to_token = {v: k for k, v in self.vocab.items()} # Reverse mapping
1365
+ super().__init__(unk_token=unk_token, pad_token=pad_token, bos_token=bos_token, eos_token=eos_token,
1366
+ model_max_length=model_max_length, *args, **kwargs)
1367
+
1368
+ @property
1369
+ def vocab_size(self) -> int:
1370
+ """
1371
+ Returns the size of the vocabulary.
1372
+ Returns:
1373
+ int: The number of tokens in the vocabulary.
1374
+ """
1375
+ return len(self.vocab)
1376
+
1377
+ def prepad_sequence(self, sequence, space_token = ' ', new_space_token = '@', undo = False):
1378
+ """
1379
+ Replaces spaces in the input sequence with a specified token.
1380
+ Args:
1381
+ sequence (str): The input sequence.
1382
+ undo (bool): If True, replaces the placeholder token with spaces. Defaults to False, which replaces spaces with the placeholder token.
1383
+ Returns:
1384
+ str: The preprocessed sequence with spaces or placeholder tokens replaced.
1385
+ """
1386
+ if undo:
1387
+ return sequence.replace(new_space_token, space_token)
1388
+ else:
1389
+ return sequence.replace(space_token, new_space_token)
1390
+
1391
+ def add_bos_eos(self, sequence: str) -> str:
1392
+ """
1393
+ Adds the BOS token at the beginning and the EOS token at the end of the sequence.
1394
+ Args:
1395
+ sequence (str): The input sequence.
1396
+ Returns:
1397
+ str: The sequence with the BOS and EOS tokens added.
1398
+ """
1399
+ return f'{self.bos_token} {sequence} {self.eos_token}'
1400
+
1401
+ def tokenize(self, text: str) -> List[str]:
1402
+ """
1403
+ Tokenizes the input text into a list of tokens.
1404
+ The method preprocesses the input text by replacing spaces with the placeholder token and then tries to
1405
+ find the longest possible match for each substring in the vocabulary.
1406
+ Args:
1407
+ text (str): The input text to be tokenized.
1408
+ Returns:
1409
+ List[str]: A list of tokens representing the tokenized text.
1410
+ """
1411
+ text = self.add_bos_eos(text)
1412
+ text = self.prepad_sequence(text)
1413
+ tokens = []
1414
+ i = 0
1415
+ while i < len(text):
1416
+ best_match = None
1417
+ for j in range(len(text), i, -1): # Try matching substrings of decreasing length
1418
+ subtoken = text[i:j]
1419
+ if subtoken in self.vocab:
1420
+ best_match = subtoken
1421
+ break
1422
+ if best_match:
1423
+ tokens.append(best_match)
1424
+ i += len(best_match)
1425
+ else:
1426
+ tokens.append(self.unk_token)
1427
+ i += 1
1428
+ return tokens
1429
+
1430
+ def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
1431
+ """
1432
+ Converts a list of tokens into a list of token IDs.
1433
+ Args:
1434
+ tokens (List[str]): A list of tokens to be converted into IDs.
1435
+ Returns:
1436
+ List[int]: A list of corresponding token IDs.
1437
+ """
1438
+ return [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
1439
+
1440
+ def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
1441
+ """
1442
+ Converts a list of token IDs into a list of tokens.
1443
+ Args:
1444
+ ids (List[int]): A list of token IDs to be converted into tokens.
1445
+ Returns:
1446
+ List[str]: A list of corresponding tokens.
1447
+ """
1448
+ return [self.id_to_token.get(i, self.unk_token) for i in ids]
1449
+
1450
+ def encode(self, sequence: str) -> List[int]:
1451
+ """
1452
+ Encodes a string sequence into a list of token IDs.
1453
+
1454
+ This method tokenizes the input sequence using the `tokenize` method,
1455
+ and then converts the resulting tokens into their corresponding token IDs
1456
+ using the `convert_tokens_to_ids` method.
1457
+
1458
+ Args:
1459
+ sequence (str): The input sequence (text) to be encoded.
1460
+
1461
+ Returns:
1462
+ List[int]: A list of token IDs corresponding to the input sequence.
1463
+ """
1464
+ splitted_sequence = self.tokenize(sequence)
1465
+ return self.convert_tokens_to_ids(splitted_sequence)
1466
+
1467
+ def postpad_sequence(self, sequence, pad_token_id):
1468
+ """
1469
+ Right-pads the sequence with `pad_token_id` up to the tokenizer's maximum length.
1470
+ """
1471
+ num_extra_elements = self.model_max_length - len(sequence) - 1
1472
+ if num_extra_elements > 0:
1473
+ sequence.extend([pad_token_id] * num_extra_elements)
1474
+ return sequence
1475
+
1476
+ def decode(self, token_ids: List[int]) -> str:
1477
+ """
1478
+ Decodes a list of token IDs into a string of text.
1479
+ The method converts the IDs to tokens and joins them to form a string.
1480
+ It also restores the original spaces in place of the placeholder token used during tokenization.
1481
+ Args:
1482
+ token_ids (List[int]): A list of token IDs to be decoded.
1483
1484
+ Returns:
1485
+ str: The decoded string.
1486
+ """
1487
+ tokens = self.convert_ids_to_tokens(token_ids)
1488
+ decoded = "".join(tokens)
1489
+ return self.prepad_sequence(decoded, undo=True)
1490
+
1491
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
1492
+ """
1493
+ Saves the tokenizer's vocabulary to a file.
1494
+ Useful only when the vocabulary has to be retrieved and is not already provided
1495
+ (not the case here; kept as a hook for future improvements, e.g. with sentencepiece).
1496
+ This method saves the vocabulary to a JSON file in the specified directory.
1497
+ Args:
1498
+ save_directory (str): The directory where the vocabulary file will be saved.
1499
+ filename_prefix (Optional[str]): An optional prefix for the filename.
1500
+ Returns:
1501
+ Tuple[str]: A tuple containing the path to the saved vocabulary file.
1502
+ """
1503
+ vocab_file = f"{save_directory}/{filename_prefix + '-' if filename_prefix else ''}vocab.json"
1504
+ with open(vocab_file, "w", encoding="utf-8") as f:
1505
+ json.dump(self.vocab, f, indent=2, ensure_ascii=False)
1506
+ return (vocab_file,)
1507
+
1508
+ def get_vocab(self) -> dict:
1509
+ """
1510
+ Retrieves the vocabulary used by the tokenizer.
1511
+ Returns:
1512
+ dict: The vocabulary as a dictionary.
1513
+ """
1514
+ return self.vocab
1515
+
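+ # --- Hedged usage sketch (illustrative only, not part of the uploaded module) ---
+ # Round-trips a string through STLTokenizer. Both "vocab.json" and the example
+ # formula string are placeholders; the real vocabulary file and token inventory
+ # ship with the repository.
+ def _tokenizer_roundtrip_sketch(vocab_path: str = "vocab.json"):
+     tokenizer = STLTokenizer(vocab_path)
+     ids = tokenizer.encode("( x_1 >= 0.5 )")
+     text = tokenizer.decode(ids)
+     return ids, text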
1516
+ class STLSinusoidalPositionalEmbedding(nn.Embedding):
1517
+ """This module produces sinusoidal positional embeddings of any length."""
1518
+
1519
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
1520
+ super().__init__(num_positions, embedding_dim)
1521
+ self.weight = self._init_weight(self.weight)
1522
+
1523
+ @staticmethod
1524
+ def _init_weight(out: nn.Parameter) -> nn.Parameter:
1525
+ """
1526
+ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
1527
+ the 2nd half of the vector. [dim // 2:]
1528
+ """
1529
+ n_pos, dim = out.shape
1530
+ position_enc = np.array(
1531
+ [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
1532
+ )
1533
+ out.requires_grad = False # set early to avoid an error in pytorch-1.8+
1534
+ sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
1535
+ out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
1536
+ out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
1537
+ out.detach_()
1538
+ return out
1539
+ @torch.no_grad()
1540
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
1541
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
1542
+ bsz, seq_len = input_ids_shape[:2]
1543
+ positions = torch.arange(
1544
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
1545
+ )
1546
+ return super().forward(positions)
1547
+
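+ # --- Hedged shape-check sketch (illustrative only, not part of the uploaded module) ---
+ # The positional module is indexed by shape alone: a [bsz, seq_len] input shape yields
+ # a [seq_len, embedding_dim] table of sinusoidal positions. Sizes below are arbitrary.
+ def _positional_embedding_sketch():
+     pos_emb = STLSinusoidalPositionalEmbedding(num_positions=128, embedding_dim=16)
+     positions = pos_emb(torch.Size([2, 10]))
+     return positions.shape  # torch.Size([10, 16])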
1548
+ class STLAttention(nn.Module):
1549
+ """ Multi-Head Attention as depicted from 'Attention is all you need' """
1550
+
1551
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
1552
+ is_decoder: bool = False, bias: bool = False, is_causal: bool = False):
1553
+
1554
+ super().__init__()
1555
+ self.embed_dim = embed_dim # overall embedding dimension -> to be divided between multiple heads
1556
+ self.num_heads = num_heads
1557
+ self.dropout = dropout
1558
+ self.head_dim = embed_dim // num_heads
1559
+ assert (self.head_dim * num_heads) == self.embed_dim
1560
+ self.scaling = self.head_dim ** -0.5 # used to scale the query after projection with `W_q`
1561
+ self.is_decoder = is_decoder
1562
+ self.is_causal = is_causal
1563
+
1564
+ # query/key/value projection ('roleplaying') matrices
1565
+ self.W_k = nn.Linear(embed_dim, embed_dim, bias = bias)
1566
+ self.W_q = nn.Linear(embed_dim, embed_dim, bias = bias)
1567
+ self.W_v = nn.Linear(embed_dim, embed_dim, bias = bias)
1568
+
1569
+ # to project the heads' outputs into a single vector
1570
+ self.W_o = nn.Linear(embed_dim, embed_dim, bias = bias)
1571
+
1572
+
1573
+ def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
1574
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
1575
+
1576
+
1577
+ def forward(self,
1578
+ hidden_states: torch.Tensor, # previous values, passed to the multi-head attn layer
1579
+ key_value_states: Optional[torch.Tensor] = None, # different key, value items (used in cross-attn)
1580
+ past_key_value: Optional[Tuple[torch.Tensor]] = None, # stores the key and values of previous steps
1581
+ attention_mask: Optional[torch.Tensor] = None, # masks non-allowed items (padded or future ones)
1582
+ layer_head_mask: Optional[torch.Tensor] = None, # used to de-activate specific attn heads
1583
+ output_attentions: bool = False # flag to control the output of the attn values,
1584
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1585
+
1586
+ is_cross_attention = key_value_states is not None # cross-attn if key_value_states is not None
1587
+
1588
+ batch_size, tgt_len, embed_dim = hidden_states.size()
1589
+
1590
+ # Project the current input in the `query` role:
1591
+ query = self.W_q(hidden_states) * self.scaling
1592
+
1593
+ if (is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]):
1594
+ key = past_key_value[0]
1595
+ value = past_key_value[1]
1596
+ elif is_cross_attention:
1597
+ key = self._shape(self.W_k(key_value_states), -1, batch_size)
1598
+ value = self._shape(self.W_v(key_value_states), -1, batch_size)
1599
+ elif past_key_value is not None:
1600
+ key = self._shape(self.W_k(hidden_states), -1, batch_size)
1601
+ value = self._shape(self.W_v(hidden_states), -1, batch_size)
1602
+ key = torch.cat([past_key_value[0], key], dim=2)
1603
+ value = torch.cat([past_key_value[1], value], dim=2)
1604
+ else:
1605
+ key = self._shape(self.W_k(hidden_states), -1, batch_size)
1606
+ value = self._shape(self.W_v(hidden_states), -1, batch_size)
1607
+ if self.is_decoder:
1608
+ past_key_value = (key, value)
1609
+
1610
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
1611
+
1612
+ query = self._shape(query, tgt_len, batch_size).view(*proj_shape)
1613
+ key = key.reshape(*proj_shape)
1614
+ value = value.reshape(*proj_shape)
1615
+
1616
+ src_len = key.size(1)
1617
+
1618
+
1619
+ ######################################################################################################
1620
+
1621
+ # 'traditional' attention computation
1622
+ # i.e. softmax(Q*K^T / sqrt(d_head) + attn_mask) * V
1623
+
1624
+ # Batch-wise matrix multiplication between `query` and (TRANSPOSED) `key`
1625
+ attn_weights = torch.bmm(query, key.transpose(1, 2))
1626
+
1627
+ if attention_mask is not None:
1628
+ attn_weights = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask
1629
+ attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
1630
+
1631
+ # Normalize values on the `key` axis (dim=-1)
1632
+ attn_weights = F.softmax(attn_weights, dim=-1)
1633
+
1634
+ # if layer_head_mask is not None:
1635
+ # attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(batch_size, self.num_heads, tgt_len, src_len)
1636
+ # attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
1637
+
1638
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
1639
+
1640
+ # Batch-wise matrix multiplication between the resulting probs and the value
1641
+ attn_output = torch.bmm(attn_probs, value)
1642
+
1643
+ ######################################################################################################
1644
+
1645
+ attn_output = attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim)
1646
+ attn_output = attn_output.transpose(1, 2)
1647
+
1648
+ attn_output = attn_output.reshape(batch_size, tgt_len, self.embed_dim)
1649
+ attn_output = self.W_o(attn_output)
1650
+
1651
+ return attn_output, None, past_key_value
1652
+
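+ # --- Hedged shape-check sketch (illustrative only, not part of the uploaded module) ---
+ # Self-attention without a mask preserves the [batch, seq_len, embed_dim] shape of its
+ # input; the sizes below are arbitrary.
+ def _attention_shape_sketch():
+     attn = STLAttention(embed_dim=64, num_heads=4)
+     x = torch.randn(2, 7, 64)
+     attn_output, _, _ = attn(x)
+     return attn_output.shape  # torch.Size([2, 7, 64])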
1653
+ ####
1654
+
1655
+ class STLEncoder():
1656
+ def __init__(self,
1657
+ embed_dim: int,
1658
+ anchor_filename: Optional[str] = None,
1659
+ n_vars: int = 3):
1660
+
1661
+ self.n_vars = n_vars # passed in as an argument
1662
+ self.embed_dim = embed_dim
1663
+ self.anchorset_filename = anchor_filename
1664
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1665
+ self.mu = BaseMeasure(device=self.device)
1666
+ self.kernel = StlKernel(self.mu, varn=self.n_vars)
1667
+
1668
+ if anchor_filename is None:
1669
+ anchor_filename = anchorGeneration(diff_init = True, embed_dim = self.embed_dim, n_vars = self.n_vars)
1670
+ anchor_filename += '.pickle'
1671
+
1672
+ # TO DO: check on the dimensions of the anchor set and the `embed_dim` and `n_vars` values
1673
+ anchor_set = load_pickle(anchor_filename)
1674
+ if len(anchor_set) != self.embed_dim:
1675
+ raise ValueError("The anchor set and the embedding dimension do not match!")
1676
+
1677
+ self.anchor_set = anchor_set
1678
+
1679
+ def compute_embeddings(self, formula: List[str]):
1680
+ return self.kernel.compute_bag_bag(formula, self.anchor_set)
1681
+
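+ # --- Hedged usage sketch (illustrative only, not part of the uploaded module) ---
+ # Embeds an STL formula against a precomputed anchor set via the STL kernel. The anchor
+ # file name (matching the naming scheme of `anchorGeneration`) and the use of
+ # `StlGenerator` to obtain a sample formula are assumptions for illustration.
+ def _formula_embedding_sketch(anchor_file: str = "anchor_set_no_diff_30_dim.pickle"):
+     encoder = STLEncoder(embed_dim=30, anchor_filename=anchor_file, n_vars=3)
+     phi = StlGenerator(0.4).sample(nvars=3)
+     return encoder.compute_embeddings([phi])  # kernel vector against the 30 anchors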
1682
+ class STLModel(PreTrainedModel):
1683
+ config_class = STLConfig
1684
+ base_model_prefix = "model"
1685
+ supports_gradient_checkpointing = True
1686
+
1687
+ # initializes the weights of `nn.Linear`, `nn.Embedding` and `STLSinusoidalPositionalEmbedding`
1688
+ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, STLSinusoidalPositionalEmbedding]):
1689
+ std = self.config.init_std
1690
+ if isinstance(module, nn.Linear):
1691
+ module.weight.data.normal_(mean=0.0, std=std)
1692
+ if module.bias is not None:
1693
+ module.bias.data.zero_()
1694
+ elif isinstance(module, STLSinusoidalPositionalEmbedding):
1695
+ pass
1696
+ elif isinstance(module, nn.Embedding):
1697
+ module.weight.data.normal_(mean=0.0, std=std)
1698
+ if module.padding_idx is not None:
1699
+ module.weight.data[module.padding_idx].zero_()
1700
+
1701
+ @property
1702
+ def dummy_inputs(self):
1703
+ pad_token = self.config.pad_token_id
1704
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
1705
+ dummy_inputs = {
1706
+ "attention_mask": input_ids.ne(pad_token),
1707
+ "input_ids": input_ids,
1708
+ "decoder_input_ids": input_ids,
1709
+ }
1710
+ return dummy_inputs
1711
+
1712
+ class STLDecoderBlock(nn.Module):
1713
+
1714
+ def __init__(self, embed_dim: int,
1715
+ num_decoder_attention_heads: int,
1716
+ num_decoder_ffn_dim: int,
1717
+ dropout: float = 0.0,
1718
+ attention_dropout: float = 0.0,
1719
+ activation_dropout: float = 0.0,
1720
+ ):
1721
+
1722
+ super().__init__()
1723
+
1724
+ self.embed_dim = embed_dim
1725
+
1726
+ # first block
1727
+ self.self_attn = STLAttention(
1728
+ embed_dim=self.embed_dim,
1729
+ num_heads=num_decoder_attention_heads,
1730
+ dropout=dropout,
1731
+ is_decoder=True, # not used, debugging purposes
1732
+ is_causal=True, # not used, debugging purposes
1733
+ )
1734
+ self.dropout = dropout
1735
+ self.activation_fn = nn.functional.gelu
1736
+ self.activation_dropout = activation_dropout
1737
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
1738
+
1739
+ # second block
1740
+ self.encoder_attn = STLAttention(
1741
+ self.embed_dim,
1742
+ num_decoder_attention_heads,
1743
+ dropout=attention_dropout,
1744
+ is_decoder=True, # not used, debugging purposes
1745
+ )
1746
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
1747
+
1748
+ # third block
1749
+ self.fc1 = nn.Linear(self.embed_dim, num_decoder_ffn_dim)
1750
+ self.fc2 = nn.Linear(num_decoder_ffn_dim, self.embed_dim)
1751
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
1752
+
1753
+
1754
+ def forward(
1755
+ self,
1756
+ hidden_states: torch.Tensor,
1757
+ attention_mask: Optional[torch.Tensor] = None,
1758
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1759
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1760
+ layer_head_mask: Optional[torch.Tensor] = None,
1761
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
1762
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1763
+ output_attentions: Optional[bool] = False,
1764
+ use_cache: Optional[bool] = True,
1765
+ **kwargs,
1766
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1767
+ """
1768
+ Args:
1769
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1770
+ attention_mask (`torch.FloatTensor`): attention mask of size
1771
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1772
+ encoder_hidden_states (`torch.FloatTensor`):
1773
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
1774
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
1775
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1776
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
1777
+ `(encoder_attention_heads,)`.
1778
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
1779
+ size `(decoder_attention_heads,)`.
1780
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
1781
+ output_attentions (`bool`, *optional*):
1782
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1783
+ returned tensors for more detail.
1784
+ """
1785
+
1786
+ ###################################################################
1787
+
1788
+ # BLOCK 1: processing what has been previously generated
1789
+
1790
+ # previous state is stored into an auxiliary variable `residual`
1791
+ residual = hidden_states
1792
+
1793
+ # tries to exploit previous K, V values if there are any
1794
+ # (practically picks up to the first 2 values stored in `past_key_value` vector)
1795
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
1796
+
1797
+ # masked MHSA on the already generated sequence
1798
+ # invokes `forward` method to transform the original vector accordingly
1799
+ hidden_states, self_attn_weights, present_key_value = self.self_attn.forward(
1800
+ hidden_states=hidden_states, # Q
1801
+ past_key_value=self_attn_past_key_value, # K, V
1802
+ attention_mask=attention_mask, # passed as input of the decoder layer
1803
+ layer_head_mask=layer_head_mask, # to deactivate certain attn layers
1804
+ output_attentions=output_attentions,
1805
+ )
1806
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1807
+
1808
+ # residual connection
1809
+ hidden_states = residual + hidden_states
1810
+
1811
+ # normalization
1812
+ hidden_states = self.self_attn_layer_norm(hidden_states)
1813
+
1814
+ ###################################################################
1815
+
1816
+ # BLOCK 2: cross-attn between already generated input and previous information (from the encoder)
1817
+
1818
+ # initialize K, Q, attn_weights for this new attn operation
1819
+ cross_attn_present_key_value = None
1820
+ cross_attn_weights = None
1821
+
1822
+ # the important condition is that the encoder carries some information
1823
+ if encoder_hidden_states is not None:
1824
+
1825
+ # previous state is stored into an auxiliary variable `residual`
1826
+ residual = hidden_states
1827
+
1828
+ # cross_attn cached key/values tuple is at positions 3, 4 of PAST_key_value tuple
1829
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
1830
+
1831
+ # MHSA in cross-attn
1832
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn.forward(
1833
+ hidden_states=hidden_states, # Q = generated output
1834
+ key_value_states=encoder_hidden_states, # K, V = encoder memory (used only in the 1st step when `use_cache = True`)
1835
+ attention_mask=encoder_attention_mask, # just pads some elements (not causal this time!)
1836
+ layer_head_mask=cross_attn_layer_head_mask, # again to mask certain heads
1837
+ past_key_value=cross_attn_past_key_value, # K, V = encoder CACHED memory (used from the 2nd step on when `use_cache = True`)
1838
+ output_attentions=output_attentions,
1839
+ )
1840
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1841
+
1842
+ # residual connection
1843
+ hidden_states = residual + hidden_states
1844
+
1845
+ # normalization
1846
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
1847
+
1848
+ # add cross-attn to positions 3, 4 of PRESENT_key_value tuple
1849
+ present_key_value = present_key_value + cross_attn_present_key_value
1850
+
1851
+ ###################################################################
1852
+
1853
+ # BLOCK 3: FFNN (transforming some merged generated output - encoder information)
1854
+
1855
+ # previous state is stored into an auxiliary variable `residual`
1856
+ residual = hidden_states
1857
+
1858
+ # FFNN - core
1859
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
1860
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
1861
+ hidden_states = self.fc2(hidden_states)
1862
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1863
+
1864
+ # residual connection
1865
+ hidden_states = residual + hidden_states
1866
+
1867
+ # normalization
1868
+ hidden_states = self.final_layer_norm(hidden_states)
1869
+
1870
+ outputs = (hidden_states,)
1871
+
1872
+ if output_attentions:
1873
+ outputs += (self_attn_weights, cross_attn_weights)
1874
+
1875
+ if use_cache: # otherwise it picks K and V each time
1876
+ outputs += (present_key_value,)
1877
+
1878
+ return outputs
1879
+
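+ # --- Hedged shape-check sketch (illustrative only, not part of the uploaded module) ---
+ # A single decoder block run in self-attention-only mode (no encoder memory). With the
+ # default `use_cache=True`, the second element of the returned tuple is this layer's
+ # cached (key, value) pair. Sizes below are arbitrary.
+ def _decoder_block_sketch():
+     block = STLDecoderBlock(embed_dim=32, num_decoder_attention_heads=4, num_decoder_ffn_dim=64)
+     x = torch.randn(2, 5, 32)
+     hidden, present_key_value = block(x)
+     return hidden.shape  # torch.Size([2, 5, 32])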
1880
+ class STLDecoder(STLModel):
1881
+ def __init__(self, config):
1882
+ super().__init__(config)
1883
+
1884
+ # Extract from `config` file
1885
+ embed_dim = config.d_model
1886
+ num_decoder_attention_heads = config.decoder_attention_heads
1887
+ num_decoder_ffn_dim = config.decoder_ffn_dim
1888
+ max_position_embeddings = config.max_position_embeddings
1889
+ decoder_vocab_size = config.vocab_size
1890
+ pad_token_id = config.pad_token_id
1891
+ num_decoder_layers = config.decoder_layers
1892
+ scale_embedding = config.scale_embedding
1893
+ dropout = config.dropout
1894
+ attention_dropout = config.attention_dropout
1895
+ activation_dropout = config.activation_dropout
1896
+ decoder_layerdrop = config.decoder_layerdrop
1897
+
1898
+ self.dropout = dropout
1899
+ self.layerdrop = decoder_layerdrop
1900
+ self.padding_idx = pad_token_id
1901
+ self.max_target_positions = max_position_embeddings
1902
+ self.embed_scale = math.sqrt(embed_dim) if scale_embedding else 1.0
1903
+
1904
+ # Initialize the input embedding (if not passed already)
1905
+ self.embed_tokens = nn.Embedding(decoder_vocab_size, embed_dim, self.padding_idx)
1906
+
1907
+ # Initialize positional embedding also
1908
+ self.embed_positions = STLSinusoidalPositionalEmbedding(
1909
+ max_position_embeddings, embed_dim, self.padding_idx
1910
+ )
1911
+
1912
+ # Initialize decoder layers (of a prespecified number)
1913
+ self.layers = nn.ModuleList([STLDecoderBlock(embed_dim, num_decoder_attention_heads,
1914
+ num_decoder_ffn_dim, dropout,
1915
+ attention_dropout, activation_dropout)
1916
+ for _ in range(num_decoder_layers)])
1917
+
1918
+ self.gradient_checkpointing = False
1919
+ self.post_init()
1920
+
1921
+ def forward(
1922
+ self,
1923
+ input_ids: torch.LongTensor = None,
1924
+ attention_mask: Optional[torch.Tensor] = None,
1925
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1926
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1927
+ head_mask: Optional[torch.Tensor] = None,
1928
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1929
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1930
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1931
+ use_cache: Optional[bool] = None,
1932
+ output_attentions: Optional[bool] = None,
1933
+ output_hidden_states: Optional[bool] = None,
1934
+ return_dict: Optional[bool] = None,
1935
+ **kwargs,
1936
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
1937
+
1938
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1939
+ output_hidden_states = (
1940
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1941
+ )
1942
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1943
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1944
+
1945
+ # retrieve input_ids and inputs_embeds
1946
+ if input_ids is not None and inputs_embeds is not None:
1947
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1948
+ elif input_ids is not None:
1949
+ input_shape = input_ids.size()
1950
+ input_ids = input_ids.view(-1, input_shape[-1])
1951
+ elif inputs_embeds is not None:
1952
+ input_shape = inputs_embeds.size()[:-1]
1953
+ else:
1954
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1955
+
1956
+ # past_key_values_length
1957
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1958
+
1959
+ if inputs_embeds is None:
1960
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1961
+
1962
+ attention_mask = _prepare_4d_causal_attention_mask(
1963
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1964
+ )
1965
+
1966
+ # expand encoder attention mask
1967
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1968
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1969
+ encoder_attention_mask = _prepare_4d_attention_mask(
1970
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1971
+ )
1972
+
1973
+ # embed positions
1974
+ positions = self.embed_positions(input_shape, past_key_values_length)
1975
+
1976
+ hidden_states = inputs_embeds + positions
1977
+
1978
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1979
+
1980
+ if self.gradient_checkpointing and self.training:
1981
+ if use_cache:
1982
+ logger.warning_once(
1983
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1984
+ )
1985
+ use_cache = False
1986
+
1987
+ # decoder layers
1988
+ all_hidden_states = () if output_hidden_states else None
1989
+ all_self_attns = () if output_attentions else None
1990
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1991
+ next_decoder_cache = () if use_cache else None
1992
+
1993
+ for idx, decoder_layer in enumerate(self.layers):
1994
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1995
+ if output_hidden_states:
1996
+ all_hidden_states += (hidden_states,)
1997
+ if self.training:
1998
+ dropout_probability = torch.rand([])
1999
+ if dropout_probability < self.layerdrop:
2000
+ continue
2001
+
2002
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
2003
+
2004
+ if self.gradient_checkpointing and self.training:
2005
+ layer_outputs = self._gradient_checkpointing_func(
2006
+ decoder_layer.__call__,
2007
+ hidden_states,
2008
+ attention_mask,
2009
+ encoder_hidden_states,
2010
+ encoder_attention_mask,
2011
+ head_mask[idx] if head_mask is not None else None,
2012
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
2013
+ None,
2014
+ output_attentions,
2015
+ use_cache,
2016
+ )
2017
+ else:
2018
+ layer_outputs = decoder_layer(
2019
+ hidden_states,
2020
+ attention_mask=attention_mask,
2021
+ encoder_hidden_states=encoder_hidden_states,
2022
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
2023
+ cross_attn_layer_head_mask=(
2024
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
2025
+ ),
2026
+ past_key_value=past_key_value,
2027
+ output_attentions=output_attentions,
2028
+ use_cache=use_cache,
2029
+ )
2030
+ hidden_states = layer_outputs[0]
2031
+
2032
+ if use_cache:
2033
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
2034
+
2035
+ if output_attentions:
2036
+ all_self_attns += (layer_outputs[1],)
2037
+
2038
+ if encoder_hidden_states is not None:
2039
+ all_cross_attentions += (layer_outputs[2],)
2040
+
2041
+ # add hidden states from the last decoder layer
2042
+ if output_hidden_states:
2043
+ all_hidden_states += (hidden_states,)
2044
+
2045
+ next_cache = next_decoder_cache if use_cache else None
2046
+ if not return_dict:
2047
+ return tuple(
2048
+ v
2049
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
2050
+ if v is not None
2051
+ )
2052
+ return BaseModelOutputWithPastAndCrossAttentions(
2053
+ last_hidden_state=hidden_states,
2054
+ past_key_values=next_cache,
2055
+ hidden_states=all_hidden_states,
2056
+ attentions=all_self_attns,
2057
+ cross_attentions=all_cross_attentions,
2058
+ )
2059
+
2060
+ ####
2061
+
2062
+ class STLForCausalLM(STLModel, GenerationMixin):
2063
+ _tied_weights_keys = ["lm_head.weight"]
2064
+
2065
+ def __init__(self, config):
2066
+ config = copy.deepcopy(config)
2067
+ config.is_decoder = True
2068
+ config.is_encoder_decoder = False
2069
+
2070
+ super().__init__(config)
2071
+ self.model = STLDecoder(config)
2072
+
2073
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
2074
+
2075
+ # Initialize weights and apply final processing
2076
+ self.post_init()
2077
+
2078
+ def get_input_embeddings(self):
2079
+ return self.model.embed_tokens
2080
+
2081
+ def set_input_embeddings(self, value):
2082
+ self.model.embed_tokens = value
2083
+
2084
+ def get_output_embeddings(self):
2085
+ return self.lm_head
2086
+
2087
+ def set_output_embeddings(self, new_embeddings):
2088
+ self.lm_head = new_embeddings
2089
+
2090
+ def set_decoder(self, decoder):
2091
+ self.model = decoder
2092
+
2093
+ def get_decoder(self):
2094
+ return self.model
2095
+
2096
+ def forward(
2097
+ self,
2098
+ input_ids: torch.LongTensor = None, # input sequence
2099
+ attention_mask: Optional[torch.Tensor] = None, # masked MHSA + padding
2100
+ encoder_hidden_states: Optional[torch.FloatTensor] = None, # embedding
2101
+ encoder_attention_mask: Optional[torch.FloatTensor] = None, # MHSA + padding
2102
+ head_mask: Optional[torch.Tensor] = None,
2103
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
2104
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
2105
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2106
+ labels: Optional[torch.LongTensor] = None, # output sequence
2107
+ use_cache: Optional[bool] = None,
2108
+ output_attentions: Optional[bool] = None,
2109
+ output_hidden_states: Optional[bool] = None,
2110
+ return_dict: Optional[bool] = None,
2111
+ **kwargs,
2112
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
2113
+
2114
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2115
+ output_hidden_states = (
2116
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2117
+ )
2118
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2119
+
2120
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
2121
+ outputs = self.model(
2122
+ input_ids=input_ids,
2123
+ attention_mask=attention_mask,
2124
+ encoder_hidden_states=encoder_hidden_states,
2125
+ encoder_attention_mask=encoder_attention_mask,
2126
+ head_mask=head_mask,
2127
+ cross_attn_head_mask=cross_attn_head_mask,
2128
+ past_key_values=past_key_values,
2129
+ inputs_embeds=inputs_embeds,
2130
+ use_cache=use_cache,
2131
+ output_attentions=output_attentions,
2132
+ output_hidden_states=output_hidden_states,
2133
+ return_dict=return_dict,
2134
+ **kwargs
2135
+ )
2136
+
2137
+ logits = self.lm_head(outputs[0])
2138
+
2139
+ loss = None
2140
+ if labels is not None:
2141
+ labels = labels.to(logits.device)
2142
+ loss_fct = CrossEntropyLoss()
2143
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
2144
+
2145
+ if not return_dict:
2146
+ output = (logits,) + outputs[1:]
2147
+ return (loss,) + output if loss is not None else output
2148
+
2149
+ return CausalLMOutputWithCrossAttentions(
2150
+ loss=loss,
2151
+ logits=logits,
2152
+ past_key_values=outputs.past_key_values,
2153
+ hidden_states=outputs.hidden_states,
2154
+ attentions=outputs.attentions,
2155
+ cross_attentions=outputs.cross_attentions,
2156
+ )
2157
+
2158
+ @staticmethod
2159
+ def _reorder_cache(past_key_values, beam_idx):
2160
+ reordered_past = ()
2161
+ for layer_past in past_key_values:
2162
+ reordered_past += (
2163
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
2164
+ )
2165
+ return reordered_past
2166
+
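+ # --- Hedged usage sketch (illustrative only, not part of the uploaded module) ---
+ # Builds a deliberately tiny model from STLConfig and runs one teacher-forced forward
+ # pass; the hyper-parameters and token ids below are arbitrary placeholders chosen only
+ # to keep the sketch cheap to run.
+ def _causal_lm_forward_sketch():
+     config = STLConfig(vocab_size=35, d_model=64, decoder_layers=2,
+                        decoder_attention_heads=4, decoder_ffn_dim=128,
+                        max_position_embeddings=64)
+     model = STLForCausalLM(config)
+     input_ids = torch.tensor([[2, 5, 6, 7, 3]])  # <bos> ... <eos>, placeholder ids
+     outputs = model(input_ids=input_ids, labels=input_ids)
+     return outputs.loss, outputs.logits.shape  # logits: [1, 5, config.vocab_size]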