Backward Pass for Multi-Head Attention (MHA) operator
- Suhas Bhairav

- Jul 29
The backward pass, or backpropagation, for the Multi-Head Attention (MHA) operator is where the model figures out how to adjust its internal parameters (the Query, Key, Value, and Output projection matrices) to reduce the prediction error. This involves applying the chain rule of calculus to propagate the gradients of the loss function back through all the operations performed in the forward pass.

Let's break down the general flow of the MHA backward pass. We'll conceptualize it as going backward through the steps of the forward pass:
Recall the MHA Forward Pass:
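Input Linear Projections: $Q_i = XW_i^Q$, $K_i = XW_i^K$, $V_i = XW_i^V$ for each head $i$.
Scaled Dot-Product Attention: $\text{head}_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$
Concatenation: $\text{Concat}(\text{head}_1, \ldots, \text{head}_h)$
Final Linear Projection: $\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$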
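For reference, the forward pass above can be sketched in NumPy. This is a minimal illustration under assumptions that are not in the original: a single unbatched sequence, self-attention (the same $X$ feeds every projection), no biases, no masking, and per-head weights stored as Python lists. The names `mha_forward` and `cache` are hypothetical. The cache holds the activations that the backward pass needs.

```python
import numpy as np

def softmax(a):
    # Subtract the row-wise max for numerical stability; the result is unchanged.
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mha_forward(X, WQ, WK, WV, WO):
    """X: (n, d_model); WQ/WK/WV: lists of h arrays (d_model, d_k); WO: (h*d_k, d_model)."""
    heads, cache = [], []
    for Wq, Wk, Wv in zip(WQ, WK, WV):
        Q, K, V = X @ Wq, X @ Wk, X @ Wv          # input linear projections
        d_k = Q.shape[-1]
        S = softmax(Q @ K.T / np.sqrt(d_k))       # attention weights, (n, n)
        heads.append(S @ V)                       # head_i, (n, d_k)
        cache.append((Q, K, V, S))                # saved for the backward pass
    concat = np.concatenate(heads, axis=-1)       # (n, h*d_k)
    return concat @ WO, concat, cache

rng = np.random.default_rng(0)
n, d_model, h = 4, 8, 2
d_k = d_model // h
X = rng.standard_normal((n, d_model))
WQ = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
WK = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
WV = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
WO = rng.standard_normal((h * d_k, d_model))

out, concat, cache = mha_forward(X, WQ, WK, WV, WO)
print(out.shape)  # (4, 8)
```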
Now, let's reverse this process for the backward pass, starting with the gradient of the loss with respect to the output of the MHA layer, denoted as $\frac{\partial L}{\partial \text{MultiHead}}$.
The MHA Backward Pass Steps:
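Backward through Final Linear Projection ($W^O$):
Gradient w.r.t. $W^O$: This is a simple matrix multiplication. The gradient of the loss w.r.t. $W^O$ is calculated as $\frac{\partial L}{\partial W^O} = \left(\text{Concat}(\text{head}_1, \ldots, \text{head}_h)\right)^T \cdot \frac{\partial L}{\partial \text{MultiHead}}$.
Gradient w.r.t. Concatenated Heads: The gradient is also propagated back to the concatenated output of the heads: $\frac{\partial L}{\partial \text{Concat}} = \frac{\partial L}{\partial \text{MultiHead}} \cdot (W^O)^T$.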
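The two formulas for the output projection can be checked numerically. This sketch assumes a random matrix standing in for $\text{Concat}(\ldots)$ and a surrogate loss $L = \sum (G \odot \text{MultiHead})$ with a random $G$, chosen so that $\frac{\partial L}{\partial \text{MultiHead}} = G$ exactly. The helper `num_grad` is a hypothetical central finite-difference routine, not a library function.

```python
import numpy as np

def num_grad(f, M, eps=1e-6):
    """Central finite-difference gradient of the scalar f() w.r.t. array M (perturbed in place)."""
    g = np.zeros_like(M)
    for idx in np.ndindex(M.shape):
        old = M[idx]
        M[idx] = old + eps
        fp = f()
        M[idx] = old - eps
        fm = f()
        M[idx] = old                       # restore the original value
        g[idx] = (fp - fm) / (2 * eps)
    return g

rng = np.random.default_rng(1)
n, hd, d_model = 4, 6, 5
concat = rng.standard_normal((n, hd))      # stands in for Concat(head_1, ..., head_h)
WO = rng.standard_normal((hd, d_model))
G = rng.standard_normal((n, d_model))      # assumed upstream gradient dL/dMultiHead

# Surrogate loss L = sum(G * MultiHead), chosen so that dL/dMultiHead == G exactly.
loss = lambda: np.sum(G * (concat @ WO))

dWO = concat.T @ G       # dL/dW^O = Concat^T . dL/dMultiHead
dconcat = G @ WO.T       # dL/dConcat = dL/dMultiHead . (W^O)^T

print(np.allclose(dWO, num_grad(loss, WO)), np.allclose(dconcat, num_grad(loss, concat)))  # True True
```

Because the surrogate loss is linear in both `WO` and `concat`, the central difference matches the analytic gradient up to floating-point rounding.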
Backward through Concatenation:
The gradient $\frac{\partial L}{\partial \text{Concat}}$ is simply split and distributed to each individual head's output. Because Concat was formed by stacking $\text{head}_1$ to $\text{head}_h$ horizontally (along the feature dimension), $\frac{\partial L}{\partial \text{head}_i}$ is the block of columns of $\frac{\partial L}{\partial \text{Concat}}$ that corresponds to $\text{head}_i$.
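Undoing the concatenation is just slicing along the same axis the heads were stacked on. This small sketch assumes $h = 3$ heads of width $d_k = 2$ and uses a made-up upstream gradient so that the slices are easy to read.

```python
import numpy as np

h, n, d_k = 3, 4, 2
# Assumed upstream gradient w.r.t. Concat(head_1, ..., head_h), shape (n, h*d_k).
dconcat = np.arange(n * h * d_k, dtype=float).reshape(n, h * d_k)

# Heads were stacked along the last axis, so the i-th head (0-based) owns
# columns [i*d_k, (i+1)*d_k).
dheads = np.split(dconcat, h, axis=-1)
print([g.shape for g in dheads])   # [(4, 2), (4, 2), (4, 2)]
print(dheads[1][0])                # [2. 3.]
```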
Backward through Each Scaled Dot-Product Attention Head (for each $\text{head}_i$):
This is the most complex part, as it involves propagating gradients through the matrix multiplication, scaling, softmax, and the final multiplication by $V_i$. Let's denote the output of the attention mechanism for a single head as $O = \text{softmax}(A) V_i$, where $A = \frac{Q_i K_i^T}{\sqrt{d_k}}$. We have $\frac{\partial L}{\partial O} = \frac{\partial L}{\partial \text{head}_i}$ from the previous step.
Gradient w.r.t. $V_i$: $\frac{\partial L}{\partial V_i} = \left(\text{softmax}(A)\right)^T \cdot \frac{\partial L}{\partial O}$.
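Gradient w.r.t. Softmax Input ($A$): Let $S = \text{softmax}(A)$. First propagate through the product with $V_i$ to get the gradient w.r.t. the attention weights, $\frac{\partial L}{\partial S} = \frac{\partial L}{\partial O} V_i^T$. Softmax is applied independently to each row, so its Jacobian only couples entries within a row, and the result is $\frac{\partial L}{\partial A} = S \odot \left(\frac{\partial L}{\partial S} - \text{rowsum}\left(\frac{\partial L}{\partial S} \odot S\right)\mathbf{1}^T\right)$, i.e. entry by entry $\frac{\partial L}{\partial A_{jk}} = S_{jk}\left(\frac{\partial L}{\partial S_{jk}} - \sum_l \frac{\partial L}{\partial S_{jl}} S_{jl}\right)$. In practice this is a few element-wise operations plus one row-wise sum, so the per-row Jacobian matrices are never formed explicitly.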
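Backward through the Scaling Factor ($\frac{1}{\sqrt{d_k}}$): $d_k$ is a fixed hyperparameter, not a learned parameter, so no gradient is computed for it. Backpropagating through the division simply multiplies the upstream gradient by $\frac{1}{\sqrt{d_k}}$, which is where that factor in the next two formulas comes from.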
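Gradient w.r.t. $Q_i$ and $K_i$ (from $A = \frac{Q_i K_i^T}{\sqrt{d_k}}$):
$\frac{\partial L}{\partial Q_i} = \frac{1}{\sqrt{d_k}} \cdot \frac{\partial L}{\partial A} K_i$.
$\frac{\partial L}{\partial K_i} = \frac{1}{\sqrt{d_k}} \cdot \left(\frac{\partial L}{\partial A}\right)^T Q_i$. (Note: transpose operations are crucial here. $K_i$ enters $A$ transposed, so its gradient uses $\left(\frac{\partial L}{\partial A}\right)^T$, which also gives a result with the same shape as $K_i$.)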
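The per-head formulas above can be implemented in a few lines and checked against central finite differences. This is a minimal NumPy sketch under assumptions not in the original: a single unbatched head, no masking or dropout, and a surrogate loss $L = \sum (G \odot O)$ with a random $G$, chosen so that $\frac{\partial L}{\partial O} = G$ exactly. The names `head_forward`, `head_backward`, and `num_grad` are hypothetical.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def head_forward(Q, K, V):
    S = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))   # S = softmax(A)
    return S @ V, S

def head_backward(dO, Q, K, V, S):
    scale = 1.0 / np.sqrt(Q.shape[-1])
    dV = S.T @ dO                                            # dL/dV_i = S^T . dL/dO
    dS = dO @ V.T                                            # dL/dS = dL/dO . V_i^T
    dA = S * (dS - np.sum(dS * S, axis=-1, keepdims=True))  # row-wise softmax backward
    dQ = scale * dA @ K                                      # dL/dQ_i
    dK = scale * dA.T @ Q                                    # dL/dK_i (note the transpose)
    return dQ, dK, dV

def num_grad(f, M, eps=1e-6):
    """Central finite-difference gradient of the scalar f() w.r.t. array M (perturbed in place)."""
    g = np.zeros_like(M)
    for idx in np.ndindex(M.shape):
        old = M[idx]
        M[idx] = old + eps
        fp = f()
        M[idx] = old - eps
        fm = f()
        M[idx] = old
        g[idx] = (fp - fm) / (2 * eps)
    return g

rng = np.random.default_rng(2)
n, d_k = 5, 3
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
G = rng.standard_normal((n, d_k))   # assumed upstream gradient dL/dO

loss = lambda: np.sum(G * head_forward(Q, K, V)[0])   # surrogate loss with dL/dO == G
_, S = head_forward(Q, K, V)
dQ, dK, dV = head_backward(G, Q, K, V, S)
for name, analytic, M in (("Q", dQ, Q), ("K", dK, K), ("V", dV, V)):
    print(name, np.allclose(analytic, num_grad(loss, M), atol=1e-6))
```

Writing `dK = scale * Q.T @ dA` instead would produce an array with the transposed shape, which is an easy way to catch the transpose mistake.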
Backward through Initial Linear Projections (for each $Q_i, K_i, V_i$):
Finally, the gradients $\frac{\partial L}{\partial Q_i}$, $\frac{\partial L}{\partial K_i}$, and $\frac{\partial L}{\partial V_i}$ computed in the previous step are propagated back to the projection matrices ($W_i^Q, W_i^K, W_i^V$) and to the original input embedding $X$.
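Gradients w.r.t. $W_i^Q, W_i^K, W_i^V$:
$\frac{\partial L}{\partial W_i^Q} = X^T \cdot \frac{\partial L}{\partial Q_i}$.
$\frac{\partial L}{\partial W_i^K} = X^T \cdot \frac{\partial L}{\partial K_i}$.
$\frac{\partial L}{\partial W_i^V} = X^T \cdot \frac{\partial L}{\partial V_i}$.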
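Gradients w.r.t. Input $X$ (for each head):
$\frac{\partial L}{\partial X_{\text{from } Q_i}} = \frac{\partial L}{\partial Q_i} \cdot (W_i^Q)^T$.
$\frac{\partial L}{\partial X_{\text{from } K_i}} = \frac{\partial L}{\partial K_i} \cdot (W_i^K)^T$.
$\frac{\partial L}{\partial X_{\text{from } V_i}} = \frac{\partial L}{\partial V_i} \cdot (W_i^V)^T$.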
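The projection backward follows the same pattern for $Q_i$, $K_i$, and $V_i$, so this sketch checks only the $Q_i$ path. It assumes a random upstream gradient `dQ` standing in for $\frac{\partial L}{\partial Q_i}$ and a surrogate loss $L = \sum (dQ \odot XW_i^Q)$ that makes this exact. `num_grad` is the same hypothetical finite-difference helper.

```python
import numpy as np

def num_grad(f, M, eps=1e-6):
    """Central finite-difference gradient of the scalar f() w.r.t. array M (perturbed in place)."""
    g = np.zeros_like(M)
    for idx in np.ndindex(M.shape):
        old = M[idx]
        M[idx] = old + eps
        fp = f()
        M[idx] = old - eps
        fm = f()
        M[idx] = old
        g[idx] = (fp - fm) / (2 * eps)
    return g

rng = np.random.default_rng(3)
n, d_model, d_k = 4, 6, 3
X = rng.standard_normal((n, d_model))
WQ = rng.standard_normal((d_model, d_k))
dQ = rng.standard_normal((n, d_k))   # assumed gradient dL/dQ_i from the attention backward

loss = lambda: np.sum(dQ * (X @ WQ))   # surrogate loss with dL/dQ_i == dQ

dWQ = X.T @ dQ          # dL/dW_i^Q = X^T . dL/dQ_i
dX_from_Q = dQ @ WQ.T   # contribution to dL/dX through the Q_i path

print(np.allclose(dWQ, num_grad(loss, WQ)), np.allclose(dX_from_Q, num_grad(loss, X)))  # True True
```

Swapping `WQ` and `dQ` for the $K_i$ or $V_i$ counterparts gives the other two paths unchanged.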
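Summing Gradients for Input $X$:
Since the original input $X$ was used for all $Q_i, K_i, V_i$ projections across all heads, the final gradient for $X$ is the sum of the gradients from all paths: $\frac{\partial L}{\partial X} = \sum_{i=1}^{h}\left(\frac{\partial L}{\partial X_{\text{from } Q_i}} + \frac{\partial L}{\partial X_{\text{from } K_i}} + \frac{\partial L}{\partial X_{\text{from } V_i}}\right)$. (This holds for self-attention. In cross-attention, the $K_i$ and $V_i$ paths flow back to the encoder output instead of $X$.)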
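Putting the five steps together, here is an end-to-end sketch of the MHA backward pass, checked against finite differences on the full layer. As before, this assumes self-attention on one unbatched sequence, no biases, masking, or dropout, a surrogate loss $L = \sum (G \odot \text{MultiHead})$ with random $G$, and hypothetical helper names (`mha_forward`, `mha_backward`, `num_grad`). Weights are scaled by $1/\sqrt{d_{model}}$ only to keep the attention scores moderate.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mha_forward(X, WQ, WK, WV, WO):
    heads, cache = [], []
    for Wq, Wk, Wv in zip(WQ, WK, WV):
        Q, K, V = X @ Wq, X @ Wk, X @ Wv
        S = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))
        heads.append(S @ V)
        cache.append((Q, K, V, S))
    concat = np.concatenate(heads, axis=-1)
    return concat @ WO, concat, cache

def mha_backward(dOut, X, WQ, WK, WV, WO, concat, cache):
    dWO = concat.T @ dOut                          # step 1: final linear projection
    dconcat = dOut @ WO.T
    dheads = np.split(dconcat, len(WQ), axis=-1)   # step 2: undo concatenation
    dX = np.zeros_like(X)
    dWQ, dWK, dWV = [], [], []
    for dO, (Q, K, V, S), Wq, Wk, Wv in zip(dheads, cache, WQ, WK, WV):
        scale = 1.0 / np.sqrt(Q.shape[-1])
        dV = S.T @ dO                              # step 3: attention head
        dS = dO @ V.T
        dA = S * (dS - np.sum(dS * S, axis=-1, keepdims=True))
        dQ, dK = scale * dA @ K, scale * dA.T @ Q
        dWQ.append(X.T @ dQ)                       # step 4: input projections
        dWK.append(X.T @ dK)
        dWV.append(X.T @ dV)
        dX += dQ @ Wq.T + dK @ Wk.T + dV @ Wv.T    # step 5: sum every path into X
    return dX, dWQ, dWK, dWV, dWO

def num_grad(f, M, eps=1e-6):
    """Central finite-difference gradient of the scalar f() w.r.t. array M (perturbed in place)."""
    g = np.zeros_like(M)
    for idx in np.ndindex(M.shape):
        old = M[idx]
        M[idx] = old + eps
        fp = f()
        M[idx] = old - eps
        fm = f()
        M[idx] = old
        g[idx] = (fp - fm) / (2 * eps)
    return g

rng = np.random.default_rng(4)
n, d_model, h = 4, 6, 2
d_k = d_model // h
X = rng.standard_normal((n, d_model))
WQ, WK, WV = ([rng.standard_normal((d_model, d_k)) / np.sqrt(d_model) for _ in range(h)]
              for _ in range(3))
WO = rng.standard_normal((h * d_k, d_model)) / np.sqrt(d_model)
G = rng.standard_normal((n, d_model))   # assumed upstream gradient dL/dMultiHead

loss = lambda: np.sum(G * mha_forward(X, WQ, WK, WV, WO)[0])
_, concat, cache = mha_forward(X, WQ, WK, WV, WO)
dX, dWQ, dWK, dWV, dWO = mha_backward(G, X, WQ, WK, WV, WO, concat, cache)
print(np.allclose(dX, num_grad(loss, X), atol=1e-6))          # True
print(np.allclose(dWK[1], num_grad(loss, WK[1]), atol=1e-6))  # True
print(np.allclose(dWO, num_grad(loss, WO), atol=1e-6))        # True
```

Dropping any one of the three terms in the `dX +=` line makes the `dX` check fail, which is a quick way to see why the paths must be summed.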
Key Challenges and Optimizations in Practice:
Numerical Stability: Exponentiating large attention scores inside the softmax can overflow. Frameworks like PyTorch and TensorFlow use optimized kernels that handle these calculations carefully, e.g. by subtracting each row's maximum score before exponentiating (the same idea as the log-sum-exp trick), which leaves the softmax output unchanged.
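Memory Efficiency: The backward pass needs intermediate activations from the forward pass (for example $X$, $Q_i$, $K_i$, $V_i$, and the attention weights $\text{softmax}(A)$), so they are normally stored. For large LLMs and long sequences, this can consume enormous amounts of memory, since each head's attention-weight matrix alone grows quadratically with sequence length. Techniques like gradient checkpointing (recomputing certain activations during the backward pass rather than storing them) are used to mitigate this.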
Hardware Acceleration: The entire process is highly parallelized on GPUs, leveraging specialized hardware for matrix multiplications and other linear algebra operations.
Dropout: If dropout is applied (for example, to the attention weights), its backward pass reuses the exact mask and scaling factor from the forward pass, so gradients flow only through the connections that were kept.
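Masking: For masked self-attention (e.g., in decoder-only models), masked positions have their attention scores set to a very large negative number before the softmax, so their attention weights are effectively zero. Because $\frac{\partial L}{\partial A}$ is multiplied element-wise by $\text{softmax}(A)$, those positions receive (effectively) zero gradient, so no gradient flows through masked-out connections. The same mask must be used in the forward and backward passes.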
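The masking and dropout points above can be illustrated for a single head. This NumPy sketch assumes a causal mask built with `np.tril`, a masking value of `-1e9`, inverted dropout applied to the attention weights with rate `p = 0.5`, and a random upstream gradient `dO`. These choices are illustrative, not taken from any particular framework.

```python
import numpy as np

def softmax(a):
    # Row-max subtraction: the numerically stable form described under Numerical Stability.
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(5)
n, d_k, p = 4, 3, 0.5
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
dO = rng.standard_normal((n, d_k))             # assumed upstream gradient dL/dO

# Forward: causal mask, softmax, then (assumed) inverted dropout on the attention weights.
causal = np.tril(np.ones((n, n), dtype=bool))  # query j may attend only to keys k <= j
A = np.where(causal, Q @ K.T / np.sqrt(d_k), -1e9)
S = softmax(A)                                 # masked entries underflow to exactly 0
keep = rng.random((n, n)) >= p                 # dropout mask, reused in the backward pass
P = S * keep / (1 - p)                         # dropped-out attention weights
O = P @ V

# Backward: reuse the same keep mask and scale, then the usual softmax backward.
dP = dO @ V.T
dS = dP * keep / (1 - p)                       # dL/dS is zero wherever a weight was dropped
dA = S * (dS - np.sum(dS * S, axis=-1, keepdims=True))

print(np.all(dA[~causal] == 0))                # True: no gradient through masked pairs
```

Note that a dropped weight gets zero gradient at the softmax output, but its score in $A$ can still receive a small gradient through the row-sum term, because softmax couples every entry in a row. Masked pairs, by contrast, get exactly zero because their $S$ entries are zero.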
The MHA backward pass is a computationally intensive and intricate process, but it's fundamental to how LLMs learn. Modern deep learning frameworks abstract away much of this complexity, allowing researchers and developers to focus on model architecture and data, while the underlying mathematical machinery efficiently computes the necessary gradients for learning.

