Implement Pre-LN vs Post-LN Architecture Comparison

Implement and compare two variants of layer normalization (LN) placement in a Transformer layer: Pre-LN and Post-LN. Pre-LN normalizes the input to each sublayer, leaving the residual path untouched, whereas Post-LN normalizes the output after the residual connection has been added. You should implement the forward pass for both architectures, including the necessary components such as multi-head attention and feed-forward networks.
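
For reference, the two placements differ only in where the normalization sits relative to the residual connection. Below is a minimal sketch of the two residual patterns; the names attn, ffn, norm1, and norm2 are placeholders for the sublayers, not a required interface.

# Post-LN (original Transformer): normalize after each residual addition.
def post_ln_block(x, attn, ffn, norm1, norm2):
    x = norm1(x + attn(x))
    x = norm2(x + ffn(x))
    return x

# Pre-LN: normalize the sublayer input; the residual path stays unnormalized.
def pre_ln_block(x, attn, ffn, norm1, norm2):
    x = x + attn(norm1(x))
    x = x + ffn(norm2(x))
    return x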

Constraints

  • You must implement both Pre-LN and Post-LN architectures within the same codebase.
  • You can use layer normalization and other necessary modules from PyTorch.

Examples

Example 1

{
  "input": "A sequence of embeddings with shape (batch_size, sequence_length, embedding_dim)",
  "output": "The output of the model after processing the sequence, with shape (batch_size, sequence_length, embedding_dim)"
}

</>Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim, pre_ln=True):
        super(TransformerLayer, self).__init__()
        self.pre_ln = pre_ln
        # Multi-head self-attention from PyTorch; batch_first=True matches the
        # (batch_size, sequence_length, embedding_dim) input layout.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        # Position-wise feed-forward network.
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        # One layer norm per sublayer (attention and feed-forward).
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # TODO: Implement the forward pass for the Transformer layer.
        # Place layer normalization before each sublayer (pre_ln=True) or after
        # each residual addition (pre_ln=False).
        # Note: nn.MultiheadAttention returns (output, attn_weights); use the output.
        pass
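
As a sanity check once the TODO above is filled in, both variants should preserve the input shape; the hyperparameters below are arbitrary example values.

layer_pre = TransformerLayer(embed_dim=64, num_heads=4, hidden_dim=256, pre_ln=True)
layer_post = TransformerLayer(embed_dim=64, num_heads=4, hidden_dim=256, pre_ln=False)
x = torch.randn(2, 10, 64)  # (batch_size, sequence_length, embedding_dim)
assert layer_pre(x).shape == x.shape
assert layer_post(x).shape == x.shape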

Test

Input:

A random tensor of shape (batch_size, sequence_length, embedding_dim), e.g. torch.randn(2, 10, 64), passed through the layer once with pre_ln=True and once with pre_ln=False.

Output:

A tensor with the same shape as the input for both settings; the placement of layer normalization changes the values but must preserve the shape.