Status: Open. Labels: stale, stat:awaiting response from contributor, type:support
Description
I have the following Transformer encoder block, and calling it fails with the error "Unrecognized keyword arguments passed to MultiHeadAttention: {'return_attention_scores': True}":
from tensorflow.keras import layers

def transformer_block(self, x, num_heads, mlp_dim, dropout=0.1, alpha=0.3):
    """
    Standard Transformer encoder block with residuals.
    Input and output shapes are the same: (B, num_patches, dim)
    """
    # LayerNorm + Multi-Head Self-Attention
    x_norm1 = layers.LayerNormalization(epsilon=1e-6)(x)
    attn_output, attn_scores = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=x.shape[-1] // num_heads,
        return_attention_scores=True)(x_norm1, x_norm1)  # self-attention
    attn_output = layers.Dropout(dropout)(attn_output)
    # Residual connection
    x = x + attn_output
    # LayerNorm + MLP
    x_norm2 = layers.LayerNormalization(epsilon=1e-6)(x)
    mlp_output = layers.Dense(mlp_dim)(x_norm2)
    mlp_output = layers.LeakyReLU(alpha=alpha)(mlp_output)
    mlp_output = layers.Dropout(dropout)(mlp_output)
    mlp_output = layers.Dense(x.shape[-1])(mlp_output)
    mlp_output = layers.Dropout(dropout)(mlp_output)
    # Residual connection
    x = x + mlp_output
    return x, attn_scores
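A minimal sketch of a working variant, offered as a likely fix rather than a confirmed answer. `return_attention_scores` is an argument of the `MultiHeadAttention` call, not of its constructor, so it is moved from `MultiHeadAttention(...)` to `(...)(x_norm1, x_norm1, return_attention_scores=True)`. Other assumptions for illustration: `self` is dropped so the function runs standalone, residuals use `layers.Add()`, `LeakyReLU` gets its slope positionally (the keyword is `alpha` in older Keras and `negative_slope` in Keras 3), and the input shape `(16, 64)`, `num_heads=4` and `mlp_dim=128` are arbitrary values.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def transformer_block(x, num_heads, mlp_dim, dropout=0.1, alpha=0.3):
    """Transformer encoder block that also returns the attention scores.

    Input and output shapes are the same: (B, num_patches, dim).
    """
    dim = x.shape[-1]
    # LayerNorm + multi-head self-attention
    x_norm1 = layers.LayerNormalization(epsilon=1e-6)(x)
    # return_attention_scores goes to the call, not to the constructor
    attn_output, attn_scores = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=dim // num_heads,
    )(x_norm1, x_norm1, return_attention_scores=True)
    attn_output = layers.Dropout(dropout)(attn_output)
    x = layers.Add()([x, attn_output])  # residual connection

    # LayerNorm + MLP
    x_norm2 = layers.LayerNormalization(epsilon=1e-6)(x)
    mlp_output = layers.Dense(mlp_dim)(x_norm2)
    # Slope passed positionally: works whether the keyword is alpha or negative_slope
    mlp_output = layers.LeakyReLU(alpha)(mlp_output)
    mlp_output = layers.Dropout(dropout)(mlp_output)
    mlp_output = layers.Dense(dim)(mlp_output)
    mlp_output = layers.Dropout(dropout)(mlp_output)
    x = layers.Add()([x, mlp_output])  # residual connection
    return x, attn_scores


# Hypothetical usage: 16 patches of dimension 64, 4 heads
inputs = keras.Input(shape=(16, 64))
outputs, scores = transformer_block(inputs, num_heads=4, mlp_dim=128)
model = keras.Model(inputs, [outputs, scores])

batch = np.random.rand(2, 16, 64).astype("float32")
out, att = model.predict(batch, verbose=0)
print(out.shape, att.shape)  # (2, 16, 64) (2, 4, 16, 16)
```

The attention scores come back with shape `(batch, num_heads, query_len, key_len)`, so exposing them as a second model output makes them available for inspection or visualization.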