Implement Single-Head Self-Attention

Implement a single-head self-attention module. The module takes as input a batch of sequence vectors of shape (batch_size, sequence_length, embedding_dim) representing the input sequence batch. The output is a batch of sequence vectors of the same shape (batch_size, sequence_length, embedding_dim), in which each position's vector is a weighted combination of the input sequence produced by attending over it.

Constraints

  • The input sequence will have shape (batch_size, sequence_length, embedding_dim)
  • The dimension of the key vectors (d_k) will be provided to the constructor
  • The output must have shape (batch_size, sequence_length, embedding_dim)
  • Use scaled dot-product attention
  • You may use PyTorch tensor operations but no built-in attention layers (e.g. nn.MultiheadAttention)
  • Numerical precision of ±0.001 is acceptable for floating point comparisons
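A minimal sketch of such a module is shown below. The problem statement does not specify the projection layers, so this version makes a common assumption: learned, bias-free linear projections where queries and keys are projected to d_k dimensions and values are kept at embedding_dim so the output shape matches the input. The class and attribute names are illustrative, not mandated by the problem, and the example outputs above will depend on the (unspecified) projection weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Sketch under assumed design choices: learned bias-free Q/K/V
    projections; V stays at embedding_dim so output shape == input shape.
    """

    def __init__(self, embedding_dim: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        # Queries and keys are projected to d_k dimensions.
        self.w_q = nn.Linear(embedding_dim, d_k, bias=False)
        self.w_k = nn.Linear(embedding_dim, d_k, bias=False)
        # Values keep embedding_dim so the output has the required shape.
        self.w_v = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, embedding_dim)
        q = self.w_q(x)                        # (B, S, d_k)
        k = self.w_k(x)                        # (B, S, d_k)
        v = self.w_v(x)                        # (B, S, E)
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5  # (B, S, S)
        weights = F.softmax(scores, dim=-1)    # normalize over keys
        return weights @ v                     # (B, S, E)
```

Scaling by sqrt(d_k) keeps the dot products from growing with the key dimension, which would otherwise push the softmax into a near-one-hot regime.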

Examples

Example 1

{
  "input": {
    "sequence": [[[1, 0, 1], [0, 2, 0], [1, 1, 0]]],
    "d_k": 3
  },
  "output": [[[0.726, 0.158, 0.726], [0, 2, 0], [0.549, 0.451, 0]]]
}

Example 2

{
  "input": {
    "sequence": [
      [[1, 2], [3, 4], [5, 6]],
      [[1, 2], [3, 4], [5, 6]]
    ],
    "d_k": 2
  },
  "output": [
    [[2.318, 2.727], [3.543, 4.229], [3.785, 4.454]],
    [[2.318, 2.727], [3.543, 4.229], [3.785, 4.454]]
  ]
}
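The core computation can also be checked in isolation with plain tensor operations. The sketch below applies the scaled dot-product formula directly, feeding the same tensor in as queries, keys, and values (i.e. identity projections) purely for illustration; with learned projections, as the module requires, the numeric outputs would differ.

```python
import torch

def scaled_dot_product_attention(q, k, v):
    # softmax(Q K^T / sqrt(d_k)) V, with d_k taken from the query size.
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ v

# Input batch from Example 1; identity projections for illustration only.
x = torch.tensor([[[1., 0., 1.], [0., 2., 0.], [1., 1., 0.]]])
out = scaled_dot_product_attention(x, x, x)
print(out.shape)  # torch.Size([1, 3, 3])
```

Because each row of the attention weights is a softmax, the output at every position is a convex combination of the value vectors.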
