|
from typing import List, Optional

from transformers import PretrainedConfig
|
|
|
class LidirlCNNConfig(PretrainedConfig):
    """Configuration object holding the hyperparameters of the LidirlCNN model.

    Every constructor argument is stored verbatim as an attribute of the same
    name, so it round-trips through the ``PretrainedConfig`` save/load
    machinery. Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        embed_dim: Stored as ``self.embed_dim`` (presumably the width of the
            token embedding — confirm against the model code).
        channels: Stored as ``self.channels``; defaults to ``[32]``. A fresh
            list is stored, so the caller's sequence is never aliased.
            (Presumably one entry per convolutional layer — confirm.)
        kernels: Stored as ``self.kernels``; defaults to ``[3]``. Copied like
            ``channels``. (Presumably per-layer kernel sizes — confirm.)
        strides: Stored as ``self.strides``; defaults to ``[1]``. Copied like
            ``channels``. (Presumably per-layer strides — confirm.)
        vocab_size: Stored as ``self.vocab_size``; defaults to 256
            (possibly a byte-level vocabulary — confirm).
        label_size: Stored as ``self.label_size``; defaults to 200.
        max_length: Stored as ``self.max_length``; defaults to 1024.
        multilabel: Stored as ``self.multilabel``; defaults to ``False``.
        montecarlo_layer: Stored as ``self.montecarlo_layer``; defaults to
            ``False``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # Identifier under which this config type is known to transformers'
    # Auto* classes and written into saved config files.
    model_type = "LidirlCNN"

    def __init__(
        self,
        embed_dim: int = 32,
        channels: Optional[List[int]] = None,
        kernels: Optional[List[int]] = None,
        strides: Optional[List[int]] = None,
        vocab_size: int = 256,
        label_size: int = 200,
        max_length: int = 1024,
        multilabel: bool = False,
        montecarlo_layer: bool = False,
        **kwargs,
    ):
        self.embed_dim = embed_dim

        # Use None sentinels rather than list literals as defaults. A default
        # list is created once, at definition time, and would otherwise be
        # shared (and mutable) across every config built with the defaults.
        # Copying explicit arguments also keeps the config from aliasing the
        # caller's list.
        self.channels = [32] if channels is None else list(channels)
        self.kernels = [3] if kernels is None else list(kernels)
        self.strides = [1] if strides is None else list(strides)

        self.vocab_size = vocab_size
        self.label_size = label_size
        self.max_length = max_length
        self.multilabel = multilabel
        self.montecarlo_layer = montecarlo_layer

        # Set the model-specific attributes first, then let the base class
        # consume the remaining generic kwargs. This matches the ordering
        # used in the transformers custom-config examples.
        super().__init__(**kwargs)