Commit
·
e87aa99
1
Parent(s):
835c7e8
- dist/assets/.DS_Store +0 -0
- dist/bibliography.bib +9 -0
- dist/index.html +0 -0
- src/bibliography.bib +9 -0
- src/index.html +85 -85
dist/assets/.DS_Store
DELETED
Binary file (6.15 kB)
|
|
dist/bibliography.bib
CHANGED
@@ -316,6 +316,15 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
|
|
316 |
archivePrefix={arXiv},
|
317 |
primaryClass={cs.AI}
|
318 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
@misc{hendrycks2021measuring,
|
320 |
title={Measuring Massive Multitask Language Understanding},
|
321 |
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
|
|
|
316 |
archivePrefix={arXiv},
|
317 |
primaryClass={cs.AI}
|
318 |
}
|
319 |
+
@misc{liu2023ringattentionblockwisetransformers,
|
320 |
+
title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
|
321 |
+
author={Hao Liu and Matei Zaharia and Pieter Abbeel},
|
322 |
+
year={2023},
|
323 |
+
eprint={2310.01889},
|
324 |
+
archivePrefix={arXiv},
|
325 |
+
primaryClass={cs.CL},
|
326 |
+
url={https://arxiv.org/abs/2310.01889},
|
327 |
+
}
|
328 |
@misc{hendrycks2021measuring,
|
329 |
title={Measuring Massive Multitask Language Understanding},
|
330 |
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
|
dist/index.html
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
src/bibliography.bib
CHANGED
@@ -316,6 +316,15 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
|
|
316 |
archivePrefix={arXiv},
|
317 |
primaryClass={cs.AI}
|
318 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
@misc{hendrycks2021measuring,
|
320 |
title={Measuring Massive Multitask Language Understanding},
|
321 |
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
|
|
|
316 |
archivePrefix={arXiv},
|
317 |
primaryClass={cs.AI}
|
318 |
}
|
319 |
+
@misc{liu2023ringattentionblockwisetransformers,
|
320 |
+
title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
|
321 |
+
author={Hao Liu and Matei Zaharia and Pieter Abbeel},
|
322 |
+
year={2023},
|
323 |
+
eprint={2310.01889},
|
324 |
+
archivePrefix={arXiv},
|
325 |
+
primaryClass={cs.CL},
|
326 |
+
url={https://arxiv.org/abs/2310.01889},
|
327 |
+
}
|
328 |
@misc{hendrycks2021measuring,
|
329 |
title={Measuring Massive Multitask Language Understanding},
|
330 |
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
|
src/index.html
CHANGED
@@ -865,22 +865,23 @@
|
|
865 |
|
866 |
<h4>Memory usage revisited</h4>
|
867 |
|
868 |
-
<p><a target="_self" href="#memory_usage_in_transformers">Earlier</a>, we discussed the memory usage of optimizer states, gradients, and parameters during standard training. Let's call our model's parameter count <d-math>\Psi</d-math> (previously this was <d-math>N</d-math>, but here we use the original ZeRO paper's notation
|
869 |
|
870 |
<ul>
|
871 |
<li>Model’s parameters (half precision; i.e., BF16/FP16): <d-math>2\Psi</d-math></li>
|
872 |
<li>Model’s gradients (half precision; i.e., BF16/FP16): <d-math>2\Psi</d-math></li>
|
873 |
<li>Model’s parameters in FP32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
|
874 |
-
<li>Model’s gradients in FP32: <d-math>4\Psi</d-math> (optional, only
|
875 |
</ul>
|
876 |
|
877 |
-
<p>If we don
|
878 |
|
879 |
<p>The idea of ZeRO is to shard these objects across the DP ranks, with each node only storing a slice of the items. These slices are then reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
|
880 |
|
881 |
<p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
|
882 |
<p>Here, <d-math>\Psi</d-math> denotes the number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam, as we've just seen), and <d-math>N_d</d-math> denotes DP degree.</p>
|
883 |
|
|
|
884 |
|
885 |
<p>Let’s explain this by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
|
886 |
|
@@ -890,17 +891,16 @@
|
|
890 |
|
891 |
<p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts, where <d-math>N_d</d-math> is the DP degree. This means that the model replicas distributed on the DP ranks each only keep track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states, and during the optimization step, only <d-math>\frac{1}{N_d}</d-math> of the FP32 weights are updated.</p>
|
892 |
|
893 |
-
<p>However, during the forward pass, each replica needs all the parameters. We thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective
|
894 |
|
895 |
<p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw in the previous figure! Here’s a summary of the sequence of operations for a single training step:</p>
|
896 |
-
|
897 |
-
<ul>
|
898 |
<li>Perform a forward pass with the same full set of BF16 parameters on each replica, but different micro-batches across replicas.</li>
|
899 |
<li>Perform a backward pass with the same full set of gradients on each replica, but different micro-batches across replicas.</li>
|
900 |
<li>Perform a <strong><em>reduce-scatter</em></strong> on the gradients (another primitive - we'll explain this one shortly).</li>
|
901 |
<li>Each replica performs an optimizer step on its local optimizer states (only <d-math>\frac{1}{N_d}</d-math> of the optimizer states) to get <d-math>\frac{1}{N_d}</d-math> updated FP32 parameters, which can then be converted to <d-math>\frac{1}{N_d}</d-math> of the full set of BF16 parameters.</li>
|
902 |
<li>Perform an all-gather on the BF16 parameters to send the missing slices back to each replica. This is a new operation in ZeRO and is not used in vanilla DP.</li>
|
903 |
-
</
|
904 |
<aside>Note: Reduce-scatter is two times faster than all-reduce! <em>Yay, a third communication primitive!</em></aside>
|
905 |
|
906 |
<p>You may be wondering what this "reduce-scatter" operation is and what this all looks like, so let's try to make it more graphical with the figure below. We'll go over all the steps of a forward/backward pass cycle:</p>
|
@@ -914,7 +914,7 @@
|
|
914 |
<p>If you've been following along, you'll recall from our discussion of vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of BF16 parameters. There are two main strategies for this:</p>
|
915 |
|
916 |
<ul>
|
917 |
-
<li><strong>During the optimizer step:</strong> We can initiate the all-gather immediately after the optimizer updates
|
918 |
<li><strong>During the forward pass:</strong> We can overlap the all-gather of each layer’s parameters with the forward pass.</li>
|
919 |
</ul>
|
920 |
|
@@ -929,19 +929,19 @@
|
|
929 |
|
930 |
<h4>ZeRO-2: Adding <strong>gradient partitioning</strong></h4>
|
931 |
|
932 |
-
<p>Since on each replica we only need to have the gradient shard corresponding to its optimizer state shard, it makes sense to shard gradients as well, similarly to the optimizer states. Then, during the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation! Here, we only
|
933 |
|
934 |
-
<aside>In the case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> FP32 grads
|
935 |
|
936 |
<p><img alt="dp_zero2.gif" src="/assets/images/dp_zero2.gif" /></p>
|
937 |
|
938 |
-
<p>It
|
939 |
<!-- RH: In this figure, on the right, can "AllGather Params" and "ReduceScatter Grad" be changed to "All-gather params" and "Reduce-scatter grads"? -->
|
940 |
<p><img alt="dp_zero2_overlap.svg" src="/assets/images/dp_zero2_overlap.svg" /></p>
|
941 |
|
942 |
-
<aside>Note: You might notice that there is no real overhead to using ZeRO-2 over ZeRO-1, and indeed ZeRO-2 is usually the better option.</aside>
|
943 |
|
944 |
-
<p>Now that we
|
945 |
|
946 |
<h4>ZeRO-3: Adding <strong>parameter partitioning</strong> (FSDP)</h4>
|
947 |
|
@@ -954,7 +954,7 @@
|
|
954 |
</div>
|
955 |
</div>
|
956 |
|
957 |
-
<p>So how do we do a forward or backward pass in practice if
|
958 |
|
959 |
<p><img alt="dp_zero3_fwd.svg" src="/assets/images/dp_zero3_fwd.svg" /></p>
|
960 |
|
@@ -962,15 +962,17 @@
|
|
962 |
|
963 |
<p><img alt="dp_zero3_bwd.svg" src="/assets/images/dp_zero3_bwd.svg" /></p>
|
964 |
|
965 |
-
<p>The other issue is that we need to do these all-gathers continuously throughout the forward and backward
|
966 |
<!-- RH: In this figure, change "AllGather Params" and "ReduceScatter Grads" to "All-gather params" and "Reduce-scatter grads" and lowercase "Free"? -->
|
967 |
<p><img alt="dp_zero3_overlap.svg" src="/assets/images/dp_zero3_overlap.svg" /></p>
|
968 |
|
969 |
<p>During the forward pass we do all-gather operations for the parameters when we need them, so there's a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we use them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> communication tax. Finally, we need the same reduce-scatter operation as in ZeRO-2 for the gradients, which also costs <d-math>\Psi</d-math> in communication. So, we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for ZeRO-2.</p>
|
970 |
|
971 |
-
<p>This may sound like a lot of communication overhead, but it's actually not a big deal, as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called <strong><em>prefetching</em></strong>. With prefetching, we all-gather the weights for <em>Layer n+1</em> while we do the forward pass for <em>Layer n</em>, and similarly, we all-gather the weights for <em>Layer n-1</em> while doing the backward pass for <em>Layer n</em>. Of course, this overlap only works as long as we don’t scale DP too much (as a rule of thumb, DP
|
|
|
|
|
972 |
|
973 |
-
<p>In terms of memory, we can see that our equation has now reached its final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math>, which means we can drive memory usage down indefinitely if we can increase the DP
|
974 |
|
975 |
<p>Let’s summarize our journey into DP and ZeRO so far. We've seen that we can increase the throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO, we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients, and optimizer states across DP replicas, while incurring a small communication cost.</p>
|
976 |
<aside>If you want to read more about FSDP1, FSDP2, and some of the implementation complexities around them, check out <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
|
@@ -1006,7 +1008,7 @@
|
|
1006 |
|
1007 |
|
1008 |
|
1009 |
-
<p>So, we've sharded the model’s parameters, gradients, and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome tensor parallelism (TP), a method that shards weights, gradients, and optimizer states as well as activations - and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how TP works with simple matrix multiplication (matmul) operations.</p>
|
1010 |
|
1011 |
<p>Tensor parallelism leverages the mathematical properties of matrix multiplication, <d-math>A \times B</d-math>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p>
|
1012 |
|
@@ -1021,7 +1023,7 @@
|
|
1021 |
|
1022 |
<ul>
|
1023 |
<li><d-math>X</d-math> represents the input or activation values</li>
|
1024 |
-
<li><d-math>W</d-math> represents the weight of the
|
1025 |
</ul>
|
1026 |
|
1027 |
<p>In practice, a small example of the operation looks like this:</p>
|
@@ -1064,19 +1066,20 @@
|
|
1064 |
|
1065 |
<h3>Tensor parallelism in a transformer block</h3>
|
1066 |
|
1067 |
-
<p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: a feedforward multi-layer perceptron (MLP) block and a multi-head attention (MHA) block.
|
1068 |
|
1069 |
-
<p>The feedforward part can be parallelized by having a column-linear followed by a row-linear split
|
1070 |
|
1071 |
<p><img alt="image.png" src="/assets/images/tp_diagram4.png" /></p>
|
1072 |
|
1073 |
-
<p>Now that we
|
1074 |
-
|
1075 |
-
<p>We can generally follow a similar approach here, where the Query (Q), Key (K), and Value (V) matrices are split in a column-parallel fashion and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual or a subset of heads <!-- RH: Is just "computes the attention" OK there (rather than attention scores/matrix/whatever), and should you give an indication of what heads are (e.g., "attention heads" or "heads (attention mechanisms)" or however you would describe them), or will readers be expected to follow all of this? -->. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query attention (MQA)</em></strong></a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention (GQA)</em></strong></a>, where keys and values are shared between queries. </p>
|
1076 |
|
1077 |
-
<p>
|
1078 |
|
1079 |
<p><img alt="image.png" src="/assets/images/tp_full_diagram.png" /></p>
|
|
|
|
|
|
|
1080 |
|
1081 |
<p>Note also that tensor parallelism is not a silver bullet for training. We’ve added several distributed communication primitives directly in the computation path of our model, which are therefore hard to fully hide/overlap with computation (like we did in ZeRO), so our final performance will be the result of a trade-off between the computation and memory gains and the added communication overhead. Let's illustrate this:</p>
|
1082 |
<!-- RH: In this figure, change "TP region" to "TP Region" and "AllReduce Activs" to "All-reduce activs"? -->
|
@@ -1084,11 +1087,11 @@
|
|
1084 |
|
1085 |
<aside>It's possible to partially hide this communication by performing block matrix multiplication coupled with async communication/computation.</aside>
|
1086 |
|
1087 |
-
<p>Looking at the timeline of operations in tensor-parallel MLP (the same applies for
|
1088 |
|
1089 |
-
<aside>For example, Megatron-LM
|
1090 |
|
1091 |
-
<p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, TP introduces significant communication requirements that heavily depend on the network infrastructure. The inability to fully hide this particular all-reduce behind computation means it directly adds to the critical path of forward propagation
|
1092 |
|
1093 |
<aside>This is an active area of research, with recent work like Domino <d-cite bibtex-key="wang2024domino"></d-cite> exploring novel techniques to maximize this overlap. </aside>
|
1094 |
|
@@ -1176,7 +1179,7 @@
|
|
1176 |
<li><em>f</em> is an all-reduce to synchronize gradients.</li>
|
1177 |
</ul>
|
1178 |
|
1179 |
-
<p>These <em>f</em> and <em>f*</em> operations are called <strong><em>conjugate pairs</em></strong> because they complement each other - in each pass, when one is a no-op the other is an all-reduce, and it's the opposite in the other pass.</p>
|
1180 |
|
1181 |
<p>For sequence parallelism, we use different operations labeled <em>g</em> and <em>g*</em>. Specifically, we avoid using all-reduce in the SP regions since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
|
1182 |
|
@@ -1186,24 +1189,24 @@
|
|
1186 |
<div>
|
1187 |
<p style="margin-bottom: 0;"><strong>Initial LayerNorm layer (SP region)</strong></p>
|
1188 |
<ul style="margin-top: 0;">
|
1189 |
-
<li>Input tensors <em>X1
|
1190 |
-
<li>Each GPU computes LayerNorm independently on its sequence chunk, giving <em>Y1
|
1191 |
</ul>
|
1192 |
<p style="margin-bottom: 0;"><strong>First transition (SP → TP)</strong></p>
|
1193 |
<ul style="margin-top: 0;">
|
1194 |
<li><em>g</em> operation (all-gather) combines <em>Y1</em> and <em>Y2</em> back to full sequence length.</li>
|
1195 |
-
<li>Restores <em>Y</em> <d-math>(b,s,h)
|
1196 |
</ul>
|
1197 |
<p style="margin-bottom: 0;"><strong>First linear layer (TP region)</strong></p>
|
1198 |
<ul style="margin-top: 0;">
|
1199 |
-
<li><em>
|
1200 |
<li>GELU is applied independently on each GPU.</li>
|
1201 |
-
<li><em>Z1*</em> and <em>Z2*</em> are <d-math>(b,s,h/2)</d-math
|
1202 |
</ul>
|
1203 |
<p style="margin-bottom: 0;"><strong>Second linear layer (TP region)</strong></p>
|
1204 |
<ul style="margin-top: 0;">
|
1205 |
-
<li><em>
|
1206 |
-
<li><em>W1</em> and <em>W2</em> are <d-math>(b,s,h)</d-math
|
1207 |
</ul>
|
1208 |
<p style="margin-bottom: 0;"><strong>Final transition (TP → SP)</strong></p>
|
1209 |
<ul style="margin-top: 0;">
|
@@ -1388,7 +1391,7 @@
|
|
1388 |
|
1389 |
<p><img alt="ring-attention.gif" src="/assets/images/ring-attention.gif" /></p>
|
1390 |
|
1391 |
-
<p>It's probably obvious to you from this animation why the authors chose to call this approach Ring Attention
|
1392 |
|
1393 |
<p>There is one big problem, though, which is that a naive implementation of Ring Attention leads to some strong imbalances between GPUs due to the shape of the causal attention matrix. Let’s take a look at the softmax computation by considering the attention score matrix with the causal attention mask:</p>
|
1394 |
|
@@ -1399,8 +1402,10 @@
|
|
1399 |
<p>Let’s see if we can balance our computations better.</p>
|
1400 |
|
1401 |
<h3>Zig-Zag Ring Attention – A balanced compute implementation</h3>
|
1402 |
-
|
1403 |
-
<
|
|
|
|
|
1404 |
|
1405 |
<p><img alt="cp_zigzagmask.svg" src="/assets/images/cp_zigzagmask.svg" /></p>
|
1406 |
|
@@ -1447,8 +1452,7 @@
|
|
1447 |
</div>
|
1448 |
</div>
|
1449 |
|
1450 |
-
|
1451 |
-
<p>In the <a target="_self" href="#tensor-parallelism">"Tensor Parallelism"</a> section, we saw that trying to scale tensor parallelism past the number of GPUs on a single node - typically 4 or 8 - requires lower-bandwidth network communication, which can significantly impair performance. We can see the effects of this inter-node communication clearly in the all-reduce operation when we benchmark it on our cluster across several nodes (each node here has 8 GPUs):</p>
|
1452 |
|
1453 |
<!-- <iframe class="l-body-outset" id="plotFrame11" src="assets/data/benchmarks/pp_comm_bandwidth.html" width="90%" scrolling="no" frameborder="0"></iframe> -->
|
1454 |
<div class="l-body-outset" id="fragment-pp_comm_bandwidth"></div>
|
@@ -1639,14 +1643,14 @@
|
|
1639 |
|
1640 |
<p><img alt="image.png" src="/assets/images/pp_zerobubble_ppschedule.png" /></p>
|
1641 |
|
1642 |
-
<div class="figure-legend"><p>On the top (Figure 2 from the Zero Bubble paper): the classical 1F1B schedule, interleaving forward and backward passes but keeping a coarse-grained backward pass. On the bottom (Figure 3 from the Zero Bubble paper): two handcrafted schedules splitting the backward pass into finer-grained <d-math>B</d-math> and <d-math>W</d-math> operations. The lower schedule is an example of a (theoretical) zero bubble schedule taking advantage of this fine-grained decomposition.</p>
|
1643 |
</div>
|
1644 |
|
1645 |
<p>DeepSeek’s DualPipe, introduced with its V3 technical report <d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>, is an extension of this decomposed approach to the additional case of two streams propagating from both ends of the PP dimension, with these streams being interleaved to further minimize idle time in the GPUs. This schedule is displayed in the following scheduling graph - as you can see, it's even more complex than the previous ones:</p>
|
1646 |
|
1647 |
<p><img alt="image.png" src="/assets/images/pp_zerobubble_dualpipe.png" /></p>
|
1648 |
|
1649 |
-
<p>In general, fully optimizing such complex schedules involves carefully measuring the duration of the various fine-grained operations and solving
|
1650 |
|
1651 |
<p>This concludes our tour of the world of pipeline schedules and bubbles. We hope you enjoyed it!</p>
|
1652 |
|
@@ -1701,7 +1705,7 @@
|
|
1701 |
|
1702 |
<p>However, there are a few major differences between the PP and ZeRO-3 approaches:</p>
|
1703 |
|
1704 |
-
<aside>Note here that we say
|
1705 |
<div class="l-body">
|
1706 |
<table>
|
1707 |
<thead>
|
@@ -1756,11 +1760,11 @@
|
|
1756 |
|
1757 |
<p>The main reason we don't want to use TP only for parallelism is that, in practice, TP has two limitations (which we've discussed in previous sections). First, since its communication operations are part of the critical path of computation, it's difficult to scale well beyond a certain point, after which communication overhead begins to dominate. Second, unlike ZeRO and PP, which are model-agnostic, TP requires careful handling of activation sharding - sometimes along the hidden dimension (in the TP region) and sometimes along the sequence dimension (in the SP region) - making it more cumbersome to implement correctly and requiring model-specific knowledge to ensure proper sharding patterns throughout.</p>
|
1758 |
|
1759 |
-
<p>As a consequence, when combining parallelism strategies, TP will typically be kept for high-speed intra-node communications, while ZeRO-3 or PP can be used for parallelism groups spanning lower-speed inter-node communications
|
1760 |
|
1761 |
<p><strong>Context parallelism</strong> and <strong>expert parallelism</strong> also help us shard activations, and can be seen as complementary to TP. CP handles long sequences while EP enables distributed Mixture of Experts training, and they can be combined without any particular issues.</p>
|
1762 |
|
1763 |
-
<p>CP specifically targets the challenge of training with very long sequences by sharding activations along the sequence dimension across GPUs. While most
|
1764 |
|
1765 |
<div class="large-image-background-transparent">
|
1766 |
<div class="boxed-image">
|
@@ -1768,7 +1772,7 @@
|
|
1768 |
</div>
|
1769 |
</div>
|
1770 |
|
1771 |
-
<p><strong>Expert parallelism</strong> specifically targets the challenge of training MoE models by sharding specialized "experts" across GPUs and dynamically routing tokens to relevant experts during computation. The key communication operations in EP are the "all-to-all" operations routing tokens to their assigned experts and gathering the results back. While this introduces some communication overhead, it enables scaling model capacity significantly since each token is only processed during inference (and training) by a much smaller fraction of the total parameters
|
1772 |
<aside>For instance, DeepSeek-V3 uses 256 experts.</aside>
|
1773 |
|
1774 |
<div class="large-image-background-transparent">
|
@@ -1809,7 +1813,7 @@
|
|
1809 |
</tr>
|
1810 |
<tr>
|
1811 |
<td>Communication for matrix multiplication operations (column/row linear)</td>
|
1812 |
-
<td>Communication for attention keys/values</td>
|
1813 |
<td>Communication for token routing to experts</td>
|
1814 |
</tr>
|
1815 |
<tr>
|
@@ -1930,7 +1934,7 @@
|
|
1930 |
<li>At 512+ GPU scale, pure data parallelism/ZeRO-3 will start to becomes inefficient due to communication cost - it can be better to then combine DP with either tensor or pipeline parallelism.</li>
|
1931 |
<li>At 1024+ GPU scale, a recommended setup may be tensor parallelism (TP=8) with data parallelism (ZeRO-2) and pipeline parallelism.</li>
|
1932 |
</ul>
|
1933 |
-
<aside>We focus on fitting a single instance for now - even though we may use DP for ZeRO to achieve this goal - we're only interested here in the model parameters memory savings that it provide when used with ZeRO-3.</aside> <!-- RH: I'm having trouble working out what you're trying to say here. Maybe something like "We'll focus on fitting a single instance for now. Even though we may use DP to achieve this goal, we're only interested here in the model parameter memory savings that it provides when used with ZeRO-3."? -->
|
1934 |
</p>
|
1935 |
|
1936 |
<p><em>Special considerations:</em>
|
@@ -1984,7 +1988,7 @@
|
|
1984 |
|
1985 |
<p>In the <a href="https://github.com/huggingface/nanotron">Nanotron</a> repository, you'll find several scripts you can use to run all the experiments discussed previously and benchmark your own model and cluster.</p>
|
1986 |
|
1987 |
-
<p>We actually ran benchmarks ourselves on <strong>several thousand distributed configurations</strong>, covering every model size we've discussed here as well as a very large number of cluster configurations (namely, 1-64 nodes of 8xH100s)
|
1988 |
<aside>We want to take this opportunity to apologize to our coworkers for blocking most of the science cluster, and in turn forgive any threats that may have been whispered.</aside>
|
1989 |
|
1990 |
<p>Now let's take a step back to gather and analyze the results of all our benchmarks and see if, beyond theory, we can actually discover using real-world data how various configurations fare against each other.</p>
|
@@ -2013,7 +2017,7 @@
|
|
2013 |
|
2014 |
<h3>Lessons learned on benchmarking</h3>
|
2015 |
|
2016 |
-
<p>Our goal for this book was not only to discuss theory and implementations, but to provide actual data points as well. So, the plan was simple: let's run every possible distributed configuration for every model and a number of cluster sizes
|
2017 |
|
2018 |
<p>
|
2019 |
On paper, this sounds easy enough: we can easily launch big arrays of jobs on our cluster. However, as soon as we launched the first batches of experiments, our troubles began:
|
@@ -2032,7 +2036,7 @@
|
|
2032 |
<li>Minimizing cluster restart times and optimizing idle time</li>
|
2033 |
<li>Analyzing detailed NCCL debug logs</li>
|
2034 |
<li>Understanding memory usage patterns and CUDA memory allocator behaviors</li>
|
2035 |
-
<li>Improving pipeline parallelism performance on multi-node
|
2036 |
</ul>
|
2037 |
|
2038 |
<p>These challenges taught us valuable lessons about the complexities of distributed training infrastructure. What looks simple in theory often requires careful attention to many moving parts in practice.</p>
|
@@ -2094,7 +2098,7 @@
|
|
2094 |
<p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
|
2095 |
<div class="figure-legend"><p>Source: https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p></div>
|
2096 |
|
2097 |
-
<p>The goal
|
2098 |
|
2099 |
<p>A piece of code running on a core of the GPU is called a <strong><em>kernel</em></strong>. It can be written at a high level in CUDA or Triton, for instance, and is then compiled to Parallel Thread Execution (PTX), the low-level assembly used by NVIDIA GPUs.</p>
|
2100 |
|
@@ -2160,7 +2164,7 @@
|
|
2160 |
|
2161 |
<ul>
|
2162 |
<li>Threads are grouped in <em>warps</em>, each containing 32 threads. All the threads in a warp are synchronized to execute instructions simultaneously but on different parts of the data.</li>
|
2163 |
-
<li>Warps are grouped in larger <em>blocks</em> of more flexible size (for example, there may be 512 or 1,024 threads in a block)
|
2164 |
</ul>
|
2165 |
|
2166 |
<p>The main thing to retain here is that there are various sizing and allocation constraints (size of the various memories, number of concurrent blocks and threads in the warps) which need to be taken into account to use the GPU architecture in the most efficient way.</p>
|
@@ -2270,7 +2274,7 @@
|
|
2270 |
|
2271 |
</ol>
|
2272 |
|
2273 |
-
<p>We'll start by looking at one of the most frequent uses of CUDA: optimizing memory access.
|
2274 |
|
2275 |
<h4>Memory coalescing</h4>
|
2276 |
|
@@ -2421,7 +2425,7 @@
|
|
2421 |
<p>In several places now, we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workloads on the GPU in a non-blocking way.</p>
|
2422 |
|
2423 |
<p>This can be useful for overlapping communication and computation – as we've seen many times in our journey – but it can also be extended to the more general idea of trying to avoid, at all cost, going back and forth between host and GPU kernel commands.</p>
|
2424 |
-
<p>This idea is beautifully illustrated by <a href="https://
|
2425 |
<div style="display: flex; gap: 20px; align-items: flex-start;">
|
2426 |
<div style="width: 50%;">
|
2427 |
<img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
|
@@ -2440,22 +2444,23 @@
|
|
2440 |
<p>How can we avoid the back and forth shown on the left? Well, the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations as possible together in a single kernel for the GPU to run, called a “fused kernel,” as shown on the right.</p>
|
2441 |
|
2442 |
|
2443 |
-
<p>Fused kernels are especially efficient and simple to write for successions of point-like operations that are performed independently of each other on each input token. In this case, there is no point in
|
2444 |
|
2445 |
-
<p>In a Transformer model, this "fusing" approach can be applied every time we have a succession of point-wise operations, such as in the computations involved in the LayerNorm layers.</p>
|
2446 |
|
2447 |
<p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>FlashAttention</em></strong>.</p>
|
2448 |
|
2449 |
<h3>FlashAttention</h3>
|
2450 |
|
2451 |
<p>FlashAttention was introduced by <a href="https://tridao.me">Tri Dao</a> and proposed to optimize attention computations by writing custom CUDA kernels to make them much faster <em>and</em> more memory efficient. The idea behind FlashAttention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory.</p>
|
2452 |
-
|
|
|
2453 |
|
2454 |
-
A basic implementation of the attention mechanism involves a lot of transfer between memory and workers. It requires materializing the <strong>S</strong>
|
2455 |
|
2456 |
<p style="text-align: center"><img alt="image.png" src="/assets/images/flashattn.png" style="width: 500px" /></p>
|
2457 |
|
2458 |
-
<p>Since bandwidth is much lower in HBM, this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p>
|
2459 |
|
2460 |
<p>The key element is to compute the <strong>S</strong> matrix in small pieces that can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large <strong>S</strong> matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So, we can compute part of <d-math>O</d-math> directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length): the attention matrix.</p>
|
2461 |
|
@@ -2465,7 +2470,7 @@
|
|
2465 |
<p>The idea of FlashAttention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers. Notably:</p>
|
2466 |
<ul>
|
2467 |
<li>By avoiding materializing the <strong>S</strong> matrix, we <strong>reduce the memory burden of attention</strong>.</li>
|
2468 |
-
<li>We also <strong>remove a large part of the naive impact of the S^2
|
2469 |
</ul>
|
2470 |
|
2471 |
<p>All variants of linear attention and subquadratic approaches to approximate attention (developed shortly after the invention of the Transformer architecture) have mostly been put aside in favor of this exact and fast FlashAttention implementation and mechanism.</p>
|
@@ -2557,15 +2562,15 @@
|
|
2557 |
|
2558 |
<p>We can see that float32 spans 80 orders of magnitude, and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range, but e4m3 has an even smaller range.</p>
|
2559 |
|
2560 |
-
<p>How come some formats are able to maintain the
|
2561 |
|
2562 |
<p><img alt="image.png" src="/assets/images/mixedprecision_2.png" /></p>
|
2563 |
|
2564 |
-
<p>We can see here that although bfloat16 maintains the range of float32 (unlike float16), it does this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire: e4m3 can represent only 7 and e5m2 only 3 numbers in the interval [1,2].</p
|
2565 |
|
2566 |
<p>A common metric to measure a format's resolution is <em>epsilon</em>: the first representable number after <d-math>1.00</d-math>. We can see that for the float32 format, <d-math>10^{-4}</d-math> is an upper bound (it’s actually <d-math>1.19^{-7}</d-math>). For float16 it's ~<d-math>10^{-3}</d-math>, and for bfloat it's 10x higher still.</p>
|
2567 |
|
2568 |
-
<p>The idea of mixed precision training is to use some of these lower-precision formats for certain computations while maintaining the performance of full precision training. As it turns out, we <strong>can’t</strong> totally abandon float32 and usually will need to do some of the computations in full precision. </p
|
2569 |
|
2570 |
<p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further, all the way down to 8 bits.</p>
|
2571 |
|
@@ -2599,7 +2604,7 @@
|
|
2599 |
<!-- Hynek uncomment this once it's added to -->
|
2600 |
<!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
|
2601 |
|
2602 |
-
<p>The first successful very large scale training with FP8 mixed precision was publicly reported
|
2603 |
|
2604 |
<p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
|
2605 |
|
@@ -2996,7 +3001,7 @@
|
|
2996 |
<h3>A0: Parallel Programming Crash Course</h3>
|
2997 |
|
2998 |
|
2999 |
-
<p>Throughout this book, we've scaled LLM training from one to hundreds of GPUs. This requires the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called <strong><em>collective operations</em></strong>. In this section, we’ll do a small crash course on those operations - <em>Broadcast</em>, <em>AllReduce</em>, <em>Scatter</em>, and more. Let’s dive in!</p> <!-- RH: Up to now, you've styled the names of the operations broadcast, all-reduce, etc. Would it be better to stick to the same style in the Appendix, to avoid any potential confusion on the part of the reader? Or refer to them as patterns when using the uppercase versions of the names? -->
|
3000 |
|
3001 |
<p>The general setup is that we have a number of independent nodes, which could be CPU cores, GPUs, or compute nodes. Each performs some computation, and then we want to communicate the result or parts of it to the other nodes for the next computation step (<d-math>t+1</d-math>).</p>
|
3002 |
|
@@ -3199,7 +3204,7 @@
|
|
3199 |
|
3200 |
<p>As the name suggests, the goal of the Scatter operation is to take data on one node and scatter it across all the nodes, which it does by distributing a slice of the data to each node. It’s thus different from the Broadcast operation, which sends each node a complete copy of the data without slicing it, and it’s the logical inverse of the Gather operation.</p>
|
3201 |
|
3202 |
-
<p>The ReduceScatter pattern is slightly more complex. As in the
|
3203 |
|
3204 |
<p style="text-align: center"><img alt="image.png" src="/assets/images/a0_scatter_reducescatter.png" style="width: 1000px" /></p>
|
3205 |
|
@@ -3286,16 +3291,15 @@
|
|
3286 |
<ol>
|
3287 |
<li><strong>ReduceScatter</strong></li>
|
3288 |
<ul>
|
3289 |
-
<li>Each device splits its data (e.g., gradients) into
|
3290 |
<li>As each device receives a chunk, it adds (reduces) its corresponding chunk to the received one.</li>
|
3291 |
-
<li>This process continues around the ring until each device holds a
|
3292 |
</ul>
|
3293 |
<li><strong>AllGather</strong></li>
|
3294 |
<ul>
|
3295 |
<li>Now, each device needs to collect the fully reduced chunks from the other devices.</li>
|
3296 |
-
<li>
|
3297 |
-
<li>
|
3298 |
-
</ul>
|
3299 |
</ol>
|
3300 |
|
3301 |
<p>Let’s illustrate this with the following gifs, where we have 5 GPUs, each with a tensor of length 5. The first animation shows the ReduceScatter step, where, at the end, each GPU receives the reduced results for a specific chunk of data (the orange rectangle).</p>
|
@@ -3368,13 +3372,9 @@
|
|
3368 |
</ul>
|
3369 |
|
3370 |
<p>There are a few finer points in the decision tree that we leave to the reader to explore in the PyTorch guide referenced above.</p>
|
3371 |
-
|
3372 |
-
<p>Now that we've covered the fundamental operations for distributed training, you should be ready to follow the blog post easily.</p> <!-- RH: What blog post? Do you mean the PyTorch guide? If so, I'd just say "you should find it easy to follow" (and combine these last two paragraphs into one). -->
|
3373 |
|
3374 |
<h3>A1: Distributed Training Profiling</h3>
|
3375 |
|
3376 |
-
<p>The next topic we'll explore is profiling.</p> <!-- RH: Feel free to adjust that, but I think there should be some sort of lead-in to this section above the first subhead. -->
|
3377 |
-
|
3378 |
<h4>Kernels</h4>
|
3379 |
|
3380 |
<p>Let's begin by assuming for now that the kernels are already integrated into PyTorch. As a simple example, we can look at the layer normalization function implemented in PyTorch as <code>torch.nn.functional.layer_norm</code>. There are several methods to profile the kernel that underlies this function. The most straightforward approach might be to use the Python <code>time</code> module. However, since CUDA operations are asynchronous, measuring time with this method will only capture the overhead associated with launching the kernel in Python, rather than the actual execution time of the kernel itself.</p>
|
@@ -3402,7 +3402,7 @@
|
|
3402 |
return start.elapsed_time(end)
|
3403 |
</d-code>
|
3404 |
|
3405 |
-
<p>A more efficient approach to profiling is to utilize the PyTorch profiler, as
|
3406 |
|
3407 |
<d-code block language="python">
|
3408 |
import torch
|
@@ -3555,11 +3555,11 @@
|
|
3555 |
|
3556 |
<li><strong>Activations (hidden states):</strong> For a single layer, the hidden state tensor is of size <d-math>seq \cdot mbs \cdot h</d-math> elements.</li>
|
3557 |
|
3558 |
-
<li><strong>Model weights and gradients:</strong> Each weight matrix in your model (
|
3559 |
|
3560 |
-
<li><strong>Optimizer states:</strong> For each weight matrix (of <d-math>h^2</d-math> elements), an optimizer like Adam with mixed precision training will keep momentum and variance states in FP32 precision (<d-math>2 \cdot h^2</d-math>), plus master weights in FP32 (<d-math>h^2</d-math>). So, the total
|
3561 |
|
3562 |
-
<li><strong>Total model parameters:</strong>
|
3563 |
<ul>
|
3564 |
<li>Attention parameters:
|
3565 |
<ul>
|
@@ -3567,7 +3567,7 @@
|
|
3567 |
<li>Output projection: <d-math>h^2</d-math> parameters</li>
|
3568 |
</ul>
|
3569 |
</li>
|
3570 |
-
<li>MLP parameters with
|
3571 |
<ul>
|
3572 |
<li>Gate and up projections: <d-math>8h^2</d-math> parameters (2 matrices of size <d-math>h \times 4h</d-math>)</li>
|
3573 |
<li>Down projection: <d-math>4h^2</d-math> parameters (1 matrix of size <d-math>4h \times h</d-math>)</li>
|
@@ -3667,7 +3667,7 @@
|
|
3667 |
`
|
3668 |
<h4>TP communication analysis</h4>
|
3669 |
|
3670 |
-
<p>For tensor parallelism, activations are sharded across GPUs
|
3671 |
|
3672 |
<ul>
|
3673 |
<li>For each column-linear operation in the forward pass:
|
@@ -3684,7 +3684,7 @@
|
|
3684 |
<li>Total communication per block: <d-math>8 \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
|
3685 |
<li>Total communication for full model: <d-math>8 \cdot num\_layers \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
|
3686 |
</ul>
|
3687 |
-
<p>Let's analyze if we can overlap the all-gather communication
|
3688 |
|
3689 |
<d-math block>
|
3690 |
t_{comm} = \frac{seq \cdot mbs \cdot h \cdot (TP-1)}{TP \cdot peak\_bw}
|
@@ -3701,7 +3701,7 @@
|
|
3701 |
\frac{t_{comm}}{t_{compute}} = \frac{TP-1}{2 \cdot h} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
|
3702 |
</d-math>
|
3703 |
|
3704 |
-
<p>This ratio tells us whether we can successfully hide the all-gather communication behind the computation of the next linear
|
3705 |
|
3706 |
|
3707 |
<h4>PP communication analysis</h4>
|
|
|
865 |
|
866 |
<h4>Memory usage revisited</h4>
|
867 |
|
868 |
+
<p><a target="_self" href="#memory_usage_in_transformers">Earlier</a>, we discussed the memory usage of optimizer states, gradients, and parameters during standard training. Let's call our model's parameter count <d-math>\Psi</d-math> (previously this was <d-math>N</d-math>, but here we use the original ZeRO paper's<d-cite bibtex-key="rajbhandari2020zero"></d-cite> notation). In mixed precision training (discussed further <a target="_self" href="#mixed_precision_training">later in the book</a>) with the Adam optimizer, the memory usage for each item we need to store is:</p>
|
869 |
|
870 |
<ul>
|
871 |
<li>Model’s parameters (half precision; i.e., BF16/FP16): <d-math>2\Psi</d-math></li>
|
872 |
<li>Model’s gradients (half precision; i.e., BF16/FP16): <d-math>2\Psi</d-math></li>
|
873 |
<li>Model’s parameters in FP32 and optimizer states: <d-math>4\Psi + (4\Psi + 4\Psi)</d-math></li>
|
874 |
+
<li>Model’s gradients in FP32: <d-math>4\Psi</d-math> (optional, only included if we want to accumulate gradients in FP32)</li>
|
875 |
</ul>
|
876 |
|
877 |
+
<p>If we don't accumulate gradients in FP32, this gives us a total memory consumption of <d-math>2\Psi + 2\Psi + 12\Psi</d-math>, and if we do it gives us <d-math>2\Psi + 6\Psi + 12\Psi</d-math>. Let's focus for now on the case without FP32 gradient accumulation for simplicity.</p>
|
878 |
|
879 |
<p>The idea of ZeRO is to shard these objects across the DP ranks, with each node only storing a slice of the items. These slices are then reconstructed when and if needed, thereby dividing memory usage by the data parallel degree <d-math>N_d</d-math>:</p>
|
880 |
|
881 |
<p><img alt="zero_memory.svg" src="/assets/images/zero_memory.svg" /></p>
|
882 |
<p>Here, <d-math>\Psi</d-math> denotes the number of parameters, <d-math>k</d-math> denotes the memory multiplier of optimizer states (<d-math>k=12</d-math> for Adam, as we've just seen), and <d-math>N_d</d-math> denotes DP degree.</p>
|
883 |
|
884 |
+
<aside>If you're using FP32 gradient accumulation with ZeRO-2 or ZeRO-3, you would need to add an additional <d-math>\frac{4\Psi}{N_d}</d-math> to the gradient term.</aside>
|
885 |
|
886 |
<p>Let’s explain this by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p>
|
887 |
|
|
|
891 |
|
892 |
<p>In ZeRO-1, the optimizer states are partitioned into <d-math>N_d</d-math> equal parts, where <d-math>N_d</d-math> is the DP degree. This means that the model replicas distributed on the DP ranks each only keep track of <d-math>\frac{1}{N_d}</d-math> of the optimizer states, and during the optimization step, only <d-math>\frac{1}{N_d}</d-math> of the FP32 weights are updated.</p>
|
893 |
|
894 |
+
<p>However, during the forward pass, each replica needs all the parameters. We thus need to add an additional <strong><em>all-gather</em></strong> (the second type of collective communication primitive we've encountered!) after the optimizer step so that each model replica has the full set of updated weights.</p>
|
895 |
|
896 |
<p>This explains the memory formula of <d-math>2\Psi + 2\Psi + \frac{k\Psi}{N_d}</d-math> that we saw in the previous figure! Here’s a summary of the sequence of operations for a single training step:</p>
|
897 |
+
<ol>
|
|
|
898 |
<li>Perform a forward pass with the same full set of BF16 parameters on each replica, but different micro-batches across replicas.</li>
|
899 |
<li>Perform a backward pass with the same full set of gradients on each replica, but different micro-batches across replicas.</li>
|
900 |
<li>Perform a <strong><em>reduce-scatter</em></strong> on the gradients (another primitive - we'll explain this one shortly).</li>
|
901 |
<li>Each replica performs an optimizer step on its local optimizer states (only <d-math>\frac{1}{N_d}</d-math> of the optimizer states) to get <d-math>\frac{1}{N_d}</d-math> updated FP32 parameters, which can then be converted to <d-math>\frac{1}{N_d}</d-math> of the full set of BF16 parameters.</li>
|
902 |
<li>Perform an all-gather on the BF16 parameters to send the missing slices back to each replica. This is a new operation in ZeRO and is not used in vanilla DP.</li>
|
903 |
+
</ol>
|
904 |
<aside>Note: Reduce-scatter is two times faster than all-reduce! <em>Yay, a third communication primitive!</em></aside>
|
905 |
|
906 |
<p>You may be wondering what this "reduce-scatter" operation is and what this all looks like, so let's try to make it more graphical with the figure below. We'll go over all the steps of a forward/backward pass cycle:</p>
|
|
|
914 |
<p>If you've been following along, you'll recall from our discussion of vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of BF16 parameters. There are two main strategies for this:</p>
|
915 |
|
916 |
<ul>
|
917 |
+
<li><strong>During the optimizer step:</strong> We can initiate the all-gather immediately after the optimizer updates the first slice of the parameters. This allows the communication to potentially overlap with the updating of the other parameters.</li>
|
918 |
<li><strong>During the forward pass:</strong> We can overlap the all-gather of each layer’s parameters with the forward pass.</li>
|
919 |
</ul>
|
920 |
|
|
|
929 |
|
930 |
<h4>ZeRO-2: Adding <strong>gradient partitioning</strong></h4>
|
931 |
|
932 |
+
<p>Since on each replica we only need to have the gradient shard corresponding to its optimizer state shard, it makes sense to shard gradients as well, similarly to the optimizer states. Then, during the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation! Here, we only store the <d-math>\frac{1}{N_d}</d-math> gradients that are needed in memory, thus saving more memory compared to ZeRO-1.</p>
|
933 |
|
934 |
+
<aside>In the case of FP32 gradient accumulation, we only need to keep <d-math>\frac{1}{N_d}</d-math> FP32 grads used to accumulate the BF16 grads coming from the reduce-scatter. And in the optimizer step, these <d-math>\frac{1}{N_d}</d-math> FP32 grads are used to update the local shard of the optimizer states.</aside>
|
935 |
|
936 |
<p><img alt="dp_zero2.gif" src="/assets/images/dp_zero2.gif" /></p>
|
937 |
|
938 |
+
<p>It's easy to see now that sharding the gradients leads to <d-math>2\Psi + \frac{2\Psi+k\Psi}{N_d}</d-math>, and as <d-math>N_d</d-math> is increased, we can use up to 8x less memory than the baseline. In terms of communication, the same process applies as for ZeRO-1, with the only difference being that we communicate and release memory on the fly: they both require a reduce-scatter for the gradients and an all-gather over all parameters. ZeRO-2 is thus also equivalent to vanilla DP training with regard to communication.</p>
|
939 |
<!-- RH: In this figure, on the right, can "AllGather Params" and "ReduceScatter Grad" be changed to "All-gather params" and "Reduce-scatter grads"? -->
|
940 |
<p><img alt="dp_zero2_overlap.svg" src="/assets/images/dp_zero2_overlap.svg" /></p>
|
941 |
|
942 |
+
<aside>Note: You might notice that there is no real overhead to using ZeRO-2 over ZeRO-1 besides implementation complexity, and indeed ZeRO-2 is usually the better option.</aside>
|
943 |
|
944 |
+
<p>Now that we've sharded gradients as well, are we done, or can we keep making improvements? Here comes ZeRO-3!</p>
|
945 |
|
946 |
<h4>ZeRO-3: Adding <strong>parameter partitioning</strong> (FSDP)</h4>
|
947 |
|
|
|
954 |
</div>
|
955 |
</div>
|
956 |
|
957 |
+
<p>So how do we do a forward or backward pass in practice if the parameters of the model are distributed? Quite simply, we gather them on demand when we need them. In the forward pass, this looks as follows:</p>
|
958 |
|
959 |
<p><img alt="dp_zero3_fwd.svg" src="/assets/images/dp_zero3_fwd.svg" /></p>
|
960 |
|
|
|
962 |
|
963 |
<p><img alt="dp_zero3_bwd.svg" src="/assets/images/dp_zero3_bwd.svg" /></p>
|
964 |
|
965 |
+
<p>The other issue is that we need to do these all-gathers continuously throughout the forward and backward pass in a training step, which amounts to <d-math>2\cdot \text{num\_layers} -1</d-math> additional all-gathers in a training step compared to ZeRO-2. Each comes with a small <em>base latency</em> overhead, as we can see in the following figure:</p>
|
966 |
<!-- RH: In this figure, change "AllGather Params" and "ReduceScatter Grads" to "All-gather params" and "Reduce-scatter grads" and lowercase "Free"? -->
|
967 |
<p><img alt="dp_zero3_overlap.svg" src="/assets/images/dp_zero3_overlap.svg" /></p>
|
968 |
|
969 |
<p>During the forward pass we do all-gather operations for the parameters when we need them, so there's a <d-math>\Psi</d-math> communication tax. Since we discard the parameters immediately after we use them in the forward pass, we need one more all-gather during the backward pass as well, incurring another <d-math>\Psi</d-math> communication tax. Finally, we need the same reduce-scatter operation as in ZeRO-2 for the gradients, which also costs <d-math>\Psi</d-math> in communication. So, we arrive at a total communication cost of <d-math>3\Psi</d-math>, compared to <d-math>2\Psi</d-math> for ZeRO-2.</p>
|
970 |
|
971 |
+
<p>This may sound like a lot of communication overhead, but it's actually not a big deal, as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called <strong><em>prefetching</em></strong>. With prefetching, we all-gather the weights for <em>Layer n+1</em> while we do the forward pass for <em>Layer n</em>, and similarly, we all-gather the weights for <em>Layer n-1</em> while doing the backward pass for <em>Layer n</em>. Of course, this overlap only works as long as we don’t scale DP too much (as a rule of thumb, DP shouldn’t exceed 512).</p>
|
972 |
+
|
973 |
+
<aside>Note: We use "DP" to refer to both the data parallelism technique and the number of GPUs used for data parallelism <em>(DP = DP size = DP degree)</em>.</aside>
|
974 |
|
975 |
+
<p>In terms of memory, we can see that our equation has now reached its final form of <d-math>\frac{2\Psi +2\Psi+k\Psi}{N_d}</d-math>, which means we can theoretically drive memory usage down indefinitely if we can increase the DP size, at least for the model-related parameters. Notice that it doesn’t help with the intermediate activations, though - for that, we can use activation checkpointing and gradient accumulation, as we saw earlier.</p>
|
976 |
|
977 |
<p>Let’s summarize our journey into DP and ZeRO so far. We've seen that we can increase the throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO, we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients, and optimizer states across DP replicas, while incurring a small communication cost.</p>
|
978 |
<aside>If you want to read more about FSDP1, FSDP2, and some of the implementation complexities around them, check out <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">this nice blog</a>.</aside>
|
|
|
1008 |
|
1009 |
|
1010 |
|
1011 |
+
<p>So, we've sharded the model’s parameters, gradients, and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome tensor parallelism (TP), a method that shards weights, gradients, and optimizer states as well as activations - and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how TP works with simple matrix multiplication (matmul) operations.</p>
|
1012 |
|
1013 |
<p>Tensor parallelism leverages the mathematical properties of matrix multiplication, <d-math>A \times B</d-math>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p>
|
1014 |
|
|
|
1023 |
|
1024 |
<ul>
|
1025 |
<li><d-math>X</d-math> represents the input or activation values</li>
|
1026 |
+
<li><d-math>W</d-math> represents the weight of the Linear layer</li>
|
1027 |
</ul>
|
1028 |
|
1029 |
<p>In practice, a small example of the operation looks like this:</p>
|
|
|
1066 |
|
1067 |
<h3>Tensor parallelism in a transformer block</h3>
|
1068 |
|
1069 |
+
<p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: a feedforward multi-layer perceptron (MLP) block and a multi-head attention (MHA) block. We can apply tensor parallelism to both.</p>
|
1070 |
|
1071 |
+
<p>The feedforward part can be parallelized by having a column-linear followed by a row-linear split, which amounts to a broadcast to copy the input and an all-reduce in the forward pass. Note that the broadcast isn’t needed in actual training, where we can make sure inputs are already synced across TP ranks. This setup is more efficient than starting with a row-linear followed by column-linear split, as we can skip the intermediate all-reduce between the split operations.</p>
|
1072 |
|
1073 |
<p><img alt="image.png" src="/assets/images/tp_diagram4.png" /></p>
|
1074 |
|
1075 |
+
<p>Now that we've found an efficient schema for the feedforward part of the transformer, let's take a look at the multi-head attention block.</p>
|
|
|
|
|
1076 |
|
1077 |
+
<p>We can generally follow a similar approach here, where the Query (Q), Key (K), and Value (V) matrices are split in a column-parallel fashion and the output projection can be considered a row-linear. With multi-head attention, the column-parallel approach has a very natural interpretation: each GPU computes the attention for an individual or a subset of attention heads. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query attention (MQA)</em></strong></a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention (GQA)</em></strong></a>, where keys and values are shared between queries. </p>
|
1078 |
|
1079 |
<p><img alt="image.png" src="/assets/images/tp_full_diagram.png" /></p>
|
1080 |
+
|
1081 |
+
<p>We're able to apply tensor parallelism so effectively to both the Attention and MLP blocks because they have dimensions that are naturally independent. The Attention block can be parallelized along the <code>num_attention_heads</code> dimension, as each attention head operates independently. Similarly, the MLP block can be parallelized along the <code>hidden_dim</code> dimension, as operations within the feedforward network are independent along this dimension.</p>
|
1082 |
+
<p>It's worth noting, however, that the tensor parallelism degree should not exceed the number of attention heads because we shard the QKV projection along the <code>num_attention_heads</code> dimension. When using Grouped Query Attention (GQA), we have <d-math>num\_attention\_heads</d-math> query heads but only <d-math>num\_kv\_heads</d-math> key/value heads (with <d-math>num\_attention\_heads >= num\_kv\_heads</d-math>). In this case, we can still set <d-math>TP = num\_attention\_heads</d-math>, but we'll need to ensure that the K/V heads stay properly synchronized across GPUs. For instance, Llama-3 8B has 32 query heads but only 8 key/value heads, so while the TP degree could theoretically go up to 32, we would need careful implementation to maintain K/V head synchronization across the tensor-parallel workers.</p>
|
1083 |
|
1084 |
<p>Note also that tensor parallelism is not a silver bullet for training. We’ve added several distributed communication primitives directly in the computation path of our model, which are therefore hard to fully hide/overlap with computation (like we did in ZeRO), so our final performance will be the result of a trade-off between the computation and memory gains and the added communication overhead. Let's illustrate this:</p>
|
1085 |
<!-- RH: In this figure, change "TP region" to "TP Region" and "AllReduce Activs" to "All-reduce activs"? -->
|
|
|
1087 |
|
1088 |
<aside>It's possible to partially hide this communication by performing block matrix multiplication coupled with async communication/computation.</aside>
|
1089 |
|
1090 |
+
<p>Looking at the timeline of operations in tensor-parallel MLP (the same applies for MHA), we can better understand the trade-offs involved. In the forward pass of each decoder layer, we hit a synchronization point with the all-reduce operation that cannot be overlapped with computation. This <em>exposed communication overhead</em> is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
|
1091 |
|
1092 |
+
<aside>For example, Megatron-LM and Nanotron implement a partial overlapping of all-gather with Fully-Connected (FC1) computation, where a portion of the matrix multiplication result gets sent to the other GPU while the remaining part is still being computed.</aside>
|
1093 |
|
1094 |
+
<p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, TP introduces significant communication requirements that heavily depend on the network infrastructure. The inability to fully hide this particular all-reduce behind computation means it directly adds to the critical path of forward propagation, where the critical path refers to the sequence of operations that determine the minimum time required to complete the forward pass.</p>
|
1095 |
|
1096 |
<aside>This is an active area of research, with recent work like Domino <d-cite bibtex-key="wang2024domino"></d-cite> exploring novel techniques to maximize this overlap. </aside>
|
1097 |
|
|
|
1179 |
<li><em>f</em> is an all-reduce to synchronize gradients.</li>
|
1180 |
</ul>
|
1181 |
|
1182 |
+
<p>These <em>f</em> and <em>f*</em> operations are called <strong><em>conjugate pairs</em></strong> because they complement each other - in each pass, when one is a no-op the other is an all-reduce, and it's the opposite in the other pass.</p>
|
1183 |
|
1184 |
<p>For sequence parallelism, we use different operations labeled <em>g</em> and <em>g*</em>. Specifically, we avoid using all-reduce in the SP regions since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
|
1185 |
|
|
|
1189 |
<div>
|
1190 |
<p style="margin-bottom: 0;"><strong>Initial LayerNorm layer (SP region)</strong></p>
|
1191 |
<ul style="margin-top: 0;">
|
1192 |
+
<li>Input tensors <em>X1*</em> and <em>X2*</em> <d-math>(b,s/2,h)</d-math> enter, already split across the sequence dimension.</li>
|
1193 |
+
<li>Each GPU computes LayerNorm independently on its sequence chunk, giving <em>Y1*</em> and <em>Y2*</em>.</li>
|
1194 |
</ul>
|
1195 |
<p style="margin-bottom: 0;"><strong>First transition (SP → TP)</strong></p>
|
1196 |
<ul style="margin-top: 0;">
|
1197 |
<li><em>g</em> operation (all-gather) combines <em>Y1</em> and <em>Y2</em> back to full sequence length.</li>
|
1198 |
+
<li>Restores <em>Y</em> <d-math>(b,s,h)</d-math> since column-linear layers need the full hidden dimension <d-math>h</d-math>.</li>
|
1199 |
</ul>
|
1200 |
<p style="margin-bottom: 0;"><strong>First linear layer (TP region)</strong></p>
|
1201 |
<ul style="margin-top: 0;">
|
1202 |
+
<li><em>A1</em> and <em>A2</em> are column-linear layers, so they split <em>Y</em> along the hidden dimension.</li>
|
1203 |
<li>GELU is applied independently on each GPU.</li>
|
1204 |
+
<li><em>Z1*</em> and <em>Z2*</em> are <d-math>(b,s,h/2)</d-math>.</li>
|
1205 |
</ul>
|
1206 |
<p style="margin-bottom: 0;"><strong>Second linear layer (TP region)</strong></p>
|
1207 |
<ul style="margin-top: 0;">
|
1208 |
+
<li><em>B1</em> and <em>B2</em> are row-linear layers, so they restore the hidden dimension.</li>
|
1209 |
+
<li><em>W1</em> and <em>W2</em> are <d-math>(b,s,h)</d-math> that need to be summed together.</li>
|
1210 |
</ul>
|
1211 |
<p style="margin-bottom: 0;"><strong>Final transition (TP → SP)</strong></p>
|
1212 |
<ul style="margin-top: 0;">
|
|
|
1391 |
|
1392 |
<p><img alt="ring-attention.gif" src="/assets/images/ring-attention.gif" /></p>
|
1393 |
|
1394 |
+
<p>It's probably obvious to you from this animation why the authors chose to call this approach Ring Attention<d-cite bibtex-key="liu2023ringattentionblockwisetransformers"></d-cite>!</p>
|
1395 |
|
1396 |
<p>There is one big problem, though, which is that a naive implementation of Ring Attention leads to some strong imbalances between GPUs due to the shape of the causal attention matrix. Let’s take a look at the softmax computation by considering the attention score matrix with the causal attention mask:</p>
|
1397 |
|
|
|
1402 |
<p>Let’s see if we can balance our computations better.</p>
|
1403 |
|
1404 |
<h3>Zig-Zag Ring Attention – A balanced compute implementation</h3>
|
1405 |
+
<p>We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens to the GPUs in a purely sequential manner, but instead mixing up the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called Zig-Zag Attention. In this new arrangement, the attention mask will show an even distribution of computation, but if you count the number of colored squares, you'll see that the computation is now balanced across all GPUs.</p>
|
1406 |
+
<aside>
|
1407 |
+
<p>We show here Zig-Zag Attention, which slightly differs from Striped Attention<d-cite bibtex-key="brandon2023fasterring"></d-cite>. For details on the differences, check <a href="https://github.com/zhuzilin/ring-flash-attention/issues/2#issuecomment-2236746166">this GitHub discussion</a>.</p>
|
1408 |
+
</aside>
|
1409 |
|
1410 |
<p><img alt="cp_zigzagmask.svg" src="/assets/images/cp_zigzagmask.svg" /></p>
|
1411 |
|
|
|
1452 |
</div>
|
1453 |
</div>
|
1454 |
|
1455 |
+
<p>In the <a target="_self" href="#tensor-parallelism">"Tensor Parallelism"</a> section, we saw that trying to scale tensor parallelism past the number of GPUs on a single node - typically 4 or 8 - forces us to use lower-bandwidth network communication, which can significantly impair performance. We can see the effects of this inter-node communication clearly in the all-reduce operation when we benchmark it on our cluster across several nodes (each node here has 8 GPUs):</p>
|
|
|
1456 |
|
1457 |
<!-- <iframe class="l-body-outset" id="plotFrame11" src="assets/data/benchmarks/pp_comm_bandwidth.html" width="90%" scrolling="no" frameborder="0"></iframe> -->
|
1458 |
<div class="l-body-outset" id="fragment-pp_comm_bandwidth"></div>
|
|
|
1643 |
|
1644 |
<p><img alt="image.png" src="/assets/images/pp_zerobubble_ppschedule.png" /></p>
|
1645 |
|
1646 |
+
<div class="figure-legend"><p>On the top (Figure 2 from the Zero Bubble paper): the classical 1F1B schedule, interleaving forward and backward passes but keeping a coarse-grained backward pass. On the bottom (Figure 3 from the Zero Bubble paper): two handcrafted schedules splitting the backward pass into finer-grained <d-math>B</d-math> and <d-math>W</d-math> operations. The lower schedule is an example of a (theoretical) zero bubble schedule taking advantage of this fine-grained decomposition.</p>
|
1647 |
</div>
|
1648 |
|
1649 |
<p>DeepSeek’s DualPipe, introduced with its V3 technical report <d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>, is an extension of this decomposed approach to the additional case of two streams propagating from both ends of the PP dimension, with these streams being interleaved to further minimize idle time in the GPUs. This schedule is displayed in the following scheduling graph - as you can see, it's even more complex than the previous ones:</p>
|
1650 |
|
1651 |
<p><img alt="image.png" src="/assets/images/pp_zerobubble_dualpipe.png" /></p>
|
1652 |
|
1653 |
+
<p>In general, fully optimizing such complex schedules involves carefully measuring the duration of the various fine-grained operations and solving an Integer Linear Programming (ILP) problem to minimize the final bubble time. (See, for instance, the Zero Bubble paper<d-cite bibtex-key="qi2023zerobubblepipelineparallelism"></d-cite> for a discussion of the heuristics and algorithms used to perform such scheduling.) As a result, the zero bubble and DualPipe schedules are too complex for us to give code snippets here, but you should have a general idea of the concepts involved. </p>
|
1654 |
|
1655 |
<p>This concludes our tour of the world of pipeline schedules and bubbles. We hope you enjoyed it!</p>
|
1656 |
|
|
|
1705 |
|
1706 |
<p>However, there are a few major differences between the PP and ZeRO-3 approaches:</p>
|
1707 |
|
1708 |
+
<aside>Note here that we say "a layer" to simplify, but the actual sharding unit of the model can vary - it might be multiple layers, a single layer, or even a subset of a layer, depending on the specific implementation.</aside>
|
1709 |
<div class="l-body">
|
1710 |
<table>
|
1711 |
<thead>
|
|
|
1760 |
|
1761 |
<p>The main reason we don't want to use TP only for parallelism is that, in practice, TP has two limitations (which we've discussed in previous sections). First, since its communication operations are part of the critical path of computation, it's difficult to scale well beyond a certain point, after which communication overhead begins to dominate. Second, unlike ZeRO and PP, which are model-agnostic, TP requires careful handling of activation sharding - sometimes along the hidden dimension (in the TP region) and sometimes along the sequence dimension (in the SP region) - making it more cumbersome to implement correctly and requiring model-specific knowledge to ensure proper sharding patterns throughout.</p>
|
1762 |
|
1763 |
+
<p>As a consequence, when combining parallelism strategies, TP will typically be kept for high-speed intra-node communications, while ZeRO-3 or PP can be used for parallelism groups spanning lower-speed inter-node communications as their communication patterns require less bandwidth (for PP) or can be more easily overlapped with computation (for ZeRO-3). The main consideration when combining these techniques is to organize the GPUs efficiently in groups for each parallelism dimension to maximize throughput and minimize communication overhead, while being mindful of TP's scaling limitations. For instance, the groups of GPUs communicating for TP should be kept inside nodes.</p>
|
1764 |
|
1765 |
<p><strong>Context parallelism</strong> and <strong>expert parallelism</strong> also help us shard activations, and can be seen as complementary to TP. CP handles long sequences while EP enables distributed Mixture of Experts training, and they can be combined without any particular issues.</p>
|
1766 |
|
1767 |
+
<p>CP specifically targets the challenge of training with very long sequences by sharding activations along the sequence dimension across GPUs. While most modules, like MLP and LayerNorm, can process these sharded sequences independently, attention blocks require communication since each token needs access to keys/values from the full sequence. As we saw in the <a target="_self" href="#context_parallelism">CP section</a>, this is handled efficiently through Ring Attention patterns that overlap computation and communication. CP is particularly valuable when scaling to extreme sequence lengths (128k+ tokens) where, even when using full activation recomputation, the memory requirements for attention would be prohibitive on a single GPU.</p>
|
1768 |
|
1769 |
<div class="large-image-background-transparent">
|
1770 |
<div class="boxed-image">
|
|
|
1772 |
</div>
|
1773 |
</div>
|
1774 |
|
1775 |
+
<p><strong>Expert parallelism</strong> specifically targets the challenge of training MoE models by sharding specialized "experts" across GPUs and dynamically routing tokens to relevant experts during computation. The key communication operations in EP are the "all-to-all" operations routing tokens to their assigned experts and gathering the results back. While this introduces some communication overhead, it enables scaling model capacity significantly since each token is only processed during inference (and training) by a much smaller fraction of the total parameters. In terms of distributed training/inference, partitioning experts across GPUs becomes relevant when models scales to a large number of experts.</p>
|
1776 |
<aside>For instance, DeepSeek-V3 uses 256 experts.</aside>
|
1777 |
|
1778 |
<div class="large-image-background-transparent">
|
|
|
1813 |
</tr>
|
1814 |
<tr>
|
1815 |
<td>Communication for matrix multiplication operations (column/row linear)</td>
|
1816 |
+
<td>Communication for attention keys/values</td>
|
1817 |
<td>Communication for token routing to experts</td>
|
1818 |
</tr>
|
1819 |
<tr>
|
|
|
1934 |
<li>At 512+ GPU scale, pure data parallelism/ZeRO-3 will start to becomes inefficient due to communication cost - it can be better to then combine DP with either tensor or pipeline parallelism.</li>
|
1935 |
<li>At 1024+ GPU scale, a recommended setup may be tensor parallelism (TP=8) with data parallelism (ZeRO-2) and pipeline parallelism.</li>
|
1936 |
</ul>
|
1937 |
+
<aside>We focus on fitting a single instance for now - even though we may use DP for ZeRO to achieve this goal - we're only interested here in the model parameters memory savings that it provide when used with ZeRO-3.</aside> <!-- RH: I'm having trouble working out what you're trying to say here. Maybe something like "We'll focus on fitting a single instance for now. Even though we may use DP to achieve this goal, we're only interested here in the model parameter memory savings that it provides when used with ZeRO-3."? --> <!-- NT: I'm in favor of removing as well -->
|
1938 |
</p>
|
1939 |
|
1940 |
<p><em>Special considerations:</em>
|
|
|
1988 |
|
1989 |
<p>In the <a href="https://github.com/huggingface/nanotron">Nanotron</a> repository, you'll find several scripts you can use to run all the experiments discussed previously and benchmark your own model and cluster.</p>
|
1990 |
|
1991 |
+
<p>We actually ran benchmarks ourselves on <strong>several thousand distributed configurations</strong>, covering every model size we've discussed here as well as a very large number of cluster configurations (namely, 1-64 nodes of 8xH100s) in order to produce the results we've covered up to now in this book.</p>
|
1992 |
<aside>We want to take this opportunity to apologize to our coworkers for blocking most of the science cluster, and in turn forgive any threats that may have been whispered.</aside>
|
1993 |
|
1994 |
<p>Now let's take a step back to gather and analyze the results of all our benchmarks and see if, beyond theory, we can actually discover using real-world data how various configurations fare against each other.</p>
|
|
|
2017 |
|
2018 |
<h3>Lessons learned on benchmarking</h3>
|
2019 |
|
2020 |
+
<p>Our goal for this book was not only to discuss theory and implementations, but to provide actual data points as well. So, the plan was simple: let's run every possible distributed configuration for every model and a number of cluster sizes. Even after excluding impossible configurations, we still needed to run thousands of experiments. </p>
|
2021 |
|
2022 |
<p>
|
2023 |
On paper, this sounds easy enough: we can easily launch big arrays of jobs on our cluster. However, as soon as we launched the first batches of experiments, our troubles began:
|
|
|
2036 |
<li>Minimizing cluster restart times and optimizing idle time</li>
|
2037 |
<li>Analyzing detailed NCCL debug logs</li>
|
2038 |
<li>Understanding memory usage patterns and CUDA memory allocator behaviors</li>
|
2039 |
+
<li>Improving pipeline parallelism performance on multi-node setups</li>
|
2040 |
</ul>
|
2041 |
|
2042 |
<p>These challenges taught us valuable lessons about the complexities of distributed training infrastructure. What looks simple in theory often requires careful attention to many moving parts in practice.</p>
|
|
|
2098 |
<p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
|
2099 |
<div class="figure-legend"><p>Source: https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p></div>
|
2100 |
|
2101 |
+
<p>The goal when using a GPU is to run as many workloads as possible, in parallel, on the available cores, by taking advantage of this hierarchical organization of compute/memory resources.</p>
|
2102 |
|
2103 |
<p>A piece of code running on a core of the GPU is called a <strong><em>kernel</em></strong>. It can be written at a high level in CUDA or Triton, for instance, and is then compiled to Parallel Thread Execution (PTX), the low-level assembly used by NVIDIA GPUs.</p>
|
2104 |
|
|
|
2164 |
|
2165 |
<ul>
|
2166 |
<li>Threads are grouped in <em>warps</em>, each containing 32 threads. All the threads in a warp are synchronized to execute instructions simultaneously but on different parts of the data.</li>
|
2167 |
+
<li>Warps are grouped in larger <em>blocks</em> of more flexible size (for example, there may be 512 or 1,024 threads in a block), with each block assigned to a single SM. An SM may run several blocks in parallel. However, depending on the resources available, not all of the blocks may get assigned for execution immediately; some may be waitlisted until more resources become available.</li>
|
2168 |
</ul>
|
2169 |
|
2170 |
<p>The main thing to retain here is that there are various sizing and allocation constraints (size of the various memories, number of concurrent blocks and threads in the warps) which need to be taken into account to use the GPU architecture in the most efficient way.</p>
|
|
|
2274 |
|
2275 |
</ol>
|
2276 |
|
2277 |
+
<p>We'll start by looking at one of the most frequent uses of CUDA: optimizing memory access. The global memory in GPUs (the largest memory area, as you saw earlier) has a long latency and low bandwidth in comparison to the cache, creating a major bottleneck for many applications. Efficiently accessing data from global memory can greatly improve performance.</p>
|
2278 |
|
2279 |
<h4>Memory coalescing</h4>
|
2280 |
|
|
|
2425 |
<p>In several places now, we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workloads on the GPU in a non-blocking way.</p>
|
2426 |
|
2427 |
<p>This can be useful for overlapping communication and computation – as we've seen many times in our journey – but it can also be extended to the more general idea of trying to avoid, at all cost, going back and forth between host and GPU kernel commands.</p>
|
2428 |
+
<p>This idea is beautifully illustrated by <a href="https://upload.wikimedia.org/wikipedia/commons/b/b2/Hausziege_04.jpg">Horace He</a> in <a href="https://horace.io/brrr_intro.html">these diagrams:</a></p>
|
2429 |
<div style="display: flex; gap: 20px; align-items: flex-start;">
|
2430 |
<div style="width: 50%;">
|
2431 |
<img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
|
|
|
2444 |
<p>How can we avoid the back and forth shown on the left? Well, the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations as possible together in a single kernel for the GPU to run, called a “fused kernel,” as shown on the right.</p>
|
2445 |
|
2446 |
|
2447 |
+
<p>Fused kernels are especially efficient and simple to write for successions of point-like operations that are performed independently of each other on each input token. In this case, there is no point in sending the computed values back to global memory before moving them to SM memory and spinning up a new kernel. It's much more efficient to keep all the values locally until all the computations have been performed.</p>
|
2448 |
|
2449 |
+
<p>In a Transformer model, this "fusing" approach can be applied every time we have a succession of point-wise operations, such as in the computations involved in the LayerNorm layers.</p>
|
2450 |
|
2451 |
<p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>FlashAttention</em></strong>.</p>
|
2452 |
|
2453 |
<h3>FlashAttention</h3>
|
2454 |
|
2455 |
<p>FlashAttention was introduced by <a href="https://tridao.me">Tri Dao</a> and proposed to optimize attention computations by writing custom CUDA kernels to make them much faster <em>and</em> more memory efficient. The idea behind FlashAttention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory.</p>
|
2456 |
+
|
2457 |
+
<p>The global memory in modern GPUs often uses a technology called <a href="https://semianalysis.com/2024/09/03/the-memory-wall/#hbm-roadmap"></a>High Bandwidth Memory (HBM)</a>, which despite its name, is slower than SRAM in the GPU memory hierarchy. This HBM terminology will be important when we discuss the details of FlashAttention's implementation.</p>
|
2458 |
|
2459 |
+
<p>A basic implementation of the attention mechanism involves a lot of transfer between memory and workers. It requires materializing the <strong>S</strong> matrix (where S = QK^T, the attention scores) and the <strong>P</strong> matrix (where P = softmax(S), the normalized attention weights) in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
|
2460 |
|
2461 |
<p style="text-align: center"><img alt="image.png" src="/assets/images/flashattn.png" style="width: 500px" /></p>
|
2462 |
|
2463 |
+
<p>Since bandwidth is much lower in HBM, this introduces a severe bottleneck in the attention computation. Can we do better? <a href="https://upload.wikimedia.org/wikipedia/commons/b/b2/Hausziege_04.jpg">Tri Dao</a> says yes!</p>
|
2464 |
|
2465 |
<p>The key element is to compute the <strong>S</strong> matrix in small pieces that can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large <strong>S</strong> matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So, we can compute part of <d-math>O</d-math> directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length): the attention matrix.</p>
|
2466 |
|
|
|
2470 |
<p>The idea of FlashAttention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers. Notably:</p>
|
2471 |
<ul>
|
2472 |
<li>By avoiding materializing the <strong>S</strong> matrix, we <strong>reduce the memory burden of attention</strong>.</li>
|
2473 |
+
<li>We also <strong>remove a large part of the naive impact of the <d-math>O(S^2)</d-math> cost of attention</strong>.</li>
|
2474 |
</ul>
|
2475 |
|
2476 |
<p>All variants of linear attention and subquadratic approaches to approximate attention (developed shortly after the invention of the Transformer architecture) have mostly been put aside in favor of this exact and fast FlashAttention implementation and mechanism.</p>
|
|
|
2562 |
|
2563 |
<p>We can see that float32 spans 80 orders of magnitude, and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range, but e4m3 has an even smaller range.</p>
|
2564 |
|
2565 |
+
<p>How come some formats are able to maintain the full range while others aren't? Let's investigate their resolutions by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
|
2566 |
|
2567 |
<p><img alt="image.png" src="/assets/images/mixedprecision_2.png" /></p>
|
2568 |
|
2569 |
+
<p>We can see here that although bfloat16 maintains the range of float32 (unlike float16), it does this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire: e4m3 can represent only 7 and e5m2 only 3 numbers in the interval [1,2].</p>
|
2570 |
|
2571 |
<p>A common metric to measure a format's resolution is <em>epsilon</em>: the first representable number after <d-math>1.00</d-math>. We can see that for the float32 format, <d-math>10^{-4}</d-math> is an upper bound (it’s actually <d-math>1.19^{-7}</d-math>). For float16 it's ~<d-math>10^{-3}</d-math>, and for bfloat it's 10x higher still.</p>
|
2572 |
|
2573 |
+
<p>The idea of mixed precision training is to use some of these lower-precision formats for certain computations while maintaining the performance of full precision training. As it turns out, we <strong>can’t</strong> totally abandon float32 and usually will need to do some of the computations in full precision. </p>
|
2574 |
|
2575 |
<p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further, all the way down to 8 bits.</p>
|
2576 |
|
|
|
2604 |
<!-- Hynek uncomment this once it's added to -->
|
2605 |
<!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
|
2606 |
|
2607 |
+
<p>The first successful very large scale training with FP8 mixed precision was publicly reported in the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. The authors carefully analyzed each operation of the forward pass (<em>Fprop</em>) as well as the activation (<em>Dgrad</em>) and weight (<em>Wgrad</em>) backward passes. Similar to BF16 mixed precision training, some aggregations and master weights are kept in higher precision while the operations themselves are performed in FP8.</p>
|
2608 |
|
2609 |
<p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
|
2610 |
|
|
|
3001 |
<h3>A0: Parallel Programming Crash Course</h3>
|
3002 |
|
3003 |
|
3004 |
+
<p>Throughout this book, we've scaled LLM training from one to hundreds of GPUs. This requires the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called <strong><em>collective operations</em></strong>. In this section, we’ll do a small crash course on those operations - <em>Broadcast</em>, <em>AllReduce</em>, <em>Scatter</em>, and more. Let’s dive in!</p> <!-- RH: Up to now, you've styled the names of the operations broadcast, all-reduce, etc. Would it be better to stick to the same style in the Appendix, to avoid any potential confusion on the part of the reader? Or refer to them as patterns when using the uppercase versions of the names? --> <!-- NT: Both are fine to me, shouldn't be confusing to reader i guess -->
|
3005 |
|
3006 |
<p>The general setup is that we have a number of independent nodes, which could be CPU cores, GPUs, or compute nodes. Each performs some computation, and then we want to communicate the result or parts of it to the other nodes for the next computation step (<d-math>t+1</d-math>).</p>
|
3007 |
|
|
|
3204 |
|
3205 |
<p>As the name suggests, the goal of the Scatter operation is to take data on one node and scatter it across all the nodes, which it does by distributing a slice of the data to each node. It’s thus different from the Broadcast operation, which sends each node a complete copy of the data without slicing it, and it’s the logical inverse of the Gather operation.</p>
|
3206 |
|
3207 |
+
<p>The ReduceScatter pattern is slightly more complex. As in the AllReduce case , you apply an operation on the data from all nodes. But unlike AllReduce where each node receives the full output tensor, in ReduceScatter each node only receives a slice of the output tensor. The following image illustrates the difference between these operations:</p>
|
3208 |
|
3209 |
<p style="text-align: center"><img alt="image.png" src="/assets/images/a0_scatter_reducescatter.png" style="width: 1000px" /></p>
|
3210 |
|
|
|
3291 |
<ol>
|
3292 |
<li><strong>ReduceScatter</strong></li>
|
3293 |
<ul>
|
3294 |
+
<li>Each device splits its data (e.g., gradients) into N chunks (where N is the number of GPUs) and sends one chunk to its neighbor. Simultaneously, each device receives a chunk from its other neighbor.</li>
|
3295 |
<li>As each device receives a chunk, it adds (reduces) its corresponding chunk to the received one.</li>
|
3296 |
+
<li>This process continues around the ring until each device holds a fully reduced chunk representing the sum of the gradients across all devices for that chunk.</li>
|
3297 |
</ul>
|
3298 |
<li><strong>AllGather</strong></li>
|
3299 |
<ul>
|
3300 |
<li>Now, each device needs to collect the fully reduced chunks from the other devices.</li>
|
3301 |
+
<li>Each device sends its reduced chunk to its neighbor, and receives the reduced chunk from its other neighbor.</li>
|
3302 |
+
<li>The devices continue forwarding the chunks they receive until every device has all the fully reduced chunks, giving each device the complete, summed-up gradient.</li>
|
|
|
3303 |
</ol>
|
3304 |
|
3305 |
<p>Let’s illustrate this with the following gifs, where we have 5 GPUs, each with a tensor of length 5. The first animation shows the ReduceScatter step, where, at the end, each GPU receives the reduced results for a specific chunk of data (the orange rectangle).</p>
|
|
|
3372 |
</ul>
|
3373 |
|
3374 |
<p>There are a few finer points in the decision tree that we leave to the reader to explore in the PyTorch guide referenced above.</p>
|
|
|
|
|
3375 |
|
3376 |
<h3>A1: Distributed Training Profiling</h3>
|
3377 |
|
|
|
|
|
3378 |
<h4>Kernels</h4>
|
3379 |
|
3380 |
<p>Let's begin by assuming for now that the kernels are already integrated into PyTorch. As a simple example, we can look at the layer normalization function implemented in PyTorch as <code>torch.nn.functional.layer_norm</code>. There are several methods to profile the kernel that underlies this function. The most straightforward approach might be to use the Python <code>time</code> module. However, since CUDA operations are asynchronous, measuring time with this method will only capture the overhead associated with launching the kernel in Python, rather than the actual execution time of the kernel itself.</p>
|
|
|
3402 |
return start.elapsed_time(end)
|
3403 |
</d-code>
|
3404 |
|
3405 |
+
<p>A more efficient approach to profiling is to utilize the PyTorch profiler, as <a target="_self" href="#profiling_gpu_compute_and_communication">explained previously</a>. For example, consider the following code:</p>
|
3406 |
|
3407 |
<d-code block language="python">
|
3408 |
import torch
|
|
|
3555 |
|
3556 |
<li><strong>Activations (hidden states):</strong> For a single layer, the hidden state tensor is of size <d-math>seq \cdot mbs \cdot h</d-math> elements.</li>
|
3557 |
|
3558 |
+
<li><strong>Model weights and gradients:</strong> Each weight matrix in your model (e.g. linear layer) contains about <d-math>h^2</d-math> elements. Gradients have the same size as weights.</li>
|
3559 |
|
3560 |
+
<li><strong>Optimizer states:</strong> For each weight matrix (of <d-math>h^2</d-math> elements), an optimizer like Adam with mixed precision training will keep momentum and variance states in FP32 precision (<d-math>2 \cdot h^2</d-math>), plus master weights in FP32 (<d-math>h^2</d-math>). So, the total number of optimizer states will be around (<d-math>6 \cdot h^2</d-math>) per weight matrix.</li>
|
3561 |
|
3562 |
+
<li><strong>Total model parameters:</strong> Each transformer block will store:
|
3563 |
<ul>
|
3564 |
<li>Attention parameters:
|
3565 |
<ul>
|
|
|
3567 |
<li>Output projection: <d-math>h^2</d-math> parameters</li>
|
3568 |
</ul>
|
3569 |
</li>
|
3570 |
+
<li>MLP parameters with Gated Linear Units (GLU):
|
3571 |
<ul>
|
3572 |
<li>Gate and up projections: <d-math>8h^2</d-math> parameters (2 matrices of size <d-math>h \times 4h</d-math>)</li>
|
3573 |
<li>Down projection: <d-math>4h^2</d-math> parameters (1 matrix of size <d-math>4h \times h</d-math>)</li>
|
|
|
3667 |
`
|
3668 |
<h4>TP communication analysis</h4>
|
3669 |
|
3670 |
+
<p>For tensor parallelism, activations are sharded across GPUs in the <a target="_self" href="#sequence_parallelism">TP regions</a> (e.g. MLP block). Let's analyze the communication pattern:</p>
|
3671 |
|
3672 |
<ul>
|
3673 |
<li>For each column-linear operation in the forward pass:
|
|
|
3684 |
<li>Total communication per block: <d-math>8 \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
|
3685 |
<li>Total communication for full model: <d-math>8 \cdot num\_layers \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
|
3686 |
</ul>
|
3687 |
+
<p>Let's take a TP region within a layer and analyze if we can overlap the all-gather communication with the computation of the next linear. The communication time for all-gather operations is:</p>
|
3688 |
|
3689 |
<d-math block>
|
3690 |
t_{comm} = \frac{seq \cdot mbs \cdot h \cdot (TP-1)}{TP \cdot peak\_bw}
|
|
|
3701 |
\frac{t_{comm}}{t_{compute}} = \frac{TP-1}{2 \cdot h} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
|
3702 |
</d-math>
|
3703 |
|
3704 |
+
<p>This ratio tells us whether we can successfully hide the all-gather communication behind the computation of the next linear. Interestingly, the ratio only depends on the hidden size <d-math>h</d-math> and tensor parallelism degree <d-math>tp</d-math>, not on sequence length or batch size.</p>
|
3705 |
|
3706 |
|
3707 |
<h4>PP communication analysis</h4>
|