Diffusers documentation
BriaTransformer2DModel
A modified Flux Transformer model from Bria.
BriaTransformer2DModel
class diffusers.BriaTransformer2DModel
( patch_size: int = 1, in_channels: int = 64, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 4096, pooled_projection_dim: int = None, guidance_embeds: bool = False, axes_dims_rope: typing.List[int] = [16, 56, 56], rope_theta = 10000, time_theta = 10000 )
Parameters
- patch_size (int, defaults to 1) — Patch size used to turn the input data into small patches.
- in_channels (int, optional, defaults to 64) — The number of channels in the input.
- num_layers (int, optional, defaults to 19) — The number of layers of MMDiT blocks to use.
- num_single_layers (int, optional, defaults to 38) — The number of layers of single DiT blocks to use.
- attention_head_dim (int, optional, defaults to 128) — The number of channels in each head.
- num_attention_heads (int, optional, defaults to 24) — The number of heads to use for multi-head attention.
- joint_attention_dim (int, optional, defaults to 4096) — The number of encoder_hidden_states dimensions to use.
- pooled_projection_dim (int, defaults to None) — Number of dimensions to use when projecting the pooled_projections.
- guidance_embeds (bool, defaults to False) — Whether to use guidance embeddings.
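As a quick sanity check on how these defaults relate to each other, the following is a minimal sketch in plain Python. `BriaTransformerConfigSketch` is a hypothetical stand-in, not a diffusers class; it only mirrors the defaults from the constructor signature. The two derived properties rest on assumptions about Flux-style transformers (hidden width equals heads times channels per head, and the rotary axes together span one head); neither is stated on this page.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class BriaTransformerConfigSketch:
    """Hypothetical stand-in mirroring the documented constructor defaults."""

    patch_size: int = 1
    in_channels: int = 64
    num_layers: int = 19
    num_single_layers: int = 38
    attention_head_dim: int = 128
    num_attention_heads: int = 24
    joint_attention_dim: int = 4096
    pooled_projection_dim: Optional[int] = None
    guidance_embeds: bool = False
    axes_dims_rope: List[int] = field(default_factory=lambda: [16, 56, 56])

    @property
    def inner_dim(self) -> int:
        # Assumption: hidden width = number of heads * channels per head.
        return self.num_attention_heads * self.attention_head_dim

    @property
    def rope_dims_match_head_dim(self) -> bool:
        # Assumption: the rotary axes together cover exactly one attention head.
        return sum(self.axes_dims_rope) == self.attention_head_dim


config = BriaTransformerConfigSketch()
print(config.inner_dim)                 # 24 * 128 = 3072
print(config.rope_dims_match_head_dim)  # 16 + 56 + 56 == 128 -> True
```

If a custom `attention_head_dim` is chosen, a check like `rope_dims_match_head_dim` may help catch a mismatched `axes_dims_rope` before the real model is built.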
The Transformer model introduced in Flux. Based on FluxPipeline with several changes:
- No pooled embeddings
- Zero padding is used for prompts
- No guidance embedding, since this is not a distilled version
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
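To illustrate what zero padding of prompts could look like, here is a minimal sketch in plain Python. The function `zero_pad_prompt` and the `max_sequence_length` argument are hypothetical names chosen for illustration. The page does not describe how the padding is actually implemented, so this only shows the general idea: token embeddings are extended with all-zero vectors up to a fixed sequence length.

```python
from typing import List

Embedding = List[float]


def zero_pad_prompt(embeds: List[Embedding], max_sequence_length: int) -> List[Embedding]:
    """Pad a (seq_len, dim) list of token embeddings with zero vectors.

    Sequences longer than max_sequence_length are truncated.
    """
    if not embeds:
        raise ValueError("embeds must contain at least one token embedding")
    dim = len(embeds[0])
    truncated = embeds[:max_sequence_length]
    # One all-zero vector per missing position.
    padding = [[0.0] * dim for _ in range(max_sequence_length - len(truncated))]
    return truncated + padding


# Two real token embeddings of width 2, padded to a sequence length of 4.
padded = zero_pad_prompt([[0.5, -1.0], [2.0, 0.25]], max_sequence_length=4)
print(padded)  # [[0.5, -1.0], [2.0, 0.25], [0.0, 0.0], [0.0, 0.0]]
```

In a real pipeline the same operation would typically be done on tensors along the sequence axis; nested lists are used here only to keep the sketch dependency-free.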
forward
( hidden_states: Tensor, encoder_hidden_states: Tensor = None, pooled_projections: Tensor = None, timestep: LongTensor = None, img_ids: Tensor = None, txt_ids: Tensor = None, guidance: Tensor = None, attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, return_dict: bool = True, controlnet_block_samples = None, controlnet_single_block_samples = None )
Parameters
- hidden_states (torch.FloatTensor of shape (batch_size, channel, height, width)) — Input hidden_states.
- encoder_hidden_states (torch.FloatTensor of shape (batch_size, sequence_len, embed_dims)) — Conditional embeddings (embeddings computed from the input conditions, such as prompts) to use.
- pooled_projections (torch.FloatTensor of shape (batch_size, projection_dim)) — Embeddings projected from the embeddings of input conditions.
- timestep (torch.LongTensor) — Used to indicate the denoising step.
- block_controlnet_hidden_states (list of torch.Tensor) — A list of tensors that, if specified, are added to the residuals of transformer blocks.
- attention_kwargs (dict, optional) — A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
- return_dict (bool, optional, defaults to True) — Whether or not to return a ~models.transformer_2d.Transformer2DModelOutput instead of a plain tuple.
The BriaTransformer2DModel forward method.