<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/Atom">
  <channel>
    <title>varmology</title>
    <description>Technical blog on AI, LLMs, quantization, and machine learning engineering.</description>
    <link>https://aakashvarma.github.io/</link>
    <atom:link href="https://aakashvarma.github.io/feed.xml" rel="self" type="application/rss+xml"/>
    <pubDate>Tue, 17 Mar 2026 17:13:15 -0400</pubDate>
    <lastBuildDate>Tue, 17 Mar 2026 17:13:15 -0400</lastBuildDate>
    <generator>Jekyll v4.4.1</generator>
    
      <item>
        <title>Let it Flow</title>
        <description>&lt;p&gt;Coming soon.&lt;/p&gt;
</description>
        <pubDate>Tue, 17 Mar 2026 00:00:00 -0400</pubDate>
        <link>https://aakashvarma.github.io/let_it_flow/</link>
        <guid isPermaLink="true">https://aakashvarma.github.io/let_it_flow/</guid>
        
      </item>
    
      <item>
        <title>Layer Normalization as a Projection: The Complete Geometric Interpretation</title>
        <description>&lt;h2 id=&quot;1-introduction&quot;&gt;1. Introduction&lt;/h2&gt;

&lt;p&gt;Layer Normalization is a crucial technique in modern neural networks, particularly in Large Language Models (LLMs), where it helps stabilize training and accelerate convergence. While typically presented as a statistical normalization procedure, there’s a deeper, more elegant interpretation: layer normalization can be understood as a sequence of geometric projections in vector space.&lt;/p&gt;

&lt;p&gt;This statistical operation, now a standard component in most neural network architectures, serves as a vital stabilizer during training. By normalizing activations across the feature dimension, it helps prevent the internal covariate shift problem that can slow down or destabilize training. However, beyond its practical benefits, layer normalization harbors a beautiful geometric interpretation that provides deeper insights into why it works so effectively.&lt;/p&gt;

&lt;p&gt;This article provides a comprehensive exploration of this geometric perspective, breaking down each step with rigorous mathematical derivations and intuitive explanations. By understanding layer normalization through the lens of projections, we gain insights into why it works so effectively and how it relates to the geometry of feature spaces.&lt;/p&gt;

&lt;h2 id=&quot;2-the-standard-layer-normalization-formulation&quot;&gt;2. The Standard Layer Normalization Formulation&lt;/h2&gt;

&lt;p&gt;Before diving into the geometric interpretation, let’s review the standard formulation of layer normalization.&lt;/p&gt;

&lt;p&gt;Given an input vector \(x = (x_1, x_2, \ldots, x_d)\) of dimension \(d\), layer normalization performs the following transformation:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Where:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;\(\mathrm{E}[x] = \frac{1}{d}\sum_{i=1}^d x_i\) is the mean of the vector&lt;/li&gt;
  &lt;li&gt;\(\mathrm{Var}[x] = \frac{1}{d}\sum_{i=1}^d (x_i - \mathrm{E}[x])^2\) is the variance of the vector&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;This transformation centers the vector by subtracting the mean, then scales it by dividing by the standard deviation. The result is a vector with zero mean and unit variance across its components, regardless of the input’s original scale or offset. After this normalization, the vector typically undergoes an affine transformation with learnable parameters (a scaling and a bias term) that allows the network to recover the representational power that might be lost during normalization.&lt;/p&gt;
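
&lt;p&gt;To make the formula concrete, here is a minimal NumPy sketch of this normalization. The &lt;code&gt;eps&lt;/code&gt; guard and the optional &lt;code&gt;gamma&lt;/code&gt;/&lt;code&gt;beta&lt;/code&gt; affine parameters are illustrative additions in the spirit of practical implementations, not part of the formula above:&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    # Center by the mean, then scale by the standard deviation.
    # eps guards against division by zero when the variance is tiny.
    mean = x.mean()
    var = x.var()                      # population variance: (1/d) * sum((x - mean)**2)
    y = (x - mean) / np.sqrt(var + eps)
    # Optional learnable affine transform (scale and shift), as used in practice.
    if gamma is not None:
        y = gamma * y
    if beta is not None:
        y = y + beta
    return y

x = np.array([5.0, 8.0, 2.0])
print(layer_norm(x))                   # roughly zero mean and unit variance
&lt;/code&gt;&lt;/pre&gt;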

&lt;p&gt;At first glance, this appears to be a purely statistical operation. However, as we’ll see, it can be elegantly reinterpreted as a sequence of geometric transformations in the vector space.&lt;/p&gt;

&lt;h2 id=&quot;3-understanding-vector-centering&quot;&gt;3. Understanding Vector Centering&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-ones-vector&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-ones-vector&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;The all-ones vector \(\vec{1}\) has special significance in many mathematical fields including linear algebra, statistics, and machine learning. In the context of normalization, it represents the direction along which all components change uniformly. Its geometric interpretation connects statistical concepts like mean and variance to vector projections in high-dimensional spaces. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;The first step in layer normalization is centering the vector by subtracting the mean from each component. Centering is a fundamental preprocessing step in many statistical and machine learning methods. It shifts the coordinate system so that the “center of mass” of the data lies at the origin. In the context of a single vector, centering removes the common offset across all dimensions, focusing instead on the relative differences between components.&lt;/p&gt;

&lt;p&gt;For a vector \(x\), the centered vector is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
x_{centered} = x - \mathrm{E}[x] \cdot \vec{1} = (x_1 - \mathrm{E}[x], x_2 - \mathrm{E}[x], \ldots, x_d - \mathrm{E}[x])
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Where \(\vec{1} = (1, 1, \ldots, 1)\) is the all-ones vector.&lt;/p&gt;

&lt;p&gt;In neural networks, centering helps stabilize gradients during training by removing large offsets that might cause activations to saturate. It also makes the learning process more consistent across different input scales.&lt;/p&gt;

&lt;p&gt;Centering has several important geometric interpretations. First, it can be viewed as a translation to the origin, shifting the coordinate system so that the mean becomes the new origin. This is a rigid translation of the vector space, preserving all distances and angles between points while moving the center of mass to zero.&lt;/p&gt;

&lt;p&gt;Second, centering removes the “common mode” component of the vector that is the same across all dimensions, leaving only the pattern of variations. This “common mode” represents a uniform shift in all directions and often contains less discriminative information than the relative patterns between features.&lt;/p&gt;

&lt;p&gt;Third, as we’ll explore in detail, centering can be viewed as projecting a vector onto the hyperplane orthogonal to the all-ones vector. This perspective connects statistical centering to the geometric operation of projection, providing new insights into its properties.&lt;/p&gt;

&lt;h2 id=&quot;4-vector-centering-as-a-projection&quot;&gt;4. Vector Centering as a Projection&lt;/h2&gt;

&lt;p&gt;Now we come to the key insight: centering a vector is geometrically equivalent to projecting it onto the hyperplane orthogonal to the all-ones vector. This connection between a statistical operation (centering) and a geometric one (projection) is both elegant and profound.&lt;/p&gt;

&lt;p&gt;To understand this equivalence, we need to explore how projections work and how the all-ones vector defines a special direction in the space. Let’s define the all-ones vector \(\vec{1} = (1, 1, \ldots, 1) \in \mathbb{R}^d\). This vector has several important properties.&lt;/p&gt;

&lt;p&gt;Its length is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\|\vec{1}\| = \sqrt{\sum_{i=1}^d 1^2} = \sqrt{d}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;The all-ones vector has a magnitude that grows with the square root of the dimension, reflecting the fact that adding more dimensions increases its length.&lt;/p&gt;

&lt;p&gt;The normalized all-ones vector is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\hat{1} = \frac{\vec{1}}{\|\vec{1}\|} = \frac{(1, 1, \ldots, 1)}{\sqrt{d}} = (\frac{1}{\sqrt{d}}, \frac{1}{\sqrt{d}}, \ldots, \frac{1}{\sqrt{d}})
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This unit vector points in the same direction as \(\vec{1}\) but has length 1, making it useful for projections.&lt;/p&gt;

&lt;p&gt;For any vector \(x\), the inner product with \(\vec{1}\) gives the sum of its components:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\langle x, \vec{1} \rangle = \sum_{i=1}^d x_i
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This property connects the geometric operation of inner product with the statistical operation of summation.&lt;/p&gt;

&lt;p&gt;The inner product with the normalized all-ones vector gives:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\langle x, \hat{1} \rangle = \sum_{i=1}^d x_i \cdot \frac{1}{\sqrt{d}} = \frac{1}{\sqrt{d}} \sum_{i=1}^d x_i = \frac{\sum_{i=1}^d x_i}{\sqrt{d}} = \sqrt{d} \cdot \mathrm{E}[x]
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This remarkable result connects the mean (a statistical concept) with the inner product (a geometric concept).&lt;/p&gt;

&lt;p&gt;The hyperplane orthogonal to \(\vec{1}\) consists of all vectors \(v\) such that \(\langle v, \vec{1} \rangle = 0\), or equivalently, \(\sum_{i=1}^d v_i = 0\). This is a \((d-1)\)-dimensional subspace of \(\mathbb{R}^d\). This hyperplane has a special statistical interpretation: it contains all vectors whose components sum to zero, or equivalently, all vectors with mean zero. It represents the space of centered vectors, those with no “common mode” component.&lt;/p&gt;

&lt;p&gt;In 3D, this is the plane passing through the origin with equation \(x + y + z = 0\). We can visualize this as a plane that cuts through the origin and is tilted equally with respect to all three coordinate axes.&lt;/p&gt;
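
&lt;p&gt;These identities are easy to check numerically. The snippet below is a small illustrative verification that the inner product with the normalized all-ones vector recovers \(\sqrt{d} \cdot \mathrm{E}[x]\), and that a centered vector lies on the zero-sum hyperplane:&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

d = 5
rng = np.random.default_rng(0)
x = rng.normal(size=d)

ones = np.ones(d)
ones_hat = ones / np.sqrt(d)

# Inner product with the normalized all-ones vector equals sqrt(d) times the mean.
assert np.isclose(x @ ones_hat, np.sqrt(d) * x.mean())

# A centered vector sums to zero, i.e. it lies on the hyperplane orthogonal to ones.
centered = x - x.mean()
assert np.isclose(centered @ ones, 0.0)
&lt;/code&gt;&lt;/pre&gt;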

&lt;h2 id=&quot;5-the-geometric-equivalence-of-centering-and-hyperplane-projectio&quot;&gt;5. The Geometric Equivalence of Centering and Hyperplane Projection&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-hyperplane&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-hyperplane&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;The hyperplane projection concept can be visualized geometrically: In 3D space, the all-ones vector \(\vec{1} = (1,1,1)\) points along the main diagonal from the origin. The hyperplane orthogonal to this vector is the plane \(x + y + z = 0\), which passes through the origin and forms equal angles with all three coordinate axes. When we project a vector onto this hyperplane, we are essentially removing any component that points in the direction of this diagonal. This isolates the variations between components while eliminating the common offset. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;The key insight is understanding why projecting a vector onto the hyperplane orthogonal to the all-ones vector is geometrically equivalent to centering it.&lt;/p&gt;

&lt;p&gt;When we center a vector \(x\), we’re subtracting the same value (the mean) from each component:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
x_{centered} = (x_1 - \mathrm{E}[x], x_2 - \mathrm{E}[x], \ldots, x_d - \mathrm{E}[x])
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Geometrically, this means we’re moving the vector in the direction opposite to the all-ones vector \(\vec{1} = (1, 1, \ldots, 1)\) by a distance of \(\mathrm{E}[x]\) along each dimension.&lt;/p&gt;

&lt;p&gt;Now, consider what happens when we project a vector onto a hyperplane. The projection removes the component of the vector that is parallel to the normal vector of the hyperplane. In our case, the hyperplane is orthogonal to \(\vec{1}\), so its normal vector is \(\vec{1}\).&lt;/p&gt;

&lt;p&gt;The component of \(x\) parallel to \(\vec{1}\) is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{comp}_{\vec{1}}(x) = \frac{\langle x, \vec{1} \rangle}{\|\vec{1}\|^2} \cdot \vec{1}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Since \(\langle x, \vec{1} \rangle = \sum_{i=1}^d x_i\) and \(\|\vec{1}\|^2 = d\), we have:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{comp}_{\vec{1}}(x) = \frac{\sum_{i=1}^d x_i}{d} \cdot \vec{1} = \mathrm{E}[x] \cdot \vec{1}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This component represents a vector where all elements are equal to the mean of \(x\). It’s the part of \(x\) that points in the direction of the all-ones vector, corresponding to the “common mode” or uniform shift across all dimensions.&lt;/p&gt;

&lt;p&gt;When we project \(x\) onto the hyperplane orthogonal to \(\vec{1}\), we remove this component:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{proj}_{\text{hyperplane}}(x) = x - \text{comp}_{\vec{1}}(x) = x - \mathrm{E}[x] \cdot \vec{1}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This is exactly the centered vector! The projection operation has produced the same result as centering.&lt;/p&gt;

&lt;p&gt;So, geometrically, centering a vector is equivalent to projecting it onto the hyperplane orthogonal to the all-ones vector because centering removes the mean from each component, effectively removing the “uniform” part of the vector. Projection onto the hyperplane removes the component parallel to the normal vector, which in this case is the all-ones vector. These two operations are mathematically identical, both resulting in \(x - \mathrm{E}[x] \cdot \vec{1}\).&lt;/p&gt;

&lt;p&gt;This equivalence provides a powerful geometric interpretation of the statistical operation of centering, connecting two seemingly different mathematical concepts.&lt;/p&gt;

&lt;h2 id=&quot;6-deriving-the-projection-formula&quot;&gt;6. Deriving the Projection Formula&lt;/h2&gt;

&lt;p&gt;Let’s derive the formula for projecting a vector \(x\) onto the hyperplane orthogonal to \(\vec{1}\) in a step-by-step manner.&lt;/p&gt;

&lt;p&gt;The projection of a vector onto a subspace involves two steps: first, finding the component of the vector along the normal direction to the subspace, and second, subtracting this component from the original vector.&lt;/p&gt;

&lt;p&gt;The vector projection formula is foundational in linear algebra:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{proj}_{\text{subspace}}(v) = v - \frac{\langle v, n \rangle}{\|n\|^2} \cdot n
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This operation has wide applications beyond normalization, including in computer graphics (shadow calculations), signal processing (noise elimination), and quantum mechanics (measurement operations). Understanding projections helps connect abstract mathematical concepts to their geometric interpretations.&lt;/p&gt;

&lt;p&gt;The projection of \(x\) onto the direction of \(\hat{1}\) (the normalized all-ones vector) is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{proj}_{\hat{1}}(x) = \langle x, \hat{1} \rangle \hat{1}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This gives the component of \(x\) that points in the direction of the all-ones vector. Substituting the value of \(\langle x, \hat{1} \rangle\):&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{proj}_{\hat{1}}(x) = \sqrt{d} \cdot \mathrm{E}[x] \cdot \frac{\vec{1}}{\sqrt{d}} = \mathrm{E}[x] \cdot \vec{1}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;To get the projection onto the hyperplane orthogonal to \(\vec{1}\), we subtract this component:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
p_1(x) = x - \text{proj}_{\hat{1}}(x) = x - \mathrm{E}[x] \cdot \vec{1}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Component-wise, this gives us:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
p_1(x)_j = x_j - \mathrm{E}[x]
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Which is exactly the centered vector!&lt;/p&gt;

&lt;p&gt;Let’s verify that \(p_1(x)\) is indeed orthogonal to \(\vec{1}\):&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\langle p_1(x), \vec{1} \rangle = \sum_{j=1}^d (x_j - \mathrm{E}[x]) = \sum_{j=1}^d x_j - d \cdot \mathrm{E}[x] = \sum_{j=1}^d x_j - \sum_{j=1}^d x_j = 0
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This confirms that the projection is orthogonal to \(\vec{1}\) as required. The centered vector lies exactly on the hyperplane defined by the all-ones vector.&lt;/p&gt;

&lt;p&gt;In our case, \(n = \vec{1}\) and \(v = x\):&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
p_1(x) = x - \frac{\langle x, \vec{1} \rangle}{\|\vec{1}\|^2} \cdot \vec{1} = x - \frac{\sum_{i=1}^d x_i}{d} \cdot \vec{1} = x - \mathrm{E}[x] \cdot \vec{1}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This gives us the same result as before, confirming our understanding of centering as a projection.&lt;/p&gt;
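
&lt;p&gt;As a quick sanity check of this derivation, here is a small illustrative NumPy sketch: applying the general projection formula with \(n = \vec{1}\) reproduces the centered vector and is orthogonal to \(\vec{1}\):&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

x = np.array([5.0, 8.0, 2.0])
n = np.ones_like(x)                     # normal vector of the hyperplane

# Projection onto the subspace orthogonal to n: subtract the component along n.
p1 = x - (x @ n) / (n @ n) * n

# The result equals the centered vector and is orthogonal to the all-ones vector.
assert np.allclose(p1, x - x.mean())
assert np.isclose(p1 @ n, 0.0)
&lt;/code&gt;&lt;/pre&gt;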

&lt;h2 id=&quot;7-the-subspace-perspective&quot;&gt;7. The Subspace Perspective&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-subspace&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-subspace&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;This decomposition has deep connections to concepts in linear algebra and statistics. In statistics, it relates to the decomposition of total variance into “between-group” and “within-group” components. In signal processing, it corresponds to separating DC offset from AC components. In physics, it resembles decomposing a force into conservative and non-conservative components. The power of this perspective is that it clarifies what information layer normalization preserves (relative patterns) versus what it removes (common offsets). &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;The space \(\mathbb{R}^d\) can be decomposed into two orthogonal subspaces: the one-dimensional subspace spanned by \(\vec{1}\), which contains all vectors with equal components (the space of “uniform shifts” or “common modes”), and the \((d-1)\)-dimensional subspace orthogonal to \(\vec{1}\), which contains all vectors whose components sum to zero (the space of “variations around the mean”).&lt;/p&gt;

&lt;p&gt;Any vector \(x\) can be uniquely expressed as the sum of two components, one from each subspace:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
x = (x - p_1(x)) + p_1(x)
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Or equivalently:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
x = \mathrm{E}[x] \cdot \vec{1} + (x - \mathrm{E}[x] \cdot \vec{1})
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Where \(\mathrm{E}[x] \cdot \vec{1}\) is the component in the direction of \(\vec{1}\), representing the uniform shift (the mean), and \(x - \mathrm{E}[x] \cdot \vec{1}\) is the component orthogonal to \(\vec{1}\), representing the pattern of variations around the mean.&lt;/p&gt;

&lt;p&gt;This decomposition provides insight into the structure of the vector: it separates the overall magnitude (represented by the mean) from the pattern of variations between components.&lt;/p&gt;
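
&lt;p&gt;A short illustrative sketch of this orthogonal decomposition (the example vector is arbitrary): the common-mode and variation components are orthogonal, reconstruct the original vector, and, by orthogonality, split its squared norm:&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

x = np.array([5.0, 8.0, 2.0])
ones = np.ones_like(x)

common_mode = x.mean() * ones           # component along the all-ones direction
variation = x - common_mode             # component on the zero-sum hyperplane

assert np.isclose(common_mode @ variation, 0.0)      # the parts are orthogonal
assert np.allclose(common_mode + variation, x)       # and they reconstruct x

# Orthogonality gives a Pythagorean split of the squared norm.
assert np.isclose(x @ x, common_mode @ common_mode + variation @ variation)
&lt;/code&gt;&lt;/pre&gt;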

&lt;h2 id=&quot;8-projection-onto-the-unit-sphere-the-second-step&quot;&gt;8. Projection onto the Unit Sphere: The Second Step&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-sphere&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-sphere&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;The unit sphere projection introduces a critical non-linearity in the normalization process. Unlike the hyperplane projection (which is linear), projecting onto the unit sphere is a non-linear operation. This non-linearity contributes to the expressiveness of neural networks with layer normalization, allowing them to represent more complex functions. In optimization terms, this projection constrains the solution space to vectors of unit length, improving the conditioning of the optimization problem. Without this step, the scale of activations could vary widely between different layers and neurons, causing optimization instabilities. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;After centering the vector, the next step in layer normalization is normalizing by the standard deviation. This can be interpreted as a second geometric operation: projection onto the unit sphere, followed by a scaling.&lt;/p&gt;

&lt;p&gt;The unit sphere is the set of all points at a fixed distance (radius 1) from the origin. Projecting a vector onto the unit sphere normalizes its length while preserving its direction, making it a natural geometric counterpart to the statistical operation of dividing by the standard deviation.&lt;/p&gt;

&lt;p&gt;The projection of any non-zero vector \(v\) onto the unit sphere, denoted as \(p_S(v)\), normalizes the vector to unit length:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
p_S(v) = \frac{v}{\|v\|}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This operation preserves the direction of the vector but changes its length to 1. It can be interpreted as scaling the vector so that it just touches the unit sphere.&lt;/p&gt;

&lt;p&gt;The projection onto the unit sphere has several important properties that make it useful for normalization. It preserves the direction of the original vector, ensuring that the relative relationships between dimensions are maintained, which is often more important than the absolute values. It normalizes the length to exactly 1, which helps stabilize gradient magnitudes during training, preventing them from exploding or vanishing.&lt;/p&gt;

&lt;p&gt;Unlike projection onto a subspace, projection onto the unit sphere is a non-linear operation. This non-linearity plays a role in the expressiveness of neural networks, allowing them to represent more complex functions. One technical note is that the projection is undefined for the zero vector (since division by zero is undefined). In practice, this is rarely an issue since deep learning frameworks add a small epsilon to the denominator to prevent division by zero.&lt;/p&gt;
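
&lt;p&gt;A minimal sketch of the sphere projection, with an epsilon guard in the spirit of practical implementations (the helper name and epsilon value are illustrative):&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

def project_to_unit_sphere(v, eps=1e-12):
    # Scale v to (approximately) unit length; eps avoids dividing by zero
    # for the degenerate case of a zero vector.
    return v / (np.linalg.norm(v) + eps)

v = np.array([0.0, 3.0, -3.0])
u = project_to_unit_sphere(v)

assert np.isclose(np.linalg.norm(u), 1.0)       # unit length
assert np.allclose(u * np.linalg.norm(v), v)    # direction is preserved
&lt;/code&gt;&lt;/pre&gt;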

&lt;h2 id=&quot;9-connecting-the-norm-of-the-centered-vector-to-variance&quot;&gt;9. Connecting the Norm of the Centered Vector to Variance&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-variance&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-variance&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;This relationship connects two seemingly different mathematical domains: geometry and statistics. The equality \(\|p_1(x)\|^2 = d \cdot \mathrm{Var}[x]\) shows that the geometric concept of distance in the centered subspace directly corresponds to the statistical concept of variance scaled by dimension. Historical Note: This connection has been implicitly used in statistics for decades, particularly in Principal Component Analysis (PCA), but the explicit relationship between variance and projection distance in the context of neural network normalization was only formalized with layer normalization techniques. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;To establish the link between variance normalization and sphere projection, we need to relate the norm of the centered vector to the variance.&lt;/p&gt;

&lt;p&gt;The squared norm of the centered vector \(p_1(x) = x - \mathrm{E}[x] \cdot \vec{1}\) is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\|p_1(x)\|^2 = \sum_{i=1}^d (x_i - \mathrm{E}[x])^2
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This sum represents the total squared deviation from the mean across all dimensions. It’s closely related to the variance:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\sum_{i=1}^d (x_i - \mathrm{E}[x])^2 = d \cdot \frac{1}{d} \sum_{i=1}^d (x_i - \mathrm{E}[x])^2 = d \cdot \mathrm{Var}[x]
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Therefore:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\|p_1(x)\|^2 = d \cdot \mathrm{Var}[x]
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Taking the square root:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\|p_1(x)\| = \sqrt{d \cdot \mathrm{Var}[x]}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This beautiful result connects the geometric measure (norm) with the statistical measure (variance multiplied by dimension). It shows that the length of the centered vector is proportional to the standard deviation, with \(\sqrt{d}\) as the constant of proportionality.&lt;/p&gt;

&lt;p&gt;Rearranging the equation, we can express the variance in terms of the norm:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\mathrm{Var}[x] = \frac{\|p_1(x)\|^2}{d}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This shows that the variance is the average squared distance from the mean, which is the squared norm of the centered vector divided by the dimension.&lt;/p&gt;

&lt;p&gt;Taking the square root:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\sqrt{\mathrm{Var}[x]} = \frac{\|p_1(x)\|}{\sqrt{d}}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This result allows us to connect layer normalization’s division by the standard deviation to a geometric scaling operation related to the norm of the centered vector.&lt;/p&gt;
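
&lt;p&gt;Numerically, the norm-variance relationship can be checked in a couple of lines (an illustrative sketch):&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

x = np.array([5.0, 8.0, 2.0])
d = x.size
centered = x - x.mean()

# Squared norm of the centered vector equals d times the population variance.
assert np.isclose(centered @ centered, d * x.var())

# Equivalently, the standard deviation is the centered norm divided by sqrt(d).
assert np.isclose(np.sqrt(x.var()), np.linalg.norm(centered) / np.sqrt(d))
&lt;/code&gt;&lt;/pre&gt;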

&lt;h2 id=&quot;10-connecting-layer-normalization-to-sequential-projections&quot;&gt;10. Connecting Layer Normalization to Sequential Projections&lt;/h2&gt;

&lt;p&gt;Now we’ll derive the complete connection between layer normalization and the two projections. This will show how the statistical normalization procedure can be reinterpreted as a sequence of geometric transformations.&lt;/p&gt;

&lt;p&gt;The projection of \(p_1(x)\) onto the unit sphere is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
p_S(p_1(x)) = \frac{p_1(x)}{\|p_1(x)\|} = \frac{x - \mathrm{E}[x]}{\|p_1(x)\|}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This normalization preserves the direction of the centered vector but scales it to have unit length. Substituting the value of \(\|p_1(x)\|\):&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
p_S(p_1(x)) = \frac{x - \mathrm{E}[x]}{\sqrt{d \cdot \mathrm{Var}[x]}}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This expression shows how the projection onto the unit sphere relates to the standard statistical normalization formula, but with a different scaling factor.&lt;/p&gt;

&lt;p&gt;Let’s rewrite the original layer normalization formula:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Now let’s manipulate this to match our projection-based expression:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}} \cdot \frac{\sqrt{d}}{\sqrt{d}} = \sqrt{d} \cdot \frac{x - \mathrm{E}[x]}{\sqrt{d \cdot \mathrm{Var}[x]}} = \sqrt{d} \cdot p_S(p_1(x))
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This gives us our final result:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}} = \sqrt{d} \cdot p_S(p_1(x))
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This elegant formula reveals that layer normalization is equivalent to: first projecting onto the hyperplane orthogonal to \(\vec{1}\) (centering), then projecting onto the unit sphere (normalizing), and finally scaling by \(\sqrt{d}\). The scaling factor \(\sqrt{d}\) accounts for the difference between normalizing by the norm of the centered vector and normalizing by the standard deviation.&lt;/p&gt;
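
&lt;p&gt;The equivalence can also be verified directly: the statistical formula and the projection-plus-scaling form agree to numerical precision. The following is an illustrative sketch with a random vector:&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=8)
d = x.size

# Statistical form (no affine parameters, no epsilon).
y_stat = (x - x.mean()) / np.sqrt(x.var())

# Geometric form: center (hyperplane projection), project to the unit sphere, scale by sqrt(d).
centered = x - x.mean()
y_geo = np.sqrt(d) * centered / np.linalg.norm(centered)

assert np.allclose(y_stat, y_geo)
&lt;/code&gt;&lt;/pre&gt;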

&lt;h2 id=&quot;11-alternative-derivation&quot;&gt;11. Alternative Derivation&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-alt-derivation&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-alt-derivation&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;Alternative derivations strengthen mathematical proofs by approaching the same result from different starting points. This particular approach starts from the statistical formula and derives the geometric interpretation, whereas our primary derivation began with the geometric perspective and showed its equivalence to the statistical formulation. This bidirectional relationship establishes a more robust connection between the two domains. It is similar to how in physics, one can derive the laws of motion from either energy principles or force principles and arrive at equivalent formulations. Such multiple derivations also help identify the core mathematical principles governing a phenomenon, which in turn can inspire new algorithms and approaches to normalization in deep learning. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;Let’s approach the derivation from a different angle to further reinforce our understanding. This alternative approach starts with the standard layer normalization formula and progressively transforms it into the projection-based expression.&lt;/p&gt;

&lt;p&gt;Starting with the layer normalization formula:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;We can rewrite this in terms of the centered vector \(p_1(x) = x - \mathrm{E}[x] \cdot \vec{1}\):&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{p_1(x)}{\sqrt{\mathrm{Var}[x]}}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This formulation already separates the centering step (creation of \(p_1(x)\)) from the normalization step (division by \(\sqrt{\mathrm{Var}[x]}\)).&lt;/p&gt;

&lt;p&gt;Now, let’s expand the variance in terms of the centered vector:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\mathrm{Var}[x] = \frac{1}{d} \sum_{i=1}^d (x_i - \mathrm{E}[x])^2 = \frac{1}{d} \sum_{i=1}^d p_1(x)_i^2 = \frac{\|p_1(x)\|^2}{d}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This expresses the variance as the squared norm of the centered vector divided by the dimension, connecting the statistical measure to the geometric one.&lt;/p&gt;

&lt;p&gt;Substituting this into our equation:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{p_1(x)}{\sqrt{\frac{\|p_1(x)\|^2}{d}}} = \frac{p_1(x)}{\frac{\|p_1(x)\|}{\sqrt{d}}} = \sqrt{d} \cdot \frac{p_1(x)}{\|p_1(x)\|}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Since \(\frac{p_1(x)}{\|p_1(x)\|} = p_S(p_1(x))\) is the projection onto the unit sphere, we have:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \sqrt{d} \cdot p_S(p_1(x))
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This confirms our previous derivation from a different starting point, strengthening our confidence in the result.&lt;/p&gt;

&lt;h2 id=&quot;12-the-complete-geometric-interpretation&quot;&gt;12. The Complete Geometric Interpretation&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-complete&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-complete&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;Layer normalization projection sequence relates to mathematical concepts in differential geometry, where operations on manifolds (curved spaces) involve projections onto tangent spaces followed by normalization. The fact that these operations compose to form a useful neural network operation is not coincidental. Similar sequences of operations appear in areas like quantum mechanics (normalization of wave functions), computer graphics (normal mapping and shading), signal processing (whitening transformations), and control systems (state space normalization). This suggests that layer normalization taps into a fundamental geometric principle that has broad applicability across multiple domains where normalization is beneficial. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;We can now interpret layer normalization geometrically as a sequence of operations: First, we project the vector \(x\) onto the hyperplane orthogonal to \(\vec{1}\), giving us \(p_1(x) = x - \mathrm{E}[x] \cdot \vec{1}\). Geometrically, this centers the vector by subtracting the mean from each component. The resulting vector \(p_1(x)\) lies in a subspace where the sum of all components is zero.&lt;/p&gt;

&lt;p&gt;Second, we project the centered vector \(p_1(x)\) onto the unit sphere, giving us \(p_S(p_1(x)) = \frac{p_1(x)}{\|p_1(x)\|}\). This normalizes the vector to have a length of 1. The resulting vector points in the same direction as \(p_1(x)\) but has unit length.&lt;/p&gt;

&lt;p&gt;Finally, we scale the unit vector by \(\sqrt{d}\), giving us \(\sqrt{d} \cdot p_S(p_1(x))\). This scaling factor ensures the final result matches the standard layer normalization formula.&lt;/p&gt;

&lt;p&gt;The complete transformation can be written as:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}} = \sqrt{d} \cdot p_S(p_1(x))
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This geometric interpretation provides insights into why layer normalization works so effectively in neural networks. It removes the “common mode” component of the input (which often carries less discriminative information) and standardizes the scale of the remaining variations, helping gradient-based optimization algorithms converge more quickly and stably.&lt;/p&gt;

&lt;p&gt;In 3D, the hyperplane orthogonal to \(\vec{1} = (1, 1, 1)\) is the plane \(x + y + z = 0\). This plane passes through the origin and is oriented symmetrically with respect to all three coordinate axes. The geometric interpretation of layer normalization involves projecting a vector onto this plane, then onto the unit sphere, and finally scaling by \(\sqrt{d}\). This sequence of operations standardizes the vector, making it more amenable to further processing in a neural network.&lt;/p&gt;

&lt;h2 id=&quot;13-working-example-layer-normalization-in-3d-space&quot;&gt;13. Working Example: Layer Normalization in 3D Space&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-example&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-example&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;Visualizing in 3D space helps build intuition about the geometric interpretation. The vector (5,8,2) starts in a general position in space. After centering, it moves to the plane x+y+z=0. Then projection onto the unit sphere normalizes its length, before the final scaling gives it a length of √3. This concrete example demonstrates that the mathematical formulations actually produce the expected results, confirming our theoretical understanding. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;To make our discussion concrete, let’s trace a vector’s transformation through layer normalization step by step. We’ll use a 3D example with vector \(x = (5, 8, 2)\) and follow its journey.&lt;/p&gt;

&lt;p&gt;First, we calculate its statistical properties:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Mean: \(\mathrm{E}[x] = \frac{5 + 8 + 2}{3} = 5\)&lt;/li&gt;
  &lt;li&gt;Variance: \(\mathrm{Var}[x] = \frac{1}{3}[(5-5)^2 + (8-5)^2 + (2-5)^2] = \frac{1}{3}(0 + 9 + 9) = 6\)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Step 1: Centering the vector&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;We center the vector by subtracting the mean from each component:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
x - \mathrm{E}[x] \cdot \vec{1} = (5, 8, 2) - 5 \cdot (1, 1, 1) = (0, 3, -3)
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This centered vector lies on the hyperplane \(x + y + z = 0\): \(0 + 3 + (-3) = 0\).&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Step 2: Verifying that centering is a projection&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;For the all-ones vector, we have \(\vec{1} = (1, 1, 1)\) with length \(\|\vec{1}\| = \sqrt{3}\).&lt;/p&gt;

&lt;p&gt;The component of \(x\) along \(\vec{1}\) is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\frac{\langle (5, 8, 2), \vec{1} \rangle}{\|\vec{1}\|^2} \cdot \vec{1} = \frac{15}{3} \cdot (1, 1, 1) = 5 \cdot (1, 1, 1) = (5, 5, 5)
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Subtracting gives us the projection onto the hyperplane:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
(5, 8, 2) - (5, 5, 5) = (0, 3, -3)
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This matches our centered vector, confirming that centering equals projection.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Step 3: Projecting onto the unit sphere&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The squared norm of the centered vector is \(\|p_1(x)\|^2 = 0^2 + 3^2 + (-3)^2 = 18\), which equals \(d \cdot \mathrm{Var}[x] = 3 \cdot 6 = 18\).&lt;/p&gt;

&lt;p&gt;We project onto the unit sphere:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
p_S(p_1(x)) = \frac{(0, 3, -3)}{\sqrt{18}} = (0, \frac{3}{\sqrt{18}}, -\frac{3}{\sqrt{18}})
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Step 4: Final scaling&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;We scale by \(\sqrt{d} = \sqrt{3}\):&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\sqrt{3} \cdot p_S(p_1(x)) = (0, \frac{3\sqrt{3}}{\sqrt{18}}, -\frac{3\sqrt{3}}{\sqrt{18}}) = (0, \frac{3}{\sqrt{6}}, -\frac{3}{\sqrt{6}})
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Verification&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;This matches the direct layer normalization calculation:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}} = \frac{(0, 3, -3)}{\sqrt{6}} = (0, \frac{3}{\sqrt{6}}, -\frac{3}{\sqrt{6}})
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;The normalized vector has:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Mean of zero: \(\mathrm{E}[y] = \frac{1}{3}(0 + \frac{3}{\sqrt{6}} + (-\frac{3}{\sqrt{6}})) = 0\)&lt;/li&gt;
  &lt;li&gt;Variance of one: \(\mathrm{Var}[y] = \frac{1}{3}(0^2 + (\frac{3}{\sqrt{6}})^2 + (-\frac{3}{\sqrt{6}})^2) = 1\)&lt;/li&gt;
  &lt;li&gt;Norm of \(\sqrt{d}\): \(\|y\| = \sqrt{0^2 + (\frac{3}{\sqrt{6}})^2 + (-\frac{3}{\sqrt{6}})^2} = \sqrt{3}\)&lt;/li&gt;
&lt;/ul&gt;
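
&lt;p&gt;The whole trace can be reproduced in a few lines of NumPy; the following is an illustrative check of the numbers above:&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

x = np.array([5.0, 8.0, 2.0])
d = x.size

mean = x.mean()                                    # 5.0
var = x.var()                                      # 6.0
centered = x - mean                                # [0., 3., -3.]
assert np.isclose(centered @ centered, d * var)    # 18 = 3 * 6

unit = centered / np.linalg.norm(centered)         # projection onto the unit sphere
y = np.sqrt(d) * unit                              # final scaling by sqrt(3)

assert np.allclose(y, centered / np.sqrt(var))     # matches the direct formula
assert np.isclose(y.mean(), 0.0)                   # zero mean
assert np.isclose(y.var(), 1.0)                    # unit variance
assert np.isclose(np.linalg.norm(y), np.sqrt(d))   # norm sqrt(d)
&lt;/code&gt;&lt;/pre&gt;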

&lt;h2 id=&quot;14-key-properties-of-layer-normalization&quot;&gt;14. Key Properties of Layer Normalization&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-properties&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-properties&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;Layer normalization invariance properties have profound implications for deep learning. Scale invariance means a model does not need to learn separate weights for inputs of different magnitudes. Shift invariance means it can focus on relative patterns rather than absolute values. Together, these properties create a more stable optimization landscape and better generalization, especially for models like transformers that must process inputs with widely varying scales and offsets. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;The geometric perspective reveals several important properties that explain why layer normalization is so effective in neural networks:&lt;/p&gt;

&lt;p&gt;The normalized vector always has zero mean. This follows directly from the centered vector being on the hyperplane orthogonal to the all-ones vector:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\mathrm{E}[y] = \frac{1}{\sqrt{\mathrm{Var}[x]}} \cdot \frac{1}{d} \left( \sum_{i=1}^d x_i - d \cdot \mathrm{E}[x] \right) = 0
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This property ensures that subsequent layers receive well-centered inputs, preventing activation saturation.&lt;/p&gt;

&lt;p&gt;The normalized vector always has unit variance:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\mathrm{Var}[y] = \frac{1}{d \cdot \mathrm{Var}[x]} \sum_{i=1}^d (x_i - \mathrm{E}[x])^2 = \frac{d \cdot \mathrm{Var}[x]}{d \cdot \mathrm{Var}[x]} = 1
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This stabilizes gradient magnitudes during backpropagation, making the optimization more consistent.&lt;/p&gt;

&lt;p&gt;If we multiply all elements of \(x\) by a positive constant \(c\), the output remains unchanged:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\frac{c \cdot x - \mathrm{E}[c \cdot x]}{\sqrt{\mathrm{Var}[c \cdot x]}} = \frac{c \cdot (x - \mathrm{E}[x])}{c \cdot \sqrt{\mathrm{Var}[x]}} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This makes neural networks more robust to input scaling variations.&lt;/p&gt;

&lt;p&gt;If we add a constant \(b\) to all elements of \(x\), the output remains unchanged:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\frac{x + b - \mathrm{E}[x + b]}{\sqrt{\mathrm{Var}[x + b]}} = \frac{x + b - (\mathrm{E}[x] + b)}{\sqrt{\mathrm{Var}[x]}} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Geometrically, this means adding a vector along the all-ones direction, which the projection removes entirely.&lt;/p&gt;

&lt;p&gt;The normalized vector always has a norm of \(\sqrt{d}\):&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\|y\| = \frac{\|x - \mathrm{E}[x]\|}{\sqrt{\mathrm{Var}[x]}} = \frac{\sqrt{d \cdot \mathrm{Var}[x]}}{\sqrt{\mathrm{Var}[x]}} = \sqrt{d}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This consistent magnitude helps prevent exploding or vanishing gradients in deep networks.&lt;/p&gt;
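
&lt;p&gt;The invariance properties are also easy to confirm empirically. In this illustrative sketch the constants 3.0 and 7.0 are arbitrary:&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

def layer_norm(x):
    return (x - x.mean()) / np.sqrt(x.var())

rng = np.random.default_rng(2)
x = rng.normal(size=16)
y = layer_norm(x)

# Invariance to positive rescaling and to uniform shifts of the input.
assert np.allclose(layer_norm(3.0 * x), y)
assert np.allclose(layer_norm(x + 7.0), y)

# The output norm is always sqrt(d).
assert np.isclose(np.linalg.norm(y), np.sqrt(x.size))
&lt;/code&gt;&lt;/pre&gt;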

&lt;h2 id=&quot;15-why-this-geometric-interpretation-matters&quot;&gt;15. Why This Geometric Interpretation Matters&lt;/h2&gt;

&lt;p&gt;The geometric perspective on layer normalization provides several significant insights beyond the standard statistical view:&lt;/p&gt;

&lt;p&gt;The projection onto the hyperplane orthogonal to the all-ones vector reduces the effective dimensionality from \(d\) to \(d-1\). This removes a redundant degree of freedom (the common offset), allowing the network to focus its capacity on modeling informative patterns of variation between features.&lt;/p&gt;

&lt;p&gt;By projecting onto the hyperplane, layer normalization isolates and removes the “common mode” component—the uniform signal across all dimensions. In many contexts, this global offset carries less discriminative information than the relative variations between features.&lt;/p&gt;

&lt;p&gt;While normalizing the length, layer normalization preserves the direction of the centered vector. This maintains the relative relationships between features, which often encode the essential information extracted from the input.&lt;/p&gt;

&lt;p&gt;Standardizing the scale and removing shifts creates a more symmetrical optimization landscape. This makes gradient descent more effective by preventing pathological curvature and allowing more balanced optimization steps across different dimensions.&lt;/p&gt;

&lt;p&gt;The geometric view clarifies why layer normalization is invariant to both shifts and rescalings of the input—properties that make networks more robust to variations in input distributions and reduce the need for careful data preprocessing.&lt;/p&gt;

&lt;p&gt;This interpretation bridges statistical operations (centering, standardizing) and geometric transformations (projections), providing a unifying framework that enhances our understanding of how neural networks process information through sequential layers.&lt;/p&gt;

&lt;h2 id=&quot;16-applications-and-comparison-with-normalization-techniques&quot;&gt;16. Applications and Comparison with Normalization Techniques&lt;/h2&gt;
&lt;p&gt;&lt;label for=&quot;sn-comparison&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-comparison&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;Comparison of Normalization Techniques Through the Geometric Lens: Batch Normalization normalizes across the batch dimension, effectively projecting onto hyperplanes defined by batch statistics for each feature. This makes it dependent on batch size and requiring running statistics during inference. Instance Normalization, used in image processing, applies normalization to each channel separately, performing projections in channel-specific subspaces. This is particularly effective for style transfer tasks. Group Normalization is a middle ground between layer and instance normalization, dividing channels into groups and normalizing within each group. Geometrically, this corresponds to projecting onto group-specific hyperplanes. Weight Normalization, instead of normalizing activations, normalizes weight vectors by projecting them onto unit spheres. This aims to improve the conditioning of the optimization problem from the parameter side rather than the activation side. Each technique corresponds to a different choice of projection subspace, with layer normalization offering the advantage of being independent of batch statistics while still normalizing across the full feature dimension. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;Layer normalization has become a critical component in modern deep neural networks for several key reasons:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;1. Gradient Stability&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;By normalizing activations, layer normalization helps prevent exploding or vanishing gradients, a critical issue in deep networks. The consistent scale at each layer ensures gradients remain within a reasonable range as they propagate backward through the network.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;2. Faster Convergence&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The standardized scale and zero mean of normalized activations create a more favorable optimization landscape. This allows optimizers to take larger, more effective steps, reducing the number of iterations needed to reach good solutions.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;3. Reduction of Internal Covariate Shift&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Normalization stabilizes the distributions of network activations, preventing the phenomenon where each layer must continuously adapt to shifting input statistics. This allows each layer to learn more efficiently.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;4. Independence from Batch Size&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Unlike batch normalization, layer normalization operates independently for each sample, making it ideal for:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Variable batch sizes&lt;/li&gt;
  &lt;li&gt;Recurrent neural networks&lt;/li&gt;
  &lt;li&gt;Transformer architectures&lt;/li&gt;
  &lt;li&gt;Online learning scenarios&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;This independence from batch statistics provides consistent behavior during both training and inference, eliminating the need for running statistics.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;5. Facilitation of Deep Architectures&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Layer normalization has been crucial for enabling the training of very deep networks, particularly transformers with dozens or hundreds of layers. By stabilizing the signal through these deep stacks, it prevents the compounding effects of statistical irregularities.&lt;/p&gt;

&lt;p&gt;The geometric interpretation helps us understand the relationship between different normalization techniques as different projection operations applied to different subspaces of the data.&lt;/p&gt;
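
&lt;p&gt;One way to see the different projection subspaces is to compare which axis the statistics are computed over. The sketch below contrasts layer normalization (per sample, across features) with a training-time view of batch normalization (per feature, across the batch); running statistics and affine parameters are omitted for brevity:&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

rng = np.random.default_rng(3)
X = rng.normal(size=(4, 8))             # a batch of 4 samples with 8 features

# Layer norm: statistics over the feature axis, independently for each sample.
ln = (X - X.mean(axis=1, keepdims=True)) / X.std(axis=1, keepdims=True)

# Batch norm (training-time view): statistics over the batch axis, per feature.
bn = (X - X.mean(axis=0, keepdims=True)) / X.std(axis=0, keepdims=True)

assert np.allclose(ln.mean(axis=1), 0.0)    # each sample (row) is centered
assert np.allclose(bn.mean(axis=0), 0.0)    # each feature (column) is centered
&lt;/code&gt;&lt;/pre&gt;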

&lt;h2 id=&quot;17-conclusion&quot;&gt;17. Conclusion&lt;/h2&gt;

&lt;p&gt;Layer normalization, while typically presented as a statistical operation, reveals its deeper nature when viewed through the lens of geometric transformations in vector space. This interpretation unfolds as a sequence of elegant projections:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;
    &lt;p&gt;&lt;strong&gt;Hyperplane Projection&lt;/strong&gt; (Centering): We project the input vector onto the hyperplane orthogonal to the all-ones vector, removing the “common mode” component and centering the representation.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;&lt;strong&gt;Unit Sphere Projection&lt;/strong&gt; (Normalizing): We project the centered vector onto the unit sphere, preserving its direction while standardizing its length.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;&lt;strong&gt;Scaling&lt;/strong&gt;: We scale by \(\sqrt{d}\) to match the conventional formulation, ensuring unit variance.&lt;/p&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;These operations are captured in the formula:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}} = \sqrt{d} \cdot p_S(p_1(x))
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This geometric perspective provides several key insights:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;It connects statistical operations (centering, standardizing) with geometric transformations (projections)&lt;/li&gt;
  &lt;li&gt;It explains why layer normalization helps gradient-based optimization&lt;/li&gt;
  &lt;li&gt;It reveals why the technique is invariant to shifts and rescalings&lt;/li&gt;
  &lt;li&gt;It provides a unifying framework for understanding various normalization approaches&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The formula \(y = \sqrt{d} \cdot p_S(p_1(x))\) encapsulates this understanding, showing that layer normalization fundamentally projects data onto a standardized subspace where the “common mode” has been removed and the scale has been normalized.&lt;/p&gt;

&lt;p&gt;By viewing layer normalization as a geometric transformation rather than just a statistical operation, we gain a more intuitive understanding of its effects and can better appreciate its role in the remarkable success of modern neural networks, particularly transformers and other deep architectures.&lt;/p&gt;

&lt;h2 id=&quot;references&quot;&gt;References&lt;/h2&gt;

&lt;div class=&quot;work-references&quot;&gt;
&lt;p&gt;[1] Ba, J. L., Kiros, J. R., &amp;amp; Hinton, G. E. (2016). &quot;Layer normalization.&quot; &lt;em&gt;arXiv preprint arXiv:1607.06450&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[2] Ioffe, S., &amp;amp; Szegedy, C. (2015). &quot;Batch normalization: Accelerating deep network training by reducing internal covariate shift.&quot; In &lt;em&gt;International Conference on Machine Learning&lt;/em&gt; (pp. 448-456).&lt;/p&gt;
&lt;p&gt;[3] Ulyanov, D., Vedaldi, A., &amp;amp; Lempitsky, V. (2016). &quot;Instance normalization: The missing ingredient for fast stylization.&quot; &lt;em&gt;arXiv preprint arXiv:1607.08022&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[4] Wu, Y., &amp;amp; He, K. (2018). &quot;Group normalization.&quot; In &lt;em&gt;Proceedings of the European Conference on Computer Vision (ECCV)&lt;/em&gt; (pp. 3-19).&lt;/p&gt;
&lt;p&gt;[5] Salimans, T., &amp;amp; Kingma, D. P. (2016). &quot;Weight normalization: A simple reparameterization to accelerate training of deep neural networks.&quot; &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, 29, 901-909.&lt;/p&gt;
&lt;/div&gt;
</description>
        <pubDate>Sat, 17 May 2025 00:00:00 -0400</pubDate>
        <link>https://aakashvarma.github.io/layernorm/</link>
        <guid isPermaLink="true">https://aakashvarma.github.io/layernorm/</guid>
        
      </item>
    
      <item>
        <title>The RoPE Compatibility Problem in DeepSeek&apos;s Multi Head Latent Attention</title>
        <description>&lt;h2 id=&quot;1-introduction&quot;&gt;1. Introduction&lt;/h2&gt;

&lt;h3 id=&quot;11-multi-head-latent-attention-advancing-inference-efficiency-in-large-language-models&quot;&gt;1.1 Multi-Head Latent Attention: Advancing Inference Efficiency in Large Language Models&lt;/h3&gt;

&lt;p&gt;Large Language Models (LLMs) have transformed natural language processing capabilities, yet their deployment presents substantial challenges as model size increases to hundreds of billions of parameters with extended context windows of tens or hundreds of thousands of tokens. During the autoregressive generation process, the Key-Value (KV) cache emerges as a critical bottleneck, presenting organizations with a fundamental trade-off between computational efficiency and memory resource allocation.&lt;/p&gt;

&lt;p&gt;Without KV caching, the computational complexity for generating each token scales quadratically with sequence length (&lt;em&gt;O(n&lt;sup&gt;2&lt;/sup&gt;)&lt;/em&gt; per token), while maintaining minimal memory requirements &lt;em&gt;O(1)&lt;/em&gt;. This approach becomes prohibitively expensive for long sequences, as each new token would require recomputing attention scores with all previous tokens. The KV caching strategy reduces this to linear computational complexity (&lt;em&gt;O(n)&lt;/em&gt; per token), but at the cost of linear memory growth &lt;em&gt;O(n)&lt;/em&gt;. For standard Multi-Head Attention (MHA), the total KV cache memory consumption can be expressed as:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{Memory}_{\text{MHA}} = B \times L \times N_L \times 2 \times N_H \times D_H \times P \tag{1}
&lt;/script&gt;&lt;/div&gt;
&lt;p&gt;&lt;label for=&quot;sn-memory-example&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-memory-example&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;Concrete Memory Savings Example&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;Let’s calculate the memory requirements for a typical large language model with:&lt;br /&gt;- Batch size (B) = 1&lt;br /&gt;- Sequence length (L) = 32,768 tokens&lt;br /&gt;- Number of layers (NL) = 32&lt;br /&gt;- Number of heads (NH) = 32&lt;br /&gt;- Head dimension (DH) = 128&lt;br /&gt;- Precision (P) = 2 bytes (FP16)&lt;br /&gt;- MLA content dimension (DC) = 64&lt;br /&gt;- MLA rotary dimension (DR) = 8&lt;br /&gt;&lt;br /&gt;&lt;strong&gt;Standard MHA memory (Equation 1):&lt;/strong&gt;&lt;br /&gt;Memory for standard multi-head attention is given by: Memory_MHA = B × L × NL × 2 × NH × DH × P.&lt;br /&gt;&lt;br /&gt;With our parameters this becomes:&lt;br /&gt;Memory_MHA = 1 × 32,768 × 32 × 2 × 32 × 128 × 2 bytes,&lt;br /&gt;which simplifies step by step to 17,179,869,184 bytes, or approximately 16 GB.&lt;br /&gt;&lt;br /&gt;&lt;strong&gt;MLA memory with compression (Equation 2):&lt;/strong&gt;&lt;br /&gt;For MLA, memory is given by: Memory_MLA = B × L × NL × (DC + DR) × P.&lt;br /&gt;&lt;br /&gt;With our parameters this becomes:&lt;br /&gt;Memory_MLA = 1 × 32,768 × 32 × (64 + 8) × 2 bytes,&lt;br /&gt;which simplifies step by step to 150,994,944 bytes, or approximately 144 MB.&lt;br /&gt;&lt;br /&gt;&lt;strong&gt;Memory reduction ratio:&lt;/strong&gt;&lt;br /&gt;The reduction ratio is Memory_MHA divided by Memory_MLA, i.e., 17,179,869,184 bytes divided by 150,994,944 bytes, which is about 113.78, corresponding to roughly a 99.1% reduction in KV cache memory. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;This dramatic reduction enables models to handle much longer contexts with the same hardware, or allows deployment on more resource-constrained devices.&lt;/p&gt;

&lt;p&gt;This formula encapsulates the memory requirements across batch size (&lt;em&gt;B&lt;/em&gt;), sequence length (&lt;em&gt;L&lt;/em&gt;), number of layers (&lt;em&gt;N&lt;sub&gt;L&lt;/sub&gt;&lt;/em&gt;), heads (&lt;em&gt;N&lt;sub&gt;H&lt;/sub&gt;&lt;/em&gt;), head dimension (&lt;em&gt;D&lt;sub&gt;H&lt;/sub&gt;&lt;/em&gt;), and precision (&lt;em&gt;P&lt;/em&gt;). The factor of 2 accounts for storing both keys and values separately. As models grow larger and context windows expand, this memory requirement becomes increasingly untenable, even on high-end hardware accelerators.&lt;/p&gt;

&lt;p&gt;Multi-Head Latent Attention (MLA) addresses this challenge through a novel approach that transforms the fundamental equation of memory consumption:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{Memory}_{\text{MLA}} = B \times L \times N_L \times (D_C + D_R) \times P \tag{2}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Where &lt;em&gt;D&lt;sub&gt;C&lt;/sub&gt;&lt;/em&gt; is the compression KV dimension and &lt;em&gt;D&lt;sub&gt;R&lt;/sub&gt;&lt;/em&gt; is the dimension of the key rotary position component. This reformulation enables substantial memory savings without compromising model capabilities, creating new possibilities for deploying models in resource-constrained environments.&lt;/p&gt;
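
&lt;p&gt;To make these formulas concrete, here is a minimal Python sketch that evaluates Equations 1 and 2 for an arbitrary configuration. The parameter values mirror the sidenote example and are illustrative assumptions, not measurements from any particular deployment.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
# Minimal sketch of the KV-cache memory formulas (Equations 1 and 2).
# The configuration below is illustrative, not taken from a specific model.

def mha_kv_cache_bytes(B, L, NL, NH, DH, P):
    # Memory_MHA = B * L * NL * 2 * NH * DH * P  (keys and values stored separately)
    return B * L * NL * 2 * NH * DH * P

def mla_kv_cache_bytes(B, L, NL, DC, DR, P):
    # Memory_MLA = B * L * NL * (DC + DR) * P  (compressed latent plus rotary key part)
    return B * L * NL * (DC + DR) * P

mha = mha_kv_cache_bytes(B=1, L=32768, NL=32, NH=32, DH=128, P=2)
mla = mla_kv_cache_bytes(B=1, L=32768, NL=32, DC=64, DR=8, P=2)

print(mha / 2**30)   # about 16 GiB
print(mla / 2**20)   # about 144 MiB
print(mha / mla)     # reduction ratio, about 113.8x
&lt;/code&gt;&lt;/pre&gt;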

&lt;h4 id=&quot;architectural-differences-mha-vs-mla&quot;&gt;Architectural Differences: MHA vs. MLA&lt;/h4&gt;

&lt;p&gt;Standard Multi-Head Attention (MHA) and Multi-Head Latent Attention (MLA) share the same high-level goal of enabling tokens to attend to each other, but differ significantly in their internal architecture and memory efficiency characteristics. In standard MHA, each token’s hidden representation undergoes three parallel linear projections to create query, key, and value vectors. This process can be represented mathematically as:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_t = W^Q h_t \tag{3}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_t = W^K h_t \tag{4}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
v_t = W^V h_t \tag{5}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;These projections are then split into &lt;em&gt;N&lt;sub&gt;H&lt;/sub&gt;&lt;/em&gt; attention heads, each operating in a lower-dimensional space:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_t^i, k_t^i, v_t^i \in \mathbb{R}^{D_H} \tag{6}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;The attention mechanism computes weighted interactions between tokens, where the weights are determined by the compatibility between queries and keys. For each head &lt;em&gt;i&lt;/em&gt;, the attention output is computed as:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{o}_{t,i} = \sum_{j=1}^{t} \text{Softmax}_j\left(\frac{(\text{q}_{t,i})^T\text{k}_{j,i}}{\sqrt{d_h}}\right)\text{v}_{j,i} \tag{7}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{align} \text{where:} \\ \text{j} &amp;: \text{ Position index } (1 \leq j \leq t) \text{ of previous tokens in the sequence} \\ \text{t} &amp;: \text{ Current position in the sequence} \\ \text{q}_{t,i} &amp;: \text{ Query vector at position } t \text{ for head } i \\ \text{k}_{j,i} &amp;: \text{ Key vector at position } j \text{ for head } i \\ \text{v}_{j,i} &amp;: \text{ Value vector at position } j \text{ for head } i \\ \text{d}_h &amp;: \text{ Dimension of each attention head} \\ \text{o}_{t,i} &amp;: \text{ Output of attention at position } t \text{ for head } i \\ \text{i} &amp;\in \{1, 2, \ldots, n_h\}, \text{ where } n_h \text{ is the total number of attention heads} \end{align} \tag{8}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;The outputs from all heads are concatenated and projected through an output matrix:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{u}_t = W^O[\text{o}_{t,1}; \text{o}_{t,2}; ...; \text{o}_{t,n_h}] \tag{9}
&lt;/script&gt;&lt;/div&gt;
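
&lt;p&gt;As a reference point before moving to MLA, the following NumPy sketch spells out Equations 7-9 for a single batch element. The loops and sizes are chosen for readability rather than efficiency, and all dimensions are illustrative.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

# Naive single-example multi-head attention (Equations 7-9), written for clarity.
n_h, d_h, seqlen = 4, 8, 16          # illustrative sizes
h = n_h * d_h
rng = np.random.default_rng(0)

q = rng.normal(size=(n_h, seqlen, d_h))   # per-head queries
k = rng.normal(size=(n_h, seqlen, d_h))   # per-head keys
v = rng.normal(size=(n_h, seqlen, d_h))   # per-head values
W_O = rng.normal(size=(h, h))             # output projection

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

outputs = np.zeros((n_h, seqlen, d_h))
for i in range(n_h):
    for t in range(seqlen):
        # causal attention: position t only attends to positions j = 0..t (Equation 7)
        scores = q[i, t] @ k[i, : t + 1].T / np.sqrt(d_h)
        weights = softmax(scores)
        outputs[i, t] = weights @ v[i, : t + 1]

# concatenate the heads per position and apply the output projection (Equation 9)
u = outputs.transpose(1, 0, 2).reshape(seqlen, h) @ W_O.T
&lt;/code&gt;&lt;/pre&gt;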

&lt;p&gt;During inference, MHA caches the full key and value vectors for each token across all layers and heads, creating substantial memory pressure as sequence length increases. MLA fundamentally reimagines this architecture by introducing an intermediate compression step and decoupling content information from positional information. The architecture consists of two parallel paths:&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/mla/mla_arch.png&quot; alt=&quot;MLA Architecture&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 1: Architecture of Multi Head Latent Attention
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;Content Path&lt;/strong&gt; (with compression):&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
c_t^{KV} = W^{DKV}h_t \tag{10}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_t^C = W^{UK}c_t^{KV} \tag{11}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
v_t = W^{UV}c_t^{KV} \tag{12}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_t^C = W^{UQ}c_t^{Q} \tag{13}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
c_t^{Q} = W^{DQ}h_t \tag{14}
&lt;/script&gt;&lt;/div&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;Position Path&lt;/strong&gt; (with RoPE):&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_t^R = R_{\Theta,t}^d \cdot W^{KR}h_t \tag{15}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_t^R = R_{\Theta,t}^d \cdot W^{QR}c_t^Q \tag{16}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Where &lt;em&gt;R&lt;sub&gt;Θ,t&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt; represents the rotary position encoding matrix. The final key and query representations are formed by concatenating both components:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_t = [k_t^C; k_t^R] \tag{17}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_t = [q_t^C; q_t^R] \tag{18}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;During inference, MLA caches both the compressed latent vectors &lt;em&gt;c&lt;sub&gt;t&lt;/sub&gt;&lt;sup&gt;KV&lt;/sup&gt;&lt;/em&gt; and the rotary key components &lt;em&gt;k&lt;sub&gt;t&lt;/sub&gt;&lt;sup&gt;R&lt;/sup&gt;&lt;/em&gt; as shown in Figure 2. This is reflected in the memory formula, where &lt;em&gt;D&lt;sub&gt;C&lt;/sub&gt;&lt;/em&gt; represents the dimension of &lt;em&gt;c&lt;sub&gt;t&lt;/sub&gt;&lt;sup&gt;KV&lt;/sup&gt;&lt;/em&gt; and &lt;em&gt;D&lt;sub&gt;R&lt;/sub&gt;&lt;/em&gt; represents the dimension of &lt;em&gt;k&lt;sub&gt;t&lt;/sub&gt;&lt;sup&gt;R&lt;/sup&gt;&lt;/em&gt;.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/mla/cache.png&quot; alt=&quot;MLA Tensor Cache&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 2: MLA caches both the compressed latent vectors &lt;em&gt;c&lt;sub&gt;t&lt;/sub&gt;&lt;sup&gt;KV&lt;/sup&gt;&lt;/em&gt; and the rotary key components &lt;em&gt;k&lt;sub&gt;t&lt;/sub&gt;&lt;sup&gt;R&lt;/sup&gt;&lt;/em&gt;
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;The attention calculation in MLA becomes:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{o}_{t,i} = \sum_{j=1}^{t} \text{Softmax}_j\left(\frac{(\text{q}_{t,i}^C)^T\text{k}_{j,i}^C + (\text{q}_{t,i}^R)^T\text{k}_j^R}{\sqrt{d_h + d_h^R}}\right)\text{v}_{j,i}^C \tag{19}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{align} \text{where:}\\ \text{j} &amp;: \text{ Position index } (1 \leq j \leq t) \text{ of previous tokens in the sequence} \\ \text{t} &amp;: \text{ Current position in the sequence} \\ \text{q}_{t,i}^C &amp;: \text{ Content component of query vector at position } t \text{ for head } i \\ \text{k}_{j,i}^C &amp;: \text{ Content component of key vector at position } j \text{ for head } i \\ \text{q}_{t,i}^R &amp;: \text{ Rotary position component of query vector at position } t \text{ for head } i \\ \text{k}_j^R &amp;: \text{ Rotary position component of key vector at position } j \\ \text{v}_{j,i}^C &amp;: \text{ Value vector at position } j \text{ for head } i \\ \text{d}_h &amp;: \text{ Dimension of the content component} \\ \text{d}_h^R &amp;: \text{ Dimension of the rotary position component} \\ \text{o}_{t,i} &amp;: \text{ Output of attention at position } t \text{ for head } i \\ \text{i} &amp;\in \{1, 2, \ldots, n_h\}, \text{ where } n_h \text{ is the total number of attention heads} \end{align} \tag{20}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;And the final multi-head output remains:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{u}_t = W^O[\text{o}_{t,1}; \text{o}_{t,2}; ...; \text{o}_{t,n_h}] \tag{21}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This formulation separates content-based attention (first term) from position-aware attention (second term), allowing each to be processed optimally. The content path can be efficiently compressed without worrying about position encoding, while the position path handles rotary encodings separately, maintaining relative position awareness. This decoupling strategy is particularly important because applying rotary position encodings directly to compressed representations would create mathematical incompatibilities during inference, requiring costly recomputations for each new token (which we will see in the rest of the article with derivations). By separating content from position, MLA achieves both memory efficiency and computational efficiency.&lt;/p&gt;

&lt;h3 id=&quot;12-rotary-position-embeddings-rope-mathematical-foundations&quot;&gt;1.2 Rotary Position Embeddings (RoPE): Mathematical Foundations&lt;/h3&gt;

&lt;p&gt;Transformer architectures have demonstrated remarkable efficacy across diverse natural language processing tasks, yet they inherently lack sequential awareness due to their parallel token processing mechanism. To mitigate this limitation, position encoding methodologies have been developed to incorporate sequential information into the representation space. Among these approaches, Rotary Position Embedding (RoPE), introduced by Su et al. (2021), represents a mathematically sophisticated advancement in positional encoding.&lt;/p&gt;

&lt;p&gt;RoPE encodes positional information by applying a position-dependent rotation to pairs of dimensions in the embedding space. For a token at position &lt;em&gt;m&lt;/em&gt; with embedding vector &lt;em&gt;𝐱&lt;sub&gt;m&lt;/sub&gt; ∈ ℝ&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt;, RoPE transforms query and key vectors as follows:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
f_q(\mathbf{x}_m, m) = (\mathbf{W}_q\mathbf{x}_m)e^{im\theta} \tag{22}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
f_k(\mathbf{x}_n, n) = (\mathbf{W}_k\mathbf{x}_n)e^{in\theta} \tag{23}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Here, the complex exponential &lt;em&gt;e&lt;sup&gt;imθ&lt;/sup&gt;&lt;/em&gt; represents rotation in the complex plane. This operation rotates the query and key vectors by angles proportional to their positions in the sequence. The rotation angle increases with the position index, creating unique position-dependent transformations for each token. For practical implementation in neural networks, these complex number rotations are expressed using real-valued rotation matrices. For embedding vectors with dimension &lt;em&gt;d&lt;/em&gt; (where &lt;em&gt;d&lt;/em&gt; is even), we can view the embedding space as composed of &lt;em&gt;d/2&lt;/em&gt; two-dimensional subspaces. In each two-dimensional subspace corresponding to dimensions &lt;em&gt;(2i-1, 2i)&lt;/em&gt;, RoPE applies a 2×2 rotation matrix:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{pmatrix} \cos m\theta_i &amp; -\sin m\theta_i \\ \sin m\theta_i &amp; \cos m\theta_i \end{pmatrix} \tag{24}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Generalizing to a &lt;em&gt;d&lt;/em&gt;-dimensional space (where &lt;em&gt;d&lt;/em&gt; is even), RoPE uses a block-diagonal rotation matrix &lt;em&gt;R&lt;sub&gt;Θ,m&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt;:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
R_{\Theta,m}^d = \begin{pmatrix} \cos m\theta_1 &amp; -\sin m\theta_1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 &amp; 0 \\ \sin m\theta_1 &amp; \cos m\theta_1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 &amp; 0 \\ 0 &amp; 0 &amp; \cos m\theta_2 &amp; -\sin m\theta_2 &amp; \cdots &amp; 0 &amp; 0 \\ 0 &amp; 0 &amp; \sin m\theta_2 &amp; \cos m\theta_2 &amp; \cdots &amp; 0 &amp; 0 \\ \vdots &amp; \vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots &amp; \vdots \\ 0 &amp; 0 &amp; 0 &amp; 0 &amp; \cdots &amp; \cos m\theta_{d/2} &amp; -\sin m\theta_{d/2} \\ 0 &amp; 0 &amp; 0 &amp; 0 &amp; \cdots &amp; \sin m\theta_{d/2} &amp; \cos m\theta_{d/2} \end{pmatrix} \tag{25}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Where &lt;em&gt;θ&lt;sub&gt;i&lt;/sub&gt; = 10000&lt;sup&gt;-2(i-1)/d&lt;/sup&gt;&lt;/em&gt; for &lt;em&gt;i ∈ [1, 2, …, d/2]&lt;/em&gt;&lt;/p&gt;
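
&lt;p&gt;A short sketch of Equation 25 may help make the structure tangible: it assembles the block-diagonal matrix for a given position &lt;em&gt;m&lt;/em&gt; and applies it to a vector. The dimension and position below are arbitrary illustrative choices.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

def rope_rotation_matrix(m, d, base=10000.0):
    # Block-diagonal rotation matrix from Equation 25 for position m.
    # theta_i = base**(-2 * (i - 1) / d) for i = 1 .. d/2 (here i is 0-indexed).
    R = np.zeros((d, d))
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)
        c, s = np.cos(m * theta), np.sin(m * theta)
        R[2 * i, 2 * i] = c
        R[2 * i, 2 * i + 1] = -s
        R[2 * i + 1, 2 * i] = s
        R[2 * i + 1, 2 * i + 1] = c
    return R

d = 8                                            # illustrative even dimension
x = np.random.default_rng(0).normal(size=d)
x_rotated = rope_rotation_matrix(m=5, d=d) @ x   # position-5 encoding of x
&lt;/code&gt;&lt;/pre&gt;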

&lt;p&gt;Relative position encoding: The principal advantage of Rotary Position Embedding (RoPE) is its intrinsic capacity to encode relative positional information rather than absolute positions. This property becomes mathematically evident when examining the attention mechanism. For a query vector at position &lt;em&gt;m&lt;/em&gt; and a key vector at position &lt;em&gt;n&lt;/em&gt;, the attention score is formulated as:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_m^T k_n = (R_{\Theta,m}^d \mathbf{W}_q\mathbf{x}_m)^T(R_{\Theta,n}^d \mathbf{W}_k\mathbf{x}_n) = \mathbf{x}_m^T \mathbf{W}_q^T R_{\Theta,m-n}^d \mathbf{W}_k\mathbf{x}_n \tag{26}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;&lt;label for=&quot;sn-rope-proof&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-rope-proof&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;Mathematical Proof of Relative Position Property&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;This relative position property comes from fundamental properties of rotation matrices:&lt;br /&gt;&lt;br /&gt;&lt;strong&gt;Property 1:&lt;/strong&gt; The transpose of a rotation matrix inverts the rotation:&lt;br /&gt;(R&lt;sub&gt;Θ,m&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;)&lt;sup&gt;T&lt;/sup&gt; = R&lt;sub&gt;Θ,-m&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;br /&gt;&lt;br /&gt;&lt;strong&gt;Property 2:&lt;/strong&gt; Multiplying rotation matrices compounds their rotations:&lt;br /&gt;R&lt;sub&gt;Θ,a&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; · R&lt;sub&gt;Θ,b&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; = R&lt;sub&gt;Θ,a+b&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;br /&gt;&lt;br /&gt;Therefore:&lt;br /&gt;(R&lt;sub&gt;Θ,m&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;)&lt;sup&gt;T&lt;/sup&gt; · R&lt;sub&gt;Θ,n&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; = R&lt;sub&gt;Θ,-m&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; · R&lt;sub&gt;Θ,n&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; = R&lt;sub&gt;Θ,-m+n&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; = R&lt;sub&gt;Θ,n-m&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;br /&gt;&lt;br /&gt;This mathematical property enables the model to naturally compute relative positional relationships between tokens without storing absolute positions. &lt;/span&gt;&lt;/p&gt;

&lt;p&gt;Where &lt;em&gt;R&lt;sub&gt;Θ,m-n&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; = (R&lt;sub&gt;Θ,m&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;)&lt;sup&gt;T&lt;/sup&gt; R&lt;sub&gt;Θ,n&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt;. This means the attention score naturally incorporates relative position information &lt;em&gt;(m-n)&lt;/em&gt; rather than absolute positions. Consequently, this mathematical property enables transformer architectures to develop position-invariant representations of token relationships, thereby enhancing the model’s capability to capture linguistic dependencies across diverse contextual environments.&lt;/p&gt;
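
&lt;p&gt;This relative-position property is easy to verify numerically. The sketch below uses the equivalent complex-number view of the per-pair rotation (rather than the explicit matrix above) and checks that the query-key score is unchanged when both positions are shifted by the same offset.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

def rope_rotate(x, m, base=10000.0):
    # Apply the position-m RoPE rotation to x by treating consecutive pairs of
    # coordinates as complex numbers; this is equivalent to multiplying by the
    # block-diagonal matrix of Equation 25.
    d = x.shape[0]
    pairs = x.reshape(d // 2, 2)
    z = pairs[:, 0] + 1j * pairs[:, 1]
    theta = base ** (-2.0 * np.arange(d // 2) / d)
    rotated = z * np.exp(1j * m * theta)
    return np.stack([rotated.real, rotated.imag], axis=1).reshape(d)

rng = np.random.default_rng(1)
d = 8
q, k = rng.normal(size=d), rng.normal(size=d)

def score(m, n):
    # attention logit between a query rotated to position m and a key at position n
    return rope_rotate(q, m) @ rope_rotate(k, n)

# shifting both positions by the same offset leaves the score unchanged,
# so the logit depends only on the relative position
assert np.isclose(score(3, 7), score(13, 17))
assert np.isclose(score(0, 4), score(21, 25))
&lt;/code&gt;&lt;/pre&gt;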

&lt;h2 id=&quot;2-the-decoupled-rope-strategy-in-mla&quot;&gt;2. The Decoupled RoPE Strategy in MLA&lt;/h2&gt;

&lt;h3 id=&quot;21-separating-content-and-position-information&quot;&gt;2.1 Separating Content and Position Information&lt;/h3&gt;

&lt;p&gt;The key innovation in MLA is the decoupled Rotary Position Embedding (RoPE) strategy, which elegantly separates content information from positional information:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Content Path (no position encoding):&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
c_t^{KV} = W^{DKV}h_t \tag{27}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_t^C = W^{UK}c_t^{KV} \tag{28}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_t^C = W^{UQ}c_t^Q \tag{29}
&lt;/script&gt;&lt;/div&gt;

&lt;ol&gt;
  &lt;li&gt;Position Path (Rotated keys and queries):&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_t^R = R_{\Theta,t}^d \cdot W^{KR}h_t \tag{30}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_t^R = R_{\Theta,t}^d \cdot W^{QR}c_t^Q \tag{31}
&lt;/script&gt;&lt;/div&gt;

&lt;ol&gt;
  &lt;li&gt;Final Representations (concatenation):&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_t = [k_t^C; k_t^R] \tag{32}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_t = [q_t^C; q_t^R] \tag{33}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;The MLA approach bifurcates the attention mechanism into dual parallel pathways, enabling distinct processing optimizations for different aspects of token representation. This architectural design represents a significant advancement over traditional attention mechanisms by addressing the fundamental tension between computational efficiency and positional awareness.&lt;/p&gt;

&lt;p&gt;In the content path, semantic information is processed without positional encoding, allowing for substantial dimensionality reduction through compression. The down-projection matrix &lt;em&gt;W&lt;sup&gt;DKV&lt;/sup&gt;&lt;/em&gt; transforms the high-dimensional hidden state &lt;em&gt;h&lt;sub&gt;t&lt;/sub&gt;&lt;/em&gt; into a compact latent representation &lt;em&gt;c&lt;sub&gt;t&lt;/sub&gt;&lt;sup&gt;KV&lt;/sup&gt;&lt;/em&gt; with dimension &lt;em&gt;d&lt;sub&gt;c&lt;/sub&gt;&lt;/em&gt;, where typically &lt;em&gt;d&lt;sub&gt;c&lt;/sub&gt; ≪ n&lt;sub&gt;heads&lt;/sub&gt; × d&lt;sub&gt;head&lt;/sub&gt;&lt;/em&gt;. This compression captures the essential semantic content while eliminating redundant information, resulting in a more memory-efficient representation that can be cached during inference.&lt;/p&gt;

&lt;p&gt;The separate position path maintains spatial awareness through Rotary Position Embeddings (RoPE), applied via the rotation matrix &lt;em&gt;R&lt;sub&gt;Θ,t&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt;. By isolating positional information in a dedicated pathway with dimension &lt;em&gt;d&lt;sub&gt;R&lt;/sub&gt;&lt;/em&gt; (typically much smaller than the content dimension), MLA preserves the model’s ability to understand token relationships without applying position encodings to the compressed representations. This separation is crucial for preventing the computational challenges that would arise from applying rotation matrices to compressed vectors (which we will explore in the sections below).&lt;/p&gt;

&lt;h3 id=&quot;22-attention-calculation-with-decoupled-rope&quot;&gt;2.2 Attention Calculation with Decoupled RoPE&lt;/h3&gt;

&lt;p&gt;The attention score between a query at position &lt;em&gt;p&lt;/em&gt; and a key at position &lt;em&gt;j&lt;/em&gt; becomes:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
a_{pj} = \frac{q_p^T k_j}{\sqrt{d_h + d_h^R}} = \frac{(q_p^C)^T k_j^C + (q_p^R)^T k_j^R}{\sqrt{d_h + d_h^R}} \tag{34}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Expanding each component:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Content Path:&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{align} (q_p^C)^T k_j^C &amp;= (W^{UQ}c_p^Q)^T(W^{UK}c_j^{KV}) \\ &amp;= (c_p^Q)^T (W^{UQ})^T W^{UK} c_j^{KV} \end{align} \tag{35}
&lt;/script&gt;&lt;/div&gt;

&lt;ol&gt;
  &lt;li&gt;Positional Path (with RoPE):&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{align} (q_p^R)^T k_j^R &amp;= (R_{\Theta,p}^d \cdot W^{QR}c_p^Q)^T (R_{\Theta,j}^d \cdot W^{KR}h_j) \\ &amp;= (c_p^Q)^T (W^{QR})^T (R_{\Theta,p}^d)^T R_{\Theta,j}^d W^{KR} h_j \\ &amp;= (c_p^Q)^T (W^{QR})^T R_{\Theta,p-j}^d W^{KR} h_j \end{align} \tag{36}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This decomposition highlights how the content similarity component measures semantic relationships independent of position, while the positional relationship component explicitly encodes the relative position &lt;em&gt;(p-j)&lt;/em&gt; between tokens. The additive interaction between these components in the attention calculation allows the model to consider both semantic compatibility and positional context when determining attention weights.&lt;/p&gt;

&lt;h3 id=&quot;23-optimizations-for-efficient-inference&quot;&gt;2.3 Optimizations for Efficient Inference&lt;/h3&gt;

&lt;p&gt;For the content similarity component, MLA employs matrix absorption:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{align} (q_p^C)^T k_j^C &amp;= (W^{UQ}c_p^Q)^T(W^{UK}c_j^{KV}) \\ &amp;= (c_p^Q)^T (W^{UQ})^T W^{UK} c_j^{KV} \end{align} \tag{37}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;By defining the absorbed matrix &lt;em&gt;(W&lt;sup&gt;UQ&lt;/sup&gt;)’ = (W&lt;sup&gt;UQ&lt;/sup&gt;)&lt;sup&gt;T&lt;/sup&gt; W&lt;sup&gt;UK&lt;/sup&gt;&lt;/em&gt;, we get:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{align} (q_p^C)^T k_j^C &amp;= (c_p^Q)^T (W^{UQ})&apos; c_j^{KV} \\ &amp;= \left(((W^{UQ})&apos;)^T c_p^Q\right)^T c_j^{KV} \end{align} \tag{38}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This optimization represents a significant computational efficiency gain during inference. By pre-computing the absorbed matrix &lt;em&gt;(W&lt;sup&gt;UQ&lt;/sup&gt;)’&lt;/em&gt;, we transform what would be two sequential matrix multiplications (the up-projection of query followed by dot product with up-projected key) into a single multiplication followed by a dot product with the compressed key vector.&lt;/p&gt;

&lt;p&gt;The absorbed matrix &lt;em&gt;(W&lt;sup&gt;UQ&lt;/sup&gt;)’&lt;/em&gt; effectively encapsulates both the query and key up-projection operations in a single transformation. This is particularly valuable during inference, as it reduces the computational overhead for each token generation step. The operation &lt;em&gt;(((W&lt;sup&gt;UQ&lt;/sup&gt;)’)&lt;sup&gt;T&lt;/sup&gt; c&lt;sub&gt;p&lt;/sub&gt;&lt;sup&gt;Q&lt;/sup&gt;)&lt;sup&gt;T&lt;/sup&gt; c&lt;sub&gt;j&lt;/sub&gt;&lt;sup&gt;KV&lt;/sup&gt;&lt;/em&gt; directly computes the content similarity using only the compressed representations, without requiring full decompression of the cached vectors.&lt;/p&gt;
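
&lt;p&gt;A small NumPy sketch of this absorption is given below. The matrix shapes are hypothetical stand-ins chosen only to check that the precomputed absorbed matrix reproduces the score obtained by explicitly decompressing both sides.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

rng = np.random.default_rng(0)
d_c, d_cq, d_h = 16, 12, 32     # illustrative compressed and head dimensions

W_UQ = rng.normal(size=(d_h, d_cq))   # query up-projection
W_UK = rng.normal(size=(d_h, d_c))    # key up-projection
c_q  = rng.normal(size=d_cq)          # compressed query latent  c_p^Q
c_kv = rng.normal(size=d_c)           # cached compressed latent c_j^KV

# reference: decompress both sides, then take the dot product (Equation 37)
score_ref = (W_UQ @ c_q) @ (W_UK @ c_kv)

# absorbed form: precompute (W^UQ)^T W^UK once, apply its transpose to the
# compressed query, and dot with the cached latent (Equation 38)
W_absorbed = W_UQ.T @ W_UK
score_fast = (W_absorbed.T @ c_q) @ c_kv

assert np.isclose(score_ref, score_fast)
&lt;/code&gt;&lt;/pre&gt;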

&lt;h3 id=&quot;24-maintaining-relative-position-in-the-position-path&quot;&gt;2.4 Maintaining Relative Position in the Position Path&lt;/h3&gt;

&lt;p&gt;For the positional component, we have:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{align} (q_p^R)^T k_j^R &amp;= (R_{\Theta,p}^d \cdot W^{QR}c_p^Q)^T (R_{\Theta,j}^d \cdot W^{KR}h_j) \\ &amp;= (c_p^Q)^T (W^{QR})^T (R_{\Theta,p}^d)^T R_{\Theta,j}^d W^{KR} h_j \\ &amp;= (c_p^Q)^T (W^{QR})^T R_{\Theta,p-j}^d W^{KR} h_j \end{align} \tag{39}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Since &lt;em&gt;R&lt;sub&gt;Θ,p-j&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; = (R&lt;sub&gt;Θ,p&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;)&lt;sup&gt;T&lt;/sup&gt; R&lt;sub&gt;Θ,j&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt;, the attention score naturally encodes the relative position &lt;em&gt;(p-j)&lt;/em&gt; between the tokens. This mathematical property is central to the effectiveness of the decoupled RoPE approach. The fundamental challenge in position encoding for efficient inference is maintaining awareness of relative positions while avoiding recomputation of key vectors for each new token. The decoupled approach solves this by leveraging a key property of rotation matrices: the product of a rotation matrix and its transpose encodes the relative angle between them. This means that by caching the position-encoded vectors &lt;em&gt;k&lt;sub&gt;j&lt;/sub&gt;&lt;sup&gt;R&lt;/sup&gt; = R&lt;sub&gt;Θ,j&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; ⋅ W&lt;sup&gt;KR&lt;/sup&gt;h&lt;sub&gt;j&lt;/sub&gt;&lt;/em&gt; for each token position &lt;em&gt;j&lt;/em&gt;, and computing &lt;em&gt;q&lt;sub&gt;p&lt;/sub&gt;&lt;sup&gt;R&lt;/sup&gt; = R&lt;sub&gt;Θ,p&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; ⋅ W&lt;sup&gt;QR&lt;/sup&gt;c&lt;sub&gt;p&lt;/sub&gt;&lt;sup&gt;Q&lt;/sup&gt;&lt;/em&gt; for the current position &lt;em&gt;p&lt;/em&gt;, their dot product naturally captures the relative positional relationship without requiring recomputation of previous keys.&lt;/p&gt;

&lt;h3 id=&quot;25-complete-inference-time-attention-calculation&quot;&gt;2.5 Complete Inference-Time Attention Calculation&lt;/h3&gt;

&lt;p&gt;During inference, the optimized attention calculation becomes:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
a_{pj} = \frac{\left(((W^{UQ})&apos;)^T c_p^Q\right)^T c_j^{KV} + (R_{\Theta,p}^d \cdot W^{QR}c_p^Q)^T k_j^R}{\sqrt{d_h + d_h^R}} \tag{40}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Where &lt;em&gt;(W&lt;sup&gt;UQ&lt;/sup&gt;)’&lt;/em&gt; is pre-computed, &lt;em&gt;c&lt;sub&gt;j&lt;/sub&gt;&lt;sup&gt;KV&lt;/sup&gt;&lt;/em&gt; and &lt;em&gt;k&lt;sub&gt;j&lt;/sub&gt;&lt;sup&gt;R&lt;/sup&gt;&lt;/em&gt; are cached for all previous tokens, and we only calculate &lt;em&gt;c&lt;sub&gt;p&lt;/sub&gt;&lt;sup&gt;Q&lt;/sup&gt;&lt;/em&gt; and &lt;em&gt;(R&lt;sub&gt;Θ,p&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; ⋅ W&lt;sup&gt;QR&lt;/sup&gt;c&lt;sub&gt;p&lt;/sub&gt;&lt;sup&gt;Q&lt;/sup&gt;)&lt;/em&gt; for the current token.&lt;/p&gt;
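
&lt;p&gt;The following sketch mimics this decoding-time computation for a single head with illustrative dimensions. The cached quantities and the rotated query component are random stand-ins; the point is only to show which tensors are read from the cache and which are computed per step.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

rng = np.random.default_rng(0)
d_c, d_cq, d_h, d_r = 16, 12, 32, 8   # illustrative dimensions, single head

W_UQ = rng.normal(size=(d_h, d_cq))
W_UK = rng.normal(size=(d_h, d_c))
W_absorbed = W_UQ.T @ W_UK            # precomputed once, as in Section 2.3

# per-token cache: compressed latents c_j^KV and rotary key parts k_j^R
cached_c_kv = [rng.normal(size=d_c) for _ in range(5)]
cached_k_r  = [rng.normal(size=d_r) for _ in range(5)]

# current token: compressed query latent and its rotated query part
c_q = rng.normal(size=d_cq)
q_r = rng.normal(size=d_r)            # stands in for R_{Theta,p} W^QR c_p^Q

q_content = W_absorbed.T @ c_q        # maps the query into the cached latent space
scores = [
    (q_content @ c_kv + q_r @ k_r) / np.sqrt(d_h + d_r)
    for c_kv, k_r in zip(cached_c_kv, cached_k_r)
]
&lt;/code&gt;&lt;/pre&gt;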

&lt;h2 id=&quot;3-why-the-naive-approach-to-combining-rope-and-mla-fails&quot;&gt;3. Why the Naive Approach to Combining RoPE and MLA Fails&lt;/h2&gt;

&lt;p&gt;Now that we understand the decoupled RoPE solution, let’s examine why a more straightforward approach doesn’t work.&lt;/p&gt;

&lt;h3 id=&quot;31-the-naive-approach-applying-rope-after-decompression&quot;&gt;3.1 The Naive Approach: Applying RoPE After Decompression&lt;/h3&gt;

&lt;p&gt;A seemingly natural way to combine RoPE with MLA would be to apply the rotational encoding after decompressing the cached latent vectors as shown in Figure 3:&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/mla/mla_arch_wrong.png&quot; alt=&quot;Naive MLA&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 3: Applying RoPE after decompression in the MLA architecture
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;ol&gt;
  &lt;li&gt;Compress the hidden states for storage in the KV cache:&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
c_t^{KV} = W^{DKV}h_t \tag{41}
&lt;/script&gt;&lt;/div&gt;

&lt;ol&gt;
  &lt;li&gt;During attention computation, decompress the cached vectors:&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_t = W^{UK}c_t^{KV} = W^{UK}W^{DKV}h_t \tag{42}
&lt;/script&gt;&lt;/div&gt;

&lt;ol&gt;
  &lt;li&gt;Apply RoPE to the decompressed vectors based on their position:&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_m(m) = R_{\Theta,m}^d \cdot k_m = R_{\Theta,m}^d \cdot W^{UK}W^{DKV}h_m \tag{43}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This approach seems intuitive but creates a fundamental problem during inference.&lt;/p&gt;

&lt;h3 id=&quot;32-the-re-computation-problem&quot;&gt;3.2 The Re-computation Problem&lt;/h3&gt;

&lt;p&gt;During inference, for the current query token at position &lt;em&gt;p&lt;/em&gt; and a key token at position &lt;em&gt;j &amp;lt; p&lt;/em&gt;, the attention score calculation requires:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
a_{pj} = q_p(p)^T k_j(p-j) \tag{44}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Notice the crucial detail: &lt;em&gt;k&lt;sub&gt;j&lt;/sub&gt;(p-j)&lt;/em&gt; – we need the key vector for token &lt;em&gt;j&lt;/em&gt; encoded with the &lt;strong&gt;relative&lt;/strong&gt; position &lt;em&gt;(p-j)&lt;/em&gt;, not just its original absolute position &lt;em&gt;j&lt;/em&gt;. But here’s the problem: during inference, we’ve only cached &lt;em&gt;c&lt;sub&gt;j&lt;/sub&gt;&lt;sup&gt;KV&lt;/sup&gt;&lt;/em&gt; for previous tokens, not their RoPE-encoded keys. To compute &lt;em&gt;k&lt;sub&gt;j&lt;/sub&gt;(p-j)&lt;/em&gt; correctly, we need:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_j(p-j) = R_{\Theta,p-j}^d \cdot W^{UK}c_j^{KV} \tag{45}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This requires applying a different rotation matrix &lt;em&gt;R&lt;sub&gt;Θ,p-j&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt; to each cached key, which depends on the distance &lt;em&gt;(p-j)&lt;/em&gt; from the current position &lt;em&gt;p&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;Let’s prove why all keys must be recomputed with each new token. According to the RoPE formulation, the attention score between a query at position &lt;em&gt;p&lt;/em&gt; and a key at position &lt;em&gt;j&lt;/em&gt; is:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_p^T k_j = (R_{\Theta,p}^d \mathbf{W}_q\mathbf{x}_p)^T(R_{\Theta,j}^d \mathbf{W}_k\mathbf{x}_j) = \mathbf{x}_p^T \mathbf{W}_q^T R_{\Theta,p-j}^d \mathbf{W}_k\mathbf{x}_j \tag{46}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;In MLA with compressed representations, this becomes:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_p^T k_j = (R_{\Theta,p}^d \cdot W^{UQ}c_p^Q)^T(R_{\Theta,j}^d \cdot W^{UK}c_j^{KV}) \tag{47}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;But during inference, to capture the correct relative position &lt;em&gt;(p-j)&lt;/em&gt;, we need to recalculate:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_j(p-j) = R_{\Theta,p-j}^d \cdot W^{UK}c_j^{KV} \tag{48}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Or equivalently:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
k_j(p-j) = (R_{\Theta,p}^d)^T R_{\Theta,j}^d \cdot W^{UK}c_j^{KV} \tag{49}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This means for each new token position &lt;em&gt;p&lt;/em&gt;, we must recompute all previous keys with their relative position to &lt;em&gt;p&lt;/em&gt;, which effectively eliminates the benefit of KV caching. Instead of simply retrieving cached vectors, we must perform a matrix multiplication for every previous token with each new step, significantly increasing the computational cost.&lt;/p&gt;
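
&lt;p&gt;A rough counting sketch makes the scaling explicit: under the naive scheme the number of extra matrix-vector products per decoding step grows linearly with the number of cached tokens, while the decoupled scheme does a constant amount of extra work. The figures are illustrative, and the count ignores the attention dot products that both schemes share.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
# Rough per-step operation count for the naive approach: for each new token at
# position p, every cached latent c_j^KV must be decompressed and re-rotated,
# so the extra matrix-vector products per decoding step grow with the number of
# cached tokens. The decoupled scheme only processes the current token.

def naive_extra_matvecs(num_cached_tokens):
    # one decompression (W^UK c_j^KV) plus one rotation per cached token
    return 2 * num_cached_tokens

def decoupled_extra_matvecs(num_cached_tokens):
    # cached k_j^R entries are reused unchanged; only the current token is
    # projected and rotated, independent of how many tokens are cached
    return 2

for n in (1024, 8192, 32768):
    print(n, naive_extra_matvecs(n), decoupled_extra_matvecs(n))
&lt;/code&gt;&lt;/pre&gt;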

&lt;h3 id=&quot;33-the-matrix-absorption-impossibility&quot;&gt;3.3 The Matrix Absorption Impossibility&lt;/h3&gt;

&lt;p&gt;A natural optimization attempt would be to absorb some of the matrix multiplications. Let’s explore this possibility:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
q_p^T k_j = (R_{\Theta,p}^d \cdot W^{UQ}c_p^Q)^T(R_{\Theta,j}^d \cdot W^{UK}c_j^{KV}) \tag{50}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
= (c_p^Q)^T (W^{UQ})^T (R_{\Theta,p}^d)^T R_{\Theta,j}^d W^{UK} c_j^{KV} \tag{51}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
= (c_p^Q)^T (W^{UQ})^T R_{\Theta,p-j}^d W^{UK} c_j^{KV} \tag{52}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;If the rotation matrix &lt;em&gt;R&lt;sub&gt;Θ,p-j&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt; commuted with &lt;em&gt;W&lt;sup&gt;UK&lt;/sup&gt;&lt;/em&gt; (meaning &lt;em&gt;R&lt;sub&gt;Θ,p-j&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt; ⋅ W&lt;sup&gt;UK&lt;/sup&gt; = W&lt;sup&gt;UK&lt;/sup&gt; ⋅ R&lt;sub&gt;Θ,p-j&lt;/sub&gt;&lt;sup&gt;d&lt;/sup&gt;&lt;/em&gt;), we could define (as we show below, this is not actually possible):&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
(W^{UQ})&apos; = (W^{UQ})^T (W^{UK}) \tag{53}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;And compute:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
(c_p^Q)^T (W^{UQ})&apos; R_{\Theta,p-j}^d c_j^{KV} \tag{54}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This would allow us to apply the rotation directly to the compressed representations, avoiding the need for decompression and recomputation. However, rotation matrices do not generally commute with arbitrary matrices. To prove this, let’s consider the product of a &lt;em&gt;2×2&lt;/em&gt; rotation matrix and a general &lt;em&gt;2×2&lt;/em&gt; matrix:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
R_{\theta} = \begin{pmatrix} \cos\theta &amp; -\sin\theta \\ \sin\theta &amp; \cos\theta \end{pmatrix} \tag{55}
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
A = \begin{pmatrix} a &amp; b \\ c &amp; d \end{pmatrix} \tag{56}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Computing &lt;em&gt;R&lt;sub&gt;θ&lt;/sub&gt; ⋅ A&lt;/em&gt;:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{pmatrix} \cos\theta &amp; -\sin\theta \\ \sin\theta &amp; \cos\theta \end{pmatrix} \begin{pmatrix} a &amp; b \\ c &amp; d \end{pmatrix} = \begin{pmatrix} a\cos\theta-c\sin\theta &amp; b\cos\theta-d\sin\theta \\ a\sin\theta+c\cos\theta &amp; b\sin\theta+d\cos\theta \end{pmatrix} \tag{57}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Computing &lt;em&gt;A ⋅ R&lt;sub&gt;θ&lt;/sub&gt;&lt;/em&gt;:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\begin{pmatrix} a &amp; b \\ c &amp; d \end{pmatrix} \begin{pmatrix} \cos\theta &amp; -\sin\theta \\ \sin\theta &amp; \cos\theta \end{pmatrix} = \begin{pmatrix} a\cos\theta+b\sin\theta &amp; -a\sin\theta+b\cos\theta \\ c\cos\theta+d\sin\theta &amp; -c\sin\theta+d\cos\theta \end{pmatrix} \tag{58}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;These results are different unless &lt;em&gt;A&lt;/em&gt; has a special structure. Therefore:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
R_{\Theta,p-j}^d \cdot W^{UK} \neq W^{UK} \cdot R_{\Theta,p-j}^d \tag{59}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This non-commutativity prevents the matrix absorption optimization, forcing us to recalculate all keys with their appropriate rotations for each new token position.&lt;/p&gt;
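
&lt;p&gt;This is straightforward to confirm numerically with a random matrix, as in the short check below (any &lt;em&gt;A&lt;/em&gt; without special structure will do).&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

theta = 0.3
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
A = np.random.default_rng(0).normal(size=(2, 2))   # a generic 2x2 matrix

# R @ A and A @ R differ for a generic A (Equations 57 and 58)
assert not np.allclose(R @ A, A @ R)

# by contrast, two rotations always commute and compose by angle addition
phi = 0.5
R2 = np.array([[np.cos(phi), -np.sin(phi)],
               [np.sin(phi),  np.cos(phi)]])
assert np.allclose(R @ R2, R2 @ R)
&lt;/code&gt;&lt;/pre&gt;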

&lt;h2 id=&quot;4-conclusion-why-decoupled-rope-succeeds-where-the-naive-approach-fails&quot;&gt;4. Conclusion: Why Decoupled RoPE Succeeds Where the Naive Approach Fails&lt;/h2&gt;

&lt;p&gt;The decoupled RoPE strategy succeeds by separating positional and content information into parallel paths, allowing each to be processed optimally:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Content path&lt;/strong&gt; can be efficiently compressed without worrying about position encoding.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Position path&lt;/strong&gt; handles rotary encodings separately, maintaining relative position awareness.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Concatenation&lt;/strong&gt; combines both signals without requiring recalculation of previous keys.&lt;/p&gt;

&lt;p&gt;This separation allows MLA to achieve both memory efficiency (through compression) and computational efficiency (by avoiding recomputation), while still preserving the powerful relative position encoding capabilities of RoPE.&lt;/p&gt;

&lt;p&gt;In contrast, the naive approach attempts to apply position encoding on top of the compression-decompression pipeline, creating a fundamental mathematical incompatibility that forces costly recomputations during inference.&lt;/p&gt;

&lt;p&gt;The decoupled RoPE strategy represents an elegant architectural solution that demonstrates the importance of carefully considering how different components of a model interact, particularly when optimizing for inference efficiency.&lt;/p&gt;

&lt;h2 id=&quot;5-references&quot;&gt;5. References&lt;/h2&gt;

&lt;div class=&quot;work-references&quot;&gt;
&lt;p&gt;[1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., &amp;amp; Liu, Y. (2021). &quot;RoFormer: Enhanced Transformer with Rotary Position Embedding.&quot; &lt;em&gt;arXiv preprint arXiv:2104.09864&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[2] DeepSeek-AI, Liu, A., Feng, B., Wang, B., Wang, B., Liu, B., Zhao, C., et al. (2024). &quot;DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model.&quot; &lt;em&gt;arXiv preprint arXiv:2405.04434&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[3] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., &amp;amp; Polosukhin, I. (2017). &quot;Attention is all you need.&quot; &lt;em&gt;Advances in Neural Information Processing Systems&lt;/em&gt;, 30.&lt;/p&gt;
&lt;/div&gt;

</description>
        <pubDate>Sun, 27 Apr 2025 00:00:00 -0400</pubDate>
        <link>https://aakashvarma.github.io/mla/</link>
        <guid isPermaLink="true">https://aakashvarma.github.io/mla/</guid>
        
      </item>
    
      <item>
        <title>Analysis of Matrix Multiplications in Transformer Architectures</title>
        <description>&lt;p&gt;This analysis takes inspiration from Lequn Chen’s &lt;a href=&quot;https://le.qun.ch/en/blog/2023/05/13/transformer-batching/&quot;&gt;excellent article on transformer batching&lt;/a&gt; which analyzed performance on the A100 GPU. Building on their insights, this analysis focuses on the H100 architecture and provides fresh perspectives on transformer computations, with detailed performance analysis, comprehensive roofline model examination, and future optimization strategies specific to H100’s architecture.&lt;/p&gt;

&lt;h2 id=&quot;introduction&quot;&gt;Introduction&lt;/h2&gt;

&lt;p&gt;Transformer blocks are built on two primary types of matrix multiplications: &lt;strong&gt;dense layer operations&lt;/strong&gt; and the &lt;strong&gt;QK multiplication&lt;/strong&gt; in self-attention mechanisms&lt;label for=&quot;sn-transformer-block&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-transformer-block&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;img src=&quot;/assets/images/transformer_bench/transformer_block.png&quot; alt=&quot;Transformer Block&quot; style=&quot;max-width:100%&quot; /&gt;&lt;br /&gt;&lt;br /&gt;To learn more about the transformer block, please read &lt;a href=&quot;https://magazine.sebastianraschka.com/&quot;&gt;Sebastian Raschka’s&lt;/a&gt; blogs. &lt;/span&gt;. These operations form the backbone of how Transformers process and encode input data, and their computational cost can be analyzed in terms of FLOPs (floating-point operations).&lt;/p&gt;

&lt;h2 id=&quot;dense-layers&quot;&gt;Dense Layers&lt;/h2&gt;

&lt;p&gt;Dense layers are a fundamental component of Transformer blocks. These layers project inputs from one space to another. They are frequently used in the multi-head attention mechanism in projection operations, such as the generation of Q (query), K (key), and V (value) vectors in self-attention layers. Dense layers are also a crucial part of the Multi-Layer Perceptron (MLP) block, such as in models like LLaMA. A dense layer operates on an input tensor &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X&lt;/code&gt; of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(batch,seqlen,h)&lt;/code&gt;, where &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch&lt;/code&gt; is the batch size, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seqlen&lt;/code&gt; is the sequence length, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h&lt;/code&gt; is the hidden size. It uses a weight matrix &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;W&lt;/code&gt; of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(h,h)&lt;/code&gt; to perform a linear projection, producing an output tensor of the same shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(batch,seqlen,h)&lt;/code&gt;&lt;label for=&quot;sn-dense-broadcast&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-dense-broadcast&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;For higher-dimensional inputs, vector-matrix multiplication broadcasts across all dimensions except the last one. A dense layer of shape &lt;code&gt;(h, h)&lt;/code&gt; applied to a tensor of shape &lt;code&gt;(b, s, h)&lt;/code&gt; first reshapes to &lt;code&gt;(b*s, h)&lt;/code&gt;, performs the matrix multiplication, then reshapes back to &lt;code&gt;(b, s, h)&lt;/code&gt;.&lt;br /&gt;&lt;br /&gt;Note: This broadcasting pattern is core to transformer architectures, allowing efficient parallel processing while preserving hidden dimension operations. &lt;/span&gt; through the matrix multiplication &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X⋅W&lt;/code&gt;.&lt;/p&gt;
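
&lt;p&gt;The reshape-multiply-reshape pattern described in the sidenote can be sketched in a few lines of NumPy; the sizes below are arbitrary, and the check simply confirms that batched matrix multiplication gives the same result.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
import numpy as np

batch, seqlen, h = 2, 5, 8            # arbitrary illustrative sizes
rng = np.random.default_rng(0)
X = rng.normal(size=(batch, seqlen, h))
W = rng.normal(size=(h, h))

# reshape to (batch*seqlen, h), multiply, reshape back to (batch, seqlen, h)
Y_reshaped = (X.reshape(batch * seqlen, h) @ W).reshape(batch, seqlen, h)

# the batched matmul in NumPy applies the same weight to every (seqlen, h) slice
Y_broadcast = X @ W

assert np.allclose(Y_reshaped, Y_broadcast)
&lt;/code&gt;&lt;/pre&gt;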

&lt;h2 id=&quot;self-attention&quot;&gt;Self Attention&lt;/h2&gt;

&lt;p&gt;The QK multiplication is a core operation in the self-attention mechanism of Transformer models, enabling the computation of how each token in a sequence “attends” to every other token. This operation generates the attention scores that underpin the model’s ability to contextualize the input. To begin, the input tensor X of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(batch,seqlen,h)&lt;/code&gt;, where batch is the batch size, seqlen is the sequence length, and h is the hidden size, is linearly projected into the Query (Q) and Key (K) matrices. Both Q and K have the same shape as X, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(batch,seqlen,h)&lt;/code&gt;, where &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h=n⋅d&lt;/code&gt;, with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;n&lt;/code&gt; being the number of attention heads and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;d&lt;/code&gt; the dimensionality of each head in Multi Head Attention.&lt;/p&gt;

&lt;h2 id=&quot;flops-and-io&quot;&gt;FLOPs and IO&lt;/h2&gt;

&lt;h3 id=&quot;dense-layer&quot;&gt;Dense Layer&lt;/h3&gt;

&lt;p&gt;The computational cost of this operation, measured in floating-point operations (FLOPs)&lt;label for=&quot;sn-flops-def&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-flops-def&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;Here FLOPs stands for the number of floating-point operations needed for the matrix multiplication, and IO stands for the number of input and output data transfers. In this section these are not hardware (GPU) metrics; they are theoretical counts. &lt;/span&gt;, is calculated as follows: each element in the output requires &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h&lt;/code&gt; multiplications and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h−1&lt;/code&gt; additions, or approximately &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;2h&lt;/code&gt; operations per output element. With &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch⋅seqlen⋅h&lt;/code&gt; output elements, the total number of operations is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;FLOPs = b⋅seqlen⋅h⋅(2h)&lt;/code&gt;, which simplifies to&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{FLOPs} = 2 \cdot b \cdot \text{seqlen} \cdot h^2
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;As a result, dense layers scale quadratically with the hidden size &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h&lt;/code&gt;, making them computationally expensive as &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h&lt;/code&gt; increases&lt;label for=&quot;sn-quadratic&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-quadratic&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;The computational complexity increases quadratically with the hidden size, making this a critical consideration for large models. &lt;/span&gt;.&lt;/p&gt;

&lt;p&gt;The input matrix &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X&lt;/code&gt; has a shape of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b,seqlen,h)&lt;/code&gt;, so the total number of elements read from X is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b⋅seqlen⋅h&lt;/code&gt;. The weight matrix W, with a shape of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(h,h)&lt;/code&gt;, has &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h⋅h&lt;/code&gt; elements that are read. After performing the matrix multiplication &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X⋅W&lt;/code&gt;, the output matrix has the same shape as X, which is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b,seqlen,h)&lt;/code&gt;, and the number of elements written to the output is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b⋅seqlen⋅h&lt;/code&gt;.&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{IO} = 2 \cdot b \cdot \text{seqlen} \cdot h + h^2
&lt;/script&gt;&lt;/div&gt;
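
&lt;p&gt;These two counts can be wrapped in a small helper for experimentation. Note that the IO expression above counts elements, so converting to bytes requires multiplying by the element size (for example, 2 bytes for FP16); the configuration below is illustrative.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
def dense_flops(b, seqlen, h):
    # FLOPs = 2 * b * seqlen * h^2
    return 2 * b * seqlen * h * h

def dense_io_elements(b, seqlen, h):
    # read X (b*seqlen*h) and W (h*h), write the output (b*seqlen*h)
    return 2 * b * seqlen * h + h * h

# illustrative configuration; multiply IO elements by the element size
# (e.g. 2 bytes for FP16) to convert to bytes
b, seqlen, h = 1, 2048, 4096
print(dense_flops(b, seqlen, h))
print(dense_io_elements(b, seqlen, h) * 2)
&lt;/code&gt;&lt;/pre&gt;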

&lt;h3 id=&quot;self-attention-1&quot;&gt;Self Attention&lt;/h3&gt;

&lt;h4 id=&quot;init&quot;&gt;Init&lt;/h4&gt;

&lt;p&gt;During initialization, when the entire sequence is processed at once, the Q and K matrices have the shapes &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, seqlen, d)&lt;/code&gt;. To compute attention, the K matrix is transposed to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, d, seqlen)&lt;/code&gt;. The matrix multiplication &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Q⋅K^T&lt;/code&gt; then produces an output tensor of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, seqlen, seqlen)&lt;/code&gt;, where each element represents the attention score between a pair of tokens.&lt;/p&gt;

&lt;p&gt;For each element in the output, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;d&lt;/code&gt; multiplications and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(d−1)&lt;/code&gt; additions are required, totalling approximately &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;2d&lt;/code&gt; operations per element. Since the output matrix has &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seqlen⋅seqlen&lt;/code&gt; elements, and this computation occurs for each batch and head, the total number of FLOPs can be calculated as:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{FLOPs} = 2 \cdot d \cdot b \cdot n \cdot \text{seqlen}^2
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;The Q matrix and K matrix have a shape of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b,n,seqlen,d)&lt;/code&gt;, so the number of elements read is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b⋅n⋅seqlen⋅d&lt;/code&gt; for each. The output attention scores have a shape of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b,n,seqlen,seqlen)&lt;/code&gt;, so the number of elements written is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b⋅n⋅seqlen^2&lt;/code&gt;. Adding these together gives:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{IO}_{\text{Init}} = 2 \cdot (b \cdot n \cdot \text{seqlen} \cdot d) + (b \cdot n \cdot \text{seqlen}^2)
&lt;/script&gt;&lt;/div&gt;

&lt;h4 id=&quot;auto-regressive-step&quot;&gt;Auto-Regressive Step&lt;/h4&gt;

&lt;p&gt;In the auto-regressive phase, where tokens are processed incrementally, the computation is performed for only the current token against all previously decoded tokens. Here, the Q matrix has a shape of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, 1, d)&lt;/code&gt;, while the K matrix remains &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, seqlen, d)&lt;/code&gt;. After transposition, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;K^T&lt;/code&gt; has the shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, d, seqlen)&lt;/code&gt;. The resulting output tensor has the shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, 1, seqlen)&lt;/code&gt;, representing attention scores for the current token against all preceding tokens.&lt;/p&gt;

&lt;p&gt;For each output element, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;d&lt;/code&gt; multiplications and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(d−1)&lt;/code&gt; additions are required, as before. However, since only &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seqlen&lt;/code&gt; elements are computed (instead of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seqlen^2&lt;/code&gt;), the total FLOPs are:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{FLOPs} = 2 \cdot d \cdot b \cdot n \cdot 1 \cdot \text{seqlen} = 2 \cdot b \cdot n \cdot d \cdot \text{seqlen}
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;The Q matrix has a shape of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b,n,1,d)&lt;/code&gt;, so the number of elements read is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b⋅n⋅1⋅d&lt;/code&gt;; the K matrix has a shape of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, seqlen, d)&lt;/code&gt;, so the number of elements read is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b⋅n⋅seqlen⋅d&lt;/code&gt;; and the output attention scores have a shape of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b,n,1,seqlen)&lt;/code&gt;, so the number of elements written is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b⋅n⋅seqlen&lt;/code&gt;. Adding these together gives:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{IO}_{\text{AR}} = (b \cdot n \cdot d) + (b \cdot n \cdot \text{seqlen} \cdot d) + (b \cdot n \cdot \text{seqlen})
&lt;/script&gt;&lt;/div&gt;
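
&lt;p&gt;For completeness, here is a minimal sketch of the attention-side counts and the resulting arithmetic intensity in FLOPs per byte, again with an illustrative configuration and FP16 elements assumed.&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
def qk_init_counts(b, n, seqlen, d):
    # prefill: full seqlen x seqlen score matrix per batch and head
    flops = 2 * d * b * n * seqlen * seqlen
    io = 2 * (b * n * seqlen * d) + b * n * seqlen * seqlen
    return flops, io

def qk_ar_counts(b, n, seqlen, d):
    # decode: a single query row against all cached keys
    flops = 2 * b * n * d * seqlen
    io = (b * n * d) + (b * n * seqlen * d) + (b * n * seqlen)
    return flops, io

# illustrative configuration (FP16 elements, 2 bytes each)
b, n, seqlen, d, bytes_per_elem = 1, 32, 4096, 128, 2
for flops, io in (qk_init_counts(b, n, seqlen, d), qk_ar_counts(b, n, seqlen, d)):
    print(flops / (io * bytes_per_elem))   # arithmetic intensity in FLOPs/Byte
&lt;/code&gt;&lt;/pre&gt;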

&lt;h2 id=&quot;arithmetic-intensity&quot;&gt;Arithmetic Intensity&lt;/h2&gt;

&lt;p&gt;Arithmetic intensity is a critical metric that represents the ratio of computational operations (FLOPs) to memory operations (IO bytes), expressed as FLOPs/Byte&lt;label for=&quot;sn-ai-model&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-ai-model&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;img src=&quot;/assets/images/transformer_bench/ai_model.png&quot; alt=&quot;Arithmetic intensity&quot; style=&quot;max-width:100%&quot; /&gt;&lt;br /&gt;&lt;br /&gt;Arithmetic Intensity (AI) = FLOPs/Bytes is a key performance indicator that helps determine whether an operation is compute-bound or memory-bound. Higher AI values suggest compute-bound operations, while lower values indicate memory-bound operations.&lt;br /&gt;Read more &lt;a href=&quot;https://dando18.github.io/posts/2020/04/02/roofline-model&quot;&gt;here&lt;/a&gt; &lt;/span&gt;. The three plots visualize this relationship for different MatMul layers in Transformer blocks (Dense Layer, QK Init, and QK AR) using logarithmic scales on both axes, where each increment represents an order of magnitude increase. The diagonal gray line represents a 1:1 ratio between FLOPs and bytes, with points above this line indicating operations that perform more computations per byte of memory accessed.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/flops_vs_io_single.png&quot; alt=&quot;Arithmetic Intensity Single&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 1: Arithmetic Intensity Analysis for Transformer Operations (Dense Layer, QK Init, QK AR) with sequence length 100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/flops_vs_io_all.png&quot; alt=&quot;Arithmetic Intensity All&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 2: Arithmetic Intensity Analysis for Transformer Operations (Dense Layer, QK Init, QK AR)
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;For Dense Layers, the arithmetic intensity is governed by:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{FLOPs} = 2 \cdot b \cdot \text{seqlen} \cdot h^2
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{IO} = 2 \cdot b \cdot \text{seqlen} \cdot h + h^2
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;This results in quadratic scaling with hidden size (h), making these layers increasingly compute-intensive as models grow larger&lt;label for=&quot;sn-dense-ai&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-dense-ai&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;Dense Layer Arithmetic Intensity:&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;\(\text{AI} = \frac{\text{FLOPs}}{\text{IO}} = \frac{2bsh^2}{2bsh + h^2} = \frac{2bsh}{2bs + h} = \frac{h}{1 + \frac{h}{2bs}} = O\left(\frac{1}{\frac{1}{h} + \frac{1}{2bs}}\right)\)&lt;br /&gt;&lt;br /&gt;This ratio reveals key insights:&lt;br /&gt;- When h is large: AI approaches 2b·s&lt;br /&gt;- As b·s increases: AI approaches h&lt;br /&gt;- When both h and b·s are large: AI is limited by roughly min(h, 2b·s)&lt;br /&gt;&lt;br /&gt;Note: I’ve used s instead of seqlen for consistency with typical notation, but they represent the same sequence length parameter.&lt;br /&gt;&lt;br /&gt;This relationship explains why increasing the batch size improves efficiency: as b·s grows, the h/(2bs) term in the denominator shrinks toward zero, pushing arithmetic intensity toward h. This is why dense layers in large models benefit significantly from batch processing. &lt;/span&gt;. The stepping pattern visible in the graph reflects this quadratic relationship, where larger hidden sizes show steeper curves and higher arithmetic intensity. This explains why dense layers in large models can become significant computational bottlenecks.&lt;/p&gt;
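
&lt;p&gt;As a quick numerical check of the derivation in the sidenote, the following sketch evaluates the dense-layer arithmetic intensity and shows how it saturates as the batch grows. The values of h and b are illustrative, not benchmark data.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;def dense_ai(b, s, h):
    # Arithmetic intensity of a (b*s, h) x (h, h) dense layer, in FLOPs per element moved.
    flops = 2 * b * s * h * h
    io = 2 * b * s * h + h * h   # read the input and write the output, plus the weight matrix
    return flops / io

h = 4096
for b in (1, 8, 64, 512):
    # With s = 1 (an auto-regressive step), AI grows with the batch until it approaches h.
    print(b, round(dense_ai(b, s=1, h=h), 1))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;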

&lt;p&gt;QK Init (Init) operations are characterized by:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{FLOPs} = b \cdot n \cdot \text{seqlen}^2 \cdot 2 \cdot d
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{IO} = 2 \cdot (b \cdot n \cdot \text{seqlen} \cdot d) + (b \cdot n \cdot \text{seqlen}^2)
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;The middle graph shows parallel lines for different sequence lengths, indicating consistent arithmetic intensity patterns that scale predictably with sequence length. As the sequence length increases, the operation becomes more compute-heavy, so a larger seqlen makes the QK Init stage an increasingly compute-bound bottleneck.&lt;/p&gt;

&lt;p&gt;QK AR (Auto-Regressive) computations follow:&lt;/p&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{FLOPs} = b \cdot n \cdot d \cdot \text{seqlen} \cdot 2
&lt;/script&gt;&lt;/div&gt;

&lt;div class=&quot;mathblock&quot;&gt;&lt;script type=&quot;math/tex; mode=display&quot;&gt;
\text{IO} = (b \cdot n \cdot d) + (b \cdot n \cdot \text{seqlen} \cdot d) + (b \cdot n \cdot \text{seqlen})
&lt;/script&gt;&lt;/div&gt;

&lt;p&gt;Unlike QK Init, this operation scales linearly with sequence length, resulting in more favorable arithmetic intensity characteristics&lt;label for=&quot;sn-self-attention-ai&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-self-attention-ai&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;Self-Attention Arithmetic Intensity:&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;For QK&lt;sup&gt;T&lt;/sup&gt; multiplication:&lt;br /&gt;\(\text{FLOPs} = b \cdot n \cdot \text{seqlen}^2 \cdot 2 \cdot d\)&lt;br /&gt;\(\text{IO} = 2(b \cdot n \cdot \text{seqlen} \cdot d) + (b \cdot n \cdot \text{seqlen}^2)\)&lt;br /&gt;&lt;br /&gt;QK Arithmetic Intensity:&lt;br /&gt;\(\text{AI} = \frac{\text{FLOPs}}{\text{IO}} = \frac{b \cdot n \cdot \text{seqlen}^2 \cdot 2d}{2bnd \cdot \text{seqlen} + bn \cdot \text{seqlen}^2} = \frac{2 \cdot \text{seqlen} \cdot d}{2d + \text{seqlen}}\)&lt;br /&gt;&lt;br /&gt;This derivation reveals crucial properties:&lt;br /&gt;- Batch size b cancels out completely&lt;br /&gt;- AI depends only on sequence length and head dimension&lt;br /&gt;- Scaling b increases both compute and memory linearly&lt;br /&gt;- No inherent efficiency gain from batching, unlike dense layers&lt;br /&gt;&lt;br /&gt;The final expression shows why self-attention’s performance characteristics remain constant regardless of batch size, making it fundamentally different from dense layer operations. &lt;/span&gt;. This is evident in the rightmost graph, where points cluster tightly along similar trajectories regardless of sequence length.&lt;/p&gt;
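
&lt;p&gt;The sidenote’s key point, that batch size cancels out of the QK arithmetic intensity, can be verified directly from the formulas above. A minimal sketch (illustrative values for d and seqlen):&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;def qk_init_ai(seqlen, d):
    # 2*seqlen*d / (2*d + seqlen): b and n cancel out of FLOPs / IO.
    return 2 * seqlen * d / (2 * d + seqlen)

def qk_ar_ai(seqlen, d):
    # FLOPs = 2*b*n*d*seqlen, IO = b*n*(d + seqlen*d + seqlen): b and n cancel again.
    return 2 * d * seqlen / (d + seqlen * d + seqlen)

d = 128
for seqlen in (100, 500, 2000):
    print(seqlen, round(qk_init_ai(seqlen, d), 1), round(qk_ar_ai(seqlen, d), 2))
# qk_ar_ai stays close to 2 FLOPs per element for long sequences: firmly memory-bound.
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;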

&lt;p&gt;However, these graphs represent theoretical relationships that don’t account for real-world hardware constraints. The Roofline Model becomes crucial here as it helps bridge this gap by providing a framework to understand actual performance limitations. In the Roofline Model, performance is bounded by two primary factors: the peak computational performance (represented by a horizontal line) and the memory bandwidth limit (shown as a diagonal line). The lower of these two bounds at any given arithmetic intensity determines the maximum achievable performance. We’ll look at the roofline model in the sections below.&lt;/p&gt;
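
&lt;p&gt;That bound can be written as a one-line function. Here is a minimal sketch, where the H100 peak numbers are approximate published-spec values (SXM, dense BF16) used purely for illustration:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;def roofline_flops(ai, peak_flops, mem_bw):
    # Attainable FLOP/s is the lower of the compute roof and the memory-bandwidth slope.
    # ai is arithmetic intensity in FLOPs per byte moved.
    return min(peak_flops, mem_bw * ai)

PEAK = 990e12   # ~990 TFLOP/s dense BF16 on H100 SXM (approximate spec)
BW = 3.35e12    # ~3.35 TB/s HBM bandwidth on H100 SXM (approximate spec)

for ai in (2, 64, 512, 2048):
    print(ai, roofline_flops(ai, PEAK, BW) / 1e12)   # attainable TFLOP/s
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;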

&lt;h2 id=&quot;analysis-of-dense-layer&quot;&gt;Analysis of Dense Layer&lt;/h2&gt;

&lt;p&gt;To analyze the dense layers, let’s look at the Throughput vs Batch graph. Throughput is calculated in tokens per second on the Y-axis and the X-axis shows the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch * seqlen&lt;/code&gt; dimension for that particular Dense operation&lt;label for=&quot;sn-dense-reshape&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-dense-reshape&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;For higher-dimensional inputs, the vector-matrix multiplication is broadcasted to all dimensions except for the last one. For example, when applying a dense layer of shape &lt;code&gt;(h, h)&lt;/code&gt; to a tensor of shape &lt;code&gt;(b, s, h)&lt;/code&gt;, the tensor is reshaped to &lt;code&gt;(b*s, h)&lt;/code&gt; before the matrix multiplication and then reshaped back to &lt;code&gt;(b, s, h)&lt;/code&gt; afterward. &lt;/span&gt;.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/dense_init_small_seqlen.png&quot; alt=&quot;Dense Layer Small Seqlen&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 3a: Throughput vs Batch for Dense Layer with small sequence length on NVIDIA H100 for INIT stage
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/dense_init_large_seqlen.png&quot; alt=&quot;Dense Layer Large Seqlen&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 3b: Throughput vs Batch for Dense Layer with large sequence length on NVIDIA H100 for INIT stage
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;In Figure 3a, we can see that there is a benefit from batching: the throughput increases as the batch size increases. This holds for smaller sequence lengths, but as the seqlen is made larger&lt;label for=&quot;sn-large-prompt&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-large-prompt&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;When the prompt is larger, around 100+ tokens of input &lt;/span&gt;, increasing the batch size improves the throughput only up to a certain point, beyond which the throughput saturates, as shown in Figure 3b.&lt;/p&gt;

&lt;p&gt;We can infer that for smaller input prompts the matrices are too small for the H100 to keep all of its compute units busy, whereas larger sequence lengths are able to saturate them.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/dense_init_consolidated.png&quot; alt=&quot;Dense Layer Consolidated&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 4: Consolidated view of Dense Layer throughput across different dimensions on NVIDIA H100 for INIT stage
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;To get a consolidated view of Throughput vs batch across variations in h, d, n, and seqlen, it is not very useful to plot every combination separately. Instead, using FLOPs on the x-axis allows us to analyze different model sizes on a single plot. This figure uses FLOPs as the x-axis, which serves as a proxy for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;b*s&lt;/code&gt; since FLOPs are &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;O(bsh^2)&lt;/code&gt;&lt;label for=&quot;sn-flops-relation&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-flops-relation&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;The relationship between FLOPs and batch size demonstrates how computational complexity scales with model parameters, directly impacting throughput characteristics. &lt;/span&gt;. This plot shows that throughput increases with batch size when &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seqlen&lt;/code&gt; is smaller (i.e., when the prompts are short). But when the effective batch is large (either because the batch size is high while serving the LLM, or because a long prompt makes &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seqlen&lt;/code&gt; large, or both), the throughput saturates.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/dense_ar_1.png&quot; alt=&quot;Dense AR Stage 1&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 5a: Dense Layer performance in auto-regressive stage - Throughput Analysis on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/dense_ar_2.png&quot; alt=&quot;Dense AR Stage 2&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 5b: Dense Layer performance in auto-regressive stage across different hidden-dims
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;In the autoregressive (AR) stage, the sequence length is always 1&lt;label for=&quot;sn-ar-seqlen&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-ar-seqlen&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;The seqlen is 1 in the AR step for the dense layer because only the new token generated in the previous step needs to be processed at this step. Keys and values are still needed for all previous tokens, but that is handled by the KV cache. Hence processing a single token gives seqlen = 1. &lt;/span&gt;. So there is no practical upper limit on the throughput, even for higher batch sizes. This is reflected in the graphs in Figures 5a and 5b.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/dense_ar_latency.png&quot; alt=&quot;Dense AR Latency&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 6: Latency analysis for Dense Layer in auto-regressive stage on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;From Figure 6, we can see that batching the dense layer in the auto-regressive generation stage does not significantly affect generation latency. This is a good property: a batch of 100 has roughly the same latency as much smaller batch sizes.&lt;/p&gt;

&lt;p&gt;In system design, managing batch sizes and sequence lengths is crucial, particularly for larger models during the Init phase&lt;label for=&quot;sn-system-design&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-system-design&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;The relationship between batch size and sequence length creates a complex optimization space that directly impacts system performance and resource utilization. Understanding these dynamics is crucial for efficient model deployment. &lt;/span&gt;. This phase tends to be the primary performance bottleneck, requiring careful optimization to improve efficiency. Conversely, the autoregressive generation phase scales more effectively, making it less of a limiting factor in overall performance. Smaller models with hidden sizes below 2048 demonstrate better efficiency across both phases, highlighting their suitability for latency-sensitive applications. Additionally, effective batching strategies can significantly enhance the performance of the generation phase without incurring notable penalties. These insights suggest the need for distinct optimization strategies tailored to the Init and generation phases in model serving.&lt;/p&gt;

&lt;h2 id=&quot;analysis-of-self-attention&quot;&gt;Analysis of Self Attention&lt;/h2&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_small_seqlen.png&quot; alt=&quot;Self Attention Small Seqlen&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 7a: Self Attention performance with small sequence length on NVIDIA H100 for INIT stage
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_large_seqlen.png&quot; alt=&quot;Self Attention Large Seqlen&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 7b: Self Attention performance with large sequence length on NVIDIA H100 for INIT stage
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Analyzing the graphs above reveals that for smaller sequence lengths (shorter prompts) in the Init stage, batching has a more significant impact, providing noticeable benefits&lt;label for=&quot;sn-batching-impact&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-batching-impact&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;The impact of batching on self-attention performance varies significantly with sequence length, creating an important consideration for optimization strategies. &lt;/span&gt;. However, in the graph in Figure 7b, where the sequence length is larger (seqlen = 500) during the initialization stage, the throughput of the QK matrix multiplication begins to saturate as the batch size increases.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_flops_1.png&quot; alt=&quot;Self Attention FLOPs 1&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 8a: Self Attention performance analysis for INIT stage across different hidden dimensions, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_flops_2.png&quot; alt=&quot;Self Attention FLOPs 2&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 8b: Self Attention performance analysis for INIT stage across different sequence lengths, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Let’s examine the plots with FLOPs on the x-axis, representing different model sizes on the same graph&lt;label for=&quot;sn-flops-comparison&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-flops-comparison&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;Using FLOPs as a metric allows for direct comparison across different model configurations, providing insights into computational efficiency scaling. &lt;/span&gt;. For sequence lengths less than 500, throughput increases as the batch size grows. However, for sequence lengths greater than 500, the plots become linear, showing no increase in throughput despite an increase in batch size.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_ar_1.png&quot; alt=&quot;Self Attention AR 1&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 9a: Self Attention auto-regressive performance with small sequence length for h = 4096, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_ar_2.png&quot; alt=&quot;Self Attention AR 2&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 9b: Self Attention auto-regressive performance with large sequence length for h = 4096, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_ar_3.png&quot; alt=&quot;Self Attention AR 3&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 9c: Self Attention auto-regressive performance across different hidden dimensions, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_ar_4.png&quot; alt=&quot;Self Attention AR 4&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 9d: Self Attention auto-regressive performance across different sequence length, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;A similar pattern is observed in the auto-regressive stage, where increasing the batch size for larger sequence lengths has minimal to no effect. This occurs because the arithmetic intensity of the operation is independent of batch size, so small and large batches end up with the same FLOP-to-IO ratio. Additionally, as auto-regression progresses, the sequence length increases, further diminishing the impact of batching.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_latency_1.png&quot; alt=&quot;Self Attention Latency 1&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 10a: Latency analysis for Self Attention across different sequence lengths, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/self_attention_latency_2.png&quot; alt=&quot;Self Attention Latency 2&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 10b: Latency analysis for Self Attention across different hidden dimensions, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Self-attention latency is comparable to that of a dense layer but increases with batch size, unlike a dense layer. This latency scales approximately linearly with batch size because self-attention primarily involves batched matrix multiplication. With a fixed FLOP-to-I/O ratio, increasing the batch size proportionally raises both FLOPs and I/O, maintaining a constant ratio&lt;label for=&quot;sn-self-attn-ai&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-self-attn-ai&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;Self-Attention Arithmetic Intensity:&lt;/strong&gt;&lt;br /&gt;\(\text{QK AI} = \frac{2 \cdot \text{seqlen} \cdot d}{2d + \text{seqlen}}\)&lt;br /&gt;&lt;br /&gt;- Batch size b cancels out completely&lt;br /&gt;- AI depends only on sequence length and head dimension&lt;br /&gt;&lt;br /&gt;Therefore, increasing batch size does not change the AI, it increases both FLOPs and IO at the same multiplier. &lt;/span&gt;. For example, increasing the batch size from 100 to 1000 directs the system to process more items simultaneously, boosting total throughput without accelerating the processing of individual items. The fundamental matrix multiplication operations still require the same number of steps per item, as the computational work (FLOPs) and memory operations (I/O) scale together. Additionally, in auto-regressive tasks, as the sequence length grows, more time is required to process each subsequent step.&lt;/p&gt;

&lt;h2 id=&quot;roofline-model&quot;&gt;Roofline Model&lt;/h2&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/roofline_model_overview.png&quot; alt=&quot;Roofline Model Overview&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 11: Roofline Model analysis for all operations on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Figure 11 presents the data points for all benchmark combinations, organized using the roofline model&lt;label for=&quot;sn-roofline&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-roofline&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;img src=&quot;/assets/images/transformer_bench/rf_model.png&quot; alt=&quot;Roofline Model&quot; style=&quot;max-width:100%&quot; /&gt;&lt;br /&gt;&lt;br /&gt;The Roofline Model is a performance model that expresses the limits a specific piece of hardware places on the performance of an algorithm. It is often shown visually as a log-log plot of Arithmetic Intensity vs FLOP/s. Read the math behind it &lt;a href=&quot;https://dando18.github.io/posts/2020/04/02/roofline-model&quot;&gt;here&lt;/a&gt; &lt;/span&gt;. Different stages and layers are distinguished through color coding. Overlaid on the figure are the theoretical memory bandwidth and FLOP/s limits, based on NVIDIA H100 specifications.&lt;/p&gt;

&lt;p&gt;Two key insights emerge from this visualization:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;The data points cluster into distinct groups and sub-groups, naturally reflecting the computational and memory characteristics of various stages and layers.&lt;/li&gt;
  &lt;li&gt;The data points closely follow the theoretical roofline, demonstrating that the benchmarks effectively leverage the hardware’s capabilities relative to its peak performance.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;To observe the impact of batching, let’s examine a specific case (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;h=4096, s=100&lt;/code&gt;)&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/transformer_bench/roofline_model_specific.png&quot; alt=&quot;Roofline Model Specific&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 12: Detailed Roofline analysis for h=4096 and s=100, measured on NVIDIA H100
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;Arithmetic Intensity and Achieved FLOP/s&lt;/strong&gt;: Arithmetic intensity across operations follows the sequence: dense_init &amp;gt; qk_init &amp;gt; dense_ar &amp;gt; qk_ar. Achieved FLOP/s also follows this order. The dense layer during initialization is constrained by the GPU’s peak computational performance. For small models and short sequence lengths, batching provides slight improvements, but significant performance gains require investing in a more powerful GPU.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Dense Layer in Auto-Regression&lt;/strong&gt;: Unlike initialization, the dense layer in the auto-regression stage behaves differently. For the same model size, its data points align with the slope of the GPU’s memory bandwidth, indicating that its performance is memory bandwidth-bound. Under this constraint, increasing the batch size enhances the achieved FLOP/s by improving arithmetic intensity.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Batching and Self-Attention&lt;/strong&gt;: Batching significantly impacts self-attention. While it does not alter the arithmetic intensity of self-attention, it increases the achieved FLOP/s for short sequence lengths by enabling parallel processing.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Kernel Implementation in Self-Attention&lt;/strong&gt;: The increase in achieved FLOP/s for self-attention, despite unchanged arithmetic intensity, suggests that the kernel implementation may be suboptimal, potentially failing to fully utilize the GPU’s compute units.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;data-availability&quot;&gt;Data Availability&lt;/h2&gt;

&lt;p&gt;All the data used in this analysis is publicly available in CSV format at &lt;a href=&quot;https://github.com/doteval/transformer_bench/tree/main/data&quot;&gt;transformer_bench/data&lt;/a&gt;. While this article focuses on bf16 dtype results, the repository contains data for fp32 and fp16 dtypes as well on the H100 GPU. You are encouraged to perform your own analysis using these additional precision formats and contribute your findings to the repository at &lt;a href=&quot;https://github.com/doteval/transformer_bench/tree/main&quot;&gt;doteval/transformer_bench&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;summary&quot;&gt;Summary&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;Performance Hierarchy and Hardware Constraints&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;Arithmetic intensity and achieved FLOP/s follow a clear hierarchy: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dense_init &amp;gt; qk_init &amp;gt; dense_ar &amp;gt; qk_ar&lt;/code&gt;&lt;/li&gt;
      &lt;li&gt;Dense layer initialization is compute-bound by GPU peak performance&lt;/li&gt;
      &lt;li&gt;Dense layer auto-regression is memory bandwidth-bound&lt;/li&gt;
      &lt;li&gt;Performance improvements in compute-bound operations require GPU upgrades, while memory-bound operations benefit from optimized batching strategies&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Sequence Length and Batching Dynamics&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;Short sequence lengths (&amp;lt; 500 tokens) show significant benefits from batching&lt;/li&gt;
      &lt;li&gt;Longer sequences (&amp;gt; 500 tokens) show diminishing returns from increased batch sizes&lt;/li&gt;
      &lt;li&gt;In autoregressive generation, sequence length remains at 1, allowing for efficient batching&lt;/li&gt;
      &lt;li&gt;Throughput saturation occurs at different batch sizes depending on sequence length and model size&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Self-Attention Characteristics and Optimization&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;Self-attention benefits from batching without changing arithmetic intensity&lt;/li&gt;
      &lt;li&gt;Current kernel implementations show signs of suboptimal compute unit utilization&lt;/li&gt;
      &lt;li&gt;Parallel processing capabilities are not fully exploited, suggesting room for optimization&lt;/li&gt;
      &lt;li&gt;Performance scales linearly with batch size due to the nature of matrix multiplication operations&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Model Size Considerations&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;Smaller models (hidden sizes &amp;lt; 2048) demonstrate better efficiency across all phases&lt;/li&gt;
      &lt;li&gt;Larger models face significant computational bottlenecks during Init&lt;/li&gt;
      &lt;li&gt;Memory bandwidth becomes a limiting factor for large models in autoregressive phase&lt;/li&gt;
      &lt;li&gt;Different optimization strategies are needed for different model sizes&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;System Design and Implementation Insights&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;Init phase is typically the primary performance bottleneck&lt;/li&gt;
      &lt;li&gt;Autoregressive generation phase shows more favorable scaling characteristics&lt;/li&gt;
      &lt;li&gt;Different phases require distinct optimization approaches due to varying performance characteristics&lt;/li&gt;
      &lt;li&gt;System designs need to balance between throughput optimization and latency requirements based on use case&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;future&quot;&gt;Future&lt;/h2&gt;

&lt;h3 id=&quot;optimizing-self-attention-through-matrix-fusion&quot;&gt;Optimizing Self-Attention Through Matrix Fusion&lt;/h3&gt;

&lt;p&gt;In the self-attention mechanism, we can identify a key optimization opportunity in the matrix multiplication operations. Currently, the computation flow involves the following steps (a minimal sketch follows the list):&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Computing &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;QK^T&lt;/code&gt; which produces an intermediate result with shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, s, s)&lt;/code&gt;&lt;/li&gt;
  &lt;li&gt;Applying softmax to this intermediate result&lt;/li&gt;
  &lt;li&gt;Multiplying with V to get the final output of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(b, n, s, d)&lt;/code&gt;&lt;/li&gt;
&lt;/ol&gt;
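
&lt;p&gt;To make the intermediate tensor concrete, here is a minimal NumPy sketch of this unfused flow. It only illustrates the shapes involved, in particular the full (b, n, s, s) score matrix; it is not a production kernel.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np

def naive_attention(Q, K, V):
    # Q, K, V: (b, n, s, d). The unfused flow materializes the full (b, n, s, s) score matrix.
    d = Q.shape[-1]
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d)     # (b, n, s, s) intermediate
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V                                     # (b, n, s, d) output

b, n, s, d = 2, 4, 16, 8
rng = np.random.default_rng(0)
out = naive_attention(rng.standard_normal((b, n, s, d)),
                      rng.standard_normal((b, n, s, d)),
                      rng.standard_normal((b, n, s, d)))
print(out.shape)   # (2, 4, 16, 8)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;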

&lt;p&gt;A more efficient approach would combine these operations into a unified computation:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;The key insight is that we can fuse these three matrix operations (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;QK^T&lt;/code&gt;, softmax, and multiplication with V) into a single GPU kernel operation&lt;/li&gt;
  &lt;li&gt;This fusion is particularly effective because the head dimension (d=128) is relatively small&lt;/li&gt;
  &lt;li&gt;The main challenge lies in handling the softmax operation, which traditionally requires computing across the entire sequence dimension&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The softmax computation presents a specific challenge, but it was solved elegantly by FlashAttention. There are three versions of FlashAttention: the third is an optimization specific to H100 GPUs, while the techniques from the first two papers can be implemented on any GPU. Links to the papers are in the references.&lt;/p&gt;

&lt;h3 id=&quot;efficient-request-batching-strategy&quot;&gt;Efficient Request Batching Strategy&lt;/h3&gt;

&lt;p&gt;Analysis revealed significant potential in batching multiple requests, even when they have different sequence lengths. Rather than using simple padding, we can implement a more sophisticated approach based on our performance analysis:&lt;/p&gt;

&lt;p&gt;Key Observations:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Dense layer performance:
    &lt;ul&gt;
      &lt;li&gt;Shows strong batching benefits&lt;/li&gt;
      &lt;li&gt;Maintains nearly constant latency during autoregressive generation&lt;/li&gt;
      &lt;li&gt;Treats sequence dimension similarly to batch dimension&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;Self-attention characteristics:
    &lt;ul&gt;
      &lt;li&gt;Must process each sequence independently&lt;/li&gt;
      &lt;li&gt;Cannot be batched across different sequences&lt;/li&gt;
      &lt;li&gt;Takes less execution time compared to dense layers&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Implementation Strategy (a minimal sketch follows the list):&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Input Processing:
    &lt;ul&gt;
      &lt;li&gt;Take variable-length inputs: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[(s1, h), (s2, h), ...]&lt;/code&gt;&lt;/li&gt;
      &lt;li&gt;Combine them into a single matrix: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(sum(si), h)&lt;/code&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;Computation Flow:
    &lt;ul&gt;
      &lt;li&gt;Process the combined matrix through dense layers&lt;/li&gt;
      &lt;li&gt;Split the results back into individual sequences&lt;/li&gt;
      &lt;li&gt;Handle self-attention computations separately for each sequence&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ol&gt;
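
&lt;p&gt;Here is a minimal NumPy sketch of the packing idea described above. The helper and variable names are illustrative, not from any particular serving system; self-attention would still run separately per sequence on the split outputs.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np

def packed_dense(inputs, W):
    # inputs: list of (s_i, h) arrays with different lengths s_i; W: (h, h) dense weight.
    lengths = [x.shape[0] for x in inputs]
    packed = np.concatenate(inputs, axis=0)   # (sum(s_i), h): one big matmul, no padding
    out = packed @ W
    splits = np.cumsum(lengths)[:-1]
    return np.split(out, splits, axis=0)      # back to per-sequence (s_i, h) pieces

rng = np.random.default_rng(0)
h = 64
seqs = [rng.standard_normal((s, h)) for s in (5, 17, 9)]   # three requests, different lengths
W = rng.standard_normal((h, h))
outs = packed_dense(seqs, W)
print([o.shape for o in outs])   # [(5, 64), (17, 64), (9, 64)]
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;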

&lt;p&gt;This approach offers several advantages:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Eliminates unnecessary padding computations&lt;/li&gt;
  &lt;li&gt;Maintains computational efficiency for dense layers&lt;/li&gt;
  &lt;li&gt;Preserves sequence-specific attention patterns&lt;/li&gt;
  &lt;li&gt;Balances throughput improvements with latency considerations&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The strategy is particularly effective because it:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Leverages the strengths of dense layer batching&lt;/li&gt;
  &lt;li&gt;Respects the inherent limitations of self-attention&lt;/li&gt;
  &lt;li&gt;Minimizes computational overhead&lt;/li&gt;
  &lt;li&gt;Provides flexibility in handling variable-length inputs&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;This method is presented in Orca [1]; reference available &lt;a href=&quot;https://www.usenix.org/conference/osdi22/presentation/yu&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;references&quot;&gt;References&lt;/h2&gt;

&lt;div class=&quot;work-references&quot;&gt;
&lt;p&gt;[1] Yu, G.-I., Jeong, J. S., Kim, G.-W., Kim, S., &amp;amp; Chun, B.-G. (2022). &quot;Orca: A Distributed Serving System for Transformer-Based Generative Models.&quot; In &lt;em&gt;16th USENIX Symposium on Operating Systems Design and Implementation (OSDI 22)&lt;/em&gt; (pp. 521-538). Carlsbad, CA: USENIX Association.&lt;/p&gt;
&lt;p&gt;[2] Dando, A. (2020). &quot;Arithmetic Intensity and the Roofline Model.&quot; &lt;em&gt;Dando&apos;s Blog&lt;/em&gt;, April 2, 2020.&lt;/p&gt;
&lt;p&gt;[3] NVIDIA. &quot;Guide for GEMM.&quot; &lt;em&gt;NVIDIA Deep Learning Performance Documentation&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[4] Milakov, M., &amp;amp; Gimelshein, N. (2018). &quot;Online normalizer calculation for softmax.&quot; &lt;em&gt;arXiv preprint arXiv:1805.02867v2&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[5] Dao, T., Fu, D. Y., Ermon, S., Rudra, A., &amp;amp; Ré, C. (2022). &quot;FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.&quot; &lt;em&gt;arXiv preprint arXiv:2205.14135v2&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[6] Dao, T. (2023). &quot;FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.&quot; &lt;em&gt;arXiv preprint arXiv:2307.08691&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[7] Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., &amp;amp; Dao, T. (2024). &quot;FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.&quot; &lt;em&gt;arXiv preprint arXiv:2407.08608&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[8] Chen, L. (2023). &quot;Transformer Batching.&quot; &lt;em&gt;Lequn Chen Blog&lt;/em&gt;, May 13, 2023.&lt;/p&gt;
&lt;/div&gt;
</description>
        <pubDate>Wed, 01 Jan 2025 00:00:00 -0500</pubDate>
        <link>https://aakashvarma.github.io/transformer_bench/</link>
        <guid isPermaLink="true">https://aakashvarma.github.io/transformer_bench/</guid>
        
      </item>
    
      <item>
        <title>Balancing Memory &amp; Compute: Strategies to Manage KV Cache in LLMs</title>
        <description>&lt;p&gt;KV caching as is method to optimize the inference process of large language models (LLMs), reducing the compute requirements from quadratic to linear scaling with the sequence length. Specifically, KV caching involves storing the key and value tensors of past tokens in GPU memory during the generation process, thus avoiding re-computation at each step.&lt;/p&gt;

&lt;p&gt;KV caching represents a trade-off between memory usage and compute resources&lt;label for=&quot;sn-memory-compute&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-memory-compute&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;Memory-Compute Trade-off:&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;&lt;em&gt;Without KV Cache:&lt;/em&gt;&lt;br /&gt;&lt;code&gt;Compute = O(n²) per token&lt;br /&gt;Memory = O(1)&lt;/code&gt;&lt;br /&gt;&lt;br /&gt;&lt;em&gt;With KV Cache:&lt;/em&gt;&lt;br /&gt;&lt;code&gt;Compute = O(n) per token&lt;br /&gt;Memory = O(n)&lt;/code&gt;&lt;br /&gt;&lt;br /&gt;Where &lt;em&gt;n&lt;/em&gt; is sequence length &lt;/span&gt;. While it reduces computational load, it increases memory consumption due to the need to store cached tensors. In this post, we’ll delve into the challenges posed by the growing size of the KV cache and explore common strategies to address them.&lt;/p&gt;

&lt;p&gt;The size of the KV cache grows linearly with the batch size and the total sequence length. The per-token memory consumption depends on the precision used for storing the tensors.&lt;/p&gt;

&lt;p&gt;Let’s derive the formula for the total size of the KV cache:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Core Formula Parameters:&lt;/strong&gt;&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;        &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;           &lt;span class=&quot;c1&quot;&gt;# Batch size
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sequence_length&lt;/span&gt;      &lt;span class=&quot;c1&quot;&gt;# Total sequence length
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_decoder_blocks&lt;/span&gt;   &lt;span class=&quot;c1&quot;&gt;# Number of decoder blocks / attention layers
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_heads&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_attention_heads&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Number of attention heads per layer
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;   &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;head_dimension&lt;/span&gt;       &lt;span class=&quot;c1&quot;&gt;# Hidden dimension of the attention layer
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p_a&lt;/span&gt;      &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;precision_bytes&lt;/span&gt;      &lt;span class=&quot;c1&quot;&gt;# Precision (bytes)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The per-token memory consumption (in bytes) for the KV cache of a multi-head attention (MHA) model is:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;per_token_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p_a&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The total size of the KV cache (in bytes)&lt;label for=&quot;sn-formula&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-formula&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;This formula accounts for the fact that for each token in each sequence in the batch, we need to store two tensors (key and value) for each attention head and each attention layer. &lt;/span&gt;:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;total_kv_cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p_a&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
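
&lt;p&gt;As a worked example, plugging rough Llama-2-7B-style values into the formula (32 layers, 32 heads, a head dimension of 128, and fp16 storage; these numbers are illustrative, not an official spec sheet) shows how quickly the cache grows:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;# Rough Llama-2-7B-style configuration (illustrative): 32 layers, 32 heads, d_head = 128, fp16.
n_layers, n_heads, d_head, p_a = 32, 32, 128, 2

per_token = 2 * n_layers * n_heads * d_head * p_a
print(per_token / 1024)      # 512.0 KiB of cache per token

b, seq_len = 8, 4096
total = 2 * b * seq_len * n_layers * n_heads * d_head * p_a
print(total / 1024**3)       # 16.0 GiB for the whole batch
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;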

&lt;p&gt;The challenge with KV caching lies in its unbounded growth with the total sequence length, which poses difficulties in managing GPU memory, especially since the total sequence length may not be known in advance.&lt;/p&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/kv_cache_optimization/heatmap.png&quot; alt=&quot;Attention Heatmap&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 1: Attention (heat)map from the StreamingLLM paper: A lot of attention is consistently allocated to the first token and to the last neighboring tokens (local attention)
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;h2 id=&quot;exploring-ways-to-reduce-memory-footprint-of-the-kv-cache&quot;&gt;Exploring ways to reduce memory footprint of the KV cache&lt;/h2&gt;

&lt;p&gt;Let’s explore ways to reduce the memory footprint of the KV cache by examining each component of the formula:&lt;/p&gt;

&lt;h3 id=&quot;optimizing-batch-size-b&quot;&gt;Optimizing Batch Size (&lt;em&gt;b&lt;/em&gt;)&lt;/h3&gt;

&lt;p&gt;While decreasing the batch size can indeed alleviate the memory footprint of the KV cache and subsequently reduce latency, it’s generally not preferable. This is because reducing the batch size lowers hardware utilization, diminishing cost efficiency. In upcoming posts, we’ll delve into why increasing the batch size is often more desirable.&lt;/p&gt;

&lt;h3 id=&quot;optimizing-sequence-length-seq_len&quot;&gt;Optimizing Sequence Length (&lt;em&gt;seq_len&lt;/em&gt;)&lt;/h3&gt;

&lt;p&gt;To mitigate the dependency on the total sequence length&lt;label for=&quot;sn-attention-pattern&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-attention-pattern&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;Attention Pattern Analysis:&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;• Strong attention to first tokens&lt;br /&gt;• Local attention clusters&lt;br /&gt;• Special token importance&lt;br /&gt;• Periodic patterns at:&lt;br /&gt;  • Sentence boundaries&lt;br /&gt;  • Paragraph breaks&lt;br /&gt;  • List elements &lt;/span&gt;, one approach is to refrain from storing keys and values for all tokens in the sequence. This strategy might involve recomputing missing keys and values on each iteration, prioritizing computational resources over GPU memory consumption, especially when memory bandwidth is a limiting factor.&lt;/p&gt;

&lt;p&gt;Another perspective involves not storing keys and values for tokens that the model pays little or no attention to. This could be intentional in models trained to attend only to specific parts of the sequence, such as Mistral-7B, which utilizes sliding window attention (SWA) or local attention. With SWA, attention layers focus solely on neighboring tokens (only 4096), limiting the number of tensor pairs stored per sequence to the window size (4096).&lt;/p&gt;

&lt;h3 id=&quot;more-methods-for-memory-reduction&quot;&gt;More Methods for Memory Reduction&lt;/h3&gt;

&lt;h4 id=&quot;streamingllm-framework&quot;&gt;StreamingLLM Framework&lt;/h4&gt;

&lt;p&gt;Targeting models with finite-length context windows, this framework observes that initial tokens gather significant attention&lt;label for=&quot;sn-streamingllm&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-streamingllm&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;StreamingLLM Memory Usage:&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;&lt;code&gt;Fixed part = n_sink tokens&lt;br /&gt;Sliding part = window_size tokens&lt;br /&gt;&lt;br /&gt;Total Memory = (n_sink + window_size) × token_size&lt;br /&gt;vs. Original = full_context × token_size&lt;/code&gt;&lt;br /&gt;&lt;br /&gt;Typical savings: 40-60% with minimal performance impact &lt;/span&gt;. It builds a sliding window by retaining only the first positional tokens (“sink tokens”) and the last neighboring tokens (local attention) in the cache. The cache has a fixed length with both a fixed part and a sliding part.&lt;/p&gt;
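
&lt;p&gt;A minimal sketch of which token positions a StreamingLLM-style cache would retain; the n_sink and window_size values here are illustrative defaults, not values taken from the paper:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;def streaming_positions(seq_len, n_sink=4, window_size=1024):
    # Keep the first n_sink attention-sink tokens plus the most recent window_size tokens.
    sink = list(range(min(n_sink, seq_len)))
    recent_start = max(n_sink, seq_len - window_size)
    recent = list(range(recent_start, seq_len))
    return sink + recent

for seq_len in (100, 5000, 100000):
    print(seq_len, len(streaming_positions(seq_len)))   # cache size is capped at n_sink + window_size
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;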

&lt;h4 id=&quot;h2o-and-scissorhands-methods&quot;&gt;H2O and Scissorhands Methods&lt;/h4&gt;

&lt;p&gt;These methods compress the KV cache by setting a maximum number of cached tokens (budget) and discarding tokens when the cache budget is reached. H2O discards one token at a time, while Scissorhands drops tokens based on a target compression ratio. Both methods exploit the observation that influential tokens at a given step tend to remain influential in future steps.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Cache Eviction Policy&lt;/strong&gt; - Both H2O and Scissorhands employ cache eviction policies to determine which tokens to discard. Scissorhands retains the most recent tokens and tokens with the highest attention scores within a history window. H2O discards tokens with the lowest cumulated attention scores, retaining tokens consistently achieving high attention scores across iterations.&lt;/p&gt;
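
&lt;p&gt;A minimal sketch of an H2O-style eviction step. The cumulative attention scores here are made-up numbers purely for illustration; real implementations track them per head and per layer.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;def evict_lowest(cache_positions, cumulative_attention, budget):
    # H2O-style policy: once the budget is exceeded, drop the cached token whose
    # cumulative attention score is lowest (one token per generation step).
    if len(cache_positions) &lt;= budget:
        return cache_positions
    victim = min(cache_positions, key=lambda pos: cumulative_attention[pos])
    return [pos for pos in cache_positions if pos != victim]

cumulative_attention = {0: 5.1, 1: 0.2, 2: 3.7, 3: 0.9, 4: 2.4}
print(evict_lowest([0, 1, 2, 3, 4], cumulative_attention, budget=4))   # drops position 1
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;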

&lt;h4 id=&quot;fastgen-method&quot;&gt;FastGen Method&lt;/h4&gt;

&lt;p&gt;FastGen focuses on preserving model accuracy&lt;label for=&quot;sn-fastgen&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-fastgen&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;FastGen sets an error threshold (ε) for approximation:&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;&lt;code&gt;Error = ||A - Â||_F / ||A||_F&lt;br /&gt;&lt;br /&gt;Where:&lt;br /&gt;A = Original attention matrix&lt;br /&gt;Â = Approximated attention matrix&lt;br /&gt;||·||_F = Frobenius norm&lt;br /&gt;&lt;br /&gt;Typical bounds:&lt;br /&gt;ε = 0.1  → ~70% compression&lt;br /&gt;ε = 0.05 → ~50% compression&lt;br /&gt;ε = 0.01 → ~30% compression&lt;/code&gt; &lt;/span&gt; by setting a maximum approximation error for the attention matrix instead of a cache budget. It profiles the model’s attention layers to determine compression policies during a prefill phase. These policies, such as keeping special tokens or punctuation tokens, are applied to the KV cache at each generation step to meet the error target. If the target is too stringent, FastGen falls back to regular KV caching.&lt;/p&gt;

&lt;h3 id=&quot;optimizing-number-of-layers-n_layers&quot;&gt;Optimizing Number of Layers (&lt;em&gt;n_layers&lt;/em&gt;)&lt;/h3&gt;

&lt;p&gt;Reducing the number of layers in a language model does not offer significant gains in terms of memory reduction. Typically, smaller models naturally have fewer layers. Therefore, if a smaller model suits your use case and performs adequately, opting for it is a straightforward solution.&lt;/p&gt;

&lt;h3 id=&quot;optimizing-number-of-attention-heads-n_heads&quot;&gt;Optimizing Number of Attention Heads (&lt;em&gt;n_heads&lt;/em&gt;)&lt;/h3&gt;

&lt;figure&gt;
&lt;img src=&quot;/assets/images/kv_cache_optimization/attention_types.jpg&quot; alt=&quot;Types of Attention&quot; class=&quot;center&quot; /&gt;
&lt;figcaption&gt;
Figure 2: Types of Attention
&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;The multi-query attention (MQA) and grouped-query attention (GQA) architectures provide strategies for reducing the key-value (KV) cache size in models based on the Transformer architecture&lt;label for=&quot;sn-mqa-gqa&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-mqa-gqa&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;MQA vs GQA Memory:&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;&lt;code&gt;MHA: Memory = H × d × 2&lt;br /&gt;MQA: Memory = d × 2&lt;br /&gt;GQA: Memory = g × d × 2&lt;br /&gt;&lt;br /&gt;Where:&lt;br /&gt;H = Total heads&lt;br /&gt;d = Head dimension&lt;br /&gt;g = Number of groups (g &amp;lt; H)&lt;br /&gt;&lt;br /&gt;Real-world example:&lt;br /&gt;32 heads → 8 groups = 75% reduction&lt;/code&gt; &lt;/span&gt;. These approaches allow for more efficient use of resources without sacrificing model performance significantly.&lt;/p&gt;

&lt;p&gt;In MQA, all query heads share the same single key and value heads, meaning that each query head computes attention scores using the same keys, and all heads output values computed using the same values but different attention scores.&lt;/p&gt;

&lt;p&gt;GQA splits the query heads into groups, with each group sharing the same unique key-value heads. This allows for a smoother reduction in the number of key-value heads compared to MQA, providing a compromise between model representation capacity and KV cache size.&lt;/p&gt;
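
&lt;p&gt;Reusing the per-token formula from earlier, the only quantity that changes across MHA, MQA, and GQA is the number of key-value heads that actually need to be cached. A quick sketch (the 32 query heads and 8 GQA groups are illustrative):&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;def per_token_kv_bytes(n_layers, n_kv_heads, d_head, p_a):
    # Only the key-value heads are cached, so MQA and GQA shrink the cache directly.
    return 2 * n_layers * n_kv_heads * d_head * p_a

n_layers, d_head, p_a = 32, 128, 2
print(per_token_kv_bytes(n_layers, 32, d_head, p_a))   # MHA: 32 KV heads
print(per_token_kv_bytes(n_layers, 8, d_head, p_a))    # GQA: 8 groups, 4x smaller
print(per_token_kv_bytes(n_layers, 1, d_head, p_a))    # MQA: 1 shared KV head, 32x smaller
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;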

&lt;p&gt;These architectures have been implemented in various models by different research groups, such as Google Research’s PaLM, TII’s Falcon models, Meta’s Llama-2 (limited to 70B only), and Mistral AI’s Mistral-7B.&lt;/p&gt;

&lt;h3 id=&quot;optimizing-hidden-dimension-d_head&quot;&gt;Optimizing Hidden Dimension (&lt;em&gt;d_head&lt;/em&gt;)&lt;/h3&gt;

&lt;p&gt;Once again, there is not much to gain here unless you are willing to switch to a different model.&lt;/p&gt;

&lt;h3 id=&quot;optimizing-precision-p_a&quot;&gt;Optimizing Precision (&lt;em&gt;p_a&lt;/em&gt;)&lt;/h3&gt;

&lt;p&gt;Quantizing the key-value (KV) cache is an effective method for reducing its size&lt;label for=&quot;sn-precision&quot; class=&quot;margin-toggle sidenote-number&quot;&gt;&lt;/label&gt;&lt;input type=&quot;checkbox&quot; id=&quot;sn-precision&quot; class=&quot;margin-toggle&quot; /&gt;&lt;span class=&quot;sidenote&quot;&gt;&lt;strong&gt;Precision Impact:&lt;/strong&gt;&lt;br /&gt;&lt;br /&gt;&lt;code&gt;Memory reduction by precision:&lt;br /&gt;FP32 (4 bytes) → FP16 (2 bytes): 50% reduction&lt;br /&gt;FP16 (2 bytes) → INT8 (1 byte): 50% reduction&lt;br /&gt;INT8 (1 byte) → INT4 (0.5 bytes): 50% reduction&lt;/code&gt; &lt;/span&gt;, but it’s important to use quantization algorithms that operate on both weights and activations, not just weights. Algorithms like LLM.int8() or SmoothQuant are suitable for this purpose, as they quantize both weights and activations, resulting in a reduced memory footprint.&lt;/p&gt;

&lt;p&gt;However, for inference tasks, where memory bandwidth is the limiting factor rather than compute power, quantizing the cached tensors before moving them to GPU memory and dequantizing them afterward could suffice. This approach reduces the memory footprint without the overhead of more complex quantization algorithms.&lt;/p&gt;
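
&lt;p&gt;A minimal sketch of that simpler approach: symmetric per-tensor int8 quantization of a cached key or value vector before it is stored, and dequantization when it is read back. This is a pure-Python illustration of the idea, not any particular library’s kernel.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;def quantize_int8(values):
    # Symmetric per-tensor quantization: map [-max_abs, max_abs] onto integers in [-127, 127].
    scale = max(abs(v) for v in values) / 127 or 1.0
    return [round(v / scale) for v in values], scale

def dequantize_int8(quantized, scale):
    return [q * scale for q in quantized]

k_vector = [0.12, -1.7, 0.003, 2.5, -0.44]
quantized, scale = quantize_int8(k_vector)
print(quantized)                            # small integers: 1 byte each instead of 2 or 4
print(dequantize_int8(quantized, scale))    # close to the original values
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;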

&lt;p&gt;Some inference systems, like FlexGen, NVIDIA TensorRT-LLM, and vLLM framework, already incorporate KV cache quantization features. They store the KV cache and model weights in reduced bit formats (4-bit or 8-bit) dynamically without requiring a calibration step at each iteration.&lt;/p&gt;

&lt;h2 id=&quot;references&quot;&gt;References&lt;/h2&gt;

&lt;div class=&quot;work-references&quot;&gt;
&lt;p&gt;[1] Xiao, G., Tian, Y., Chen, B., Han, S., &amp;amp; Lewis, M. (2023). &quot;Efficient Streaming Language Models with Attention Sinks.&quot; In &lt;em&gt;International Conference on Learning Representations (ICLR)&lt;/em&gt;. arXiv preprint arXiv:2309.17453.&lt;/p&gt;
&lt;p&gt;[2] Zhang, Z., Sheng, Y., Zhou, T., Chen, T., Zheng, L., Cai, R., Song, Z., Tian, Y., Ré, C., Barrett, C., Wang, Z., &amp;amp; Chen, B. (2023). &quot;H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models.&quot; In &lt;em&gt;Advances in Neural Information Processing Systems (NeurIPS)&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[3] Liu, Z., Desai, A., Liao, F., Wang, W., Xie, V., Xu, Z., Kyrillidis, A., &amp;amp; Shrivastava, A. (2023). &quot;Scissorhands: Exploiting the Persistence of Importance Hypothesis for LLM KV Cache Compression at Test Time.&quot; In &lt;em&gt;Advances in Neural Information Processing Systems (NeurIPS)&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[4] Ge, Y., Qin, Y., Tang, J., &amp;amp; Liu, Y. (2024). &quot;Model Tells You What to Discard: Adaptive KV Cache Compression for LLMs.&quot; In &lt;em&gt;International Conference on Learning Representations (ICLR)&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[5] Shazeer, N. (2019). &quot;Fast Transformer Decoding: One Write-Head is All You Need.&quot; &lt;em&gt;arXiv preprint arXiv:1911.02150&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[6] Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., &amp;amp; Sanghai, S. (2023). &quot;GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.&quot; In &lt;em&gt;Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing (EMNLP)&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[7] Dettmers, T., Lewis, M., Belkada, Y., &amp;amp; Zettlemoyer, L. (2022). &quot;LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale.&quot; In &lt;em&gt;Advances in Neural Information Processing Systems (NeurIPS)&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[8] Xiao, G., Lin, J., Seznec, M., Wu, H., Demouth, J., &amp;amp; Han, S. (2023). &quot;SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models.&quot; In &lt;em&gt;International Conference on Machine Learning (ICML)&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[9] Sheng, Y., Zheng, L., Yuan, B., Li, Z., Ryabinin, M., Chen, B., Liang, P., Ré, C., Stoica, I., &amp;amp; Zhang, C. (2023). &quot;FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU.&quot; In &lt;em&gt;International Conference on Machine Learning (ICML)&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[10] Kwon, W., Li, Z., Zhuang, S., Sheng, Y., Zheng, L., Yu, C. H., Gonzalez, J., Zhang, H., &amp;amp; Stoica, I. (2023). &quot;Efficient Memory Management for Large Language Model Serving with PagedAttention.&quot; In &lt;em&gt;Proceedings of the 29th Symposium on Operating Systems Principles (SOSP)&lt;/em&gt;.&lt;/p&gt;
&lt;/div&gt;

</description>
        <pubDate>Mon, 27 May 2024 00:00:00 -0400</pubDate>
        <link>https://aakashvarma.github.io/kv_cache_optimization/</link>
        <guid isPermaLink="true">https://aakashvarma.github.io/kv_cache_optimization/</guid>
        
      </item>
    
  </channel>
</rss>

