| from typing import Optional | |
| from transformers import AutoConfig, Gemma3TextConfig | |
| from transformers.configuration_utils import PretrainedConfig | |
| from transformers.modeling_rope_utils import rope_config_validation | |
| from transformers.utils import logging | |
| from transformers.models.siglip import SiglipVisionConfig | |
| logger = logging.get_logger(__name__) | |
class AudioConfig(PretrainedConfig):
    """Configuration holder for the ``gemma3_audio`` sub-model.

    Every constructor argument is stored on the instance under the same name.
    Four dict-valued arguments fall back to a freshly built default dict when
    they are passed as ``None`` (a new dict per instance, so instances never
    share state):

    * ``nemo_conv_settings`` -> ``{"conv_channels": 1024}``
    * ``relative_attention_bias_args`` -> ``{"type": "t5", "t5_bias_max_distance": 500}``
    * ``activation_checkpointing`` -> ``{"interval": 1, "module": "transformer", "offload": False}``
    * ``encoder_embedding_config`` -> ``{"input_size": input_size}``

    Any extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the parameter names suggest a Conformer-style speech encoder;
    the exact meaning of each field is defined by the model code that consumes
    this config, which is not visible here.
    """

    model_type = "gemma3_audio"

    def __init__(
        self,
        input_size=80,
        attention_dim=1024,
        attention_heads=16,
        num_blocks=24,
        linear_units=1536,
        dropout_rate=0.0,
        kernel_size=3,
        ext_pw_kernel_size=1,
        ext_pw_out_channel=1024,
        depthwise_seperable_out_channel=1024,
        depthwise_multiplier=1,
        activation="swish",
        conv_activation="swish",
        conv_glu_type="swish",
        bias_in_glu=True,
        causal=True,
        batch_norm=False,
        cnn_layer_norm=True,
        time_reduction=8,
        input_layer="nemo_conv",
        nemo_conv_settings=None,
        chunk_size=-1,
        left_chunk=18,
        relative_attention_bias_args=None,
        activation_checkpointing=None,
        encoder_embedding_config=None,
        **kwargs,
    ):
        # Base-class setup runs first, before any field of this class is set.
        super().__init__(**kwargs)

        # Plain values: stored exactly as received.
        self.input_size = input_size
        self.attention_dim = attention_dim
        self.attention_heads = attention_heads
        self.num_blocks = num_blocks
        self.linear_units = linear_units
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.ext_pw_out_channel = ext_pw_out_channel
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.depthwise_multiplier = depthwise_multiplier
        self.activation = activation
        self.conv_activation = conv_activation
        self.conv_glu_type = conv_glu_type
        self.bias_in_glu = bias_in_glu
        self.causal = causal
        self.batch_norm = batch_norm
        self.cnn_layer_norm = cnn_layer_norm
        self.time_reduction = time_reduction
        self.input_layer = input_layer

        # Dict-valued options: only ``None`` triggers the default, so an
        # explicitly passed empty dict is preserved.
        self.nemo_conv_settings = (
            {"conv_channels": 1024} if nemo_conv_settings is None else nemo_conv_settings
        )
        self.chunk_size = chunk_size
        self.left_chunk = left_chunk
        self.relative_attention_bias_args = (
            {"type": "t5", "t5_bias_max_distance": 500}
            if relative_attention_bias_args is None
            else relative_attention_bias_args
        )
        self.activation_checkpointing = (
            {"interval": 1, "module": "transformer", "offload": False}
            if activation_checkpointing is None
            else activation_checkpointing
        )
        # The embedding default mirrors whatever ``input_size`` was passed in.
        self.encoder_embedding_config = (
            {"input_size": input_size} if encoder_embedding_config is None else encoder_embedding_config
        )
class Gemma3MMConfig(PretrainedConfig):
    r"""
    Configuration class that bundles the text, vision and audio sub-configurations of a Gemma3 multimodal model
    together with the special token indices used to splice image and audio inputs into the text sequence.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Each sub-config argument may be given as a ready-made config object (kept as-is), as a `dict` of keyword
    arguments (expanded into the matching config class), or as `None` (replaced by that class's defaults).

    Args:
        text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
            The config object of the text backbone. Defaults to `Gemma3TextConfig()`.
        vision_config (`Union[SiglipVisionConfig, dict]`, *optional*):
            Custom vision config or dict. Defaults to `SiglipVisionConfig()`.
        audio_config (`Union[AudioConfig, dict]`, *optional*):
            Custom audio config or dict. Defaults to `AudioConfig()`.
        mm_tokens_per_image (`int`, *optional*, defaults to 256):
            The number of tokens per image embedding.
        boi_token_index (`int`, *optional*, defaults to 255999):
            The begin-of-image token index to wrap the image prompt.
        eoi_token_index (`int`, *optional*, defaults to 256000):
            The end-of-image token index to wrap the image prompt.
        boa_token_index (`int`, *optional*, defaults to 256001):
            Audio counterpart of `boi_token_index` (presumably the begin-of-audio token index).
        eoa_token_index (`int`, *optional*, defaults to 256002):
            Audio counterpart of `eoi_token_index` (presumably the end-of-audio token index).
        image_token_index (`int`, *optional*, defaults to 262144):
            The image token index to encode the image prompt.
        audio_token_index (`int`, *optional*, defaults to 262143):
            The audio token index to encode the audio prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Example:

    ```python
    >>> # Initializing a Siglip-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing an audio config
    >>> audio_config = AudioConfig()

    >>> # Initializing a Gemma3 Text config
    >>> text_config = Gemma3TextConfig()

    >>> # Combining the three sub-configs into one multimodal configuration
    >>> configuration = Gemma3MMConfig(
    ...     text_config=text_config, vision_config=vision_config, audio_config=audio_config
    ... )

    >>> # Accessing a sub-configuration
    >>> text_config = configuration.text_config
    ```"""

    model_type = "gemma3mm"
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
        "audio_config": AudioConfig,
    }

    def __init__(
        self,
        text_config: Optional[Gemma3TextConfig] = None,
        vision_config: Optional[SiglipVisionConfig] = None,
        audio_config: Optional[AudioConfig] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        boa_token_index: int = 256_001,
        eoa_token_index: int = 256_002,
        image_token_index: int = 262_144,
        audio_token_index: int = 262_143,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        if text_config is None:
            text_config = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig.")
        elif isinstance(text_config, dict):
            text_config = Gemma3TextConfig(**text_config)

        # Only fall back to defaults when nothing was supplied. A plain `else`
        # here would throw away a fully-built config object passed by the caller.
        if isinstance(vision_config, dict):
            vision_config = SiglipVisionConfig(**vision_config)
        elif vision_config is None:
            vision_config = SiglipVisionConfig()
            logger.info("vision_config is None, using default SiglipVisionConfig.")

        if isinstance(audio_config, dict):
            audio_config = AudioConfig(**audio_config)
        elif audio_config is None:
            audio_config = AudioConfig()
            logger.info("audio_config is None, using default AudioConfig.")

        self.text_config = text_config
        self.vision_config = vision_config
        self.audio_config = audio_config
        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.boa_token_index = boa_token_index
        self.eoa_token_index = eoa_token_index
        self.image_token_index = image_token_index
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range

        # Base-class init runs last, after all sub-configs and token ids are set.
        super().__init__(**kwargs)