Implement Multi-Head Attention with Tensor Reshaping

Implement a PyTorch module for multi-head attention, a key component of the Transformer model, using tensor reshaping. The module should take as input a batch of query vectors of shape (batch_size, sequence_length, embedding_dim), a batch of key vectors of the same shape, and a batch of value vectors of the same shape. The output should be a batch of vectors of shape (batch_size, sequence_length, embedding_dim) in which each position attends over the input sequence.

Constraints

  • The embedding dimension must be divisible by the number of heads.
  • Use tensor reshaping to split and combine the tensors for multi-head attention (see the sketch after this list).
  • Do not use any high-level deep learning library functions that directly implement multi-head attention.
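
The split-and-combine step referenced in the second constraint can be done with view and transpose alone. Below is a minimal sketch using the shapes from Example 1 (batch 2, sequence length 10, embedding dimension 64, 8 heads); the variable names are illustrative only.

import torch

batch_size, seq_len, embedding_dim, num_heads = 2, 10, 64, 8
head_dim = embedding_dim // num_heads  # 64 / 8 = 8

x = torch.randn(batch_size, seq_len, embedding_dim)

# Split: (batch, seq_len, embedding_dim) -> (batch, num_heads, seq_len, head_dim)
heads = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# ... per-head attention would run here ...

# Combine: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, embedding_dim)
combined = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, embedding_dim)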

Examples

Example 1

{
  "input": {
    "queries": "A tensor of shape (2, 10, 64) representing queries",
    "keys": "A tensor of shape (2, 10, 64) representing keys",
    "values": "A tensor of shape (2, 10, 64) representing values",
    "num_heads": 8
  },
  "output": "A tensor of shape (2, 10, 64) after applying multi-head attention"
}

Example 2

{
  "input": {
    "queries": "A tensor of shape (5, 3, 128) representing queries",
    "keys": "A tensor of shape (5, 3, 128) representing keys",
    "values": "A tensor of shape (5, 3, 128) representing values",
    "num_heads": 4
  },
  "output": "A tensor of shape (5, 3, 128) after applying multi-head attention"
}
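
One way to satisfy the constraints is sketched below. It assumes learned linear projections for the queries, keys, values, and output, together with standard scaled dot-product attention; the problem statement does not fix these details, so treat this as one possible implementation rather than the reference solution.

import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int):
        super().__init__()
        if embedding_dim % num_heads != 0:
            raise ValueError("embedding_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        # Learned projections for queries, keys, values, and the output
        # (an assumption; the statement only requires the attention itself).
        self.q_proj = nn.Linear(embedding_dim, embedding_dim)
        self.k_proj = nn.Linear(embedding_dim, embedding_dim)
        self.v_proj = nn.Linear(embedding_dim, embedding_dim)
        self.out_proj = nn.Linear(embedding_dim, embedding_dim)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, embedding_dim) -> (batch, num_heads, seq_len, head_dim)
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, queries, keys, values):
        batch_size, seq_len, _ = queries.shape
        q = self._split_heads(self.q_proj(queries))
        k = self._split_heads(self.k_proj(keys))
        v = self._split_heads(self.v_proj(values))

        # Scaled dot-product attention, computed for all heads at once.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = torch.softmax(scores, dim=-1)
        attended = torch.matmul(weights, v)  # (batch, num_heads, seq_len, head_dim)

        # Combine heads back into (batch, seq_len, embedding_dim).
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.out_proj(attended)

# Usage matching Example 1: (2, 10, 64) inputs with 8 heads.
mha = MultiHeadAttention(embedding_dim=64, num_heads=8)
out = mha(torch.randn(2, 10, 64), torch.randn(2, 10, 64), torch.randn(2, 10, 64))
print(out.shape)  # torch.Size([2, 10, 64])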
