nouamanetazi HF staff committed on
Commit
e87aa99
·
1 Parent(s): 835c7e8
dist/assets/.DS_Store DELETED
Binary file (6.15 kB)
 
dist/bibliography.bib CHANGED
@@ -316,6 +316,15 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
   archivePrefix={arXiv},
   primaryClass={cs.AI}
 }
+@misc{liu2023ringattentionblockwisetransformers,
+  title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
+  author={Hao Liu and Matei Zaharia and Pieter Abbeel},
+  year={2023},
+  eprint={2310.01889},
+  archivePrefix={arXiv},
+  primaryClass={cs.CL},
+  url={https://arxiv.org/abs/2310.01889},
+}
 @misc{hendrycks2021measuring,
   title={Measuring Massive Multitask Language Understanding},
   author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
dist/index.html CHANGED
The diff for this file is too large to render. See raw diff
 
src/bibliography.bib CHANGED
@@ -316,6 +316,15 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
   archivePrefix={arXiv},
   primaryClass={cs.AI}
 }
+@misc{liu2023ringattentionblockwisetransformers,
+  title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
+  author={Hao Liu and Matei Zaharia and Pieter Abbeel},
+  year={2023},
+  eprint={2310.01889},
+  archivePrefix={arXiv},
+  primaryClass={cs.CL},
+  url={https://arxiv.org/abs/2310.01889},
+}
 @misc{hendrycks2021measuring,
   title={Measuring Massive Multitask Language Understanding},
   author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
src/index.html CHANGED
@@ -865,22 +865,23 @@
865
 
866
  <h4>Memory usage revisited</h4>
867
 
868
- <p><a target="_self" href="#memory_usage_in_transformers">Earlier</a>, we discussed the memory usage of optimizer states, gradients, and parameters during standard training. Let's call our model's parameter count <d-math>\Psi</d-math> (previously this was <d-math>N</d-math>, but here we use the original ZeRO paper's notation <!-- RH: Add a citation for the paper? -->). In mixed precision training (discussed further <a target="_self" href="#mixed_precision_training">later in the book</a>) with the Adam optimizer, the memory usage for each item we need to store is:</p>
869
 
870
  <ul>
871
  <li>Model’s parameters (half precision; i.e., BF16/FP16): <d-math>2\Psi</d-math></li>
872
  <li>Model’s gradients (half precision; i.e., BF16/FP16): <d-math>2\Psi</d-math></li>
873
  <li>Model’s parameters in FP32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
874
- <li>Model’s gradients in FP32: <d-math>4\Psi</d-math> (optional, only accounted <!-- RH: only needed / included? --> if we want to accumulate gradients in FP32)</li>
875
  </ul>
876
 
877
- <p>If we dont accumulate gradients in FP32, this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we do it gives us <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Lets focus for now on the case without FP32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3 <!-- RH: Would "...for simplicity; for ZeRO-2 and -3, you can just add the additional bytes to the gradient term" work here? This could use a little clarification, I think. -->.</p>
878
 
879
  <p>The idea of ZeRO is to shard these objects across the DP ranks, with each node only storing a slice of the items. These slices are then reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
880
 
881
  <p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
882
  <p>Here, <d-math>\Psi</d-math> denotes the number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam, as we've just seen), and <d-math>N_d</d-math> denotes DP degree.</p>
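  <p>To make these formulas concrete, here's a small back-of-the-envelope helper (our own sketch, not code from any framework) that computes the per-GPU memory of the model-related terms for each ZeRO stage, using the same <d-math>\Psi</d-math>, <d-math>k</d-math>, and <d-math>N_d</d-math> notation:</p>
  <d-code block language="python">
def zero_memory_gb(psi, n_d, k=12, stage=1):
    """Per-GPU memory (in GB) of BF16 params + BF16 grads + optimizer states."""
    if stage == 1:    # ZeRO-1: only the optimizer states are sharded
        total_bytes = 2 * psi + 2 * psi + k * psi / n_d
    elif stage == 2:  # ZeRO-2: optimizer states + gradients are sharded
        total_bytes = 2 * psi + (2 * psi + k * psi) / n_d
    else:             # ZeRO-3: parameters are sharded as well
        total_bytes = (2 * psi + 2 * psi + k * psi) / n_d
    return total_bytes / 1e9

# Example: an 8B-parameter model with a DP degree of 64
for stage in (1, 2, 3):
    print(f"ZeRO-{stage}: {zero_memory_gb(8e9, n_d=64, stage=stage):.1f} GB per GPU")
  </d-code>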
883
 
 
884
 
885
  <p>Let’s explain this by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
886
 
@@ -890,17 +891,16 @@
890
 
891
  <p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts, where <d-math>N_d</d-math> is the DP degree. This means that the model replicas distributed on the DP ranks each only keep track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states, and during the optimization step, only <d-math>\frac{1}{N_d}</d-math> of the FP32 weights are updated.</p>
892
 
893
- <p>However, during the forward pass, each replica needs all the parameters. We thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective <!-- RH: distributed? --> communication primitive we've encountered!) after the optimizer step so that each model replica has the full set of updated weights.</p>
894
 
895
  <p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw in the previous figure! Here’s a summary of the sequence of operations for a single training step:</p>
896
- <!-- RH: Should the following be an ordered list instead? -->
897
- <ul>
898
  <li>Perform a forward pass with the same full set of BF16 parameters on each replica, but different micro-batches across replicas.</li>
899
  <li>Perform a backward pass with the same full set of gradients on each replica, but different micro-batches across replicas.</li>
900
  <li>Perform a <strong><em>reduce-scatter</em></strong> on the gradients (another primitive - we'll explain this one shortly).</li>
901
  <li>Each replica performs an optimizer step on its local optimizer states (only <d-math>\frac{1}{N_d}</d-math> of the optimizer states) to get <d-math>\frac{1}{N_d}</d-math> updated FP32 parameters, which can then be converted to <d-math>\frac{1}{N_d}</d-math> of the full set of BF16 parameters.</li>
902
  <li>Perform an all-gather on the BF16 parameters to send the missing slices back to each replica. This is a new operation in ZeRO and is not used in vanilla DP.</li>
903
- </ul>
904
  <aside>Note: Reduce-scatter is two times faster than all-reduce! <em>Yay, a third communication primitive!</em></aside>
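  <p>To make the sequence of operations above more concrete, here is a heavily simplified sketch of one ZeRO-1 step using torch.distributed primitives. The flat BF16 gradient buffer and the <code>optimizer_shard</code> object (holding this rank's FP32 master weights and Adam states) are hypothetical stand-ins; real implementations add bucketing, communication overlap, and many other details:</p>
  <d-code block language="python">
import torch
import torch.distributed as dist

def zero1_step(flat_grads_bf16, optimizer_shard, world_size):
    # Steps 1-2: the forward and backward passes have already filled
    # flat_grads_bf16 with this replica's gradients for its own micro-batch.

    # Step 3: reduce-scatter the gradients - each rank ends up with the summed
    # gradients for its own 1/N_d slice (we assume numel divides evenly).
    shard = torch.empty(flat_grads_bf16.numel() // world_size,
                        dtype=flat_grads_bf16.dtype, device=flat_grads_bf16.device)
    dist.reduce_scatter_tensor(shard, flat_grads_bf16, op=dist.ReduceOp.SUM)
    shard.div_(world_size)  # average the gradients across replicas

    # Step 4: local optimizer step on this rank's 1/N_d of the FP32 master weights
    # (optimizer_shard is a hypothetical object holding the FP32 shard and Adam states).
    fp32_shard = optimizer_shard.step(shard.float())
    bf16_shard = fp32_shard.to(torch.bfloat16)

    # Step 5: all-gather the updated BF16 shards so every replica ends up with
    # the full set of updated parameters again.
    full_params_bf16 = torch.empty_like(flat_grads_bf16)
    dist.all_gather_into_tensor(full_params_bf16, bf16_shard)
    return full_params_bf16
  </d-code>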
905
 
906
  <p>You may be wondering what this "reduce-scatter" operation is and what this all looks like, so let's try to make it more graphical with the figure below. We'll go over all the steps of a forward/backward pass cycle:</p>
@@ -914,7 +914,7 @@
914
  <p>If you've been following along, you'll recall from our discussion of vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of BF16 parameters. There are two main strategies for this:</p>
915
 
916
  <ul>
917
- <li><strong>During the optimizer step:</strong> We can initiate the all-gather immediately after the optimizer updates part of the parameters <!-- RH: Or "the first slice of the parameters"? -->. This allows the communication to potentially overlap with the updating of the other parameters.</li>
918
  <li><strong>During the forward pass:</strong> We can overlap the all-gather of each layer’s parameters with the forward pass.</li>
919
  </ul>
920
 
@@ -929,19 +929,19 @@
929
 
930
  <h4>ZeRO-2: Adding <strong>gradient partitioning</strong></h4>
931
 
932
- <p>Since on each replica we only need to have the gradient shard corresponding to its optimizer state shard, it makes sense to shard gradients as well, similarly to the optimizer states. Then, during the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation! Here, we only keep in memory the <d-math>\frac{1}{N_d}</d-math> slice of the gradients that we actually need, thus saving more memory compared to ZeRO-1.</p>
933
 
934
- <aside>In the case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> of the FP32 gradients, into which we accumulate the BF16 gradients coming from the reduce-scatter. In the optimizer step, we then use only those <d-math>\frac{1}{N_d}</d-math> FP32 gradients.</aside>
935
 
936
  <p><img alt="dp_zero2.gif" src="/assets/images/dp_zero2.gif" /></p>
937
 
938
- <p>It's easy to see now that sharding the gradients leads to <d-math>2\Psi + \frac{2\Psi+k\Psi}{N_d}</d-math>, and as <d-math>N_d</d-math> is increased, we can use up to 8x less memory than the baseline. In terms of communication, the same process applies as for ZeRO-1, with the only difference being that we communicate and release the gradient memory on the fly: both approaches require a reduce-scatter for the gradients and an all-gather over all parameters. ZeRO-2 is thus also equivalent to vanilla DP training with regard to communication.</p>
939
940
  <p><img alt="dp_zero2_overlap.svg" src="/assets/images/dp_zero2_overlap.svg" /></p>
941
 
942
- <aside>Note: You might notice that there is no real overhead to using ZeRO-2 over ZeRO-1, and indeed ZeRO-2 is usually the better option.</aside>
943
 
944
- <p>Now that we've sharded the gradients as well, are we done, or can we keep making improvements? Here comes ZeRO-3!</p>
945
 
946
  <h4>ZeRO-3: Adding <strong>parameter partitioning</strong> (FSDP)</h4>
947
 
@@ -954,7 +954,7 @@
954
  </div>
955
  </div>
956
 
957
- <p>So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply, we gather the parameters on demand when we need them. In the forward pass, this looks as follows:</p>
958
 
959
  <p><img alt="dp_zero3_fwd.svg" src="/assets/images/dp_zero3_fwd.svg" /></p>
960
 
@@ -962,15 +962,17 @@
962
 
963
  <p><img alt="dp_zero3_bwd.svg" src="/assets/images/dp_zero3_bwd.svg" /></p>
964
 
965
- <p>The other issue is that we need to do these all-gathers continuously throughout the forward and backward passes, which amounts to <d-math>2\cdot \text{num\_layers} -1</d-math> additional all-gathers in a training step compared to ZeRO-2. Each comes with a small <em>base latency</em> overhead, as we can see in the following figure:</p>
966
967
  <p><img alt="dp_zero3_overlap.svg" src="/assets/images/dp_zero3_overlap.svg" /></p>
968
 
969
  <p>During the forward pass we do all-gather operations for the parameters when we need them, so there's a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we use them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> communication tax. Finally, we need the same reduce-scatter operation as in ZeRO-2 for the gradients, which also costs <d-math>\Psi</d-math> in communication. So, we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for ZeRO-2.</p>
970
 
971
- <p>This may sound like a lot of communication overhead, but it's actually not a big deal, as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called <strong><em>prefetching</em></strong>. With prefetching, we all-gather the weights for <em>Layer n+1</em> while we do the forward pass for <em>Layer n</em>, and similarly, we all-gather the weights for <em>Layer n-1</em> while doing the backward pass for <em>Layer n</em>. Of course, this overlap only works as long as we don’t scale DP too much (as a rule of thumb, the DP degree shouldn’t exceed 512).</p>
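  <p>As a rough illustration of prefetching, here's a minimal sketch of a ZeRO-3-style forward pass in which the all-gather for the next layer is launched asynchronously while the current layer is being computed. The <code>forward_with_flat_params</code> helper (which runs a layer using the gathered flat weights) is hypothetical:</p>
  <d-code block language="python">
import torch
import torch.distributed as dist

def gather_layer(flat_shard):
    """Launch an async all-gather that rebuilds a layer's full flat parameters."""
    world_size = dist.get_world_size()
    full = torch.empty(flat_shard.numel() * world_size,
                       dtype=flat_shard.dtype, device=flat_shard.device)
    handle = dist.all_gather_into_tensor(full, flat_shard, async_op=True)
    return full, handle

def zero3_forward_with_prefetch(layers, shards, x):
    full, handle = gather_layer(shards[0])               # gather layer 0 up front
    for i, layer in enumerate(layers):
        handle.wait()                                     # layer i's weights are now complete
        if i + 1 != len(layers):                          # prefetch layer i+1 during compute
            next_full, handle = gather_layer(shards[i + 1])
        x = forward_with_flat_params(layer, full, x)      # hypothetical helper
        del full                                          # free the full weights, keep only the shard
        if i + 1 != len(layers):
            full = next_full
    return x
  </d-code>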
 
 
972
 
973
- <p>In terms of memory, we can see that our equation has now reached its final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math>, which means we can (in theory) drive memory usage down indefinitely by increasing the DP degree, at least for the model-related tensors. Notice that it doesn’t help with the intermediate activations, though - for that, we can use activation checkpointing and gradient accumulation, as we saw earlier.</p>
974
 
975
  <p>Let’s summarize our journey into DP and ZeRO so far. We've seen that we can increase the throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO, we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients, and optimizer states across DP replicas, while incurring a small communication cost.</p>
976
  <aside>If you want to read more about FSDP1, FSDP2, and some of the implementation complexities around them, check out <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
@@ -1006,7 +1008,7 @@
1006
 
1007
 
1008
 
1009
- <p>So, we've sharded the model’s parameters, gradients, and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome tensor parallelism (TP), a method that shards weights, gradients, and optimizer states as well as activations - and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how TP works with simple matrix multiplication (matmul) operations.</p>
1010
 
1011
  <p>Tensor parallelism leverages the mathematical properties of matrix multiplication, <d-math>A \times B</d-math>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p>
1012
 
@@ -1021,7 +1023,7 @@
1021
 
1022
  <ul>
1023
  <li><d-math>X</d-math> represents the input or activation values</li>
1024
- <li><d-math>W</d-math> represents the weight matrix of the linear layer (<code>nn.Linear</code>)</li>
1025
  </ul>
1026
 
1027
  <p>In practice, a small example of the operation looks like this:</p>
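  <p>As a quick sanity check of these two properties, here's a tiny PyTorch sketch (with made-up shapes) showing that splitting <d-math>W</d-math> column-wise and concatenating the outputs, or splitting the inner dimension and summing the partial products, both reproduce the full <d-math>X \times W</d-math>:</p>
  <d-code block language="python">
import torch

torch.manual_seed(0)
X = torch.randn(4, 8)   # input activations (batch x hidden)
W = torch.randn(8, 6)   # weight of a linear layer

full = X @ W

# Column-linear: each rank holds a slice of W's columns and produces a slice of the output.
W_col_0, W_col_1 = W.chunk(2, dim=1)
col_parallel = torch.cat([X @ W_col_0, X @ W_col_1], dim=1)

# Row-linear: each rank holds a slice of W's rows and the matching slice of X's columns;
# the partial outputs are summed (this sum is the all-reduce in the TP forward pass).
X_0, X_1 = X.chunk(2, dim=1)
W_row_0, W_row_1 = W.chunk(2, dim=0)
row_parallel = X_0 @ W_row_0 + X_1 @ W_row_1

print(torch.allclose(full, col_parallel), torch.allclose(full, row_parallel, atol=1e-6))
  </d-code>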
@@ -1064,19 +1066,20 @@
1064
 
1065
  <h3>Tensor parallelism in a transformer block</h3>
1066
 
1067
- <p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: a feedforward multi-layer perceptron (MLP) block and a multi-head attention (MHA) block. We can apply tensor parallelism to both.</p>
1068
 
1069
- <p>The feedforward part can be parallelized by having a column-linear followed by a row-linear split, which amounts to a broadcast to copy the input and an all-reduce in the forward pass. Note that the broadcast isn’t needed in actual training, where we can make sure inputs are already synced across TP ranks. This setup is more efficient than starting with a row-linear followed by column-linear split, as we can skip the intermediate all-reduce between the split operations.</p>
1070
 
1071
  <p><img alt="image.png" src="/assets/images/tp_diagram4.png" /></p>
1072
 
1073
- <p>Now that we've found an efficient schema for the feedforward part of the transformer, let's take a look at the multi-head attention block.</p>
1074
-
1075
- <p>We can generally follow a similar approach here, where the Query (Q), Key (K), and Value (V) matrices are split in a column-parallel fashion and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual attention head or a subset of heads. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query attention (MQA)</em></strong></a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention (GQA)</em></strong></a>, where keys and values are shared between queries. </p>
1076
 
1077
- <p>It's worth noting, however, that the tensor parallelism degree should not exceed the number of Q/K/V heads, because each TP rank needs complete heads (otherwise, we cannot compute the attention independently on each GPU and we'll need additional communication operations). For instance, Llama-3 8B has 8 K/V heads, so the TP degree should ideally not exceed 8. If we use TP=16 for this model, we will need to duplicate the K/V heads on each GPU and make sure they stay in sync. If we’re using GQA, the TP degree should actually be smaller than the number of K/V heads. </p>
1078
 
1079
  <p><img alt="image.png" src="/assets/images/tp_full_diagram.png" /></p>
 
 
 
1080
 
1081
  <p>Note also that tensor parallelism is not a silver bullet for training. We’ve added several distributed communication primitives directly in the computation path of our model, which are therefore hard to fully hide/overlap with computation (like we did in ZeRO), so our final performance will be the result of a trade-off between the computation and memory gains and the added communication overhead. Let's illustrate this:</p>
1082
@@ -1084,11 +1087,11 @@
1084
 
1085
  <aside>It's possible to partially hide this communication by performing block matrix multiplication coupled with async communication/computation.</aside>
1086
 
1087
- <p>Looking at the timeline of operations in a tensor-parallel MLP (the same applies for attention), we can better understand the trade-offs involved. In the forward pass of each decoder layer, we hit a synchronization point with the all-reduce operation that cannot be overlapped with computation. This <em>exposed communication overhead</em> is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
1088
 
1089
- <aside>For example, Megatron-LM and Nanotron implement a partial overlapping of the all-gather with the computation of the first linear layer (FC1), where a portion of the matrix multiplication result gets sent to the other GPU while the remaining part is still being computed.</aside>
1090
 
1091
- <p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, TP introduces significant communication requirements that heavily depend on the network infrastructure. Because this particular all-reduce cannot be fully hidden behind computation, it directly adds latency to the critical path of the forward pass.</p>
1092
 
1093
  <aside>This is an active area of research, with recent work like Domino <d-cite bibtex-key="wang2024domino"></d-cite> exploring novel techniques to maximize this overlap. </aside>
1094
 
@@ -1176,7 +1179,7 @@
1176
  <li><em>f</em> is an all-reduce to synchronize gradients.</li>
1177
  </ul>
1178
 
1179
- <p>These <em>f</em> and <em>f*</em> operations are called <strong><em>conjugate pairs</em></strong> because they complement each other - in each pass, when one is a no-op the other is an all-reduce, and it's the opposite in the other pass.</p>
1180
 
1181
  <p>For sequence parallelism, we use different operations labeled <em>g</em> and <em>g*</em>. Specifically, we avoid using all-reduce in the SP regions since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
1182
 
@@ -1186,24 +1189,24 @@
1186
  <div>
1187
  <p style="margin-bottom: 0;"><strong>Initial LayerNorm layer (SP region)</strong></p>
1188
  <ul style="margin-top: 0;">
1189
- <li>Input tensors <em>X1</em> and <em>X2</em> <d-math>(b,s/2,h)</d-math> enter, already split across the sequence dimension.</li>
1190
- <li>Each GPU computes LayerNorm independently on its sequence chunk, giving <em>Y1</em> and <em>Y2</em>.</li>
1191
  </ul>
1192
  <p style="margin-bottom: 0;"><strong>First transition (SP → TP)</strong></p>
1193
  <ul style="margin-top: 0;">
1194
  <li><em>g</em> operation (all-gather) combines <em>Y1</em> and <em>Y2</em> back to full sequence length.</li>
1195
- <li>Restores <em>Y</em> <d-math>(b,s,h)</d-math>, since column-linear layers need the full hidden dimension <d-math>h</d-math>.</li>
1196
  </ul>
1197
  <p style="margin-bottom: 0;"><strong>First linear layer (TP region)</strong></p>
1198
  <ul style="margin-top: 0;">
1199
- <li><em>A</em> is a column-linear layer, so it splits <em>Y</em> along the hidden dimension.</li>
1200
  <li>GELU is applied independently on each GPU.</li>
1201
- <li><em>Z1*</em> and <em>Z2*</em> are <d-math>(b,s,h/2)</d-math>.</li>
1202
  </ul>
1203
  <p style="margin-bottom: 0;"><strong>Second linear layer (TP region)</strong></p>
1204
  <ul style="margin-top: 0;">
1205
- <li><em>B</em> is a row-linear layer, so it restores the hidden dimension.</li>
1206
- <li><em>W1</em> and <em>W2</em> are <d-math>(b,s,h)</d-math>.</li>
1207
  </ul>
1208
  <p style="margin-bottom: 0;"><strong>Final transition (TP → SP)</strong></p>
1209
  <ul style="margin-top: 0;">
@@ -1388,7 +1391,7 @@
1388
 
1389
  <p><img alt="ring-attention.gif" src="/assets/images/ring-attention.gif" /></p>
1390
 
1391
- <p>It's probably obvious to you from this animation why the authors chose to call this approach Ring Attention<d-cite bibtex-key="liu2023ringattentionblockwisetransformers"></d-cite>!</p>
1392
 
1393
  <p>There is one big problem, though, which is that a naive implementation of Ring Attention leads to some strong imbalances between GPUs due to the shape of the causal attention matrix. Let’s take a look at the softmax computation by considering the attention score matrix with the causal attention mask:</p>
1394
 
@@ -1399,8 +1402,10 @@
1399
  <p>Let’s see if we can balance our computations better.</p>
1400
 
1401
  <h3>Zig-Zag Ring Attention – A balanced compute implementation</h3>
1402
-
1403
- <p>We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens to the GPUs in a purely sequential manner, but instead mixing up the ordering a bit so that each GPU has a good mix of early and late tokens. This approach is called Zig-Zag Attention (a scheme closely related to Striped Attention<d-cite bibtex-key="brandon2023fasterring"></d-cite>). In this new arrangement, the attention mask shows an even distribution of computation, and if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.</p>
 
 
1404
 
1405
  <p><img alt="cp_zigzagmask.svg" src="/assets/images/cp_zigzagmask.svg" /></p>
1406
 
@@ -1447,8 +1452,7 @@
1447
  </div>
1448
  </div>
1449
 
1450
1451
- <p>In the <a target="_self" href="#tensor-parallelism">"Tensor Parallelism"</a> section, we saw that trying to scale tensor parallelism past the number of GPUs on a single node - typically 4 or 8 - requires lower-bandwidth network communication, which can significantly impair performance. We can see the effects of this inter-node communication clearly in the all-reduce operation when we benchmark it on our cluster across several nodes (each node here has 8 GPUs):</p>
1452
 
1453
  <!-- <iframe class="l-body-outset" id="plotFrame11" src="assets/data/benchmarks/pp_comm_bandwidth.html" width="90%" scrolling="no" frameborder="0"></iframe> -->
1454
  <div class="l-body-outset" id="fragment-pp_comm_bandwidth"></div>
@@ -1639,14 +1643,14 @@
1639
 
1640
  <p><img alt="image.png" src="/assets/images/pp_zerobubble_ppschedule.png" /></p>
1641
 
1642
- <div class="figure-legend"><p>On the top (Figure 2 from the Zero Bubble paper): the classical 1F1B schedule, interleaving forward and backward passes but keeping a coarse-grained backward pass. On the bottom (Figure 3 from the Zero Bubble paper): two handcrafted schedules splitting the backward pass into finer-grained <d-math>B</d-math> and <d-math>W</d-math> operations. The lower schedule is an example of a (theoretical) zero bubble schedule taking advantage of this fine-grained decomposition.</p> <!-- RH: Adjusted based on the paper. -->
1643
  </div>
1644
 
1645
  <p>DeepSeek’s DualPipe, introduced with its V3 technical report <d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>, is an extension of this decomposed approach to the additional case of two streams propagating from both ends of the PP dimension, with these streams being interleaved to further minimize idle time in the GPUs. This schedule is displayed in the following scheduling graph - as you can see, it's even more complex than the previous ones:</p>
1646
 
1647
  <p><img alt="image.png" src="/assets/images/pp_zerobubble_dualpipe.png" /></p>
1648
 
1649
- <p>In general, fully optimizing such complex schedules involves carefully measuring the duration of the various fine-grained operations and solving an integer linear programming (ILP) problem to minimize the final bubble time. (See, for instance, the Zero Bubble paper<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> for a discussion of the heuristics and algorithms used to perform such scheduling.) As a result, the zero bubble and DualPipe schedules are too complex for us to give code snippets here, but you should have a general idea of the concepts involved. </p>
1650
 
1651
  <p>This concludes our tour of the world of pipeline schedules and bubbles. We hope you enjoyed it!</p>
1652
 
@@ -1701,7 +1705,7 @@
1701
 
1702
  <p>However, there are a few major differences between the PP and ZeRO-3 approaches:</p>
1703
 
1704
- <aside>Note here that we say “a layer” to simplify what should in general be called “a set of layers” (the basic sharding unit of the model).</aside>
1705
  <div class="l-body">
1706
  <table>
1707
  <thead>
@@ -1756,11 +1760,11 @@
1756
 
1757
  <p>The main reason we don't want to use TP only for parallelism is that, in practice, TP has two limitations (which we've discussed in previous sections). First, since its communication operations are part of the critical path of computation, it's difficult to scale well beyond a certain point, after which communication overhead begins to dominate. Second, unlike ZeRO and PP, which are model-agnostic, TP requires careful handling of activation sharding - sometimes along the hidden dimension (in the TP region) and sometimes along the sequence dimension (in the SP region) - making it more cumbersome to implement correctly and requiring model-specific knowledge to ensure proper sharding patterns throughout.</p>
1758
 
1759
- <p>As a consequence, when combining parallelism strategies, TP will typically be kept for high-speed intra-node communications, while ZeRO-3 or PP can be used for parallelism groups that span nodes over lower-speed inter-node links, as their communication patterns require less bandwidth (for PP) or can be more easily overlapped with computation (for ZeRO-3). The main consideration when combining these techniques is to organize the GPUs efficiently in groups for each parallelism dimension to maximize throughput and minimize communication overhead, while being mindful of TP's scaling limitations. For instance, the groups of GPUs communicating for TP should be kept inside nodes.</p>
1760
 
1761
  <p><strong>Context parallelism</strong> and <strong>expert parallelism</strong> also help us shard activations, and can be seen as complementary to TP. CP handles long sequences while EP enables distributed Mixture of Experts training, and they can be combined without any particular issues.</p>
1762
 
1763
- <p>CP specifically targets the challenge of training with very long sequences by sharding activations along the sequence dimension across GPUs. While most layers, like the MLP and LayerNorm, can process these sharded sequences independently, attention layers require communication since each token needs access to keys/values from the full sequence. As we saw in the <a target="_self" href="#context_parallelism">CP section</a>, this is handled efficiently through Ring Attention patterns that overlap computation and communication. CP is particularly valuable when scaling to extreme sequence lengths (128k+ tokens) where, even when using full activation recomputation, the memory requirements for attention would be prohibitive on a single GPU.</p>
1764
 
1765
  <div class="large-image-background-transparent">
1766
  <div class="boxed-image">
@@ -1768,7 +1772,7 @@
1768
  </div>
1769
  </div>
1770
 
1771
- <p><strong>Expert parallelism</strong> specifically targets the challenge of training MoE models by sharding specialized "experts" across GPUs and dynamically routing tokens to relevant experts during computation. The key communication operations in EP are the "all-to-all" operations routing tokens to their assigned experts and gathering the results back. While this introduces some communication overhead, it enables scaling model capacity significantly since each token is only processed during inference (and training) by a much smaller fraction of the total parameters. In terms of distributed training/inference, partitioning experts across GPUs becomes relevant when models scale to a large number of experts.</p>
1772
  <aside>For instance, DeepSeek-V3 uses 256 experts.</aside>
1773
 
1774
  <div class="large-image-background-transparent">
@@ -1809,7 +1813,7 @@
1809
  </tr>
1810
  <tr>
1811
  <td>Communication for matrix multiplication operations (column/row linear)</td>
1812
- <td>Communication for attention keys/values</td>
1813
  <td>Communication for token routing to experts</td>
1814
  </tr>
1815
  <tr>
@@ -1930,7 +1934,7 @@
1930
  <li>At 512+ GPU scale, pure data parallelism/ZeRO-3 will start to become inefficient due to communication cost - it can then be better to combine DP with either tensor or pipeline parallelism.</li>
1931
  <li>At 1024+ GPU scale, a recommended setup may be tensor parallelism (TP=8) with data parallelism (ZeRO-2) and pipeline parallelism.</li>
1932
  </ul>
1933
- <aside>We'll focus on fitting a single instance for now. Even though we may use DP to achieve this goal, we're only interested here in the model parameter memory savings that it provides when used with ZeRO-3.</aside>
1934
  </p>
1935
 
1936
  <p><em>Special considerations:</em>
@@ -1984,7 +1988,7 @@
1984
 
1985
  <p>In the <a href="https://github.com/huggingface/nanotron">Nanotron</a> repository, you'll find several scripts you can use to run all the experiments discussed previously and benchmark your own model and cluster.</p>
1986
 
1987
- <p>We actually ran benchmarks ourselves on <strong>several thousand distributed configurations</strong>, covering every model size we've discussed here as well as a very large number of cluster configurations, in order to produce the results we've covered up to now in this book.</p>
1988
  <aside>We want to take this opportunity to apologize to our coworkers for blocking most of the science cluster, and in turn forgive any threats that may have been whispered.</aside>
1989
 
1990
  <p>Now let's take a step back to gather and analyze the results of all our benchmarks and see if, beyond theory, we can actually discover using real-world data how various configurations fare against each other.</p>
@@ -2013,7 +2017,7 @@
2013
 
2014
  <h3>Lessons learned on benchmarking</h3>
2015
 
2016
- <p>Our goal for this book was not only to discuss theory and implementations, but to provide actual data points as well. So, the plan was simple: let's run every possible distributed configuration for every model and a number of cluster sizes (namely 1-64 nodes of 8xH100s). Even after excluding impossible configurations, we still needed to run thousands of experiments. </p>
2017
 
2018
  <p>
2019
  On paper, this sounds easy enough: we can easily launch big arrays of jobs on our cluster. However, as soon as we launched the first batches of experiments, our troubles began:
@@ -2032,7 +2036,7 @@
2032
  <li>Minimizing cluster restart times and optimizing idle time</li>
2033
  <li>Analyzing detailed NCCL debug logs</li>
2034
  <li>Understanding memory usage patterns and CUDA memory allocator behaviors</li>
2035
- <li>Improving pipeline parallelism performance on multi-node clusters</li>
2036
  </ul>
2037
 
2038
  <p>These challenges taught us valuable lessons about the complexities of distributed training infrastructure. What looks simple in theory often requires careful attention to many moving parts in practice.</p>
@@ -2094,7 +2098,7 @@
2094
  <p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
2095
  <div class="figure-legend"><p>Source: https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p></div>
2096
 
2097
- <p>The goal when using a GPU is to run as many workloads as possible, in parallel, on the available cores, by taking advantage of this hierarchical organization of compute/memory resources.</p>
2098
 
2099
  <p>A piece of code running on a core of the GPU is called a <strong><em>kernel</em></strong>. It can be written at a high level in CUDA or Triton, for instance, and is then compiled to Parallel Thread Execution (PTX), the low-level assembly used by NVIDIA GPUs.</p>
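  <p>If you've never seen one, here is what a minimal (toy) kernel looks like when written in Triton: each program instance adds one block of elements of two vectors. We assume the number of elements is a multiple of the block size so we can skip masking, and a CUDA GPU is required:</p>
  <d-code block language="python">
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                        # which block this instance handles
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)                       # read one block from global memory
    y = tl.load(y_ptr + offsets)
    tl.store(out_ptr + offsets, x + y)                 # write the result back

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (x.numel() // 1024,)                            # one program instance per block
add_kernel[grid](x, y, out, BLOCK_SIZE=1024)
  </d-code>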
2100
 
@@ -2160,7 +2164,7 @@
2160
 
2161
  <ul>
2162
  <li>Threads are grouped in <em>warps</em>, each containing 32 threads. All the threads in a warp are synchronized to execute instructions simultaneously but on different parts of the data.</li>
2163
- <li>Warps are grouped in larger <em>blocks</em> of more flexible size (for example, there may be 512 or 1,024 threads in a block), with each block assigned to a single SM. An SM may run several blocks in parallel. However, depending on the resources available, not all of the blocks may get assigned for execution immediately; some may be waitlisted until more resources become available.</li>
2164
  </ul>
2165
 
2166
  <p>The main thing to retain here is that there are various sizing and allocation constraints (size of the various memories, number of concurrent blocks and threads in the warps) which need to be taken into account to use the GPU architecture in the most efficient way.</p>
@@ -2270,7 +2274,7 @@
2270
 
2271
  </ol>
2272
 
2273
- <p>We'll start by looking at one of the most frequent uses of CUDA: optimizing memory access. The global memory in GPUs (the largest memory area, as you saw earlier) has a long latency and low bandwidth in comparison to the cache, creating a major bottleneck for many applications. Efficiently accessing data from global memory can greatly improve performance.</p>
2274
 
2275
  <h4>Memory coalescing</h4>
2276
 
@@ -2421,7 +2425,7 @@
2421
  <p>In several places now, we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workloads on the GPU in a non-blocking way.</p>
2422
 
2423
  <p>This can be useful for overlapping communication and computation – as we've seen many times in our journey – but it can also be extended to the more general idea of trying to avoid, at all cost, going back and forth between host and GPU kernel commands.</p>
2424
- <p>This idea is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
2425
  <div style="display: flex; gap: 20px; align-items: flex-start;">
2426
  <div style="width: 50%;">
2427
  <img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
@@ -2440,22 +2444,23 @@
2440
  <p>How can we avoid the back and forth shown on the left? Well, the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations as possible together in a single kernel for the GPU to run, called a “fused kernel,” as shown on the right.</p>
2441
 
2442
 
2443
- <p>Fused kernels are especially efficient and simple to write for successions of point-like operations that are performed independently of each other on each input token. In this case, there is no point in sending the computed values back to global memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all the values locally until all the computations have been performed.</p>
2444
 
2445
- <p>In a Transformer model, this "fusing" approach can be applied every time we have a succession of point-wise operations, such as in the computations involved in the LayerNorm layers.</p>
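  <p>For instance, a bias + GELU + scaling chain is a succession of point-wise operations that would naively launch three kernels. A quick way to get them fused without writing a custom kernel is <code>torch.compile</code>, which generates a single fused kernel for such chains (a small sketch, assuming a recent PyTorch and a GPU):</p>
  <d-code block language="python">
import torch

def bias_gelu_scale(x, bias, scale):
    # Three point-wise operations that would naively launch three kernels
    # and round-trip through global memory between each of them.
    return torch.nn.functional.gelu(x + bias) * scale

fused = torch.compile(bias_gelu_scale)   # let the compiler fuse them into one kernel

x = torch.randn(4096, 4096, device="cuda")
bias = torch.randn(4096, device="cuda")
out = fused(x, bias, 0.5)
  </d-code>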
2446
 
2447
  <p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>FlashAttention</em></strong>.</p>
2448
 
2449
  <h3>FlashAttention</h3>
2450
 
2451
  <p>FlashAttention was introduced by <a href="https://tridao.me">Tri Dao</a> and proposed to optimize attention computations by writing custom CUDA kernels to make them much faster <em>and</em> more memory efficient. The idea behind FlashAttention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory.</p>
2452
- <p>Confusingly, some GPUs use a stacked DRAM technology called High Bandwidth Memory (HBM) 🫠 for their global memory.
 
2453
 
2454
- A basic implementation of the attention mechanism involves a lot of transfer between memory and workers. It requires materializing the <strong>S</strong> matrix (the raw attention scores, <d-math>S = QK^T</d-math>) and the <strong>P</strong> matrix (the softmax-normalized attention weights) in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
2455
 
2456
  <p style="text-align: center"><img alt="image.png" src="/assets/images/flashattn.png" style="width: 500px" /></p>
2457
 
2458
- <p>Since bandwidth is much lower in HBM, this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p>
2459
 
2460
  <p>The key element is to compute the <strong>S</strong> matrix in small pieces that can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large <strong>S</strong> matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So, we can compute part of <d-math>O</d-math> directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length): the attention matrix.</p>
2461
 
@@ -2465,7 +2470,7 @@
2465
  <p>The idea of FlashAttention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers. Notably:</p>
2466
  <ul>
2467
  <li>By avoiding materializing the <strong>S</strong> matrix, we <strong>reduce the memory burden of attention</strong>.</li>
2468
- <li>We also <strong>remove a large part of the naive impact of the <d-math>S^2</d-math> cost of attention</strong>.</li>
2469
  </ul>
2470
 
2471
  <p>All variants of linear attention and subquadratic approaches to approximate attention (developed shortly after the invention of the Transformer architecture) have mostly been put aside in favor of this exact and fast FlashAttention implementation and mechanism.</p>
@@ -2557,15 +2562,15 @@
2557
 
2558
  <p>We can see that float32 spans 80 orders of magnitude, and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range, but e4m3 has an even smaller range.</p>
2559
 
2560
- <p>How come some formats are able to maintain larger ranges than others? Let's investigate their resolutions by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
2561
 
2562
  <p><img alt="image.png" src="/assets/images/mixedprecision_2.png" /></p>
2563
 
2564
- <p>We can see here that although bfloat16 maintains the range of float32 (unlike float16), it does this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire: e4m3 can represent only 7 and e5m2 only 3 numbers in the interval [1,2].</p>
2565
 
2566
  <p>A common metric to measure a format's resolution is <em>epsilon</em>: the first representable number after <d-math>1.00</d-math>. We can see that for the float32 format, <d-math>10^{-4}</d-math> is an upper bound (it’s actually <d-math>1.19 \times 10^{-7}</d-math>). For float16 it's ~<d-math>10^{-3}</d-math>, and for bfloat16 it's 10x higher still.</p>
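  <p>You can check these resolutions directly in PyTorch (the float8 dtypes require a recent version); here's a small snippet printing each format's epsilon and the result of rounding a number in <d-math>[1,2]</d-math> to each format:</p>
  <d-code block language="python">
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, "epsilon =", torch.finfo(dtype).eps)

# Round 1.02 to the nearest representable number in each format:
value = torch.tensor(1.02)
for dtype in (torch.float32, torch.float16, torch.bfloat16,
              torch.float8_e4m3fn, torch.float8_e5m2):
    print(dtype, value.to(dtype).float().item())
  </d-code>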
2567
 
2568
- <p>The idea of mixed precision training is to use some of these lower-precision formats for certain computations while maintaining the performance of full precision training. As it turns out, we <strong>can’t</strong> totally abandon float32 and usually will need to do some of the computations in full precision. </p>
2569
 
2570
  <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further, all the way down to 8 bits.</p>
2571
 
@@ -2599,7 +2604,7 @@
2599
  <!-- Hynek uncomment this once it's added to -->
2600
  <!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
2601
 
2602
- <p>The first successful very large scale training with FP8 mixed precision was publicly reported in the DeepSeek-V3 technical report. The authors carefully analyzed each operation of the forward pass (<em>Fprop</em>) as well as the activation (<em>Dgrad</em>) and weight (<em>Wgrad</em>) backward passes. Similar to BF16 mixed precision training, some aggregations and master weights are kept in higher precision while the operations themselves are performed in FP8. </p>
2603
 
2604
  <p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
2605
 
@@ -2996,7 +3001,7 @@
2996
  <h3>A0: Parallel Programming Crash Course</h3>
2997
 
2998
 
2999
- <p>Throughout this book, we've scaled LLM training from one to hundreds of GPUs. This requires the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called <strong><em>collective operations</em></strong>. In this section, we’ll do a small crash course on those operations - <em>Broadcast</em>, <em>AllReduce</em>, <em>Scatter</em>, and more. Let’s dive in!</p>
3000
 
3001
  <p>The general setup is that we have a number of independent nodes, which could be CPU cores, GPUs, or compute nodes. Each performs some computation, and then we want to communicate the result or parts of it to the other nodes for the next computation step (<d-math>t+1</d-math>).</p>
3002
 
@@ -3199,7 +3204,7 @@
3199
 
3200
  <p>As the name suggests, the goal of the Scatter operation is to take data on one node and scatter it across all the nodes, which it does by distributing a slice of the data to each node. It’s thus different from the Broadcast operation, which sends each node a complete copy of the data without slicing it, and it’s the logical inverse of the Gather operation.</p>
3201
 
3202
- <p>The ReduceScatter pattern is slightly more complex. As in the Reduce case, you apply an operation on the data from all nodes. But then, instead of moving the result to just one node, you slice and distribute it evenly to all nodes. The following image illustrates the difference between these operations:</p>
3203
 
3204
  <p style="text-align: center"><img alt="image.png" src="/assets/images/a0_scatter_reducescatter.png" style="width: 1000px" /></p>
3205
 
@@ -3286,16 +3291,15 @@
3286
  <ol>
3287
  <li><strong>ReduceScatter</strong></li>
3288
  <ul>
3289
- <li>Each device splits its data (e.g., gradients) into <d-math>N</d-math> chunks (where <d-math>N</d-math> is the number of devices) and sends one chunk to its neighbor. Simultaneously, each device receives a chunk from its other neighbor.</li>
3290
  <li>As each device receives a chunk, it adds (reduces) its corresponding chunk to the received one.</li>
3291
- <li>This process continues around the ring until each device holds one chunk that is fully reduced, representing the sum of the gradients across all devices for that chunk.</li>
3292
  </ul>
3293
  <li><strong>AllGather</strong></li>
3294
  <ul>
3295
  <li>Now, each device needs to collect the fully reduced chunks from the other devices.</li>
3296
- <li>Each device sends its fully reduced chunk to its neighbor and receives the reduced chunk from its other neighbor.</li>
3297
- <li>The devices continue forwarding the chunks they receive until every device has all the fully reduced chunks, giving each device the complete, summed-up gradient (see the short simulation sketch after this list).</li>
3298
- </ul>
3299
  </ol>
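  <p>Before looking at the animations below, here's a tiny pure-Python simulation of these two phases (our own toy version, with plain numbers standing in for tensors and a sequential loop standing in for the simultaneous sends):</p>
  <d-code block language="python">
def ring_all_reduce(chunks):
    """chunks[d][c] is chunk c on device d; every device ends up with the sums."""
    n = len(chunks)

    # Phase 1 - ReduceScatter: after n-1 steps, device d holds the fully
    # reduced chunk (d + 1) % n.
    for step in range(n - 1):
        for d in range(n):
            c = (d - step) % n                  # chunk that device d passes on this step
            chunks[(d + 1) % n][c] += chunks[d][c]

    # Phase 2 - AllGather: each device keeps forwarding its freshest fully
    # reduced chunk until everyone has all of them.
    for step in range(n - 1):
        for d in range(n):
            c = (d + 1 - step) % n              # chunk that device d forwards this step
            chunks[(d + 1) % n][c] = chunks[d][c]
    return chunks

# 3 devices with 3 chunks each: every device should end up with [111, 222, 333].
print(ring_all_reduce([[1, 2, 3], [10, 20, 30], [100, 200, 300]]))
  </d-code>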
3300
 
3301
  <p>Let’s illustrate this with the following gifs, where we have 5 GPUs, each with a tensor of length 5. The first animation shows the ReduceScatter step, where, at the end, each GPU receives the reduced results for a specific chunk of data (the orange rectangle).</p>
@@ -3368,13 +3372,9 @@
3368
  </ul>
3369
 
3370
  <p>There are a few finer points in the decision tree that we leave to the reader to explore in the PyTorch guide referenced above.</p>
3371
-
3372
- <p>Now that we've covered the fundamental operations for distributed training, you should find the PyTorch guide referenced above easy to follow.</p>
3373
 
3374
  <h3>A1: Distributed Training Profiling</h3>
3375
 
3376
- <p>The next topic we'll explore is profiling: measuring where the time actually goes in our kernels, memory transfers, and communication operations so that we can find and fix bottlenecks.</p>
3377
-
3378
  <h4>Kernels</h4>
3379
 
3380
  <p>Let's begin by assuming for now that the kernels are already integrated into PyTorch. As a simple example, we can look at the layer normalization function implemented in PyTorch as <code>torch.nn.functional.layer_norm</code>. There are several methods to profile the kernel that underlies this function. The most straightforward approach might be to use the Python <code>time</code> module. However, since CUDA operations are asynchronous, measuring time with this method will only capture the overhead associated with launching the kernel in Python, rather than the actual execution time of the kernel itself.</p>
@@ -3402,7 +3402,7 @@
3402
  return start.elapsed_time(end)
3403
  </d-code>
3404
 
3405
- <p>A more efficient approach to profiling is to utilize the PyTorch profiler, as explained previously. For example, consider the following code:</p>
3406
 
3407
  <d-code block language="python">
3408
  import torch
@@ -3555,11 +3555,11 @@
3555
 
3556
  <li><strong>Activations (hidden states):</strong> For a single layer, the hidden state tensor is of size <d-math>seq \cdot mbs \cdot h</d-math> elements.</li>
3557
 
3558
- <li><strong>Model weights and gradients:</strong> Each weight matrix in your model (like those in the linear layers) holds about <d-math>h^2</d-math> elements. Gradients have the same size as weights.</li>
3559
 
3560
- <li><strong>Optimizer states:</strong> For each weight matrix (of <d-math>h^2</d-math> elements), an optimizer like Adam with mixed precision training will keep momentum and variance states in FP32 precision (<d-math>2 \cdot h^2</d-math>), plus master weights in FP32 (<d-math>h^2</d-math>). So, the total optimizer states will be around <d-math>6 \cdot h^2</d-math> per weight matrix.</li>
3561
 
3562
- <li><strong>Total model parameters:</strong> Each transformer block will store:
3563
  <ul>
3564
  <li>Attention parameters:
3565
  <ul>
@@ -3567,7 +3567,7 @@
3567
  <li>Output projection: <d-math>h^2</d-math> parameters</li>
3568
  </ul>
3569
  </li>
3570
- <li>MLP parameters with gated linear units (GLU):
3571
  <ul>
3572
  <li>Gate and up projections: <d-math>8h^2</d-math> parameters (2 matrices of size <d-math>h \times 4h</d-math>)</li>
3573
  <li>Down projection: <d-math>4h^2</d-math> parameters (1 matrix of size <d-math>4h \times h</d-math>)</li>
@@ -3667,7 +3667,7 @@
3667
3668
  <h4>TP communication analysis</h4>
3669
 
3670
- <p>For tensor parallelism, activations are sharded across GPUs for the linear operations. Let's analyze the communication pattern:</p>
3671
 
3672
  <ul>
3673
  <li>For each column-linear operation in the forward pass:
@@ -3684,7 +3684,7 @@
3684
  <li>Total communication per block: <d-math>8 \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
3685
  <li>Total communication for full model: <d-math>8 \cdot num\_layers \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
3686
  </ul>
3687
- <p>Let's analyze whether we can overlap the all-gather communication for one layer with the computation of the next linear layer. The communication time for all-gather operations is:</p>
3688
 
3689
  <d-math block>
3690
  t_{comm} = \frac{seq \cdot mbs \cdot h \cdot (TP-1)}{TP \cdot peak\_bw}
@@ -3701,7 +3701,7 @@
3701
  \frac{t_{comm}}{t_{compute}} = \frac{TP-1}{2 \cdot h} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
3702
  </d-math>
3703
 
3704
- <p>This ratio tells us whether we can successfully hide the all-gather communication behind the computation of the next linear layer. Interestingly, the ratio only depends on the hidden size <d-math>h</d-math> and the tensor parallelism degree <d-math>TP</d-math>, not on the sequence length or batch size.</p>
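  <p>Plugging in hardware numbers makes this concrete. Here's a quick calculator for the ratio, using illustrative H100-class figures of roughly 990 TFLOPS of BF16 compute and 450 GB/s of bus bandwidth (adjust these for your own cluster):</p>
  <d-code block language="python">
def tp_overlap_ratio(h, tp, peak_flops=990e12, peak_bw=450e9):
    """t_comm / t_compute for hiding the all-gather behind the next linear layer."""
    return (tp - 1) / (2 * h) * (peak_flops / peak_bw)

for h in (4096, 16384):
    print(f"h={h}, TP=8: ratio = {tp_overlap_ratio(h, tp=8):.2f}")
  </d-code>
  <p>With these illustrative numbers, a hidden size of 4,096 at TP=8 gives a ratio above 1 (the all-gather cannot be fully hidden behind the next matmul), while a hidden size of 16,384 brings it below 1.</p>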
3705
 
3706
 
3707
  <h4>PP communication analysis</h4>
 
865
 
866
  <h4>Memory usage revisited</h4>
867
 
868
+ <p><a target="_self" href="#memory_usage_in_transformers">Earlier</a>, we discussed the memory usage of optimizer states, gradients, and parameters during standard training. Let's call our model's parameter count <d-math>\Psi</d-math> (previously this was <d-math>N</d-math>, but here we use the original ZeRO paper's<d-cite bibtex-key="rajbhandari2020zero"></d-cite> notation). In mixed precision training (discussed further <a target="_self" href="#mixed_precision_training">later in the book</a>) with the Adam optimizer, the memory usage for each item we need to store is:</p>
869
 
870
  <ul>
871
  <li>Model’s parameters (half precision; i.e., BF16/FP16): <d-math>2\Psi</d-math></li>
872
  <li>Model’s gradients (half precision; i.e., BF16/FP16): <d-math>2\Psi</d-math></li>
873
  <li>Model’s parameters in FP32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
874
+ <li>Model’s gradients in FP32: <d-math>4\Psi</d-math> (optional, only included if we want to accumulate gradients in FP32)</li>
875
  </ul>
876
 
877
+ <p>If we don't accumulate gradients in FP32, this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we do it gives us <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let's focus for now on the case without FP32 gradient accumulation for simplicity.</p>
878
 
879
  <p>The idea of ZeRO is to shard these objects across the DP ranks, with each node only storing a slice of the items. These slices are then reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
880
 
881
  <p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
882
  <p>Here, <d-math>\Psi</d-math> denotes the number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam, as we've just seen), and <d-math>N_d</d-math> denotes DP degree.</p>
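<p>To make these formulas concrete, here is a small helper (a back-of-the-envelope sketch, not Nanotron code) that evaluates the per-GPU model-state memory for each ZeRO stage; the 8B-parameter model and the DP degree of 64 are purely illustrative values:</p>
<d-code block language="python">
def zero_model_states_gb(psi, k=12, n_d=1, stage=0):
    """Per-GPU memory (in GB) for BF16 params, BF16 grads, and optimizer states.
    psi: number of parameters, k: optimizer-state multiplier, n_d: DP degree."""
    params, grads, opt_states = 2 * psi, 2 * psi, k * psi  # bytes, as in the figure
    if stage >= 1:            # ZeRO-1: shard the optimizer states
        opt_states /= n_d
    if stage >= 2:            # ZeRO-2: also shard the gradients
        grads /= n_d
    if stage >= 3:            # ZeRO-3: also shard the parameters
        params /= n_d
    return (params + grads + opt_states) / 1e9

# Illustrative example: an 8B-parameter model with a DP degree of 64
for stage in range(4):
    print(f"stage {stage}: {zero_model_states_gb(8e9, n_d=64, stage=stage):.1f} GB per GPU")
</d-code>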
883
 
884
+ <aside>If you're using FP32 gradient accumulation with ZeRO-2 or ZeRO-3, you would need to add an additional <d-math>\frac{4\Psi}{N_d}</d-math> to the gradient term.</aside>
885
 
886
  <p>Let’s explain this by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
887
 
 
891
 
892
  <p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts, where <d-math>N_d</d-math> is the DP degree. This means that the model replicas distributed on the DP ranks each only keep track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states, and during the optimization step, only <d-math>\frac{1}{N_d}</d-math> of the FP32 weights are updated.</p>
893
 
894
+ <p>However, during the forward pass, each replica needs all the parameters. We thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective communication primitive we've encountered!) after the optimizer step so that each model replica has the full set of updated weights.</p>
895
 
896
  <p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw in the previous figure! Here’s a summary of the sequence of operations for a single training step:</p>
897
+ <ol>
 
898
  <li>Perform a forward pass with the same full set of BF16 parameters on each replica, but different micro-batches across replicas.</li>
899
  <li>Perform a backward pass with the same full set of gradients on each replica, but different micro-batches across replicas.</li>
900
  <li>Perform a <strong><em>reduce-scatter</em></strong> on the gradients (another primitive - we'll explain this one shortly).</li>
901
  <li>Each replica performs an optimizer step on its local optimizer states (only <d-math>\frac{1}{N_d}</d-math> of the optimizer states) to get <d-math>\frac{1}{N_d}</d-math> updated FP32 parameters, which can then be converted to <d-math>\frac{1}{N_d}</d-math> of the full set of BF16 parameters.</li>
902
  <li>Perform an all-gather on the BF16 parameters to send the missing slices back to each replica. This is a new operation in ZeRO and is not used in vanilla DP.</li>
903
+ </ol>
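<p>To make the sequence above a bit more tangible, here is a deliberately simplified sketch of a ZeRO-1-style step using raw PyTorch collectives. It assumes the process group is already initialized (e.g., via <code>torchrun</code>), uses a flat parameter layout whose total size is divisible by the world size, and replaces Adam with plain SGD on the local FP32 shard to keep things short - a toy illustration, not how real ZeRO implementations are written:</p>
<d-code block language="python">
import torch
import torch.distributed as dist

def zero1_step(model, fp32_shard, lr, shard_size):
    # Steps 1-2 (forward/backward with the full BF16 parameters) are assumed
    # to have happened already, so every parameter has a .grad attached.
    flat_grads = torch.cat([p.grad.flatten() for p in model.parameters()])

    # Step 3: reduce-scatter, so each rank only keeps its averaged gradient shard
    grad_shard = torch.empty(shard_size, dtype=flat_grads.dtype, device=flat_grads.device)
    dist.reduce_scatter_tensor(grad_shard, flat_grads, op=dist.ReduceOp.AVG)

    # Step 4: optimizer step on the local 1/N_d slice of the FP32 master weights
    fp32_shard -= lr * grad_shard.float()

    # Step 5: all-gather the updated BF16 shards so every rank has the full weights again
    flat_bf16 = torch.empty(shard_size * dist.get_world_size(),
                            dtype=torch.bfloat16, device=flat_grads.device)
    dist.all_gather_into_tensor(flat_bf16, fp32_shard.to(torch.bfloat16))
    return flat_bf16  # to be copied back into model.parameters()
</d-code>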
904
  <aside>Note: Reduce-scatter is two times faster than all-reduce! <em>Yay, a third communication primitive!</em></aside>
905
 
906
  <p>You may be wondering what this "reduce-scatter" operation is and what this all looks like, so let's try to make it more graphical with the figure below. We'll go over all the steps of a forward/backward pass cycle:</p>
 
914
  <p>If you've been following along, you'll recall from our discussion of vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of BF16 parameters. There are two main strategies for this:</p>
915
 
916
  <ul>
917
+ <li><strong>During the optimizer step:</strong> We can initiate the all-gather immediately after the optimizer updates the first slice of the parameters. This allows the communication to potentially overlap with the updating of the other parameters.</li>
918
  <li><strong>During the forward pass:</strong> We can overlap the all-gather of each layer’s parameters with the forward pass.</li>
919
  </ul>
920
 
 
929
 
930
  <h4>ZeRO-2: Adding <strong>gradient partitioning</strong></h4>
931
 
932
+ <p>Since on each replica we only need to have the gradient shard corresponding to its optimizer state shard, it makes sense to shard gradients as well, similarly to the optimizer states. Then, during the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation! Here, we only store the <d-math>\frac{1}{N_d}</d-math> gradients that are needed in memory, thus saving more memory compared to ZeRO-1.</p>
933
 
934
+ <aside>In the case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> FP32 grads used to accumulate the BF16 grads coming from the reduce-scatter. And in the optimizer step, these <d-math>\frac{1}{N_d}</d-math> FP32 grads are used to update the local shard of the optimizer states.</aside>
935
 
936
  <p><img alt="dp_zero2.gif" src="/assets/images/dp_zero2.gif" /></p>
937
 
938
+ <p>It's easy to see now that sharding the gradients leads to <d-math>2\Psi + \frac{2\Psi+k\Psi}{N_d}</d-math>, and as <d-math>N_d</d-math> is increased, we can use up to 8x less memory than the baseline. In terms of communication, the same process applies as for ZeRO-1, with the only difference being that we communicate and release memory on the fly: they both require a reduce-scatter for the gradients and an all-gather over all parameters. ZeRO-2 is thus also equivalent to vanilla DP training with regard to communication.</p>
939
  <!-- RH: In this figure, on the right, can "AllGather Params" and "ReduceScatter Grad" be changed to "All-gather params" and "Reduce-scatter grads"? -->
940
  <p><img alt="dp_zero2_overlap.svg" src="/assets/images/dp_zero2_overlap.svg" /></p>
941
 
942
+ <aside>Note: You might notice that there is no real overhead to using ZeRO-2 over ZeRO-1 besides implementation complexity, and indeed ZeRO-2 is usually the better option.</aside>
943
 
944
+ <p>Now that we've sharded gradients as well, are we done, or can we keep making improvements? Here comes ZeRO-3!</p>
945
 
946
  <h4>ZeRO-3: Adding <strong>parameter partitioning</strong> (FSDP)</h4>
947
 
 
954
  </div>
955
  </div>
956
 
957
+ <p>So how do we do a forward or backward pass in practice if the parameters of the model are distributed? Quite simply, we gather them on demand when we need them. In the forward pass, this looks as follows:</p>
958
 
959
  <p><img alt="dp_zero3_fwd.svg" src="/assets/images/dp_zero3_fwd.svg" /></p>
960
 
 
962
 
963
  <p><img alt="dp_zero3_bwd.svg" src="/assets/images/dp_zero3_bwd.svg" /></p>
964
 
965
+ <p>The other issue is that we need to do these all-gathers continuously throughout the forward and backward pass in a training step, which amounts to <d-math>2\cdot \text{num\_layers} -1</d-math> additional all-gathers in a training step compared to ZeRO-2. Each comes with a small <em>base latency</em> overhead, as we can see in the following figure:</p>
966
  <!-- RH: In this figure, change "AllGather Params" and "ReduceScatter Grads" to "All-gather params" and "Reduce-scatter grads" and lowercase "Free"? -->
967
  <p><img alt="dp_zero3_overlap.svg" src="/assets/images/dp_zero3_overlap.svg" /></p>
968
 
969
  <p>During the forward pass we do all-gather operations for the parameters when we need them, so there's a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we use them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> communication tax. Finally, we need the same reduce-scatter operation as in ZeRO-2 for the gradients, which also costs <d-math>\Psi</d-math> in communication. So, we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for ZeRO-2.</p>
970
 
971
+ <p>This may sound like a lot of communication overhead, but it's actually not a big deal, as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called <strong><em>prefetching</em></strong>. With prefetching, we all-gather the weights for <em>Layer n+1</em> while we do the forward pass for <em>Layer n</em>, and similarly, we all-gather the weights for <em>Layer n-1</em> while doing the backward pass for <em>Layer n</em>. Of course, this overlap only works as long as we don’t scale DP too much (as a rule of thumb, DP shouldn’t exceed 512).</p>
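<p>In pseudo-PyTorch, such a prefetching loop could look roughly like the following sketch (the <code>load_flat_params_</code> helper, which copies a flat gathered buffer back into the module, is hypothetical, and real FSDP/ZeRO-3 implementations are considerably more involved):</p>
<d-code block language="python">
import torch
import torch.distributed as dist

def gather_layer_params(shard):
    """Launch an async all-gather of one layer's flat parameter shard."""
    out = torch.empty(shard.numel() * dist.get_world_size(),
                      dtype=shard.dtype, device=shard.device)
    handle = dist.all_gather_into_tensor(out, shard, async_op=True)
    return out, handle

def forward_with_prefetch(layers, shards, x):
    buf, handle = gather_layer_params(shards[0])
    for i, layer in enumerate(layers):
        handle.wait()                       # full weights of layer i are now available
        load_flat_params_(layer, buf)       # hypothetical helper: unflatten buf into the module
        next_buf = None
        if i + 1 != len(layers):            # prefetch layer i+1 while layer i computes
            next_buf, handle = gather_layer_params(shards[i + 1])
        x = layer(x)
        buf = next_buf                      # drop the full weights of layer i
    return x
</d-code>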
972
+
973
+ <aside>Note: We use "DP" to refer to both the data parallelism technique and the number of GPUs used for data parallelism <em>(DP = DP size = DP degree)</em>.</aside>
974
 
975
+ <p>In terms of memory, we can see that our equation has now reached its final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math>, which means we can theoretically drive memory usage down indefinitely if we can increase the DP size, at least for the model-related parameters. Notice that it doesn’t help with the intermediate activations, though - for that, we can use activation checkpointing and gradient accumulation, as we saw earlier.</p>
976
 
977
  <p>Let’s summarize our journey into DP and ZeRO so far. We've seen that we can increase the throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO, we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients, and optimizer states across DP replicas, while incurring a small communication cost.</p>
978
  <aside>If you want to read more about FSDP1, FSDP2, and some of the implementation complexities around them, check out <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
 
1008
 
1009
 
1010
 
1011
+ <p>So, we've sharded the model’s parameters, gradients, and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome tensor parallelism (TP), a method that shards weights, gradients, and optimizer states as well as activations - and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how TP works with simple matrix multiplication (matmul) operations.</p>
1012
 
1013
  <p>Tensor parallelism leverages the mathematical properties of matrix multiplication, <d-math>A \times B</d-math>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p>
1014
 
 
1023
 
1024
  <ul>
1025
  <li><d-math>X</d-math> represents the input or activation values</li>
1026
+ <li><d-math>W</d-math> represents the weight matrix of the linear layer</li>
1027
  </ul>
1028
 
1029
  <p>In practice, a small example of the operation looks like this:</p>
 
1066
 
1067
  <h3>Tensor parallelism in a transformer block</h3>
1068
 
1069
+ <p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: a feedforward multi-layer perceptron (MLP) block and a multi-head attention (MHA) block. We can apply tensor parallelism to both.</p>
1070
 
1071
+ <p>The feedforward part can be parallelized by having a column-linear followed by a row-linear split, which amounts to a broadcast to copy the input and an all-reduce in the forward pass. Note that the broadcast isn’t needed in actual training, where we can make sure inputs are already synced across TP ranks. This setup is more efficient than starting with a row-linear followed by column-linear split, as we can skip the intermediate all-reduce between the split operations.</p>
1072
 
1073
  <p><img alt="image.png" src="/assets/images/tp_diagram4.png" /></p>
1074
 
1075
+ <p>Now that we've found an efficient schema for the feedforward part of the transformer, let's take a look at the multi-head attention block.</p>
 
 
1076
 
1077
+ <p>We can generally follow a similar approach here, where the Query (Q), Key (K), and Value (V) matrices are split in a column-parallel fashion and the output projection can be considered a row-linear. With multi-head attention, the column-parallel approach has a very natural interpretation: each GPU computes the attention for an individual or a subset of attention heads. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query attention (MQA)</em></strong></a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention (GQA)</em></strong></a>, where keys and values are shared between queries. </p>
1078
 
1079
  <p><img alt="image.png" src="/assets/images/tp_full_diagram.png" /></p>
1080
+
1081
+ <p>We're able to apply tensor parallelism so effectively to both the Attention and MLP blocks because they have dimensions that are naturally independent. The Attention block can be parallelized along the <code>num_attention_heads</code> dimension, as each attention head operates independently. Similarly, the MLP block can be parallelized along the <code>hidden_dim</code> dimension, as operations within the feedforward network are independent along this dimension.</p>
1082
+ <p>It's worth noting, however, that the tensor parallelism degree should not exceed the number of attention heads because we shard the QKV projection along the <code>num_attention_heads</code> dimension. When using grouped query attention (GQA), we have <d-math>num\_attention\_heads</d-math> query heads but only <d-math>num\_kv\_heads</d-math> key/value heads (with <d-math>num\_attention\_heads \geq num\_kv\_heads</d-math>). In this case, we can still set <d-math>TP = num\_attention\_heads</d-math>, but we'll need to ensure that the K/V heads stay properly synchronized across GPUs. For instance, Llama-3 8B has 32 query heads but only 8 key/value heads, so while the TP degree could theoretically go up to 32, we would need careful implementation to maintain K/V head synchronization across the tensor-parallel workers.</p>
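<p>A quick sanity check one might run when picking the TP degree (using the Llama-3 8B head counts mentioned above as an example):</p>
<d-code block language="python">
num_attention_heads, num_kv_heads, tp = 32, 8, 8  # e.g., Llama-3 8B with TP=8

assert num_attention_heads % tp == 0, "TP must divide the number of query heads"
q_heads_per_rank = num_attention_heads // tp      # 4 query heads per GPU

# If TP exceeds num_kv_heads, the K/V heads have to be replicated (and kept in sync) across ranks.
kv_heads_per_rank = max(1, num_kv_heads // tp)    # here: 1 K/V head per GPU
print(q_heads_per_rank, kv_heads_per_rank)
</d-code>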
1083
 
1084
  <p>Note also that tensor parallelism is not a silver bullet for training. We’ve added several distributed communication primitives directly in the computation path of our model, which are therefore hard to fully hide/overlap with computation (like we did in ZeRO), so our final performance will be the result of a trade-off between the computation and memory gains and the added communication overhead. Let's illustrate this:</p>
1085
  <!-- RH: In this figure, change "TP region" to "TP Region" and "AllReduce Activs" to "All-reduce activs"? -->
 
1087
 
1088
  <aside>It's possible to partially hide this communication by performing block matrix multiplication coupled with async communication/computation.</aside>
1089
 
1090
+ <p>Looking at the timeline of operations in tensor-parallel MLP (the same applies for MHA), we can better understand the trade-offs involved. In the forward pass of each decoder layer, we hit a synchronization point with the all-reduce operation that cannot be overlapped with computation. This <em>exposed communication overhead</em> is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
1091
 
1092
+ <aside>For example, Megatron-LM and Nanotron implement a partial overlapping of all-gather with Fully-Connected (FC1) computation, where a portion of the matrix multiplication result gets sent to the other GPU while the remaining part is still being computed.</aside>
1093
 
1094
+ <p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, TP introduces significant communication requirements that heavily depend on the network infrastructure. The inability to fully hide this particular all-reduce behind computation means it directly adds to the critical path of forward propagation, where the critical path refers to the sequence of operations that determine the minimum time required to complete the forward pass.</p>
1095
 
1096
  <aside>This is an active area of research, with recent work like Domino <d-cite bibtex-key="wang2024domino"></d-cite> exploring novel techniques to maximize this overlap. </aside>
1097
 
 
1179
  <li><em>f</em> is an all-reduce to synchronize gradients.</li>
1180
  </ul>
1181
 
1182
+ <p>These <em>f</em> and <em>f*</em> operations are called <strong><em>conjugate pairs</em></strong> because they complement each other - in each pass, when one is a no-op the other is an all-reduce, and it's the opposite in the other pass.</p>
1183
 
1184
  <p>For sequence parallelism, we use different operations labeled <em>g</em> and <em>g*</em>. Specifically, we avoid using all-reduce in the SP regions since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
1185
 
 
1189
  <div>
1190
  <p style="margin-bottom: 0;"><strong>Initial LayerNorm layer (SP region)</strong></p>
1191
  <ul style="margin-top: 0;">
1192
+ <li>Input tensors <em>X1*</em> and <em>X2*</em> <d-math>(b,s/2,h)</d-math> enter, already split across the sequence dimension.</li>
1193
+ <li>Each GPU computes LayerNorm independently on its sequence chunk, giving <em>Y1*</em> and <em>Y2*</em>.</li>
1194
  </ul>
1195
  <p style="margin-bottom: 0;"><strong>First transition (SP → TP)</strong></p>
1196
  <ul style="margin-top: 0;">
1197
  <li><em>g</em> operation (all-gather) combines <em>Y1</em> and <em>Y2</em> back to full sequence length.</li>
1198
+ <li>Restores <em>Y</em> <d-math>(b,s,h)</d-math> since column-linear layers need the full hidden dimension <d-math>h</d-math>.</li>
1199
  </ul>
1200
  <p style="margin-bottom: 0;"><strong>First linear layer (TP region)</strong></p>
1201
  <ul style="margin-top: 0;">
1202
+ <li><em>A1</em> and <em>A2</em> are column-linear layers, so they split <em>Y</em> along the hidden dimension.</li>
1203
  <li>GELU is applied independently on each GPU.</li>
1204
+ <li><em>Z1*</em> and <em>Z2*</em> are <d-math>(b,s,h/2)</d-math>.</li>
1205
  </ul>
1206
  <p style="margin-bottom: 0;"><strong>Second linear layer (TP region)</strong></p>
1207
  <ul style="margin-top: 0;">
1208
+ <li><em>B1</em> and <em>B2</em> are row-linear layers, so they restore the hidden dimension.</li>
1209
+ <li><em>W1</em> and <em>W2</em> are <d-math>(b,s,h)</d-math> tensors that need to be summed together.</li>
1210
  </ul>
1211
  <p style="margin-bottom: 0;"><strong>Final transition (TP → SP)</strong></p>
1212
  <ul style="margin-top: 0;">
 
1391
 
1392
  <p><img alt="ring-attention.gif" src="/assets/images/ring-attention.gif" /></p>
1393
 
1394
+ <p>It's probably obvious to you from this animation why the authors chose to call this approach Ring Attention<d-cite bibtex-key="liu2023ringattentionblockwisetransformers"></d-cite>!</p>
1395
 
1396
  <p>There is one big problem, though, which is that a naive implementation of Ring Attention leads to some strong imbalances between GPUs due to the shape of the causal attention matrix. Let’s take a look at the softmax computation by considering the attention score matrix with the causal attention mask:</p>
1397
 
 
1402
  <p>Let’s see if we can balance our computations better.</p>
1403
 
1404
  <h3>Zig-Zag Ring Attention – A balanced compute implementation</h3>
1405
+ <p>We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens to the GPUs in a purely sequential manner, but instead mixing up the ordering a bit so that each GPU gets a good mix of early and late tokens. This approach is called Zig-Zag Attention. In this new arrangement, the attention mask shows an even distribution of computation: if you count the number of colored squares, you'll see that the work is now balanced across all GPUs.</p>
1406
+ <aside>
1407
+ <p>We show here Zig-Zag Attention, which slightly differs from Striped Attention<d-cite bibtex-key="brandon2023fasterring"></d-cite>. For details on the differences, check <a href="https://github.com/zhuzilin/ring-flash-attention/issues/2#issuecomment-2236746166">this GitHub discussion</a>.</p>
1408
+ </aside>
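<p>One simple way to build such an interleaved assignment (the chunk-pairing scheme used by common Ring Attention implementations, as far as we are aware - treat the exact indexing as an assumption) is to cut the sequence into <d-math>2 \cdot N</d-math> chunks and give GPU <d-math>i</d-math> both the <d-math>i</d-math>-th earliest and the <d-math>i</d-math>-th latest chunk:</p>
<d-code block language="python">
def zigzag_chunk_assignment(num_gpus):
    """Return, for each GPU, the indices of the two sequence chunks it owns
    (assuming the sequence is split into 2 * num_gpus equal chunks)."""
    total_chunks = 2 * num_gpus
    return [[i, total_chunks - 1 - i] for i in range(num_gpus)]

print(zigzag_chunk_assignment(4))  # [[0, 7], [1, 6], [2, 5], [3, 4]]
</d-code>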
1409
 
1410
  <p><img alt="cp_zigzagmask.svg" src="/assets/images/cp_zigzagmask.svg" /></p>
1411
 
 
1452
  </div>
1453
  </div>
1454
 
1455
+ <p>In the <a target="_self" href="#tensor-parallelism">"Tensor Parallelism"</a> section, we saw that trying to scale tensor parallelism past the number of GPUs on a single node - typically 4 or 8 - forces us to use lower-bandwidth network communication, which can significantly impair performance. We can see the effects of this inter-node communication clearly in the all-reduce operation when we benchmark it on our cluster across several nodes (each node here has 8 GPUs):</p>
 
1456
 
1457
  <!-- <iframe class="l-body-outset" id="plotFrame11" src="assets/data/benchmarks/pp_comm_bandwidth.html" width="90%" scrolling="no" frameborder="0"></iframe> -->
1458
  <div class="l-body-outset" id="fragment-pp_comm_bandwidth"></div>
 
1643
 
1644
  <p><img alt="image.png" src="/assets/images/pp_zerobubble_ppschedule.png" /></p>
1645
 
1646
+ <div class="figure-legend"><p>On the top (Figure 2 from the Zero Bubble paper): the classical 1F1B schedule, interleaving forward and backward passes but keeping a coarse-grained backward pass. On the bottom (Figure 3 from the Zero Bubble paper): two handcrafted schedules splitting the backward pass into finer-grained <d-math>B</d-math> and <d-math>W</d-math> operations. The lower schedule is an example of a (theoretical) zero bubble schedule taking advantage of this fine-grained decomposition.</p>
1647
  </div>
1648
 
1649
  <p>DeepSeek’s DualPipe, introduced with its V3 technical report <d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>, is an extension of this decomposed approach to the additional case of two streams propagating from both ends of the PP dimension, with these streams being interleaved to further minimize idle time in the GPUs. This schedule is displayed in the following scheduling graph - as you can see, it's even more complex than the previous ones:</p>
1650
 
1651
  <p><img alt="image.png" src="/assets/images/pp_zerobubble_dualpipe.png" /></p>
1652
 
1653
+ <p>In general, fully optimizing such complex schedules involves carefully measuring the duration of the various fine-grained operations and solving an Integer Linear Programming (ILP) problem to minimize the final bubble time. (See, for instance, the Zero Bubble paper<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> for a discussion of the heuristics and algorithms used to perform such scheduling.) As a result, the zero bubble and DualPipe schedules are too complex for us to give code snippets here, but you should have a general idea of the concepts involved. </p>
1654
 
1655
  <p>This concludes our tour of the world of pipeline schedules and bubbles. We hope you enjoyed it!</p>
1656
 
 
1705
 
1706
  <p>However, there are a few major differences between the PP and ZeRO-3 approaches:</p>
1707
 
1708
+ <aside>Note here that we say "a layer" to simplify, but the actual sharding unit of the model can vary - it might be multiple layers, a single layer, or even a subset of a layer, depending on the specific implementation.</aside>
1709
  <div class="l-body">
1710
  <table>
1711
  <thead>
 
1760
 
1761
  <p>The main reason we don't want to use TP only for parallelism is that, in practice, TP has two limitations (which we've discussed in previous sections). First, since its communication operations are part of the critical path of computation, it's difficult to scale well beyond a certain point, after which communication overhead begins to dominate. Second, unlike ZeRO and PP, which are model-agnostic, TP requires careful handling of activation sharding - sometimes along the hidden dimension (in the TP region) and sometimes along the sequence dimension (in the SP region) - making it more cumbersome to implement correctly and requiring model-specific knowledge to ensure proper sharding patterns throughout.</p>
1762
 
1763
+ <p>As a consequence, when combining parallelism strategies, TP will typically be kept for high-speed intra-node communications, while ZeRO-3 or PP can be used for parallelism groups spanning lower-speed inter-node communications as their communication patterns require less bandwidth (for PP) or can be more easily overlapped with computation (for ZeRO-3). The main consideration when combining these techniques is to organize the GPUs efficiently in groups for each parallelism dimension to maximize throughput and minimize communication overhead, while being mindful of TP's scaling limitations. For instance, the groups of GPUs communicating for TP should be kept inside nodes.</p>
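<p>With PyTorch's <code>DeviceMesh</code> API, such an arrangement can be sketched as follows (assuming 64 GPUs launched with <code>torchrun</code> and ranks numbered consecutively within each node, which is the usual launcher behavior); keeping TP as the last, fastest-varying mesh dimension is what keeps each TP group inside a single node:</p>
<d-code block language="python">
from torch.distributed.device_mesh import init_device_mesh

# 64 GPUs = 8 nodes x 8 GPUs, organized as (dp=4, pp=2, tp=8).
# The last dimension varies fastest, so each group of 8 consecutive ranks
# (i.e., one node) forms a TP group communicating over NVLink.
mesh = init_device_mesh("cuda", mesh_shape=(4, 2, 8), mesh_dim_names=("dp", "pp", "tp"))

tp_group = mesh["tp"].get_group()  # intra-node, bandwidth-hungry communication
dp_group = mesh["dp"].get_group()  # inter-node, overlappable gradient communication
pp_group = mesh["pp"].get_group()  # inter-node, point-to-point activation transfers
</d-code>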
1764
 
1765
  <p><strong>Context parallelism</strong> and <strong>expert parallelism</strong> also help us shard activations, and can be seen as complementary to TP. CP handles long sequences while EP enables distributed Mixture of Experts training, and they can be combined without any particular issues.</p>
1766
 
1767
+ <p>CP specifically targets the challenge of training with very long sequences by sharding activations along the sequence dimension across GPUs. While most modules, like MLP and LayerNorm, can process these sharded sequences independently, attention blocks require communication since each token needs access to keys/values from the full sequence. As we saw in the <a target="_self" href="#context_parallelism">CP section</a>, this is handled efficiently through Ring Attention patterns that overlap computation and communication. CP is particularly valuable when scaling to extreme sequence lengths (128k+ tokens) where, even when using full activation recomputation, the memory requirements for attention would be prohibitive on a single GPU.</p>
1768
 
1769
  <div class="large-image-background-transparent">
1770
  <div class="boxed-image">
 
1772
  </div>
1773
  </div>
1774
 
1775
+ <p><strong>Expert parallelism</strong> specifically targets the challenge of training MoE models by sharding specialized "experts" across GPUs and dynamically routing tokens to relevant experts during computation. The key communication operations in EP are the "all-to-all" operations routing tokens to their assigned experts and gathering the results back. While this introduces some communication overhead, it enables scaling model capacity significantly since each token is only processed, during both training and inference, by a small fraction of the total parameters. In terms of distributed training/inference, partitioning experts across GPUs becomes relevant when models scale to a large number of experts.</p>
1776
  <aside>For instance, DeepSeek-V3 uses 256 experts.</aside>
1777
 
1778
  <div class="large-image-background-transparent">
 
1813
  </tr>
1814
  <tr>
1815
  <td>Communication for matrix multiplication operations (column/row linear)</td>
1816
+ <td>Communication for attention keys/values</td>
1817
  <td>Communication for token routing to experts</td>
1818
  </tr>
1819
  <tr>
 
1934
  <li>At 512+ GPU scale, pure data parallelism/ZeRO-3 will start to become inefficient due to communication cost - it can then be better to combine DP with either tensor or pipeline parallelism.</li>
1935
  <li>At 1024+ GPU scale, a recommended setup may be tensor parallelism (TP=8) with data parallelism (ZeRO-2) and pipeline parallelism.</li>
1936
  </ul>
1937
+ <aside>We'll focus on fitting a single instance for now. Even though we may use DP to achieve this goal, we're only interested here in the model parameter memory savings that it provides when used with ZeRO-3.</aside>
1938
  </p>
1939
 
1940
  <p><em>Special considerations:</em>
 
1988
 
1989
  <p>In the <a href="https://github.com/huggingface/nanotron">Nanotron</a> repository, you'll find several scripts you can use to run all the experiments discussed previously and benchmark your own model and cluster.</p>
1990
 
1991
+ <p>We actually ran benchmarks ourselves on <strong>several thousand distributed configurations</strong>, covering every model size we've discussed here as well as a very large number of cluster configurations (namely, 1-64 nodes of 8xH100s) in order to produce the results we've covered up to now in this book.</p>
1992
  <aside>We want to take this opportunity to apologize to our coworkers for blocking most of the science cluster, and in turn forgive any threats that may have been whispered.</aside>
1993
 
1994
  <p>Now let's take a step back to gather and analyze the results of all our benchmarks and see if, beyond theory, we can actually discover using real-world data how various configurations fare against each other.</p>
 
2017
 
2018
  <h3>Lessons learned on benchmarking</h3>
2019
 
2020
+ <p>Our goal for this book was not only to discuss theory and implementations, but to provide actual data points as well. So, the plan was simple: let's run every possible distributed configuration for every model and a number of cluster sizes. Even after excluding impossible configurations, we still needed to run thousands of experiments. </p>
2021
 
2022
  <p>
2023
  On paper, this sounds easy enough: we can easily launch big arrays of jobs on our cluster. However, as soon as we launched the first batches of experiments, our troubles began:
 
2036
  <li>Minimizing cluster restart times and optimizing idle time</li>
2037
  <li>Analyzing detailed NCCL debug logs</li>
2038
  <li>Understanding memory usage patterns and CUDA memory allocator behaviors</li>
2039
+ <li>Improving pipeline parallelism performance on multi-node setups</li>
2040
  </ul>
2041
 
2042
  <p>These challenges taught us valuable lessons about the complexities of distributed training infrastructure. What looks simple in theory often requires careful attention to many moving parts in practice.</p>
 
2098
  <p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
2099
  <div class="figure-legend"><p>Source: https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p></div>
2100
 
2101
+ <p>The goal when using a GPU is to run as many workloads as possible, in parallel, on the available cores, by taking advantage of this hierarchical organization of compute/memory resources.</p>
2102
 
2103
  <p>A piece of code running on a core of the GPU is called a <strong><em>kernel</em></strong>. It can be written at a high level in CUDA or Triton, for instance, and is then compiled to Parallel Thread Execution (PTX), the low-level assembly used by NVIDIA GPUs.</p>
2104
 
 
2164
 
2165
  <ul>
2166
  <li>Threads are grouped in <em>warps</em>, each containing 32 threads. All the threads in a warp are synchronized to execute instructions simultaneously but on different parts of the data.</li>
2167
+ <li>Warps are grouped in larger <em>blocks</em> of more flexible size (for example, there may be 512 or 1,024 threads in a block), with each block assigned to a single SM. An SM may run several blocks in parallel. However, depending on the resources available, not all of the blocks may get assigned for execution immediately; some may be waitlisted until more resources become available.</li>
2168
  </ul>
2169
 
2170
  <p>The main thing to retain here is that there are various sizing and allocation constraints (size of the various memories, number of concurrent blocks and threads in the warps) which need to be taken into account to use the GPU architecture in the most efficient way.</p>
 
2274
 
2275
  </ol>
2276
 
2277
+ <p>We'll start by looking at one of the most frequent uses of CUDA: optimizing memory access. The global memory in GPUs (the largest memory area, as you saw earlier) has a long latency and low bandwidth in comparison to the cache, creating a major bottleneck for many applications. Efficiently accessing data from global memory can greatly improve performance.</p>
2278
 
2279
  <h4>Memory coalescing</h4>
2280
 
 
2425
  <p>In several places now, we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workloads on the GPU in a non-blocking way.</p>
2426
 
2427
  <p>This can be useful for overlapping communication and computation – as we've seen many times in our journey – but it can also be extended to the more general idea of trying to avoid, at all cost, going back and forth between host and GPU kernel commands.</p>
2428
+ <p>This idea is beautifully illustrated by <a href="https://horace.io/">Horace He</a> in <a href="https://horace.io/brrr_intro.html">these diagrams</a>:</p>
2429
  <div style="display: flex; gap: 20px; align-items: flex-start;">
2430
  <div style="width: 50%;">
2431
  <img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
 
2444
  <p>How can we avoid the back and forth shown on the left? Well, the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations as possible together in a single kernel for the GPU to run, called a “fused kernel,” as shown on the right.</p>
2445
 
2446
 
2447
+ <p>Fused kernels are especially efficient and simple to write for successions of point-like operations that are performed independently of each other on each input token. In this case, there is no point in sending the computed values back to global memory before moving them to SM memory and spinning up a new kernel. It's much more efficient to keep all the values locally until all the computations have been performed.</p>
2448
 
2449
+ <p>In a Transformer model, this "fusing" approach can be applied every time we have a succession of point-wise operations, such as in the computations involved in the LayerNorm layers.</p>
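<p>As a quick illustration, <code>torch.compile</code> can generate such a fused kernel for a chain of point-wise operations (a toy example, run on a CUDA device):</p>
<d-code block language="python">
import torch

def scale_shift_gelu(x, alpha, beta):
    # Three point-wise ops: in eager mode each would be a separate kernel,
    # writing its intermediate result back to global memory in between.
    return torch.nn.functional.gelu(alpha * x + beta)

fused = torch.compile(scale_shift_gelu)  # compiled into a single fused kernel

x = torch.randn(4096, 4096, device="cuda")
out = fused(x, 0.5, 1.0)
</d-code>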
2450
 
2451
  <p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>FlashAttention</em></strong>.</p>
2452
 
2453
  <h3>FlashAttention</h3>
2454
 
2455
  <p>FlashAttention was introduced by <a href="https://tridao.me">Tri Dao</a> and proposed to optimize attention computations by writing custom CUDA kernels to make them much faster <em>and</em> more memory efficient. The idea behind FlashAttention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory.</p>
2456
+
2457
+ <p>The global memory in modern GPUs often uses a technology called <a href="https://semianalysis.com/2024/09/03/the-memory-wall/#hbm-roadmap">High Bandwidth Memory (HBM)</a>, which, despite its name, is slower than SRAM in the GPU memory hierarchy. This HBM terminology will be important when we discuss the details of FlashAttention's implementation.</p>
2458
 
2459
+ <p>A basic implementation of the attention mechanism involves a lot of transfer between memory and workers. It requires materializing the <strong>S</strong> matrix (where S = QK^T, the attention scores) and the <strong>P</strong> matrix (where P = softmax(S), the normalized attention weights) in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
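<p>In PyTorch-like pseudocode, this naive version is just a few lines - note how <strong>S</strong> and <strong>P</strong>, both of shape <d-math>(seq, seq)</d-math> per head, are fully materialized:</p>
<d-code block language="python">
import torch

def naive_attention(Q, K, V):
    d = Q.size(-1)
    S = Q @ K.transpose(-2, -1) / d**0.5  # attention scores, shape (..., seq, seq)
    P = torch.softmax(S, dim=-1)          # attention weights, same (..., seq, seq) shape
    return P @ V                          # output, shape (..., seq, head_dim)

Q = K = V = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq, head_dim)
out = naive_attention(Q, K, V)
</d-code>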
2460
 
2461
  <p style="text-align: center"><img alt="image.png" src="/assets/images/flashattn.png" style="width: 500px" /></p>
2462
 
2463
+ <p>Since bandwidth is much lower in HBM, this introduces a severe bottleneck in the attention computation. Can we do better? <a href="https://tridao.me">Tri Dao</a> says yes!</p>
2464
 
2465
  <p>The key element is to compute the <strong>S</strong> matrix in small pieces that can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large <strong>S</strong> matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So, we can compute part of <d-math>O</d-math> directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length): the attention matrix.</p>
2466
 
 
2470
  <p>The idea of FlashAttention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers. Notably:</p>
2471
  <ul>
2472
  <li>By avoiding materializing the <strong>S</strong> matrix, we <strong>reduce the memory burden of attention</strong>.</li>
2473
+ <li>We also <strong>remove a large part of the naive impact of the quadratic <d-math>O(seq^2)</d-math> cost of attention</strong>, since we never move the full <d-math>seq \times seq</d-math> score matrix through global memory.</li>
2474
  </ul>
2475
 
2476
  <p>All variants of linear attention and subquadratic approaches to approximate attention (developed shortly after the invention of the Transformer architecture) have mostly been put aside in favor of this exact and fast FlashAttention implementation and mechanism.</p>
 
2562
 
2563
  <p>We can see that float32 spans 80 orders of magnitude, and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range, but e4m3 has an even smaller range.</p>
2564
 
2565
+ <p>How come some formats are able to maintain the full range while others aren't? Let's investigate their resolutions by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
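<p>You can reproduce this experiment in a few lines (assuming a recent PyTorch version that ships the float8 dtypes):</p>
<d-code block language="python">
import torch

xs = torch.linspace(1, 2, 10_000)
for dtype in [torch.float32, torch.float16, torch.bfloat16,
              torch.float8_e4m3fn, torch.float8_e5m2]:
    rounded = xs.to(dtype).to(torch.float32)  # round to the nearest representable value
    print(dtype, rounded.unique().numel(), "distinct values between 1 and 2")
</d-code>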
2566
 
2567
  <p><img alt="image.png" src="/assets/images/mixedprecision_2.png" /></p>
2568
 
2569
+ <p>We can see here that although bfloat16 maintains the range of float32 (unlike float16), it does this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire: e4m3 can represent only 7 and e5m2 only 3 numbers in the interval [1,2].</p>
2570
 
2571
  <p>A common metric to measure a format's resolution is <em>epsilon</em>: the first representable number after <d-math>1.00</d-math>. We can see that for the float32 format, <d-math>10^{-4}</d-math> is an upper bound (it’s actually <d-math>1.19 \times 10^{-7}</d-math>). For float16 it's ~<d-math>10^{-3}</d-math>, and for bfloat16 it's 10x higher still.</p>
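<p>These values can be checked directly with <code>torch.finfo</code>:</p>
<d-code block language="python">
import torch

for dtype in [torch.float32, torch.float16, torch.bfloat16]:
    print(dtype, torch.finfo(dtype).eps)  # 1.19e-07, 9.77e-04, 7.81e-03
</d-code>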
2572
 
2573
+ <p>The idea of mixed precision training is to use some of these lower-precision formats for certain computations while maintaining the performance of full precision training. As it turns out, we <strong>can’t</strong> totally abandon float32 and usually will need to do some of the computations in full precision. </p>
2574
 
2575
  <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further, all the way down to 8 bits.</p>
2576
 
 
2604
  <!-- Hynek uncomment this once it's added to -->
2605
  <!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
2606
 
2607
+ <p>The first successful very large scale training with FP8 mixed precision was publicly reported in the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. The authors carefully analyzed each operation of the forward pass (<em>Fprop</em>) as well as the activation (<em>Dgrad</em>) and weight (<em>Wgrad</em>) backward passes. Similar to BF16 mixed precision training, some aggregations and master weights are kept in higher precision while the operations themselves are performed in FP8.</p>
2608
 
2609
  <p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
2610
 
 
3001
  <h3>A0: Parallel Programming Crash Course</h3>
3002
 
3003
 
3004
+ <p>Throughout this book, we've scaled LLM training from one to hundreds of GPUs. This requires the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called <strong><em>collective operations</em></strong>. In this section, we’ll do a small crash course on those operations - <em>Broadcast</em>, <em>AllReduce</em>, <em>Scatter</em>, and more. Let’s dive in!</p> <!-- RH: Up to now, you've styled the names of the operations broadcast, all-reduce, etc. Would it be better to stick to the same style in the Appendix, to avoid any potential confusion on the part of the reader? Or refer to them as patterns when using the uppercase versions of the names? --> <!-- NT: Both are fine to me, shouldn't be confusing to reader i guess -->
3005
 
3006
  <p>The general setup is that we have a number of independent nodes, which could be CPU cores, GPUs, or compute nodes. Each performs some computation, and then we want to communicate the result or parts of it to the other nodes for the next computation step (<d-math>t+1</d-math>).</p>
3007
 
 
3204
 
3205
  <p>As the name suggests, the goal of the Scatter operation is to take data on one node and scatter it across all the nodes, which it does by distributing a slice of the data to each node. It’s thus different from the Broadcast operation, which sends each node a complete copy of the data without slicing it, and it’s the logical inverse of the Gather operation.</p>
3206
 
3207
+ <p>The ReduceScatter pattern is slightly more complex. As in the AllReduce case, you apply an operation on the data from all nodes. But unlike AllReduce, where each node receives the full output tensor, in ReduceScatter each node only receives a slice of the output tensor. The following image illustrates the difference between these operations:</p>
3208
 
3209
  <p style="text-align: center"><img alt="image.png" src="/assets/images/a0_scatter_reducescatter.png" style="width: 1000px" /></p>
3210
 
 
3291
  <ol>
3292
  <li><strong>ReduceScatter</strong></li>
3293
  <ul>
3294
+ <li>Each device splits its data (e.g., gradients) into N chunks (where N is the number of GPUs) and sends one chunk to its neighbor. Simultaneously, each device receives a chunk from its other neighbor.</li>
3295
  <li>As each device receives a chunk, it adds (reduces) its corresponding chunk to the received one.</li>
3296
+ <li>This process continues around the ring until each device holds a fully reduced chunk representing the sum of the gradients across all devices for that chunk.</li>
3297
  </ul>
3298
  <li><strong>AllGather</strong></li>
3299
  <ul>
3300
  <li>Now, each device needs to collect the fully reduced chunks from the other devices.</li>
3301
+ <li>Each device sends its reduced chunk to its neighbor, and receives the reduced chunk from its other neighbor.</li>
3302
+ <li>The devices continue forwarding the chunks they receive until every device has all the fully reduced chunks, giving each device the complete, summed-up gradient.</li>
 
3303
  </ol>
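<p>To build intuition for the two phases above, here is a toy, single-process simulation of the ring algorithm (plain Python/NumPy standing in for real devices and communication):</p>
<d-code block language="python">
import numpy as np

def ring_allreduce(chunks):
    """chunks[d][c]: chunk c held by device d; we use n devices and n chunks."""
    n = len(chunks)

    # Phase 1: ReduceScatter - at each step, every device sends one chunk to its right
    # neighbor and adds the chunk received from its left neighbor to its own copy.
    for step in range(n - 1):
        sent = [chunks[d][(d - step) % n].copy() for d in range(n)]
        for d in range(n):
            c = (d - 1 - step) % n                 # chunk arriving from the left neighbor
            chunks[d][c] = chunks[d][c] + sent[(d - 1) % n]
    # Device d now holds the fully reduced chunk (d + 1) % n.

    # Phase 2: AllGather - the fully reduced chunks are forwarded around the ring.
    for step in range(n - 1):
        sent = [chunks[d][(d + 1 - step) % n].copy() for d in range(n)]
        for d in range(n):
            chunks[d][(d - step) % n] = sent[(d - 1) % n]
    return chunks

# 4 "devices", each holding a gradient vector of 8 values split into 4 chunks of 2.
grads = [np.arange(8, dtype=np.float32) + d for d in range(4)]
result = ring_allreduce([list(g.reshape(4, 2)) for g in grads])
assert all(np.allclose(np.concatenate(c), sum(grads)) for c in result)
</d-code>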
3304
 
3305
  <p>Let’s illustrate this with the following gifs, where we have 5 GPUs, each with a tensor of length 5. The first animation shows the ReduceScatter step, where, at the end, each GPU receives the reduced results for a specific chunk of data (the orange rectangle).</p>
 
3372
  </ul>
3373
 
3374
  <p>There are a few finer points in the decision tree that we leave to the reader to explore in the PyTorch guide referenced above.</p>
 
 
3375
 
3376
  <h3>A1: Distributed Training Profiling</h3>
3377
 
 
 
3378
  <h4>Kernels</h4>
3379
 
3380
  <p>Let's begin by assuming for now that the kernels are already integrated into PyTorch. As a simple example, we can look at the layer normalization function implemented in PyTorch as <code>torch.nn.functional.layer_norm</code>. There are several methods to profile the kernel that underlies this function. The most straightforward approach might be to use the Python <code>time</code> module. However, since CUDA operations are asynchronous, measuring time with this method will only capture the overhead associated with launching the kernel in Python, rather than the actual execution time of the kernel itself.</p>
 
3402
  return start.elapsed_time(end)
3403
  </d-code>
3404
 
3405
+ <p>A more efficient approach to profiling is to utilize the PyTorch profiler, as <a target="_self" href="#profiling_gpu_compute_and_communication">explained previously</a>. For example, consider the following code:</p>
3406
 
3407
  <d-code block language="python">
3408
  import torch
 
3555
 
3556
  <li><strong>Activations (hidden states):</strong> For a single layer, the hidden state tensor is of size <d-math>seq \cdot mbs \cdot h</d-math> elements.</li>
3557
 
3558
+ <li><strong>Model weights and gradients:</strong> Each weight matrix in your model (e.g. linear layer) contains about <d-math>h^2</d-math> elements. Gradients have the same size as weights.</li>
3559
 
3560
+ <li><strong>Optimizer states:</strong> For each weight matrix (of <d-math>h^2</d-math> elements), an optimizer like Adam with mixed precision training will keep momentum and variance states in FP32 precision (<d-math>2 \cdot h^2</d-math>), plus master weights in FP32 (<d-math>h^2</d-math>). So, the total number of optimizer states will be around (<d-math>6 \cdot h^2</d-math>) per weight matrix.</li>
3561
 
3562
+ <li><strong>Total model parameters:</strong> Each transformer block will store:
3563
  <ul>
3564
  <li>Attention parameters:
3565
  <ul>
 
3567
  <li>Output projection: <d-math>h^2</d-math> parameters</li>
3568
  </ul>
3569
  </li>
3570
+ <li>MLP parameters with Gated Linear Units (GLU):
3571
  <ul>
3572
  <li>Gate and up projections: <d-math>8h^2</d-math> parameters (2 matrices of size <d-math>h \times 4h</d-math>)</li>
3573
  <li>Down projection: <d-math>4h^2</d-math> parameters (1 matrix of size <d-math>4h \times h</d-math>)</li>
 
3667
  `
3668
  <h4>TP communication analysis</h4>
3669
 
3670
+ <p>For tensor parallelism, activations are sharded across GPUs in the <a target="_self" href="#sequence_parallelism">TP regions</a> (e.g. MLP block). Let's analyze the communication pattern:</p>
3671
 
3672
  <ul>
3673
  <li>For each column-linear operation in the forward pass:
 
3684
  <li>Total communication per block: <d-math>8 \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
3685
  <li>Total communication for full model: <d-math>8 \cdot num\_layers \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
3686
  </ul>
3687
+ <p>Let's take a TP region within a layer and analyze if we can overlap the all-gather communication with the computation of the next linear. The communication time for all-gather operations is:</p>
3688
 
3689
  <d-math block>
3690
  t_{comm} = \frac{seq \cdot mbs \cdot h \cdot (TP-1)}{TP \cdot peak\_bw}
 
3701
  \frac{t_{comm}}{t_{compute}} = \frac{TP-1}{2 \cdot h} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
3702
  </d-math>
3703
 
3704
+ <p>This ratio tells us whether we can successfully hide the all-gather communication behind the computation of the next linear. Interestingly, the ratio only depends on the hidden size <d-math>h</d-math> and the tensor parallelism degree <d-math>TP</d-math>, not on sequence length or batch size.</p>
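<p>Plugging in rough hardware numbers makes this condition tangible. The peak compute and bandwidth figures below are assumptions for illustration only (roughly an H100 with NVLink: ~989 TFLOPS BF16 and ~450 GB/s per direction); with them, the condition gives the minimum hidden size needed to hide the all-gather at a given TP degree:</p>
<d-code block language="python">
peak_flops = 989e12  # assumed BF16 peak throughput, FLOP/s
peak_bw = 450e9      # assumed intra-node bandwidth per direction, bytes/s
TP = 8

# Overlap condition: (TP - 1) / (2 * h) * peak_flops / peak_bw must stay at or below 1,
# which rearranges to a minimum hidden size h:
h_min = (TP - 1) * peak_flops / (2 * peak_bw)
print(f"TP={TP}: hidden size should be at least ~{h_min:,.0f} to hide the all-gather")
</d-code>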
3705
 
3706
 
3707
  <h4>PP communication analysis</h4>