Upload STLForCausalLM

Files changed:
- config.json (+4, -0)
- configuration_stldec.py (+68, -0)
- modeling_stldec.py (+2166, -0)

config.json (CHANGED)
@@ -5,6 +5,10 @@
     "STLForCausalLM"
   ],
   "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_stldec.STLConfig",
+    "AutoModelForCausalLM": "modeling_stldec.STLForCausalLM"
+  },
   "bos_token_id": 2,
   "d_model": 1024,
   "decoder_attention_heads": 16,
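With "auto_map" in place, the custom configuration and model classes can be loaded straight from the Hub with trust_remote_code. A minimal sketch (not part of the upload); the repository id below is a placeholder, since the actual repo name is not shown on this page:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "your-namespace/stldec"  # placeholder, not the actual repo id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)           # resolves to STLConfig
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)  # resolves to STLForCausalLM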
configuration_stldec.py (ADDED)

@@ -0,0 +1,68 @@
from transformers.configuration_utils import PretrainedConfig

class STLConfig(PretrainedConfig):

    model_type = "stldec"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=35,
        decoder_vocab_size=None,  # unused
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=3,
        scale_embedding=False,
        pad_token_id=1,
        eos_token_id=3,
        bos_token_id=2,
        forced_eos_token_id=3,
        share_encoder_decoder_embeddings=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.decoder_vocab_size = decoder_vocab_size or vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.share_encoder_decoder_embeddings = share_encoder_decoder_embeddings
        super().__init__(
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )
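A minimal usage sketch (not part of the upload): instantiating the config with its defaults and reading the aliases declared in attribute_map; the printed values follow from the defaults above.

config = STLConfig()
print(config.model_type)                                            # "stldec"
print(config.d_model, config.hidden_size)                           # 1024 1024 (hidden_size aliases d_model)
print(config.encoder_attention_heads, config.num_attention_heads)   # 16 16
print(config.bos_token_id, config.eos_token_id, config.decoder_start_token_id)  # 2 3 3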
modeling_stldec.py (ADDED)

@@ -0,0 +1,2166 @@
import ast
import copy
import math
import pickle
import os
from collections import deque
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from transformers.generation import GenerationMixin
from transformers.utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)

from .configuration_stldec import STLConfig
from nltk.translate.bleu_score import sentence_bleu
# from stl import *
import networkx as nx
from datasets import load_dataset

### from custom_typing.py

realnum = Union[float, int]


### from stl.py

# For tensor functions
import torch
from torch import Tensor
import torch.nn.functional as F


def eventually(x: Tensor, time_span: int) -> Tensor:
    """
    STL operator 'eventually' in 1D.

    Parameters
    ----------
    x: torch.Tensor
        Signal
    time_span: any numeric type
        Timespan duration

    Returns
    -------
    torch.Tensor
        A tensor containing the result of the operation.
    """
    return F.max_pool1d(x, kernel_size=time_span, stride=1)
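The eventually helper realises the bounded F operator as a sliding max over the last (time) dimension. A small illustrative sketch (not part of the upload) of what the pooling computes:

import torch
import torch.nn.functional as F

x = torch.tensor([[[0.1, -0.3, 0.7, 0.2, -0.5]]])   # 1 sample, 1 variable, 5 time points
y = F.max_pool1d(x, kernel_size=3, stride=1)
print(y)  # tensor([[[0.7000, 0.7000, 0.7000]]]): max over each window of 3 consecutive steps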
class Node:
    """Abstract node class for STL semantics tree."""

    def __init__(self) -> None:
        # Must be overloaded.
        pass

    def __str__(self) -> str:
        # Must be overloaded.
        pass

    def boolean(self, x: Tensor, evaluate_at_all_times: bool = False) -> Tensor:
        """
        Evaluates the boolean semantics at the node.

        Parameters
        ----------
        x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
            The input signals, stored as a batch tensor with three dimensions.
        evaluate_at_all_times: bool
            Whether to evaluate the semantics at all times (True) or
            just at t=0 (False).

        Returns
        -------
        torch.Tensor
            A tensor with the boolean semantics for the node.
        """
        z: Tensor = self._boolean(x)
        if evaluate_at_all_times:
            return z
        else:
            return self._extract_semantics_at_time_zero(z)

    def quantitative(
        self,
        x: Tensor,
        normalize: bool = False,
        evaluate_at_all_times: bool = False,
    ) -> Tensor:
        """
        Evaluates the quantitative semantics at the node.

        Parameters
        ----------
        x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
            The input signals, stored as a batch tensor with three dimensions.
        normalize: bool
            Whether the measure of robustness is normalized (True) or
            not (False). Currently not in use.
        evaluate_at_all_times: bool
            Whether to evaluate the semantics at all times (True) or
            just at t=0 (False).

        Returns
        -------
        torch.Tensor
            A tensor with the quantitative semantics for the node.
        """
        z: Tensor = self._quantitative(x, normalize)
        if evaluate_at_all_times:
            return z
        else:
            return self._extract_semantics_at_time_zero(z)

    def set_normalizing_flag(self, value: bool = True) -> None:
        """
        Setter for the 'normalization of robustness of the formula' flag.
        Currently not in use.
        """

    def time_depth(self) -> int:
        """Returns time depth of bounded temporal operators only."""
        # Must be overloaded.

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        """Private method equivalent to public one for inner call."""
        # Must be overloaded.

    def _boolean(self, x: Tensor) -> Tensor:
        """Private method equivalent to public one for inner call."""
        # Must be overloaded.

    @staticmethod
    def _extract_semantics_at_time_zero(x: Tensor) -> Tensor:
        """Extrapolates the vector of truth values at time zero"""
        return torch.reshape(x[:, 0, 0], (-1,))


class Atom(Node):
    """Atomic formula node; for now of the form X<=t or X>=t"""

    def __init__(self, var_index: int, threshold: realnum, lte: bool = False) -> None:
        super().__init__()
        self.var_index: int = var_index
        self.threshold: realnum = threshold
        self.lte: bool = lte

    def __str__(self) -> str:
        s: str = (
            "x_"
            + str(self.var_index)
            + (" <= " if self.lte else " >= ")
            + str(round(self.threshold, 4))
        )
        return s

    def time_depth(self) -> int:
        return 0

    def _boolean(self, x: Tensor) -> Tensor:
        # extract tensor of the same dimension as data, but with only one variable
        xj: Tensor = x[:, self.var_index, :]
        xj: Tensor = xj.view(xj.size()[0], 1, -1)
        if self.lte:
            z: Tensor = torch.le(xj, self.threshold)
        else:
            z: Tensor = torch.ge(xj, self.threshold)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        # extract tensor of the same dimension as data, but with only one variable
        xj: Tensor = x[:, self.var_index, :]
        xj: Tensor = xj.view(xj.size()[0], 1, -1)
        if self.lte:
            z: Tensor = -xj + self.threshold
        else:
            z: Tensor = xj - self.threshold
        if normalize:
            z: Tensor = torch.tanh(z)
        return z
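A short sketch (not part of the upload) of how an atomic predicate is evaluated with both semantics; the output shapes follow from boolean/quantitative extracting the value at t = 0.

import torch

phi = Atom(var_index=0, threshold=0.5, lte=True)   # x_0 <= 0.5
signal = torch.randn(10, 2, 50)                    # 10 samples, 2 variables, 50 time points
sat = phi.boolean(signal)                          # shape (10,), truth value at t = 0
rob = phi.quantitative(signal)                     # shape (10,), robustness 0.5 - x_0 at t = 0
print(str(phi), sat.shape, rob.shape)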
class Not(Node):
    """Negation node."""

    def __init__(self, child: Node) -> None:
        super().__init__()
        self.child: Node = child

    def __str__(self) -> str:
        s: str = "not ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        return self.child.time_depth()

    def _boolean(self, x: Tensor) -> Tensor:
        z: Tensor = ~self.child._boolean(x)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z: Tensor = -self.child._quantitative(x, normalize)
        return z


class And(Node):
    """Conjunction node."""

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        s: str = (
            "( "
            + self.left_child.__str__()
            + " and "
            + self.right_child.__str__()
            + " )"
        )
        return s

    def time_depth(self) -> int:
        return max(self.left_child.time_depth(), self.right_child.time_depth())

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.left_child._boolean(x)
        z2: Tensor = self.right_child._boolean(x)
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.logical_and(z1, z2)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.left_child._quantitative(x, normalize)
        z2: Tensor = self.right_child._quantitative(x, normalize)
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.min(z1, z2)
        return z

class Or(Node):
    """Disjunction node."""

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        s: str = (
            "( "
            + self.left_child.__str__()
            + " or "
            + self.right_child.__str__()
            + " )"
        )
        return s

    def time_depth(self) -> int:
        return max(self.left_child.time_depth(), self.right_child.time_depth())

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.left_child._boolean(x)
        z2: Tensor = self.right_child._boolean(x)
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.logical_or(z1, z2)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.left_child._quantitative(x, normalize)
        z2: Tensor = self.right_child._quantitative(x, normalize)
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.max(z1, z2)
        return z

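Conjunction and disjunction first truncate both children to a common signal length and then combine them pointwise (min for and, max for or). A worked sketch (not part of the upload):

import torch

phi = And(Atom(0, 0.0, lte=False), Atom(1, 0.0, lte=True))   # x_0 >= 0 and x_1 <= 0
x = torch.zeros(1, 2, 20)
x[0, 0, :] = 0.3    # robustness of x_0 >= 0 is 0.3 everywhere
x[0, 1, :] = -0.1   # robustness of x_1 <= 0 is 0.1 everywhere
print(phi.quantitative(x))   # tensor([0.1000]) = min(0.3, 0.1) at t = 0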
class Globally(Node):
    """Globally node."""
    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "always" + s0 + " ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
            return self.child.time_depth() + self.right_time_bound - 1
            # (self.right_time_bound - self.left_time_bound + 1) - diff

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])  # nested temporal parameters
        # z1 = z1[:, :, self.left_time_bound:]
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            z: Tensor = torch.ge(1.0 - eventually((~z1).double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        # z1 = z1[:, :, self.left_time_bound:]
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            z: Tensor = -eventually(-z1, self.right_time_bound - self.left_time_bound)
        return z


class Eventually(Node):
    """Eventually node."""

    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound

        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is lower than or equal to left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "eventually" + s0 + " ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
            return self.child.time_depth() + self.right_time_bound - 1
            # (self.right_time_bound - self.left_time_bound + 1) - diff

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.max(z1, 2, keepdim=True)
        else:
            z: Tensor = torch.ge(eventually(z1.double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.max(z1, 2, keepdim=True)
        else:
            z: Tensor = eventually(z1, self.right_time_bound - self.left_time_bound)
        return z

class Until(Node):
    """Until node."""

    def __init__(
        self,
        left_child: Node,
        right_child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
    ) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1

        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is lower than or equal to left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "( " + self.left_child.__str__() + " until" + s0 + " " + self.right_child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        sum_children_depth: int = self.left_child.time_depth() + self.right_child.time_depth()
        if self.unbound:
            return sum_children_depth
        elif self.right_unbound:
            return sum_children_depth + self.left_time_bound
        else:
            # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
            return sum_children_depth + self.right_time_bound - 1
            # (self.right_time_bound - self.left_time_bound + 1) - diff

    def _boolean(self, x: Tensor) -> Tensor:
        if self.unbound:
            z1: Tensor = self.left_child._boolean(x)
            z2: Tensor = self.right_child._boolean(x)
            size: int = min(z1.size()[2], z2.size()[2])
            z1: Tensor = z1[:, :, :size]
            z2: Tensor = z2[:, :, :size]
            z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
            z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
            z1_triu = torch.triu(z1_rep)
            z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]

            z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
            z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
            z2_triu = torch.triu(z2_rep)
            z2_def = z2_tril + z2_triu
            z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
                                  dim=-1)[0]
        elif self.right_unbound:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, right_unbound=True,
                                                   left_time_bound=self.left_time_bound),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._boolean(x)
        else:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
                                                   right_time_bound=self.right_time_bound - 1),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._boolean(x)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        if self.unbound:
            z1: Tensor = self.left_child._quantitative(x, normalize)
            z2: Tensor = self.right_child._quantitative(x, normalize)
            size: int = min(z1.size()[2], z2.size()[2])
            z1: Tensor = z1[:, :, :size]
            z2: Tensor = z2[:, :, :size]

            # z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
            # z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
            # z1_triu = torch.triu(z1_rep)
            # z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]

            # z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
            # z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
            # z2_triu = torch.triu(z2_rep)
            # z2_def = z2_tril + z2_triu
            # z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
            #                       dim=-1)[0]
            z: Tensor = torch.cat([torch.max(torch.min(
                torch.cat([torch.cummin(z1[:, :, t:].unsqueeze(-1), dim=2)[0], z2[:, :, t:].unsqueeze(-1)], dim=-1),
                dim=-1)[0], dim=2, keepdim=True)[0] for t in range(size)], dim=2)
        elif self.right_unbound:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, right_unbound=True,
                                                   left_time_bound=self.left_time_bound),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._quantitative(x, normalize=normalize)
        else:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
                                                   right_time_bound=self.right_time_bound-1),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._quantitative(x, normalize=normalize)
        return z

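A sketch (not part of the upload) evaluating a bounded temporal property on a batch of random signals; note that right_time_bound is stored internally as the given value plus one, so the printed interval is [0,6].

import torch

phi = Eventually(Atom(0, 1.0, lte=False), left_time_bound=0, right_time_bound=5)
x = torch.randn(100, 1, 50)        # 100 signals, 1 variable, 50 time points
sat = phi.boolean(x)               # shape (100,), truth at t = 0
rob = phi.quantitative(x)          # shape (100,), robustness at t = 0
print(str(phi), sat.float().mean().item(), rob.mean().item())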
# from anchor_set_generation import anchorGeneration

import re
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

#### utils ####

def load_pickle(path):
    with open(path, 'rb') as f:
        x = pickle.load(f)
    return x

def dump_pickle(name, thing):
    with open(name + '.pickle', 'wb') as f:
        pickle.dump(thing, f)

def from_string_to_formula(st):
    root_arity = 2 if st.startswith('(') else 1
    st_split = st.split()
    if root_arity <= 1:
        root_op_str = copy.deepcopy(st_split[0])
        if root_op_str.startswith('x'):
            atom_sign = True if st_split[1] == '<=' else False
            root_phi = Atom(var_index=int(st_split[0][2]), lte=atom_sign, threshold=float(st_split[2]))
            return root_phi
        else:
            assert (root_op_str.startswith('not') or root_op_str.startswith('eventually')
                    or root_op_str.startswith('always'))
            current_st = copy.deepcopy(st_split[2:-1])
            if root_op_str == 'not':
                root_phi = Not(child=from_string_to_formula(' '.join(current_st)))
            elif root_op_str.startswith('eventually'):
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                      right_unbound=right_unbound, left_time_bound=left_time_bound,
                                      right_time_bound=right_time_bound)
            else:
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                    right_unbound=right_unbound, left_time_bound=left_time_bound,
                                    right_time_bound=right_time_bound)
    else:
        # 1 - delete everything which is contained in other sets of parenthesis (if any)
        current_st = copy.deepcopy(st_split[1:-1])
        if '(' in current_st:
            par_queue = deque()
            par_idx_list = []
            for i, sub in enumerate(current_st):
                if sub == '(':
                    par_queue.append(i)
                elif sub == ')':
                    par_idx_list.append(tuple([par_queue.pop(), i]))
            # open_par_idx, close_par_idx = [current_st.index(p) for p in ['(', ')']]
            # union of parentheses range --> from these we may extract the substrings to be the children!!!
            children_range = []
            for begin, end in sorted(par_idx_list):
                if children_range and children_range[-1][1] >= begin - 1:
                    children_range[-1][1] = max(children_range[-1][1], end)
                else:
                    children_range.append([begin, end])
            n_children = len(children_range)
            assert (n_children in [1, 2])
            if n_children == 1:
                # one of the children is a variable --> need to individuate it
                var_child_idx = 1 if children_range[0][0] <= 1 else 0  # 0 is left child, 1 is right child
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                left_child_str = current_st[:3] if var_child_idx == 0 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[-3:] if var_child_idx == 1 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \
                    current_st[children_range[0][0] - 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
            else:
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                if current_st[children_range[1][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[1][0] -= 1
                # if there are two children, with parentheses, the element in the middle is the root
                root_op_str = current_st[children_range[0][1] + 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
                left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1]
        else:
            # no parentheses means that both children are variables
            left_child_str = current_st[:3]
            right_child_str = current_st[-3:]
            root_op_str = current_st[3]
        left_child_str = ' '.join(left_child_str)
        right_child_str = ' '.join(right_child_str)
        if root_op_str == 'and':
            root_phi = And(left_child=from_string_to_formula(left_child_str),
                           right_child=from_string_to_formula(right_child_str))
        elif root_op_str == 'or':
            root_phi = Or(left_child=from_string_to_formula(left_child_str),
                          right_child=from_string_to_formula(right_child_str))
        else:
            unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
            root_phi = Until(left_child=from_string_to_formula(left_child_str),
                             right_child=from_string_to_formula(right_child_str),
                             unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound,
                             right_time_bound=right_time_bound)
    return root_phi

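A round-trip sketch (not part of the upload): parsing a formula string written in the same grammar that __str__ produces. Note that set_time_thresholds, used for bounded temporal operators, is defined further down in the full file and is not shown in this excerpt.

phi_str = "( x_0 >= 0.5 and not ( x_1 <= -0.2 ) )"
phi = from_string_to_formula(phi_str)
print(str(phi))   # ( x_0 >= 0.5 and not ( x_1 <= -0.2 ) )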
def load_json(path: str) -> Union[Dict, List]:
    """
    Load a JSON file from the given path.
    Args:
        path (str): The path to the JSON file to be loaded.

    Returns:
        Union[Dict, List]: The parsed content of the JSON file, which could be a dictionary or a list.
    """
    with open(path, "r") as f:
        return json.load(f)

#### phis_generator ####

class StlGenerator:
    def __init__(
        self,
        leaf_prob: float = 0.3,
        inner_node_prob: list = None,
        threshold_mean: float = 0.0,
        threshold_sd: float = 1.0,
        unbound_prob: float = 0.1,
        right_unbound_prob: float = 0.2,
        time_bound_max_range: float = 20,
        adaptive_unbound_temporal_ops: bool = True,
        max_timespan: int = 100,
    ):
        """
        leaf_prob
            probability of generating a leaf (always zero for root)
        node_types = ["not", "and", "or", "always", "eventually", "until"]
            Inner node types
        inner_node_prob
            probability vector for the different types of internal nodes
        threshold_mean
        threshold_sd
            mean and std for the normal distribution of the thresholds of atoms
        unbound_prob
            probability of a temporal operator to have a time bound of the type [0, infty]
        time_bound_max_range
            maximum value of time span of a temporal operator (i.e. max value of t in [0,t])
        adaptive_unbound_temporal_ops
            if true, unbounded temporal operators are computed from current point to the end of the signal, otherwise
            they are evaluated only at time zero.
        max_timespan
            maximum time depth of a formula.
        """

        # Address the mutability of default arguments
        if inner_node_prob is None:
            inner_node_prob = [0.166, 0.166, 0.166, 0.17, 0.166, 0.166]

        self.leaf_prob = leaf_prob
        self.inner_node_prob = inner_node_prob
        self.threshold_mean = threshold_mean
        self.threshold_sd = threshold_sd
        self.unbound_prob = unbound_prob
        self.right_unbound_prob = right_unbound_prob
        self.time_bound_max_range = time_bound_max_range
        self.adaptive_unbound_temporal_ops = adaptive_unbound_temporal_ops
        self.node_types = ["not", "and", "or", "always", "eventually", "until"]
        self.max_timespan = max_timespan

    def sample(self, nvars):
        """
        Samples a random formula with distribution defined in class instance parameters
        Parameters
        ----------
        nvars : number of variables of input signals
            how many variables the formula is expected to consider.
        Returns
        -------
        TYPE
            A random formula.
        """
        return self._sample_internal_node(nvars)

    def bag_sample(self, bag_size, nvars):
        """
        Samples a bag of bag_size formulae
        Parameters
        ----------
        bag_size : INT
            number of formulae.
        nvars : INT
            number of vars in formulae.
        Returns
        -------
        a list of formulae.
        """
        formulae = []
        for _ in range(bag_size):
            phi = self.sample(nvars)
            formulae.append(phi)
        return formulae

    def _sample_internal_node(self, nvars):
        # Declare & dummy-assign "idiom"
        node: Union[None, Node]
        node = None
        # choose node type
        nodetype = rnd.choice(self.node_types, p=self.inner_node_prob)
        while True:
            if nodetype == "not":
                n = self._sample_node(nvars)
                node = Not(n)
            elif nodetype == "and":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = And(n1, n2)
            elif nodetype == "or":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = Or(n1, n2)
            elif nodetype == "always":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Globally(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "eventually":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Eventually(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "until":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Until(
                    n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
                )

            if (node is not None) and (node.time_depth() < self.max_timespan):
                return node

    def _sample_node(self, nvars):
        if rnd.rand() < self.leaf_prob:
            # sample a leaf
            var, thr, lte = self._get_atom(nvars)
            return Atom(var, thr, lte)
        else:
            return self._sample_internal_node(nvars)

    def _get_temporal_parameters(self):
        if rnd.rand() < self.unbound_prob:
            return True, False, 0, 0
        elif rnd.rand() < self.right_unbound_prob:
            return False, True, rnd.randint(self.time_bound_max_range), 1
        else:
            left_bound = rnd.randint(self.time_bound_max_range)
            return False, False, left_bound, rnd.randint(left_bound, self.time_bound_max_range) + 1

    def _get_atom(self, nvars):
        variable = rnd.randint(nvars)
        lte = rnd.rand() > 0.5
        threshold = rnd.normal(self.threshold_mean, self.threshold_sd)
        return variable, threshold, lte

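A sampling sketch (not part of the upload). StlGenerator draws random syntax trees through the module-level name rnd; the corresponding import is not visible in this excerpt, so the sketch assumes it refers to numpy.random.

import numpy.random as rnd  # assumption: 'rnd' inside StlGenerator refers to numpy.random

gen = StlGenerator(leaf_prob=0.4, time_bound_max_range=10)
formulae = gen.bag_sample(bag_size=5, nvars=2)   # 5 random formulae over 2-variable signals
for phi in formulae:
    print(str(phi), "| time depth:", phi.time_depth())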
| 896 |
+
#### traj_measure ####
|
| 897 |
+
|
| 898 |
+
class Measure:
|
| 899 |
+
def sample(self, samples=100000, varn=2, points=100):
|
| 900 |
+
# Must be overridden
|
| 901 |
+
pass
|
| 902 |
+
|
| 903 |
+
class BaseMeasure(Measure):
|
| 904 |
+
def __init__(
|
| 905 |
+
self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
|
| 906 |
+
):
|
| 907 |
+
"""
|
| 908 |
+
Parameters
|
| 909 |
+
----------
|
| 910 |
+
mu0 : mean of normal distribution of initial state, optional
|
| 911 |
+
The default is 0.0.
|
| 912 |
+
sigma0 : standard deviation of normal distribution of initial state, optional
|
| 913 |
+
The default is 1.0.
|
| 914 |
+
mu1 : DOUBLE, optional
|
| 915 |
+
mean of normal distribution of total variation. The default is 0.0.
|
| 916 |
+
sigma1 : standard deviation of normal distribution of total variation, optional
|
| 917 |
+
The default is 1.0.
|
| 918 |
+
q : DOUBLE, optional
|
| 919 |
+
probability of change of sign in derivative. The default is 0.1.
|
| 920 |
+
q0 : DOUBLE, optional
|
| 921 |
+
probability of initial sign of derivative. The default is 0.5.
|
| 922 |
+
device : 'cpu' or 'cuda', optional
|
| 923 |
+
device on which to run the algorithm. The default is 'cpu'.
|
| 924 |
+
Returns
|
| 925 |
+
-------
|
| 926 |
+
None.
|
| 927 |
+
"""
|
| 928 |
+
self.mu0 = mu0
|
| 929 |
+
self.sigma0 = sigma0
|
| 930 |
+
self.mu1 = mu1
|
| 931 |
+
self.sigma1 = sigma1
|
| 932 |
+
self.q = q
|
| 933 |
+
self.q0 = q0
|
| 934 |
+
self.device = device
|
| 935 |
+
|
| 936 |
+
def sample(self, samples=100000, varn=2, points=100):
|
| 937 |
+
"""
|
| 938 |
+
Samples a set of trajectories from the basic measure space, with parameters
|
| 939 |
+
passed to the sampler
|
| 940 |
+
Parameters
|
| 941 |
+
----------
|
| 942 |
+
points : INT, optional
|
| 943 |
+
number of points per trajectory, including initial one. The default is 1000.
|
| 944 |
+
samples : INT, optional
|
| 945 |
+
number of trajectories. The default is 100000.
|
| 946 |
+
varn : INT, optional
|
| 947 |
+
number of variables per trajectory. The default is 2.
|
| 948 |
+
Returns
|
| 949 |
+
-------
|
| 950 |
+
signal : samples x varn x points double pytorch tensor
|
| 951 |
+
The sampled signals.
|
| 952 |
+
"""
|
| 953 |
+
if self.device == "cuda" and not torch.cuda.is_available():
|
| 954 |
+
raise RuntimeError("GPU card or CUDA library not available!")
|
| 955 |
+
|
| 956 |
+
# generate unif RN
|
| 957 |
+
signal = torch.rand(samples, varn, points, device=self.device)
|
| 958 |
+
# first point is special - set to zero for the moment, and set the last point to 1
|
| 959 |
+
signal[:, :, 0] = 0.0
|
| 960 |
+
signal[:, :, -1] = 1.0
|
| 961 |
+
# sorting each trajectory
|
| 962 |
+
signal, _ = torch.sort(signal, 2)
|
| 963 |
+
# computing increments and storing them in points 1 to end
|
| 964 |
+
signal[:, :, 1:] = signal[:, :, 1:] - signal[:, :, :-1]
|
| 965 |
+
# generate initial state, according to a normal distribution
|
| 966 |
+
signal[:, :, 0] = self.mu0 + self.sigma0 * torch.randn(signal[:, :, 0].size())
|
| 967 |
+
|
| 968 |
+
# sampling change signs from bernoulli in -1, 1
|
| 969 |
+
derivs = (1 - self.q) * torch.ones(samples, varn, points, device=self.device)
|
| 970 |
+
derivs = 2 * torch.bernoulli(derivs) - 1
|
| 971 |
+
# sampling initial derivative
|
| 972 |
+
derivs[:, :, 0] = self.q0
|
| 973 |
+
derivs[:, :, 0] = 2 * torch.bernoulli(derivs[:, :, 0]) - 1
|
| 974 |
+
# taking the cumulative product along axis 2
|
| 975 |
+
derivs = torch.cumprod(derivs, 2)
|
| 976 |
+
|
| 977 |
+
# sampling total variation
|
| 978 |
+
totvar = torch.pow(
|
| 979 |
+
self.mu1 + self.sigma1 * torch.randn(samples, varn, 1, device=self.device),
|
| 980 |
+
2,
|
| 981 |
+
)
|
| 982 |
+
# multiplying total variation and derivatives and making initial point non-invasive
|
| 983 |
+
derivs = derivs * totvar
|
| 984 |
+
derivs[:, :, 0] = 1.0
|
| 985 |
+
|
| 986 |
+
# computing trajectories by multiplying and then doing a cumulative sum
|
| 987 |
+
signal = signal * derivs
|
| 988 |
+
signal = torch.cumsum(signal, 2)
|
| 989 |
+
return signal
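# Illustrative usage sketch (not part of the original code): drawing trajectories from
# the base measure defined above; only names defined in this file are used.
#     mu = BaseMeasure(mu0=0.0, sigma0=1.0, device="cpu")
#     trajs = mu.sample(samples=1000, varn=2, points=100)
#     # trajs has shape (1000, 2, 100): 1000 trajectories, 2 variables, 100 time points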
|
| 990 |
+
|
| 991 |
+
#### kernel ####
|
| 992 |
+
|
| 993 |
+
realnum = Union[float, int]
|
| 994 |
+
|
| 995 |
+
class StlKernel:
|
| 996 |
+
def __init__(
|
| 997 |
+
self,
|
| 998 |
+
measure,
|
| 999 |
+
normalize=True,
|
| 1000 |
+
exp_kernel=True,
|
| 1001 |
+
sigma2=0.2,  # 0.5 works better; it was initially set to 0.2
|
| 1002 |
+
integrate_time=False,
|
| 1003 |
+
samples=100000,
|
| 1004 |
+
varn=2,
|
| 1005 |
+
points=100,
|
| 1006 |
+
boolean=False,
|
| 1007 |
+
signals=None,
|
| 1008 |
+
):
|
| 1009 |
+
self.traj_measure = measure
|
| 1010 |
+
self.exp_kernel = exp_kernel
|
| 1011 |
+
self.normalize = normalize
|
| 1012 |
+
self.sigma2 = sigma2
|
| 1013 |
+
self.samples = samples
|
| 1014 |
+
self.varn = varn
|
| 1015 |
+
self.points = points
|
| 1016 |
+
self.integrate_time = integrate_time
|
| 1017 |
+
if signals is not None:
|
| 1018 |
+
self.signals = signals
|
| 1019 |
+
else:
|
| 1020 |
+
self.signals = measure.sample(points=points, samples=samples, varn=varn)
|
| 1021 |
+
self.boolean = boolean
|
| 1022 |
+
|
| 1023 |
+
def compute(self, phi1, phi2):
|
| 1024 |
+
return self.compute_one_one(phi1, phi2)
|
| 1025 |
+
|
| 1026 |
+
def compute_one_one(self, phi1, phi2):
|
| 1027 |
+
phis1: list = [phi1]
|
| 1028 |
+
phis2: list = [phi2]
|
| 1029 |
+
ker = self.compute_bag_bag(phis1, phis2)
|
| 1030 |
+
return ker[0, 0]
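# Illustrative usage sketch (an assumption about typical usage, not taken from elsewhere in
# this file): comparing two sampled STL formulae through the kernel; StlGenerator is the
# formula sampler referenced later in anchorGeneration.
#     mu = BaseMeasure(device="cpu")
#     kernel = StlKernel(mu, varn=2, samples=10000)
#     gen = StlGenerator(0.4)
#     phi1, phi2 = gen.sample(nvars=2), gen.sample(nvars=2)
#     k12 = kernel.compute(phi1, phi2)  # scalar kernel value between the two formulae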
|
| 1031 |
+
|
| 1032 |
+
def compute_bag(self, phis, return_robustness=True):
|
| 1033 |
+
if self.integrate_time:
|
| 1034 |
+
rhos, selfk, len0 = self._compute_robustness_time(phis)
|
| 1035 |
+
kernel_matrix = self._compute_kernel_time(
|
| 1036 |
+
rhos, rhos, selfk, selfk, len0, len0
|
| 1037 |
+
)
|
| 1038 |
+
else:
|
| 1039 |
+
rhos, selfk = self._compute_robustness_no_time(phis)
|
| 1040 |
+
kernel_matrix = self._compute_kernel_no_time(rhos, rhos, selfk, selfk)
|
| 1041 |
+
len0 = None
|
| 1042 |
+
if return_robustness:
|
| 1043 |
+
return kernel_matrix.cpu(), rhos, selfk, len0
|
| 1044 |
+
else:
|
| 1045 |
+
return kernel_matrix.cpu()
|
| 1046 |
+
|
| 1047 |
+
def compute_one_bag(self, phi1, phis2, return_robustness=False):
|
| 1048 |
+
phis1: list = [phi1]
|
| 1049 |
+
return self.compute_bag_bag(phis1, phis2, return_robustness)
|
| 1050 |
+
|
| 1051 |
+
def compute_bag_bag(self, phis1, phis2, return_robustness=False):
|
| 1052 |
+
if self.integrate_time:
|
| 1053 |
+
rhos1, selfk1, len1 = self._compute_robustness_time(phis1)
|
| 1054 |
+
rhos2, selfk2, len2 = self._compute_robustness_time(phis2)
|
| 1055 |
+
kernel_matrix = self._compute_kernel_time(
|
| 1056 |
+
rhos1, rhos2, selfk1, selfk2, len1, len2
|
| 1057 |
+
)
|
| 1058 |
+
else:
|
| 1059 |
+
rhos1, selfk1 = self._compute_robustness_no_time(phis1)
|
| 1060 |
+
rhos2, selfk2 = self._compute_robustness_no_time(phis2)
|
| 1061 |
+
len1, len2 = [None, None]
|
| 1062 |
+
kernel_matrix = self._compute_kernel_no_time(rhos1, rhos2, selfk1, selfk2)
|
| 1063 |
+
if return_robustness:
|
| 1064 |
+
return kernel_matrix.cpu(), rhos1, rhos2, selfk1, selfk2, len1, len2
|
| 1065 |
+
else:
|
| 1066 |
+
return kernel_matrix.cpu()
|
| 1067 |
+
|
| 1068 |
+
def compute_one_from_robustness(self, phi, rhos, rho_self, lengths=None, return_robustness=False):
|
| 1069 |
+
phis: list = [phi]
|
| 1070 |
+
return self.compute_bag_from_robustness(phis, rhos, rho_self, lengths, return_robustness)
|
| 1071 |
+
|
| 1072 |
+
def compute_bag_from_robustness(self, phis, rhos, rho_self, lengths=None, return_robustness=False):
|
| 1073 |
+
if self.integrate_time:
|
| 1074 |
+
rhos1, selfk1, len1 = self._compute_robustness_time(phis)
|
| 1075 |
+
kernel_matrix = self._compute_kernel_time(
|
| 1076 |
+
rhos1, rhos, selfk1, rho_self, len1, lengths
|
| 1077 |
+
)
|
| 1078 |
+
else:
|
| 1079 |
+
rhos1, selfk1 = self._compute_robustness_no_time(phis)
|
| 1080 |
+
len1 = None
|
| 1081 |
+
kernel_matrix = self._compute_kernel_no_time(rhos1, rhos, selfk1, rho_self)
|
| 1082 |
+
if return_robustness:
|
| 1083 |
+
return kernel_matrix.cpu(), rhos1, selfk1, len1
|
| 1084 |
+
else:
|
| 1085 |
+
return kernel_matrix.cpu()
|
| 1086 |
+
def _compute_robustness_time(self, phis):
n = self.samples
|
| 1087 |
+
p = self.points
|
| 1088 |
+
k = len(phis)
|
| 1089 |
+
rhos = torch.zeros((k, n, p), device="cpu")
|
| 1090 |
+
lengths = torch.zeros(k)
|
| 1091 |
+
self_kernels = torch.zeros((k, 1))
|
| 1092 |
+
for i, phi in enumerate(phis):
|
| 1093 |
+
if self.boolean:
|
| 1094 |
+
rho = phi.boolean(self.signals, evaluate_at_all_times=True).float()
|
| 1095 |
+
rho[rho == 0.0] = -1.0
|
| 1096 |
+
else:
|
| 1097 |
+
rho = phi.quantitative(self.signals, evaluate_at_all_times=True)
|
| 1098 |
+
actual_p = rho.size()[2]
|
| 1099 |
+
rho = rho.reshape(n, actual_p).cpu()
|
| 1100 |
+
rhos[i, :, :actual_p] = rho
|
| 1101 |
+
lengths[i] = actual_p
|
| 1102 |
+
self_kernels[i] = torch.tensordot(
|
| 1103 |
+
rho.reshape(1, n, -1), rho.reshape(1, n, -1), dims=[[1, 2], [1, 2]]
|
| 1104 |
+
) / (actual_p * n)
|
| 1105 |
+
return rhos, self_kernels, lengths
|
| 1106 |
+
|
| 1107 |
+
def _compute_robustness_no_time(self, phis):
|
| 1108 |
+
n = self.samples
|
| 1109 |
+
k = len(phis)
|
| 1110 |
+
rhos = torch.zeros((k, n), device=self.traj_measure.device)
|
| 1111 |
+
self_kernels = torch.zeros((k, 1), device=self.traj_measure.device)
|
| 1112 |
+
for i, phi in enumerate(phis):
|
| 1113 |
+
if self.boolean:
|
| 1114 |
+
rho = phi.boolean(self.signals, evaluate_at_all_times=False).float()
|
| 1115 |
+
rho[rho == 0.0] = -1.0
|
| 1116 |
+
else:
|
| 1117 |
+
rho = phi.quantitative(self.signals, evaluate_at_all_times=False)
|
| 1118 |
+
self_kernels[i] = rho.dot(rho) / n
|
| 1119 |
+
rhos[i, :] = rho
|
| 1120 |
+
return rhos, self_kernels
|
| 1121 |
+
|
| 1122 |
+
def _compute_kernel_time(self, rhos1, rhos2, selfk1, selfk2, len1, len2):
|
| 1123 |
+
kernel_matrix = torch.tensordot(rhos1, rhos2, [[1, 2], [1, 2]])
|
| 1124 |
+
length_normalizer = self._compute_trajectory_length_normalizer(len1, len2)
|
| 1125 |
+
kernel_matrix = kernel_matrix * length_normalizer / self.samples
|
| 1126 |
+
if self.normalize:
|
| 1127 |
+
kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
|
| 1128 |
+
if self.exp_kernel:
|
| 1129 |
+
kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
|
| 1130 |
+
return kernel_matrix
|
| 1131 |
+
|
| 1132 |
+
def _compute_kernel_no_time(self, rhos1, rhos2, selfk1, selfk2):
|
| 1133 |
+
kernel_matrix = torch.tensordot(rhos1, rhos2, [[1], [1]])
|
| 1134 |
+
kernel_matrix = kernel_matrix / self.samples
|
| 1135 |
+
if self.normalize:
|
| 1136 |
+
kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
|
| 1137 |
+
if self.exp_kernel:
|
| 1138 |
+
kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
|
| 1139 |
+
return kernel_matrix
|
| 1140 |
+
|
| 1141 |
+
@staticmethod
|
| 1142 |
+
def _normalize(kernel_matrix, selfk1, selfk2):
|
| 1143 |
+
normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
|
| 1144 |
+
kernel_matrix = kernel_matrix / normalize
|
| 1145 |
+
return kernel_matrix
|
| 1146 |
+
|
| 1147 |
+
|
| 1153 |
+
def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
|
| 1154 |
+
if sigma2 is None:
|
| 1155 |
+
sigma2 = self.sigma2
|
| 1156 |
+
if self.normalize:
|
| 1157 |
+
# selfk is (1.0^2 + 1.0^2)
|
| 1158 |
+
selfk = 2.0
|
| 1159 |
+
else:
|
| 1160 |
+
k1 = selfk1.size()[0]
|
| 1161 |
+
k2 = selfk2.size()[0]
|
| 1162 |
+
selfk = (selfk1 * selfk1).repeat(1, k2) + torch.transpose(
|
| 1163 |
+
selfk2 * selfk2, 0, 1
|
| 1164 |
+
).repeat(k1, 1)
|
| 1165 |
+
return torch.exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2))
|
| 1166 |
+
|
| 1167 |
+
@staticmethod
|
| 1168 |
+
def _compute_trajectory_length_normalizer(len1, len2):
|
| 1169 |
+
k1 = len1.size()[0]
|
| 1170 |
+
k2 = len2.size()[0]
|
| 1171 |
+
y1 = len1.reshape(-1, 1)
|
| 1172 |
+
y1 = y1.repeat(1, k2)
|
| 1173 |
+
y2 = len2.repeat(k1, 1)
|
| 1174 |
+
return 1.0 / torch.min(y1, y2)
|
| 1175 |
+
|
| 1176 |
+
class GramMatrix:
|
| 1177 |
+
def __init__(self, kernel, formulae, store_robustness=True, sample=False, sampler=None, bag_size=None):
|
| 1178 |
+
self.kernel = kernel
|
| 1179 |
+
self.formulae_list = formulae
|
| 1180 |
+
# if kernel is computed from robustness at time zero only,
|
| 1181 |
+
# we store the robustness for each formula and each sample
|
| 1182 |
+
# to speed up computation later
|
| 1183 |
+
self.store_robustness = store_robustness
|
| 1184 |
+
self.dim = len(self.formulae_list) if not bag_size else int(bag_size)
|
| 1185 |
+
self.sample = sample # whether to generate formulae in a controlled manner
|
| 1186 |
+
if self.sample:
|
| 1187 |
+
self.t = 0.99 if self.kernel.boolean else 0.85
|
| 1188 |
+
self.sampler = sampler # stl formulae generator
|
| 1189 |
+
self._compute_gram_matrix()
|
| 1190 |
+
|
| 1191 |
+
def _compute_gram_matrix(self):
|
| 1192 |
+
if self.sample:
|
| 1193 |
+
gram = torch.zeros(self.dim, self.dim)
|
| 1194 |
+
rhos = torch.zeros((self.dim, self.kernel.samples), device=self.kernel.traj_measure.device) if \
|
| 1195 |
+
not self.kernel.integrate_time else torch.zeros((self.dim, self.kernel.samples, self.kernel.points),
|
| 1196 |
+
device=self.kernel.traj_measure.device)
|
| 1197 |
+
lengths = torch.zeros(self.dim) if self.kernel.integrate_time else np.zeros(self.dim)
|
| 1198 |
+
kernels = torch.zeros((self.dim, 1), device=self.kernel.traj_measure.device)
|
| 1199 |
+
phis = [self.sampler.sample(nvars=self.kernel.varn)]
|
| 1200 |
+
gram[0, :1], rhos[0], kernels[0, :], lengths[0] = self.kernel.compute_bag(phis, return_robustness=True)
|
| 1201 |
+
while len(phis) < self.dim:
|
| 1202 |
+
i = len(phis)
|
| 1203 |
+
phi = self.sampler.sample(nvars=self.kernel.varn)
|
| 1204 |
+
gram[i, :i], rhos[i], kernels[i, :], lengths[i] = self.kernel.compute_one_from_robustness(
|
| 1205 |
+
phi, rhos[:i, :], kernels[:i, :], lengths[:i], return_robustness=True)
|
| 1206 |
+
if torch.sum(gram[i, :i + 1] >= self.t) < 3:
|
| 1207 |
+
phis.append(phi)
|
| 1208 |
+
gram[:i, i] = gram[i, :i]
|
| 1209 |
+
gram[i, i] = kernels[i, :]
|
| 1210 |
+
|
| 1211 |
+
self.formulae_list = phis
|
| 1212 |
+
self.gram = gram.cpu()
|
| 1213 |
+
self.robustness = rhos if self.store_robustness else None
|
| 1214 |
+
self.self_kernels = kernels if self.store_robustness else None
|
| 1215 |
+
self.robustness_lengths = lengths if self.store_robustness else None
|
| 1216 |
+
else:
|
| 1217 |
+
if self.store_robustness:
|
| 1218 |
+
k_matrix, rhos, selfk, len0 = self.kernel.compute_bag(
|
| 1219 |
+
self.formulae_list, return_robustness=True
|
| 1220 |
+
)
|
| 1221 |
+
self.gram = k_matrix
|
| 1222 |
+
self.robustness = rhos
|
| 1223 |
+
self.self_kernels = selfk
|
| 1224 |
+
self.robustness_lengths = len0
|
| 1225 |
+
else:
|
| 1226 |
+
self.gram = self.kernel.compute_bag(
|
| 1227 |
+
self.formulae_list, return_robustness=False
|
| 1228 |
+
)
|
| 1229 |
+
self.robustness = None
|
| 1230 |
+
self.self_kernels = None
|
| 1231 |
+
self.robustness_lengths = None
|
| 1232 |
+
|
| 1233 |
+
def compute_kernel_vector(self, phi):
|
| 1234 |
+
if self.store_robustness:
|
| 1235 |
+
return self.kernel.compute_one_from_robustness(
|
| 1236 |
+
phi, self.robustness, self.self_kernels, self.robustness_lengths
|
| 1237 |
+
)
|
| 1238 |
+
else:
|
| 1239 |
+
return self.kernel.compute_one_bag(phi, self.formulae_list)
|
| 1240 |
+
|
| 1241 |
+
def compute_bag_kernel_vector(self, phis, generate_phis=False, bag_size=None):
|
| 1242 |
+
if generate_phis:
|
| 1243 |
+
gram_test = torch.zeros(bag_size, self.dim) # self.dim, bag_size
|
| 1244 |
+
rhos_test = torch.zeros((bag_size, self.kernel.samples), device=self.kernel.traj_measure.device) if \
|
| 1245 |
+
not self.kernel.integrate_time else torch.zeros((bag_size, self.kernel.samples, self.kernel.points),
|
| 1246 |
+
device=self.kernel.traj_measure.device)
|
| 1247 |
+
lengths_test = torch.zeros(bag_size) if self.kernel.integrate_time else np.zeros(bag_size)
|
| 1248 |
+
kernels_test = torch.zeros((bag_size, 1), device=self.kernel.traj_measure.device)
|
| 1249 |
+
phi_test = []
|
| 1250 |
+
while len(phi_test) < bag_size:
|
| 1251 |
+
i = len(phi_test)
|
| 1252 |
+
phi = self.sampler.sample(nvars=self.kernel.varn)
|
| 1253 |
+
if self.store_robustness:
|
| 1254 |
+
gram_test[i, :], rhos_test[i], kernels_test[i, :], lengths_test[i] = \
|
| 1255 |
+
self.kernel.compute_one_from_robustness(phi, self.robustness, self.self_kernels,
|
| 1256 |
+
self.robustness_lengths, return_robustness=True)
|
| 1257 |
+
else:
|
| 1258 |
+
gram_test[i, :], rhos_test[i], _, kernels_test[i, :], _, lengths_test[i], _ = \
|
| 1259 |
+
self.kernel.compute_one_bag(phi, self.formulae_list, return_robustness=True)
|
| 1260 |
+
if not ((rhos_test[i] > 0).all() or (rhos_test[i] < 0).all()):
|
| 1261 |
+
phi_test.append(phi)
|
| 1262 |
+
return phi_test, gram_test.cpu()
|
| 1263 |
+
else:
|
| 1264 |
+
if self.store_robustness:
|
| 1265 |
+
return self.kernel.compute_bag_from_robustness(
|
| 1266 |
+
phis, self.robustness, self.self_kernels, self.robustness_lengths
|
| 1267 |
+
)
|
| 1268 |
+
else:
|
| 1269 |
+
return self.kernel.compute_bag_bag(phis, self.formulae_list)
|
| 1270 |
+
|
| 1271 |
+
def invert_regularized(self, alpha):
|
| 1272 |
+
regularizer = abs(pow(10, alpha)) * torch.eye(self.dim)
|
| 1273 |
+
return torch.inverse(self.gram + regularizer)
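# Illustrative sketch (an assumption, not asserted elsewhere in this file): the regularized
# inverse above is the building block of a kernel ridge regression over the formulae in the
# Gram matrix, with `y` a hypothetical target vector aligned with self.formulae_list.
#     # gm = GramMatrix(kernel, formulae)
#     # weights = gm.invert_regularized(alpha=-3) @ y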
|
| 1274 |
+
|
| 1275 |
+
#### anchor_generation ####
|
| 1276 |
+
|
| 1277 |
+
def anchorGeneration(diff_init = False, # to control whether we want formulae to be semantically different by construction
|
| 1278 |
+
embed_dim: int = 30, # embedding dimension, aka number of generated formulae in the anchor set
|
| 1279 |
+
n_vars: int = 3, # dimension of the input signal (3D in this case)
|
| 1280 |
+
leaf_prob: float = 0.4, # complexity of the generated formula
|
| 1281 |
+
cosine_similarity_threshold: float = 0.8 # if the cosine similarity between two formulae exceeds this threshold, discard one of the two
|
| 1282 |
+
) -> str:
|
| 1283 |
+
|
| 1284 |
+
# initialize STL formula generator
|
| 1285 |
+
sampler = StlGenerator(leaf_prob)
|
| 1286 |
+
|
| 1287 |
+
# effective anchor set generation
|
| 1288 |
+
if diff_init:
|
| 1289 |
+
|
| 1290 |
+
# initialize the anchor set with a randomly sampled formula
|
| 1291 |
+
diff_anchor_set = [sampler.sample(nvars=n_vars)]
|
| 1292 |
+
|
| 1293 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 1294 |
+
mu = BaseMeasure(device=device)
|
| 1295 |
+
|
| 1296 |
+
# generate a set of random signals used to test the formulae
|
| 1297 |
+
signals = mu.sample(samples=10000, varn=n_vars)
|
| 1298 |
+
|
| 1299 |
+
# computes robustness value for the initial set of formulae in the anchor set
|
| 1300 |
+
anchor_rob_vectors = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in diff_anchor_set], 0)
|
| 1301 |
+
|
| 1302 |
+
while len(diff_anchor_set) < embed_dim:
|
| 1303 |
+
# sample the 'remaining' formulae to reach the desired number of `embed_dim` formulae:
|
| 1304 |
+
candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars = n_vars)
|
| 1305 |
+
|
| 1306 |
+
# compute robustness of candidate anchor formulae on the same signals as previous anchor set
|
| 1307 |
+
candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)
|
| 1308 |
+
|
| 1309 |
+
# compute cosine similarity between current anchor set and candidate new formulae
|
| 1310 |
+
cos_simil = torch.tril(normalize(candidate_robs) @ normalize(anchor_rob_vectors).t(), diagonal=-1)
|
| 1311 |
+
|
| 1312 |
+
# check which formulae are similar (i.e. cosine similarity greater than the threshold) w.r.t. current anchors
|
| 1313 |
+
# NOTE: ask Gaia whether negative cosine similarities should be taken in absolute value or not!
|
| 1314 |
+
similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])]
|
| 1315 |
+
|
| 1316 |
+
# keep only those who are semantically distant
|
| 1317 |
+
keep_idx = list(set(np.arange(len(candidate_anchors)).tolist()).difference(set([i for sublist in similar_idx for i in sublist])))
|
| 1318 |
+
|
| 1319 |
+
diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]
|
| 1320 |
+
|
| 1321 |
+
# Convert keep_idx to a tensor on the same device as candidate_robs
|
| 1322 |
+
keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)
|
| 1323 |
+
|
| 1324 |
+
# Use index_select to pick the relevant rows
|
| 1325 |
+
selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)
|
| 1326 |
+
|
| 1327 |
+
# Concatenate on the same device
|
| 1328 |
+
anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)
|
| 1329 |
+
|
| 1330 |
+
anchor_set = diff_anchor_set[:embed_dim]
|
| 1331 |
+
|
| 1332 |
+
else:
|
| 1333 |
+
anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars)
|
| 1334 |
+
|
| 1335 |
+
filename = f'anchor_set_no_diff_{embed_dim}_dim'
|
| 1336 |
+
dump_pickle(filename, anchor_set)
|
| 1337 |
+
return filename
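# Illustrative usage sketch: generating a 30-formula anchor set over 3-dimensional signals.
# The '.pickle' suffix is an assumption based on how STLEncoder (below) loads the file.
#     fname = anchorGeneration(diff_init=True, embed_dim=30, n_vars=3)
#     anchors = load_pickle(fname + '.pickle')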
|
| 1338 |
+
|
| 1339 |
+
####
|
| 1340 |
+
|
| 1341 |
+
"""
|
| 1342 |
+
A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process.
|
| 1343 |
+
This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs,
|
| 1344 |
+
and handle padding and special tokens.
|
| 1345 |
+
"""
|
| 1346 |
+
|
| 1347 |
+
def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
|
| 1348 |
+
bos_token: str = "/s", eos_token: str = "s", model_max_length = 512, *args, **kwargs):
|
| 1349 |
+
"""
|
| 1350 |
+
Initializes the STLTokenizer with a given vocabulary and special tokens.
|
| 1351 |
+
Args:
|
| 1352 |
+
vocab_path (str): The path to the JSON file containing the vocabulary.
|
| 1353 |
+
unk_token (str, optional): The token used for unknown words. Defaults to "unk".
|
| 1354 |
+
pad_token (str, optional): The token used for padding. Defaults to "pad".
|
| 1355 |
+
bos_token (str, optional): The token used for the beginning of a sequence. Defaults to "/s".
|
| 1356 |
+
eos_token (str, optional): The token used for the end of a sequence. Defaults to "s".
|
| 1357 |
+
"""
|
| 1358 |
+
self.vocab = load_json(vocab_path)
|
| 1359 |
+
self.unk_token = unk_token
|
| 1360 |
+
self.pad_token = pad_token
|
| 1361 |
+
self.bos_token = bos_token
|
| 1362 |
+
self.eos_token = eos_token
|
| 1363 |
+
self.model_max_length = model_max_length
|
| 1364 |
+
self.id_to_token = {v: k for k, v in self.vocab.items()} # Reverse mapping
|
| 1365 |
+
super().__init__(unk_token=unk_token, pad_token=pad_token, bos_token=bos_token, eos_token=eos_token,
|
| 1366 |
+
model_max_length=model_max_length, *args, **kwargs)
|
| 1367 |
+
|
| 1368 |
+
@property
|
| 1369 |
+
def vocab_size(self) -> int:
|
| 1370 |
+
"""
|
| 1371 |
+
Returns the size of the vocabulary.
|
| 1372 |
+
Returns:
|
| 1373 |
+
int: The number of tokens in the vocabulary.
|
| 1374 |
+
"""
|
| 1375 |
+
return len(self.vocab)
|
| 1376 |
+
|
| 1377 |
+
def prepad_sequence(self, sequence, space_token = ' ', new_space_token = '@', undo = False):
|
| 1378 |
+
"""
|
| 1379 |
+
Replaces spaces in the input sequence with a specified token.
|
| 1380 |
+
Args:
|
| 1381 |
+
sequence (str): The input sequence.
|
| 1382 |
+
undo (bool): If True, replaces the placeholder token back with spaces. Defaults to False, which replaces spaces with the placeholder.
|
| 1383 |
+
Returns:
|
| 1384 |
+
str: The preprocessed sequence with spaces or padding tokens replaced.
|
| 1385 |
+
"""
|
| 1386 |
+
if undo:
|
| 1387 |
+
return sequence.replace(new_space_token, space_token)
|
| 1388 |
+
else:
|
| 1389 |
+
return sequence.replace(space_token, new_space_token)
|
| 1390 |
+
|
| 1391 |
+
def add_bos_eos(self, sequence: str) -> str:
|
| 1392 |
+
"""
|
| 1393 |
+
Adds the BOS token at the beginning and the EOS token at the end of the sequence.
|
| 1394 |
+
Args:
|
| 1395 |
+
sequence (str): The input sequence.
|
| 1396 |
+
Returns:
|
| 1397 |
+
str: The sequence with the BOS and EOS tokens added.
|
| 1398 |
+
"""
|
| 1399 |
+
return f'{self.bos_token} {sequence} {self.eos_token}'
|
| 1400 |
+
|
| 1401 |
+
def tokenize(self, text: str) -> List[str]:
|
| 1402 |
+
"""
|
| 1403 |
+
Tokenizes the input text into a list of tokens.
|
| 1404 |
+
The method preprocesses the input text by replacing spaces with a placeholder token and then tries to
|
| 1405 |
+
find the longest possible match for each substring in the vocabulary.
|
| 1406 |
+
Args:
|
| 1407 |
+
text (str): The input text to be tokenized.
|
| 1408 |
+
Returns:
|
| 1409 |
+
List[str]: A list of tokens representing the tokenized text.
|
| 1410 |
+
"""
|
| 1411 |
+
text = self.add_bos_eos(text)
|
| 1412 |
+
text = self.prepad_sequence(text)
|
| 1413 |
+
tokens = []
|
| 1414 |
+
i = 0
|
| 1415 |
+
while i < len(text):
|
| 1416 |
+
best_match = None
|
| 1417 |
+
for j in range(len(text), i, -1): # Try matching substrings of decreasing length
|
| 1418 |
+
subtoken = text[i:j]
|
| 1419 |
+
if subtoken in self.vocab:
|
| 1420 |
+
best_match = subtoken
|
| 1421 |
+
break
|
| 1422 |
+
if best_match:
|
| 1423 |
+
tokens.append(best_match)
|
| 1424 |
+
i += len(best_match)
|
| 1425 |
+
else:
|
| 1426 |
+
tokens.append(self.unk_token)
|
| 1427 |
+
i += 1
|
| 1428 |
+
return tokens
|
| 1429 |
+
|
| 1430 |
+
def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
|
| 1431 |
+
"""
|
| 1432 |
+
Converts a list of tokens into a list of token IDs.
|
| 1433 |
+
Args:
|
| 1434 |
+
tokens (List[str]): A list of tokens to be converted into IDs.
|
| 1435 |
+
Returns:
|
| 1436 |
+
List[int]: A list of corresponding token IDs.
|
| 1437 |
+
"""
|
| 1438 |
+
return [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
|
| 1439 |
+
|
| 1440 |
+
def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
|
| 1441 |
+
"""
|
| 1442 |
+
Converts a list of token IDs into a list of tokens.
|
| 1443 |
+
Args:
|
| 1444 |
+
ids (List[int]): A list of token IDs to be converted into tokens.
|
| 1445 |
+
Returns:
|
| 1446 |
+
List[str]: A list of corresponding tokens.
|
| 1447 |
+
"""
|
| 1448 |
+
return [self.id_to_token.get(i, self.unk_token) for i in ids]
|
| 1449 |
+
|
| 1450 |
+
def encode(self, sequence: str) -> List[int]:
|
| 1451 |
+
"""
|
| 1452 |
+
Encodes a string sequence into a list of token IDs.
|
| 1453 |
+
|
| 1454 |
+
This method tokenizes the input sequence using the `tokenize` method,
|
| 1455 |
+
and then converts the resulting tokens into their corresponding token IDs
|
| 1456 |
+
using the `convert_tokens_to_ids` method.
|
| 1457 |
+
|
| 1458 |
+
Args:
|
| 1459 |
+
sequence (str): The input sequence (text) to be encoded.
|
| 1460 |
+
|
| 1461 |
+
Returns:
|
| 1462 |
+
List[int]: A list of token IDs corresponding to the input sequence.
|
| 1463 |
+
"""
|
| 1464 |
+
splitted_sequence = self.tokenize(sequence)
|
| 1465 |
+
return self.convert_tokens_to_ids(splitted_sequence)
|
| 1466 |
+
|
| 1467 |
+
def postpad_sequence(self, sequence, pad_token_id):
|
| 1468 |
+
"""
|
| 1469 |
+
Fills the sequence up to max_length padding elements
|
| 1470 |
+
"""
|
| 1471 |
+
num_extra_elements = self.model_max_length - len(sequence) - 1
|
| 1472 |
+
if num_extra_elements > 0:
|
| 1473 |
+
sequence.extend([pad_token_id] * num_extra_elements)
|
| 1474 |
+
return sequence
|
| 1475 |
+
|
| 1476 |
+
def decode(self, token_ids: List[int]) -> str:
|
| 1477 |
+
"""
|
| 1478 |
+
Decodes a list of token IDs into a string of text.
|
| 1479 |
+
The method converts the IDs to tokens and joins them to form a string.
|
| 1480 |
+
It also restores the original spaces by undoing the space-placeholder substitution.
|
| 1481 |
+
Args:
|
| 1482 |
+
token_ids (List[int]): A list of token IDs to be decoded.
|
| 1483 |
+
|
| 1484 |
+
Returns:
|
| 1485 |
+
str: The decoded string.
|
| 1486 |
+
"""
|
| 1487 |
+
tokens = self.convert_ids_to_tokens(token_ids)
|
| 1488 |
+
decoded = "".join(tokens)
|
| 1489 |
+
return self.prepad_sequence(decoded, undo=True)
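# Illustrative round-trip sketch (assumes a vocab.json compatible with the STLTokenizer
# defined above; the formula string is only an example):
#     tok = STLTokenizer("vocab.json")
#     ids = tok.encode("always ( x_0 <= 0.5 )")
#     text = tok.decode(ids)  # spaces mapped to '@' during tokenization are restored here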
|
| 1490 |
+
|
| 1491 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 1492 |
+
"""
|
| 1493 |
+
Saves the tokenizer's vocabulary to a file.
|
| 1494 |
+
Useful only when the vocabulary has to be retrieved and is not given
|
| 1495 |
+
(not strictly needed here; kept as a hook for future improvements, e.g. with sentencepiece).
|
| 1496 |
+
This method saves the vocabulary to a JSON file in the specified directory.
|
| 1497 |
+
Args:
|
| 1498 |
+
save_directory (str): The directory where the vocabulary file will be saved.
|
| 1499 |
+
filename_prefix (Optional[str]): An optional prefix for the filename.
|
| 1500 |
+
Returns:
|
| 1501 |
+
Tuple[str]: A tuple containing the path to the saved vocabulary file.
|
| 1502 |
+
"""
|
| 1503 |
+
vocab_file = f"{save_directory}/{filename_prefix + '-' if filename_prefix else ''}vocab.json"
|
| 1504 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
| 1505 |
+
json.dump(self.vocab, f, indent=2, ensure_ascii=False)
|
| 1506 |
+
return (vocab_file,)
|
| 1507 |
+
|
| 1508 |
+
def get_vocab(self) -> dict:
|
| 1509 |
+
"""
|
| 1510 |
+
Retrieves the vocabulary used by the tokenizer.
|
| 1511 |
+
Returns:
|
| 1512 |
+
dict: The vocabulary as a dictionary.
|
| 1513 |
+
"""
|
| 1514 |
+
return self.vocab
|
| 1515 |
+
|
| 1516 |
+
class STLSinusoidalPositionalEmbedding(nn.Embedding):
|
| 1517 |
+
"""This module produces sinusoidal positional embeddings of any length."""
|
| 1518 |
+
|
| 1519 |
+
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
|
| 1520 |
+
super().__init__(num_positions, embedding_dim)
|
| 1521 |
+
self.weight = self._init_weight(self.weight)
|
| 1522 |
+
|
| 1523 |
+
@staticmethod
|
| 1524 |
+
def _init_weight(out: nn.Parameter) -> nn.Parameter:
|
| 1525 |
+
"""
|
| 1526 |
+
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
|
| 1527 |
+
the 2nd half of the vector. [dim // 2:]
|
| 1528 |
+
"""
|
| 1529 |
+
n_pos, dim = out.shape
|
| 1530 |
+
position_enc = np.array(
|
| 1531 |
+
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
|
| 1532 |
+
)
|
| 1533 |
+
out.requires_grad = False # set early to avoid an error in pytorch-1.8+
|
| 1534 |
+
sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
|
| 1535 |
+
out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
|
| 1536 |
+
out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
|
| 1537 |
+
out.detach_()
|
| 1538 |
+
return out
|
| 1539 |
+
@torch.no_grad()
|
| 1540 |
+
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
|
| 1541 |
+
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
| 1542 |
+
bsz, seq_len = input_ids_shape[:2]
|
| 1543 |
+
positions = torch.arange(
|
| 1544 |
+
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
| 1545 |
+
)
|
| 1546 |
+
return super().forward(positions)
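# Shape sketch (illustrative): the embedding only looks at the input shape, so for a batch
# of 4 sequences of length 10 it returns a (10, embedding_dim) tensor of fixed sinusoidal
# positions starting at `past_key_values_length`.
#     pos_emb = STLSinusoidalPositionalEmbedding(1024, 64)
#     pe = pos_emb(torch.Size([4, 10]))         # -> (10, 64)
#     pe_next = pos_emb(torch.Size([4, 1]), 9)  # single position 9, as in cached decoding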
|
| 1547 |
+
|
| 1548 |
+
class STLAttention(nn.Module):
|
| 1549 |
+
""" Multi-Head Attention as depicted from 'Attention is all you need' """
|
| 1550 |
+
|
| 1551 |
+
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
|
| 1552 |
+
is_decoder: bool = False, bias: bool = False, is_causal: bool = False):
|
| 1553 |
+
|
| 1554 |
+
super().__init__()
|
| 1555 |
+
self.embed_dim = embed_dim # overall embedding dimension -> to be divided between multiple heads
|
| 1556 |
+
self.num_heads = num_heads
|
| 1557 |
+
self.dropout = dropout
|
| 1558 |
+
self.head_dim = embed_dim // num_heads
|
| 1559 |
+
assert (self.head_dim * num_heads) == self.embed_dim
|
| 1560 |
+
self.scaling = self.head_dim ** -0.5 # used to normalize values when projected using `W_` matrices
|
| 1561 |
+
self.is_decoder = is_decoder
|
| 1562 |
+
self.is_causal = is_causal
|
| 1563 |
+
|
| 1564 |
+
# projection matrices for the query / key / value roles
|
| 1565 |
+
self.W_k = nn.Linear(embed_dim, embed_dim, bias = bias)
|
| 1566 |
+
self.W_q = nn.Linear(embed_dim, embed_dim, bias = bias)
|
| 1567 |
+
self.W_v = nn.Linear(embed_dim, embed_dim, bias = bias)
|
| 1568 |
+
|
| 1569 |
+
# to project the heads' outputs into a single vector
|
| 1570 |
+
self.W_o = nn.Linear(embed_dim, embed_dim, bias = bias)
|
| 1571 |
+
|
| 1572 |
+
|
| 1573 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
| 1574 |
+
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 1575 |
+
|
| 1576 |
+
|
| 1577 |
+
def forward(self,
|
| 1578 |
+
hidden_states: torch.Tensor, # previous values, passed to the multi-head attn layer
|
| 1579 |
+
key_value_states: Optional[torch.Tensor] = None, # different key, value items (used in cross-attn)
|
| 1580 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None, # stores the key and values of previous steps
|
| 1581 |
+
attention_mask: Optional[torch.Tensor] = None, # masks non-allowed items (padded or future ones)
|
| 1582 |
+
layer_head_mask: Optional[torch.Tensor] = None, # used to de-activate specific attn heads
|
| 1583 |
+
output_attentions: bool = False # flag to control the output of the attn values,
|
| 1584 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 1585 |
+
|
| 1586 |
+
is_cross_attention = key_value_states is not None # cross-attn if key_value_states is not None
|
| 1587 |
+
|
| 1588 |
+
batch_size, tgt_len, embed_dim = hidden_states.size()
|
| 1589 |
+
|
| 1590 |
+
# Project the current input in the `query` role:
|
| 1591 |
+
query = self.W_q(hidden_states) * self.scaling
|
| 1592 |
+
|
| 1593 |
+
if (is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]):
|
| 1594 |
+
key = past_key_value[0]
|
| 1595 |
+
value = past_key_value[1]
|
| 1596 |
+
elif is_cross_attention:
|
| 1597 |
+
key = self._shape(self.W_k(key_value_states), -1, batch_size)
|
| 1598 |
+
value = self._shape(self.W_v(key_value_states), -1, batch_size)
|
| 1599 |
+
elif past_key_value is not None:
|
| 1600 |
+
key = self._shape(self.W_k(hidden_states), -1, batch_size)
|
| 1601 |
+
value = self._shape(self.W_v(hidden_states), -1, batch_size)
|
| 1602 |
+
key = torch.cat([past_key_value[0], key], dim=2)
|
| 1603 |
+
value = torch.cat([past_key_value[1], value], dim=2)
|
| 1604 |
+
else:
|
| 1605 |
+
key = self._shape(self.W_k(hidden_states), -1, batch_size)
|
| 1606 |
+
value = self._shape(self.W_v(hidden_states), -1, batch_size)
|
| 1607 |
+
if self.is_decoder:
|
| 1608 |
+
past_key_value = (key, value)
|
| 1609 |
+
|
| 1610 |
+
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
|
| 1611 |
+
|
| 1612 |
+
query = self._shape(query, tgt_len, batch_size).view(*proj_shape)
|
| 1613 |
+
key = key.reshape(*proj_shape)
|
| 1614 |
+
value = value.reshape(*proj_shape)
|
| 1615 |
+
|
| 1616 |
+
src_len = key.size(1)
|
| 1617 |
+
|
| 1618 |
+
|
| 1619 |
+
######################################################################################################
|
| 1620 |
+
|
| 1621 |
+
# 'traditional' attention computation
|
| 1622 |
+
# i.e. softmax(Q*K^T / sqrt(d_model) + self_attn_mask) * V
|
| 1623 |
+
|
| 1624 |
+
# Batch-wise matrix multiplication between `query` and (TRANSPOSED) `key`
|
| 1625 |
+
attn_weights = torch.bmm(query, key.transpose(1, 2))
|
| 1626 |
+
|
| 1627 |
+
if attention_mask is not None:
|
| 1628 |
+
attn_weights = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask
|
| 1629 |
+
attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
|
| 1630 |
+
|
| 1631 |
+
# Normalize values on the `key` axis (dim=-1)
|
| 1632 |
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 1633 |
+
|
| 1634 |
+
# if layer_head_mask is not None:
|
| 1635 |
+
# attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(batch_size, self.num_heads, tgt_len, src_len)
|
| 1636 |
+
# attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
|
| 1637 |
+
|
| 1638 |
+
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 1639 |
+
|
| 1640 |
+
# Batch-wise matrix multiplication between the resulting probs and the value
|
| 1641 |
+
attn_output = torch.bmm(attn_probs, value)
|
| 1642 |
+
|
| 1643 |
+
######################################################################################################
|
| 1644 |
+
|
| 1645 |
+
attn_output = attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim)
|
| 1646 |
+
attn_output = attn_output.transpose(1, 2)
|
| 1647 |
+
|
| 1648 |
+
attn_output = attn_output.reshape(batch_size, tgt_len, self.embed_dim)
|
| 1649 |
+
attn_output = self.W_o(attn_output)
|
| 1650 |
+
|
| 1651 |
+
return attn_output, None, past_key_value
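# Shape sketch (illustrative): self-attention maps a (batch, seq, embed_dim) tensor to a
# tensor of the same shape; with is_decoder=True the cached (key, value) pair is returned too.
#     attn = STLAttention(embed_dim=64, num_heads=8, is_decoder=True)
#     x = torch.randn(2, 5, 64)
#     out, _, kv = attn(x)  # out: (2, 5, 64); kv[0] and kv[1]: (2, 8, 5, 8)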
|
| 1652 |
+
|
| 1653 |
+
####
|
| 1654 |
+
|
| 1655 |
+
class STLEncoder:
|
| 1656 |
+
def __init__(self,
|
| 1657 |
+
embed_dim: int,
|
| 1658 |
+
anchor_filename: Optional[str] = None,
|
| 1659 |
+
n_vars: int = 3):
|
| 1660 |
+
|
| 1661 |
+
self.n_vars = n_vars  # passed in as input
|
| 1662 |
+
self.embed_dim = embed_dim
|
| 1663 |
+
self.anchorset_filename = anchor_filename
|
| 1664 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 1665 |
+
self.mu = BaseMeasure(device=self.device)
|
| 1666 |
+
self.kernel = StlKernel(self.mu, varn=self.n_vars)
|
| 1667 |
+
|
| 1668 |
+
if anchor_filename is None:
|
| 1669 |
+
anchor_filename = anchorGeneration(diff_init = True, embed_dim = self.embed_dim, n_vars = self.n_vars)
|
| 1670 |
+
anchor_filename+='.pickle'
|
| 1671 |
+
|
| 1672 |
+
# TO DO: check on the dimensions of the anchor set and the `embed_dim` and `n_vars` values
|
| 1673 |
+
anchor_set = load_pickle(anchor_filename)
|
| 1674 |
+
if len(anchor_set) != self.embed_dim:
|
| 1675 |
+
raise ValueError("The anchor set and the embedding dimension do not match!")
|
| 1676 |
+
|
| 1677 |
+
self.anchor_set = anchor_set
|
| 1678 |
+
|
| 1679 |
+
def compute_embeddings(self, formula: List[str]):
|
| 1680 |
+
return self.kernel.compute_bag_bag(formula, self.anchor_set)
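# Illustrative usage sketch (assumes an anchor set pickle whose size matches embed_dim):
#     enc = STLEncoder(embed_dim=30, anchor_filename="anchor_set_no_diff_30_dim.pickle", n_vars=3)
#     emb = enc.compute_embeddings([phi])  # kernel values of phi against the 30 anchor formulae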
|
| 1681 |
+
|
| 1682 |
+
class STLModel(PreTrainedModel):
|
| 1683 |
+
config_class = STLConfig
|
| 1684 |
+
base_model_prefix = "model"
|
| 1685 |
+
supports_gradient_checkpointing = True
|
| 1686 |
+
|
| 1687 |
+
# initializes the weights of `nn.Linear`, `nn.Embedding` and `STLSinusoidalPositionalEmbedding`
|
| 1688 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Embedding, STLSinusoidalPositionalEmbedding]):
|
| 1689 |
+
std = self.config.init_std
|
| 1690 |
+
if isinstance(module, nn.Linear):
|
| 1691 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 1692 |
+
if module.bias is not None:
|
| 1693 |
+
module.bias.data.zero_()
|
| 1694 |
+
elif isinstance(module, STLSinusoidalPositionalEmbedding):
|
| 1695 |
+
pass
|
| 1696 |
+
elif isinstance(module, nn.Embedding):
|
| 1697 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 1698 |
+
if module.padding_idx is not None:
|
| 1699 |
+
module.weight.data[module.padding_idx].zero_()
|
| 1700 |
+
|
| 1701 |
+
@property
|
| 1702 |
+
def dummy_inputs(self):
|
| 1703 |
+
pad_token = self.config.pad_token_id
|
| 1704 |
+
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
| 1705 |
+
dummy_inputs = {
|
| 1706 |
+
"attention_mask": input_ids.ne(pad_token),
|
| 1707 |
+
"input_ids": input_ids,
|
| 1708 |
+
"decoder_input_ids": input_ids,
|
| 1709 |
+
}
|
| 1710 |
+
return dummy_inputs
|
| 1711 |
+
|
| 1712 |
+
class STLDecoderBlock(nn.Module):
|
| 1713 |
+
|
| 1714 |
+
def __init__(self, embed_dim: int,
|
| 1715 |
+
num_decoder_attention_heads: int,
|
| 1716 |
+
num_decoder_ffn_dim: int,
|
| 1717 |
+
dropout: float = 0.0,
|
| 1718 |
+
attention_dropout: float = 0.0,
|
| 1719 |
+
activation_dropout: float = 0.0,
|
| 1720 |
+
):
|
| 1721 |
+
|
| 1722 |
+
super().__init__()
|
| 1723 |
+
|
| 1724 |
+
self.embed_dim = embed_dim
|
| 1725 |
+
|
| 1726 |
+
# first block
|
| 1727 |
+
self.self_attn = STLAttention(
|
| 1728 |
+
embed_dim=self.embed_dim,
|
| 1729 |
+
num_heads=num_decoder_attention_heads,
|
| 1730 |
+
dropout=dropout,
|
| 1731 |
+
is_decoder=True, # not used, debugging purposes
|
| 1732 |
+
is_causal=True, # not used, debugging purposes
|
| 1733 |
+
)
|
| 1734 |
+
self.dropout = dropout
|
| 1735 |
+
self.activation_fn = nn.functional.gelu
|
| 1736 |
+
self.activation_dropout = activation_dropout
|
| 1737 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 1738 |
+
|
| 1739 |
+
# second block
|
| 1740 |
+
self.encoder_attn = STLAttention(
|
| 1741 |
+
self.embed_dim,
|
| 1742 |
+
num_decoder_attention_heads,
|
| 1743 |
+
dropout=attention_dropout,
|
| 1744 |
+
is_decoder=True, # not used, debugging purposes
|
| 1745 |
+
)
|
| 1746 |
+
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 1747 |
+
|
| 1748 |
+
# third block
|
| 1749 |
+
self.fc1 = nn.Linear(self.embed_dim, num_decoder_ffn_dim)
|
| 1750 |
+
self.fc2 = nn.Linear(num_decoder_ffn_dim, self.embed_dim)
|
| 1751 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 1752 |
+
|
| 1753 |
+
|
| 1754 |
+
def forward(
|
| 1755 |
+
self,
|
| 1756 |
+
hidden_states: torch.Tensor,
|
| 1757 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1758 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 1759 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1760 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 1761 |
+
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
| 1762 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 1763 |
+
output_attentions: Optional[bool] = False,
|
| 1764 |
+
use_cache: Optional[bool] = True,
|
| 1765 |
+
**kwargs,
|
| 1766 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 1767 |
+
"""
|
| 1768 |
+
Args:
|
| 1769 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 1770 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
| 1771 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 1772 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
| 1773 |
+
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 1774 |
+
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
| 1775 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 1776 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
| 1777 |
+
`(encoder_attention_heads,)`.
|
| 1778 |
+
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
| 1779 |
+
size `(decoder_attention_heads,)`.
|
| 1780 |
+
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
| 1781 |
+
output_attentions (`bool`, *optional*):
|
| 1782 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1783 |
+
returned tensors for more detail.
|
| 1784 |
+
"""
|
| 1785 |
+
|
| 1786 |
+
###################################################################
|
| 1787 |
+
|
| 1788 |
+
# BLOCK 1: processing what has been previously generated
|
| 1789 |
+
|
| 1790 |
+
# previous state is stored into an auxiliary variable `residual`
|
| 1791 |
+
residual = hidden_states
|
| 1792 |
+
|
| 1793 |
+
# tries to exploit previous K, V values if there are any
|
| 1794 |
+
# (practically picks up to the first 2 values stored in `past_key_value` vector)
|
| 1795 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 1796 |
+
|
| 1797 |
+
# masked MHSA on the already generated sequence
|
| 1798 |
+
# invokes `forward` method to transform the original vector accordingly
|
| 1799 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn.forward(
|
| 1800 |
+
hidden_states=hidden_states, # Q
|
| 1801 |
+
past_key_value=self_attn_past_key_value, # K, V
|
| 1802 |
+
attention_mask=attention_mask, # passed as input of the decoder layer
|
| 1803 |
+
layer_head_mask=layer_head_mask, # to deactivate certain attn layers
|
| 1804 |
+
output_attentions=output_attentions,
|
| 1805 |
+
)
|
| 1806 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 1807 |
+
|
| 1808 |
+
# residual connection
|
| 1809 |
+
hidden_states = residual + hidden_states
|
| 1810 |
+
|
| 1811 |
+
# normalization
|
| 1812 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 1813 |
+
|
| 1814 |
+
###################################################################
|
| 1815 |
+
|
| 1816 |
+
# BLOCK 2: cross-attn between already generated input and previous information (from the encoder)
|
| 1817 |
+
|
| 1818 |
+
# initialize K, Q, attn_weights for this new attn operation
|
| 1819 |
+
cross_attn_present_key_value = None
|
| 1820 |
+
cross_attn_weights = None
|
| 1821 |
+
|
| 1822 |
+
# the important condition is that the encoder carries some information
|
| 1823 |
+
if encoder_hidden_states is not None:
|
| 1824 |
+
|
| 1825 |
+
# previous state is stored into an auxiliary variable `residual`
|
| 1826 |
+
residual = hidden_states
|
| 1827 |
+
|
| 1828 |
+
# cross_attn cached key/values tuple is at positions 3, 4 of PAST_key_value tuple
|
| 1829 |
+
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
| 1830 |
+
|
| 1831 |
+
# MHSA in cross-attn
|
| 1832 |
+
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn.forward(
|
| 1833 |
+
hidden_states=hidden_states, # Q = generated output
|
| 1834 |
+
key_value_states=encoder_hidden_states, # K, V = encoder memory (used only in the 1st step when `use_cache = True`)
|
| 1835 |
+
attention_mask=encoder_attention_mask, # just pads some elements (not causal this time!)
|
| 1836 |
+
layer_head_mask=cross_attn_layer_head_mask, # again to mask certain heads
|
| 1837 |
+
past_key_value=cross_attn_past_key_value, # K, V = encoder CACHED memory (used from the 2nd step on when `use_cache = True`)
|
| 1838 |
+
output_attentions=output_attentions,
|
| 1839 |
+
)
|
| 1840 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 1841 |
+
|
| 1842 |
+
# residual connection
|
| 1843 |
+
hidden_states = residual + hidden_states
|
| 1844 |
+
|
| 1845 |
+
# normalization
|
| 1846 |
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
| 1847 |
+
|
| 1848 |
+
# add cross-attn to positions 3, 4 of PRESENT_key_value tuple
|
| 1849 |
+
present_key_value = present_key_value + cross_attn_present_key_value
|
| 1850 |
+
|
| 1851 |
+
###################################################################
|
| 1852 |
+
|
| 1853 |
+
# BLOCK 3: FFNN (transforming some merged generated output - encoder information)
|
| 1854 |
+
|
| 1855 |
+
# previous state is stored into an auxiliary variable `residual`
|
| 1856 |
+
residual = hidden_states
|
| 1857 |
+
|
| 1858 |
+
# FFNN - core
|
| 1859 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 1860 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
| 1861 |
+
hidden_states = self.fc2(hidden_states)
|
| 1862 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 1863 |
+
|
| 1864 |
+
# residual connection
|
| 1865 |
+
hidden_states = residual + hidden_states
|
| 1866 |
+
|
| 1867 |
+
# normalization
|
| 1868 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 1869 |
+
|
| 1870 |
+
outputs = (hidden_states,)
|
| 1871 |
+
|
| 1872 |
+
if output_attentions:
|
| 1873 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
| 1874 |
+
|
| 1875 |
+
if use_cache: # otherwise K and V are recomputed at every step
|
| 1876 |
+
outputs += (present_key_value,)
|
| 1877 |
+
|
| 1878 |
+
return outputs
|
| 1879 |
+
|
| 1880 |
+
class STLDecoder(STLModel):
|
| 1881 |
+
def __init__(self, config):
|
| 1882 |
+
super().__init__(config)
|
| 1883 |
+
|
| 1884 |
+
# Extract from `config` file
|
| 1885 |
+
embed_dim = config.d_model
|
| 1886 |
+
num_decoder_attention_heads = config.decoder_attention_heads
|
| 1887 |
+
num_decoder_ffn_dim = config.decoder_ffn_dim
|
| 1888 |
+
max_position_embeddings = config.max_position_embeddings
|
| 1889 |
+
decoder_vocab_size = config.vocab_size
|
| 1890 |
+
pad_token_id = config.pad_token_id
|
| 1891 |
+
num_decoder_layers = config.decoder_layers
|
| 1892 |
+
scale_embedding = config.scale_embedding
|
| 1893 |
+
dropout = config.dropout
|
| 1894 |
+
attention_dropout = config.attention_dropout
|
| 1895 |
+
activation_dropout = config.activation_dropout
|
| 1896 |
+
decoder_layerdrop = config.decoder_layerdrop
|
| 1897 |
+
|
| 1898 |
+
self.dropout = dropout
|
| 1899 |
+
self.layerdrop = decoder_layerdrop
|
| 1900 |
+
self.padding_idx = pad_token_id
|
| 1901 |
+
self.max_target_positions = max_position_embeddings
|
| 1902 |
+
self.embed_scale = math.sqrt(embed_dim) if scale_embedding else 1.0
|
| 1903 |
+
|
| 1904 |
+
# Initialize the input embedding (if not passed already)
|
| 1905 |
+
self.embed_tokens = nn.Embedding(decoder_vocab_size, embed_dim, self.padding_idx)
|
| 1906 |
+
|
| 1907 |
+
# Initialize positional embedding also
|
| 1908 |
+
self.embed_positions = STLSinusoidalPositionalEmbedding(
|
| 1909 |
+
max_position_embeddings, embed_dim, self.padding_idx
|
| 1910 |
+
)
|
| 1911 |
+
|
| 1912 |
+
# Initialize decoder layers (of a prespecified number)
|
| 1913 |
+
self.layers = nn.ModuleList([STLDecoderBlock(embed_dim, num_decoder_attention_heads,
|
| 1914 |
+
num_decoder_ffn_dim, dropout,
|
| 1915 |
+
attention_dropout, activation_dropout)
|
| 1916 |
+
for _ in range(num_decoder_layers)])
|
| 1917 |
+
|
| 1918 |
+
self.gradient_checkpointing = False
|
| 1919 |
+
self.post_init()
|
| 1920 |
+
|
| 1921 |
+
def forward(
|
| 1922 |
+
self,
|
| 1923 |
+
input_ids: torch.LongTensor = None,
|
| 1924 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1925 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 1926 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1927 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1928 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
| 1929 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1930 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1931 |
+
use_cache: Optional[bool] = None,
|
| 1932 |
+
output_attentions: Optional[bool] = None,
|
| 1933 |
+
output_hidden_states: Optional[bool] = None,
|
| 1934 |
+
return_dict: Optional[bool] = None,
|
| 1935 |
+
**kwargs,
|
| 1936 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
| 1937 |
+
|
| 1938 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1939 |
+
output_hidden_states = (
|
| 1940 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1941 |
+
)
|
| 1942 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 1943 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1944 |
+
|
| 1945 |
+
# retrieve input_ids and inputs_embeds
|
| 1946 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 1947 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
| 1948 |
+
elif input_ids is not None:
|
| 1949 |
+
input_shape = input_ids.size()
|
| 1950 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 1951 |
+
elif inputs_embeds is not None:
|
| 1952 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 1953 |
+
else:
|
| 1954 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
| 1955 |
+
|
| 1956 |
+
# past_key_values_length
|
| 1957 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 1958 |
+
|
| 1959 |
+
if inputs_embeds is None:
|
| 1960 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
| 1961 |
+
|
| 1962 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
| 1963 |
+
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 1964 |
+
)
|
| 1965 |
+
|
| 1966 |
+
# expand encoder attention mask
|
| 1967 |
+
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
| 1968 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 1969 |
+
encoder_attention_mask = _prepare_4d_attention_mask(
|
| 1970 |
+
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
| 1971 |
+
)
|
| 1972 |
+
|
| 1973 |
+
# embed positions
|
| 1974 |
+
positions = self.embed_positions(input_shape, past_key_values_length)
|
| 1975 |
+
|
| 1976 |
+
hidden_states = inputs_embeds + positions
|
| 1977 |
+
|
| 1978 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 1979 |
+
|
| 1980 |
+
if self.gradient_checkpointing and self.training:
|
| 1981 |
+
if use_cache:
|
| 1982 |
+
logger.warning_once(
|
| 1983 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 1984 |
+
)
|
| 1985 |
+
use_cache = False
|
| 1986 |
+
|
| 1987 |
+
# decoder layers
|
| 1988 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1989 |
+
all_self_attns = () if output_attentions else None
|
| 1990 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
| 1991 |
+
next_decoder_cache = () if use_cache else None
|
| 1992 |
+
|
| 1993 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 1994 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
| 1995 |
+
if output_hidden_states:
|
| 1996 |
+
all_hidden_states += (hidden_states,)
|
| 1997 |
+
if self.training:
|
| 1998 |
+
dropout_probability = torch.rand([])
|
| 1999 |
+
if dropout_probability < self.layerdrop:
|
| 2000 |
+
continue
|
| 2001 |
+
|
| 2002 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
| 2003 |
+
|
| 2004 |
+
if self.gradient_checkpointing and self.training:
|
| 2005 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 2006 |
+
decoder_layer.__call__,
|
| 2007 |
+
hidden_states,
|
| 2008 |
+
attention_mask,
|
| 2009 |
+
encoder_hidden_states,
|
| 2010 |
+
encoder_attention_mask,
|
| 2011 |
+
head_mask[idx] if head_mask is not None else None,
|
| 2012 |
+
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
| 2013 |
+
None,
|
| 2014 |
+
output_attentions,
|
| 2015 |
+
use_cache,
|
| 2016 |
+
)
|
| 2017 |
+
else:
|
| 2018 |
+
layer_outputs = decoder_layer(
|
| 2019 |
+
hidden_states,
|
| 2020 |
+
attention_mask=attention_mask,
|
| 2021 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 2022 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
| 2023 |
+
cross_attn_layer_head_mask=(
|
| 2024 |
+
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
| 2025 |
+
),
|
| 2026 |
+
past_key_value=past_key_value,
|
| 2027 |
+
output_attentions=output_attentions,
|
| 2028 |
+
use_cache=use_cache,
|
| 2029 |
+
)
|
| 2030 |
+
hidden_states = layer_outputs[0]
|
| 2031 |
+
|
| 2032 |
+
if use_cache:
|
| 2033 |
+
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
| 2034 |
+
|
| 2035 |
+
if output_attentions:
|
| 2036 |
+
all_self_attns += (layer_outputs[1],)
|
| 2037 |
+
|
| 2038 |
+
if encoder_hidden_states is not None:
|
| 2039 |
+
all_cross_attentions += (layer_outputs[2],)
|
| 2040 |
+
|
| 2041 |
+
# add hidden states from the last decoder layer
|
| 2042 |
+
if output_hidden_states:
|
| 2043 |
+
all_hidden_states += (hidden_states,)
|
| 2044 |
+
|
| 2045 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 2046 |
+
if not return_dict:
|
| 2047 |
+
return tuple(
|
| 2048 |
+
v
|
| 2049 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
|
| 2050 |
+
if v is not None
|
| 2051 |
+
)
|
| 2052 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 2053 |
+
last_hidden_state=hidden_states,
|
| 2054 |
+
past_key_values=next_cache,
|
| 2055 |
+
hidden_states=all_hidden_states,
|
| 2056 |
+
attentions=all_self_attns,
|
| 2057 |
+
cross_attentions=all_cross_attentions,
|
| 2058 |
+
)
|
| 2059 |
+
|
| 2060 |
+
####
|
| 2061 |
+
|
| 2062 |
+
class STLForCausalLM(STLModel, GenerationMixin):
|
| 2063 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 2064 |
+
|
| 2065 |
+
def __init__(self, config):
|
| 2066 |
+
config = copy.deepcopy(config)
|
| 2067 |
+
config.is_decoder = True
|
| 2068 |
+
config.is_encoder_decoder = False
|
| 2069 |
+
|
| 2070 |
+
super().__init__(config)
|
| 2071 |
+
self.model = STLDecoder(config)
|
| 2072 |
+
|
| 2073 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 2074 |
+
|
| 2075 |
+
# Initialize weights and apply final processing
|
| 2076 |
+
self.post_init()
|
| 2077 |
+
|
| 2078 |
+
def get_input_embeddings(self):
|
| 2079 |
+
return self.model.embed_tokens
|
| 2080 |
+
|
| 2081 |
+
def set_input_embeddings(self, value):
|
| 2082 |
+
self.model.embed_tokens = value
|
| 2083 |
+
|
| 2084 |
+
def get_output_embeddings(self):
|
| 2085 |
+
return self.lm_head
|
| 2086 |
+
|
| 2087 |
+
def set_output_embeddings(self, new_embeddings):
|
| 2088 |
+
self.lm_head = new_embeddings
|
| 2089 |
+
|
| 2090 |
+
def set_decoder(self, decoder):
|
| 2091 |
+
self.model = decoder
|
| 2092 |
+
|
| 2093 |
+
def get_decoder(self):
|
| 2094 |
+
return self.model
|
| 2095 |
+
|
| 2096 |
+
def forward(
|
| 2097 |
+
self,
|
| 2098 |
+
input_ids: torch.LongTensor = None, # input sequence
|
| 2099 |
+
attention_mask: Optional[torch.Tensor] = None, # masked MHSA + padding
|
| 2100 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None, # embedding
|
| 2101 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None, # MHSA + padding
|
| 2102 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 2103 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
| 2104 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 2105 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 2106 |
+
labels: Optional[torch.LongTensor] = None, # output sequence
|
| 2107 |
+
use_cache: Optional[bool] = None,
|
| 2108 |
+
output_attentions: Optional[bool] = None,
|
| 2109 |
+
output_hidden_states: Optional[bool] = None,
|
| 2110 |
+
return_dict: Optional[bool] = None,
|
| 2111 |
+
**kwargs,
|
| 2112 |
+
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
| 2113 |
+
|
| 2114 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 2115 |
+
output_hidden_states = (
|
| 2116 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 2117 |
+
)
|
| 2118 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 2119 |
+
|
| 2120 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 2121 |
+
outputs = self.model(
|
| 2122 |
+
input_ids=input_ids,
|
| 2123 |
+
attention_mask=attention_mask,
|
| 2124 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 2125 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 2126 |
+
head_mask=head_mask,
|
| 2127 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
| 2128 |
+
past_key_values=past_key_values,
|
| 2129 |
+
inputs_embeds=inputs_embeds,
|
| 2130 |
+
use_cache=use_cache,
|
| 2131 |
+
output_attentions=output_attentions,
|
| 2132 |
+
output_hidden_states=output_hidden_states,
|
| 2133 |
+
return_dict=return_dict,
|
| 2134 |
+
**kwargs
|
| 2135 |
+
)
|
| 2136 |
+
|
| 2137 |
+
logits = self.lm_head(outputs[0])
|
| 2138 |
+
|
| 2139 |
+
loss = None
|
| 2140 |
+
if labels is not None:
|
| 2141 |
+
labels = labels.to(logits.device)
|
| 2142 |
+
loss_fct = CrossEntropyLoss()
|
| 2143 |
+
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
| 2144 |
+
|
| 2145 |
+
if not return_dict:
|
| 2146 |
+
output = (logits,) + outputs[1:]
|
| 2147 |
+
return (loss,) + output if loss is not None else output
|
| 2148 |
+
|
| 2149 |
+
return CausalLMOutputWithCrossAttentions(
|
| 2150 |
+
loss=loss,
|
| 2151 |
+
logits=logits,
|
| 2152 |
+
past_key_values=outputs.past_key_values,
|
| 2153 |
+
hidden_states=outputs.hidden_states,
|
| 2154 |
+
attentions=outputs.attentions,
|
| 2155 |
+
cross_attentions=outputs.cross_attentions,
|
| 2156 |
+
)
|
| 2157 |
+
|
| 2158 |
+
@staticmethod
|
| 2159 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 2160 |
+
reordered_past = ()
|
| 2161 |
+
for layer_past in past_key_values:
|
| 2162 |
+
reordered_past += (
|
| 2163 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
| 2164 |
+
)
|
| 2165 |
+
return reordered_past
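# Illustrative end-to-end sketch (an assumption about intended usage, not stated elsewhere in
# this file): the decoder-only LM can be built from its configuration class and trained with
# a standard teacher-forced causal-LM loss.
#     config = STLConfig()
#     model = STLForCausalLM(config)
#     out = model(input_ids=ids, labels=ids)   # ids: (batch, seq_len) LongTensor of token ids
#     loss = out.loss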
|
| 2166 |
+
|