This generates a trace that we can visualize in TensorBoard or Chrome's trace viewer. The trace shows:
- Understanding these patterns is crucial for optimizing distributed training performance. For example, the trace would clearly show if gradient synchronization is properly overlapped with backward computation as we'll discuss later.
+ Understanding these patterns is crucial for optimizing distributed training performance. For example, the trace will clearly show if gradient synchronization is properly overlapped with backward computation, as we'll discuss later.
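For reference, a trace like the one discussed here can be produced with torch.profiler roughly as follows (a minimal sketch, not code from this book; train_step and train_loader are placeholder names):

```python
from torch.profiler import profile, schedule, tensorboard_trace_handler, ProfilerActivity

# Profile a few training steps and dump a trace viewable in TensorBoard
# (prof.export_chrome_trace("trace.json") would produce a Chrome trace instead).
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3),   # skip 1 step, warm up 1, record 3
    on_trace_ready=tensorboard_trace_handler("./profiler_logs"),
    with_stack=True,
) as prof:
    for step, batch in enumerate(train_loader):      # placeholder dataloader
        train_step(batch)                            # placeholder: forward + backward + optimizer step
        prof.step()                                  # tell the profiler a training step has finished
        if step >= 5:
            break
```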
- Now let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called data parallelism which –as we'll see– is just a parallel version of gradient accumulation .
+ Now let’s get a larger workstation with a couple of GPUs and start investigating our first scaling technique, called data parallelism - which, as we'll see, is just a parallel version of gradient accumulation.
- The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism. You've probably already seen Data Parallelism in simple training examples but as you'll soon see we'll dive quite deeper in this section so stay tuned even if you know the general approach.
+ The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro-batches of data in parallel on each GPU - hence the name data parallelism . You've probably already seen data parallelism in simple training examples, but we'll dive quite a bit deeper in this section, so stay tuned even if you know the general approach.
- Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances will be averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.
+ Using a different micro-batch for each GPU means we’ll have different gradients on each GPU, so to keep the model instances in sync across the different GPUs, we'll average the gradients from the model instances using an operation called “all-reduce.” This operation takes place during the backward pass, before the optimizer step.
- A naive DP implementation would just wait for the backward pass the finish so that we have all gradients, then it triggers an all-reduce over all DP ranks, to sync these gradients. But such an sequential steps of computation followed by communication is A BIG NO! Because we don’t want our GPUs to stay idle while communication is happening, like on the above graph.
+ A naive DP implementation would just wait for the backward pass to finish so that we have all the gradients, then trigger an all-reduce over all the DP ranks to sync the gradients. But such sequential steps of computation followed by communication are A BIG NO-NO because we don’t want our GPUs to stay idle while communication is happening, like in the above image.
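In code, the naive pattern looks roughly like this (a bare-bones sketch assuming torch.distributed is already initialized; it is not the playbook's actual implementation):

```python
import torch.distributed as dist

def naive_dp_training_step(model, optimizer, loss_fn, batch):
    # 1) Computation: full forward and backward pass on this rank's micro-batch
    loss = loss_fn(model(batch["inputs"]), batch["targets"])
    loss.backward()

    # 2) Communication: only once ALL gradients are ready do we average them across
    #    DP ranks -- the GPUs sit idle on the compute side while this runs.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= dist.get_world_size()

    optimizer.step()
    optimizer.zero_grad()
```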
- Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.
+ Instead, we should try to overlap communication and computation whenever possible so that they happen at the same time.
- Let’s see three optimizations that allow us to do much better than our naive first implementation!
+ Let’s take a look at three optimizations that allow us to do much better than our naive first implementation.
- The main drawback of the naive DDP approach we’ve just described is that after the backward pass (computation ), we have to wait for gradient synchronization (communication ) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!
-
- As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.
+ The main drawback of the naive DP approach we’ve just described is that after the backward pass (computation ), we have to wait for gradient synchronization (communication ) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!
+ As shown in the figure above, the gradients (pink boxes) for a layer can be gathered and summed even before the gradients from earlier layers (the pink boxes to their left) have been computed. For example, as soon as the backward pass of the last layer is complete (the last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.
+
- Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. Here's a full implementation of naive DP with synchronization overlap:
+ Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with the backward pass within the same training step, significantly speeding up data parallelism. Here's a full implementation of naive DP with synchronization overlap:
- 👉 Naive DP implementation with overlap in Picotron (Click to expand)
+ 👉 Naive DP implementation with overlap in Picotron (click to expand)
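The Picotron code isn't reproduced here, but the core idea can be sketched with per-parameter gradient hooks (a simplified illustration only; a production implementation would also bucket small gradients together to amortize launch overheads):

```python
import torch.distributed as dist

class GradSyncOverlap:
    """Launch an async all-reduce for each parameter's gradient as soon as it has been
    accumulated, so communication overlaps with the backward computation of earlier layers.
    Requires PyTorch >= 2.1 for register_post_accumulate_grad_hook."""

    def __init__(self, model):
        self.model = model
        self.handles = []
        for p in model.parameters():
            if p.requires_grad:
                p.register_post_accumulate_grad_hook(self._sync_grad)

    def _sync_grad(self, param):
        # async_op=True returns immediately with a work handle; the reduction runs
        # on the communication stream while backward keeps computing earlier layers
        self.handles.append(
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, async_op=True)
        )

    def wait_and_average(self):
        # Call after loss.backward() and before optimizer.step()
        for handle in self.handles:
            handle.wait()
        self.handles.clear()
        for p in self.model.parameters():
            if p.grad is not None:
                p.grad /= dist.get_world_size()
```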
- We see that above some limit, our throughput starts to drop quite significantly while the memory usage per GPU stays constant and is not affected by adding more DP ranks.
+ As shown here, above some limit, our throughput starts to drop quite significantly while the memory usage per GPU stays constant and is not affected by adding more DP ranks.
- Data parallelism was our first (simple) strategy to scale training across more GPUs. This technique works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!
+ Data parallelism was our first (simple) strategy to scale training across more GPUs. This technique works like gradient accumulation but parallelizes the forward and backward passes on micro-batches, thus increasing throughput.
- The keen reader has already probably noted however that this assumes that we can fit at least one input sample forward pass (mbs=1) into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated:
- Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)
+ The keen reader has already probably noted, however, that this assumes that we can fit at least one input sample forward pass (mbs =1) into GPU memory. This is not always the case! As we can see, larger models often don’t fit into a single GPU, even with activation recomputation activated:
+ Tip: You can quickly eyeball the minimum memory required for your model’s parameters by multiplying by 2 - e.g., 70B → 140 GB (≈130 GiB).
@@ -833,152 +834,151 @@
- We've also seen that Data Parallelism starts to have some limiting communication overhead above a certain level of scaling. Do we have other options for these larger models or large batch-size? We do have some solutions thankfully. They will involve either move some tensors to the CPU or split the weights/gradients/optimizer-states tensors across GPUs devices! Let's start diving in them.
+ We've also seen that data parallelism starts to have some limiting communication overhead above a certain level of scaling. Do we have other options for these larger models or large batch sizes? We do have some solutions, thankfully - they involve either moving some tensors to the CPU or splitting the weights/gradients/optimizer states tensors across GPU devices.
- There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined!
+ There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharding (DeepSpeed ZeRO or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined!
- The sharing paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!
+ The sharding paradigm is closely related to DP, so we’ll have a look at it first by investigating the ZeRO method.
- ZeRO (Ze ro R edundancy O ptimizer)
+ Zero Redundancy Optimizer (ZeRO)
- In this section we will introduce DeepSpeed ZeRO (Ze ro R edundancy O ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.
+ In this section we will introduce DeepSpeed ZeRO, a memory optimization technology designed to reduce memory redundancy in LLM training.
- While Data Parallelism is an efficient way to scale training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!
+ While data parallelism is an efficient way to scale training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces significant memory redundancy. ZeRO eliminates this by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks, which may or may not be fully overlapped, as we’ll see next!
- We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the DeepSpeed docs .
+ We’ll focus on ZeRO-1 to ZeRO-3 in this book, as this should give a broad view of how this technology helps reduce the memory usage while showing the trade-offs to take into account. You can find details on more ZeRO flavors in the DeepSpeed docs .
- This approach is organized into three possible optimization stage of ZeRO:
+ This approach is organized into three possible optimization stages:
ZeRO-1: optimizer state partitioning
ZeRO-2: optimizer state + gradient partitioning
- ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning
+ ZeRO-3: optimizer state + gradient + parameter partitioning
- When we say partitioning, it means along the DP axis, as ZeRO is part of Data Parallelism. We’ll see later that we can partition along other axes.
+ When we say "partitioning" here, it means along the DP axis, as ZeRO is a data-parallel method. We’ll see later that we can partition along other axes as well.
- You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different micro-batch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!
+ You might have noticed that activations are missing from the list of things we can shard. Since each DP replica of the model receives a different micro-batch, the activations on each DP rank also differ, so they are not duplicated and thus can’t be sharded!
- Let’s have a closer look how much we can save with the partitioning of each ZeRO stage!
+ Let’s have a closer look at how much we can save with the partitioning of each ZeRO stage.
Memory usage revisited
- You likely remember from our previous section the memory usage of optimizer states, gradients, and parameters during a standard training. Let's call our model's parameters count \Psi (previously N but here we use the original ZeRO paper notation). In Mixed Precision Training (more details in a later section) with the Adam optimizer, the memory usage for each item we need to store is:
+ Earlier , we discussed the memory usage of optimizer states, gradients, and parameters during standard training. Let's call our model's parameter count \Psi (previously this was N , but here we use the original ZeRO paper's notation). In mixed precision training (discussed further later in the book ) with the Adam optimizer, the memory usage for each item we need to store is:
- Model’s parameters (half precision i.e. bf16/fp16): 2\Psi
- Model’s gradients (half precision i.e. bf16/fp16): 2\Psi
- Model’s parameters in fp32 and optimizer states: 4\Psi + (4\Psi + 4\Psi)
- Model’s gradients in fp32: 4\Psi (optional, only accounted if we want to accumulate grads in fp32)
+ Model’s parameters (half precision; i.e., BF16/FP16): 2\Psi
+ Model’s gradients (half precision; i.e., BF16/FP16): 2\Psi
+ Model’s parameters in FP32 and optimizer states: 4\Psi + (4\Psi + 4\Psi)
+ Model’s gradients in FP32: 4\Psi (optional, only included if we want to accumulate gradients in FP32)
- If we don’t accumulate gradients in fp32 this gives us a total memory consumption of 2\Psi + 2\Psi + 12\Psi , and if we accumulate it would be 2\Psi + 6\Psi + 12\Psi . Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.
+ If we don't accumulate gradients in FP32, this gives us a total memory consumption of 2\Psi + 2\Psi + 12\Psi , and if we do it gives us 2\Psi + 6\Psi + 12\Psi . Let's focus for now on the case without FP32 gradient accumulation for simplicity.
- The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree N_d :
+ The idea of ZeRO is to shard these objects across the DP ranks, with each node only storing a slice of the items. These slices are then reconstructed when and if needed, thereby dividing memory usage by the data parallel degree N_d :
- Here \Psi denotes number of parameters, k denotes the memory multiplier of optimizer states (k=12 for Adam as we've just seen), and N_d denotes DP degree.
+ Here, \Psi denotes the number of parameters, k denotes the memory multiplier of optimizer states (k=12 for Adam, as we've just seen), and N_d denotes DP degree.
+ If you're using FP32 gradient accumulation with ZeRO-2 or ZeRO-3, you would need to add an additional \frac{4\Psi}{N_d} to the gradient term.
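To make these formulas concrete, here's a small helper (not from the book) that evaluates the per-GPU model-state memory for each stage:

```python
def zero_model_state_bytes(num_params, dp_size, stage=0, k=12):
    """Per-GPU memory for parameters, gradients, and optimizer states under ZeRO.
    Follows the 2*psi (BF16 params) + 2*psi (BF16 grads) + k*psi (FP32 params + Adam
    moments) breakdown above, without FP32 gradient accumulation."""
    params = 2 * num_params
    grads = 2 * num_params
    optim = k * num_params
    if stage >= 1:
        optim /= dp_size   # ZeRO-1: shard the optimizer states
    if stage >= 2:
        grads /= dp_size   # ZeRO-2: also shard the gradients
    if stage >= 3:
        params /= dp_size  # ZeRO-3: also shard the parameters
    return params + grads + optim

# Example: an 8B-parameter model on 64 DP ranks
for stage in range(4):
    print(stage, zero_model_state_bytes(8e9, dp_size=64, stage=stage) / 1e9, "GB")
# stage 0: 128.0 GB, stage 1: 33.5 GB, stage 2: 17.75 GB, stage 3: 2.0 GB
```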
- Let’s explain this graph and it’s values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.
+ Let’s explain this by exploring how each ZeRO stage works. We’ll start with ZeRO-1.
- ZeRO-1: Partitioning Optimizer States
+ ZeRO-1: Partitioning optimizer states
In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?
- In ZeRO-1, the optimizer states are partitioned into N_d equal parts where N_d is the DP degree. This means that each model replica distributed on each DP rank only keeps track of \frac{1}{N_d} of the optimizer states. During the optimization step only \frac{1}{N_d} of the float32 weights are updated.
-
- However during the forward pass, each replica need all the parameters, we thus need to add an additional all-gather (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights.
+ In ZeRO-1, the optimizer states are partitioned into N_d equal parts, where N_d is the DP degree. This means that the model replicas distributed on the DP ranks each only keep track of \frac{1}{N_d} of the optimizer states, and during the optimization step, only \frac{1}{N_d} of the FP32 weights are updated.
- This explains the memory formula of 2\Psi + 2\Psi + \frac{k\Psi}{N_d} that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step
+ However, during the forward pass, each replica needs all the parameters. We thus need to add an additional all-gather (the second type of collective communication primitive we've encountered!) after the optimizer step so that each model replica has the full set of updated weights.
-
- Forward pass with the same, full set of bf16 parameters on each replica, but different microbatches across replicas
- Backward pass with the same, full set of gradients on each replica, but different microbatches across replicas
- Perform an reduce-scatter on the gradients (we'll explain the reduce-scatter primitive in the graph below)
- Each replica perform an optimizer step on its local optimizer steps (only \frac{1}{N_d} optimizer states) to get updated \frac{1}{N_d} fp32 parameters which can then be converted to \frac{1}{N_d} of the full set of bf16 parameters.
- Perform an all-gather among the bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.
-
- Note: reduce-scatter is 2 times faster than all reduce! Yay, a third communication primitive!
+ This explains the memory formula of 2\Psi + 2\Psi + \frac{k\Psi}{N_d} that we saw in the previous figure! Here’s a summary of the sequence of operations for a single training step:
+
+ Perform a forward pass with the same full set of BF16 parameters on each replica, but different micro-batches across replicas.
+ Perform a backward pass with the same full set of gradients on each replica, but different micro-batches across replicas.
+ Perform a reduce-scatter on the gradients (another primitive - we'll explain this one shortly).
+ Each replica performs an optimizer step on its local optimizer states (only \frac{1}{N_d} of the optimizer states) to get \frac{1}{N_d} updated FP32 parameters, which can then be converted to \frac{1}{N_d} of the full set of BF16 parameters.
+ Perform an all-gather on the BF16 parameters to send the missing slices back to each replica. This is a new operation in ZeRO and is not used in vanilla DP.
+
+ Note: Reduce-scatter is two times faster than all-reduce! Yay, a third communication primitive!
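Putting steps 3-5 into (heavily simplified) code, with hypothetical flattened tensors and an optimizer that holds only this rank's FP32 shard (the reduce-scatter primitive itself is illustrated just below):

```python
import torch
import torch.distributed as dist

def zero1_sync_and_step(flat_bf16_grads, flat_bf16_params, fp32_shard, optimizer):
    """ZeRO-1 communication pattern after the backward pass (sketch). Assumes the
    gradients/parameters are flattened into contiguous tensors whose size is divisible
    by the DP world size, and fp32_shard is this rank's 1/N_d master copy."""
    world_size = dist.get_world_size()
    shard_size = flat_bf16_grads.numel() // world_size

    # Step 3: reduce-scatter -- each rank receives the summed gradients for its shard only
    grad_shard = torch.empty(shard_size, dtype=flat_bf16_grads.dtype,
                             device=flat_bf16_grads.device)
    dist.reduce_scatter_tensor(grad_shard, flat_bf16_grads, op=dist.ReduceOp.SUM)

    # Step 4: local optimizer step on the 1/N_d FP32 master weights
    fp32_shard.grad = grad_shard.float() / world_size
    optimizer.step()
    optimizer.zero_grad()

    # Step 5: all-gather the updated BF16 shards so every replica has the full parameters again
    dist.all_gather_into_tensor(flat_bf16_params,
                                fp32_shard.detach().to(flat_bf16_params.dtype))
```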
- You may be wondering what is this "reduce-scatter" operation and how this all look so let's try to make this more graphical with the figure below. We'll go over all the steps of a forward/backward pass cycle:
+ You may be wondering what this "reduce-scatter" operation is and what this all looks like, so let's try to make it more graphical with the figure below. We'll go over all the steps of a forward/backward pass cycle:
- In terms of practical communications, compared to vanilla DP, Zero-1 change our "all-reduce" gradient communication to a "reduce-scatter" operation and adds an all-gather operation over all parameters after the optimizer step. Here is how it looks:
-
+ In terms of practical communications, compared to vanilla DP, ZeRO-1 changes our all-reduce gradient communication to a reduce-scatter operation and adds an all-gather operation over all parameters after the optimizer step. Here's how it looks:
+
- If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:
+ If you've been following along, you'll recall from our discussion of vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of BF16 parameters. There are two main strategies for this:
- During optimizer step: We can initiate the all-gather immediately after the optimizer updates part of the parameters. This allows the communication to potentially overlap with other parameters update.
- During forward: We can overlap the all-gather of each layer’s parameters with the forward pass.
+ During the optimizer step: We can initiate the all-gather immediately after the optimizer updates the first slice of the parameters. This allows the communication to potentially overlap with the updating of the other parameters.
+ During the forward pass: We can overlap the all-gather of each layer’s parameters with the forward pass.
📝 Note
-
Unfortunately these techniques are not straightforward to implement and require sophisticated use of hooks/bucketing. In practice we can just use PyTorch native ZeRO-3/FSDP implementation and set the FSDPUnit to be the entire model, more details about this later.
+
+ Unfortunately, these techniques are not straightforward to implement and require sophisticated use of hooks/bucketing. In practice, we can just use PyTorch's native ZeRO-3/FSDP implementation and set the FSDPUnit to be the entire model (more details about this later).
- In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates \frac{1}{N_d} of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place as only a subset is needed for the optimization step. Meet ZeRO-2!
+ In ZeRO-1, the optimizer states have been partitioned, which means that each replica only updates \frac{1}{N_d} of the states. The keen reader might have noticed that there is no real need to have all the gradients on all the DP ranks, as only a subset of these are needed for the optimization step. Meet ZeRO-2!
- ZeRO-2: Adding Gradient Partitioning
+ ZeRO-2: Adding gradient partitioning
- Since we only need, on each replica, to have the gradient shard corresponding to the optimizer state shard, it makes sense to shard gradient as well similarly to the optimizer states. During the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation! Where we only spread the \frac{1}{N_d} gradients needed in memory, thus saving more memory compared to ZeRO-1.
+ Since on each replica we only need to have the gradient shard corresponding to its optimizer state shard, it makes sense to shard gradients as well, similarly to the optimizer states. Then, during the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation! Here, we only store the \frac{1}{N_d} gradients that are needed in memory, thus saving more memory compared to ZeRO-1.
- In case of FP32 gradient accumulation, we only need to keep \frac{1}{N_d} fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the \frac{1}{N_d} fp32_grads.
+ In the case of FP32 gradient accumulation, we only need to keep \frac{1}{N_d} FP32 grads used to accumulate the BF16 grads coming from the reduce-scatter. And in the optimizer step, these \frac{1}{N_d} FP32 grads are used to update the local shard of the optimizer states.
- It’s easy to see now that sharding the gradients leads to to 2\Psi + \frac{2\Psi+k\Psi}{N_d} and as N_d is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.
-
- In terms of communication ZeRO-2 is similar to ZeRO-1, they both require a reduce-scatter for the gradients, and an all-gather over all parameters.
-
+ It's easy to see now that sharding the gradients leads to 2\Psi + \frac{2\Psi+k\Psi}{N_d} , and as N_d is increased, we can use up to 8x less memory than the baseline. In terms of communication, the same process applies as for ZeRO-1, with the only difference being that we communicate and release memory on the fly: they both require a reduce-scatter for the gradients and an all-gather over all parameters. ZeRO-2 is thus also equivalent to vanilla DP training with regard to communication.
+
- Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option.
+ Note: You might notice that there is no real overhead to using ZeRO-2 over ZeRO-1 besides implementation complexity, and indeed ZeRO-2 is usually the better option.
- Now that we’ve sharded gradients as well, are we done or can we keep getting away with this? Well, sort of. Here comes ZeRO-3!
+ Now that we've sharded gradients as well, are we done, or can we keep making improvements? Here comes ZeRO-3!
- ZeRO-3: Adding Parameter Partitioning
+ ZeRO-3: Adding parameter partitioning (FSDP)
- For Stage 3 we extend the above approach of sharding optimizer states and gradients over DP replicas up to sharding the model’s parameters.
+ For stage 3, we extend the above approach of sharding optimizer states and gradients over DP replicas to sharding the model’s parameters.
📝 Note
-
This stage is also called FSDP (Fully Shared Data Parallelism) in PyTorch native implementation. We’ll just refer to ZeRO-3 in this blogpost but you can think of FSDP wherever you see it.
+
PyTorch's native implementation of this stage is called FSDP (Fully Sharded Data Parallelism). We’ll just refer to it as ZeRO-3 in this book, but you can think of FSDP wherever you see it.
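To give a flavor of what the native implementation looks like in practice, here's a sketch of wrapping a model with FSDP (MyTransformerBlock is an assumed name for the model's decoder-layer class, and the process group is assumed to be initialized already):

```python
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

def fully_shard_model(model, block_cls):
    # Each transformer block becomes its own FSDP unit: its parameter shards are
    # all-gathered right before they are needed and freed right after use.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={block_cls}
    )
    return FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 (SHARD_GRAD_OP is roughly ZeRO-2)
        device_id=torch.cuda.current_device(),
    )

# sharded_model = fully_shard_model(model, MyTransformerBlock)
```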
- So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:
+ So how do we do a forward or backward pass in practice if the parameters of the model are distributed? Quite simply, we gather them on demand when we need them. In the forward pass, this looks as follows:
- So as we perform the forward pass and sequentially go through the layers we retrieve the necessary parameters on demand and immediately flush them from memory when we don't need them anymore. The backward pass works the same way just inverted in flow and we produce the gradient shards:
+ As we perform the forward pass and sequentially go through the layers, we retrieve the necessary parameters on demand and immediately flush them from memory when we don't need them anymore. The backward pass works the same way, just inverted in flow. Here, we produce the gradient shards:
- The other issue is that we need to do these all-gathers continuously throughout the forward and backward step, which amounts to 2\cdot \text{num\_layers} -1 additional all-gathers in a training step compared to Zero-2, each comes with a small base latency overhead as we can see in the following figure:
-
+ The other issue is that we need to do these all-gathers continuously throughout the forward and backward pass, which amounts to 2\cdot \text{num\_layers} -1 additional all-gathers per training step compared to ZeRO-2. Each comes with a small base latency overhead, as we can see in the following figure:
+
- During the forward pass we do all-gather operations for the parameters when we need them, so a \Psi communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another \Psi in communication tax. Finally we need the same reduce-scatter as in ZeRO-2 for the gradients which costs also \Psi in communication and we arrive at a total communication cost of 3\Psi , compared to 2\Psi for Zero-2.
+ During the forward pass we do all-gather operations for the parameters when we need them, so there's a \Psi communication tax. Since we discard the parameters immediately after we use them in the forward pass, we need one more all-gather during the backward pass as well, incurring another \Psi communication tax. Finally, we need the same reduce-scatter operation as in ZeRO-2 for the gradients, which also costs \Psi in communication. So, we arrive at a total communication cost of 3\Psi , compared to 2\Psi for ZeRO-2.
- This may sounds like a lot of communication overhead but it's actually pretty fine as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called prefetching . With prefetching, we will "all-gather" weights for *Layer n+1* while we do the current forward for Layer n in the forward, and similarly, we will "all-gather" weights for Layer n-1 while doing the backward for Layer n . Of course this overlap only holds true as long as we don’t scale DP too much. (as a rule of thumb DP shouldn’t exceed 512)
+ This may sound like a lot of communication overhead, but it's actually not a big deal, as we can overlap the communication of the parameters for the next layer with the forward pass of the current layer in what is called prefetching . With prefetching, we all-gather the weights for Layer n+1 while we do the forward pass for Layer n , and similarly, we all-gather the weights for Layer n-1 while doing the backward pass for Layer n . Of course, this overlap only works as long as we don’t scale DP too much (as a rule of thumb, DP shouldn’t exceed 512).
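In PyTorch's FSDP, this prefetching behavior is exposed through constructor flags - roughly as follows (a configuration sketch, continuing the FSDP example from earlier):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch

sharded_model = FSDP(
    model,
    forward_prefetch=True,                            # all-gather the next unit's params during the current forward
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch the previous unit's params during the current backward
)
```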
+
+ Note: We use "DP" to refer to both the data parallelism technique and the number of GPUs used for data parallelism (DP = DP size = DP degree) .
- In terms of memory we can see that our equation now reached it’s final form of \frac{2\Psi +2\Psi+k\Psi}{N_d} which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t help with the intermediate activations, for that we can use activation checkpointing and gradient accumulation as we’ve seen in the previous chapters.
+ In terms of memory, we can see that our equation has now reached its final form of \frac{2\Psi +2\Psi+k\Psi}{N_d} , which means we can theoretically drive memory usage down indefinitely if we can increase the DP size, at least for the model-related parameters. Notice that it doesn’t help with the intermediate activations, though - for that, we can use activation checkpointing and gradient accumulation, as we saw earlier.
- Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizers states across DP, while incurring a small communications cost.
-
- If you want to read more about FSDP1, FSDP2 and some of the implementation complexities around them, you should take some time to go over this nice blog .
+ Let’s summarize our journey into DP and ZeRO so far. We've seen that we can increase the throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO, we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients, and optimizer states across DP replicas, while incurring a small communication cost.
+ If you want to read more about FSDP1, FSDP2, and some of the implementation complexities around them, check out this nice blog .
- However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! We recall from the activation memory discussion that this part of the memory scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with only with a short sequence length.
-
+ However, there are some limits here: DP only works if a layer of the model fits in a single GPU, and ZeRO can only partition the parameters, gradients, and optimizer states, not the activation memory! Recall from the activation memory discussion that this part of the memory scales with sequence length and batch size. We could just limit those, but in practice we don’t want to be limited by hardware to train with only a short sequence length.
+
- To overcome this issues, it's time to explore a new, orthogonal axis of parallelism - Tensor Parallelism (TP). Unlike ZeRO3 which relies on heavy parameter communication, TP proposes to shard parameters, gradients, optimizer states AND activations across devices without requiring any communication of model parameters between GPUs.
+ To overcome this issue, it's time to examine a new, orthogonal axis of parallelism - tensor parallelism (TP) . Unlike ZeRO-3, which relies on heavy parameter communication, TP proposes to shard parameters, gradients, optimizer states, AND activations across devices without requiring any communication of model parameters between GPUs.
- What? How is this even possible?! Let's explore this seemingly magical approach together! 🙂
+ What? How is this even possible?! Let's explore this seemingly magical approach together. 🙂
Tensor Parallelism
@@ -1008,9 +1008,9 @@
- So we have sharded the model’s parameters, gradients and optimizers states with ZeRO but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizers states as well as activations and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallel works with simple matrix multiplications.
+ So, we've sharded the model’s parameters, gradients, and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome tensor parallelism (TP), a method that shards weights, gradients, and optimizer states as well as activations - and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how TP works with simple matrix multiplication (matmul) operations.
- Tensor Parallelism leverages the mathematical properties of matrix multiplication A \times B . To understand how it works, let's examine two fundamental equations that make this parallelization possible:
+ Tensor parallelism leverages the mathematical properties of matrix multiplication, A \times B . To understand how it works, let's examine two fundamental equations that make this parallelization possible:
\begin{aligned}
@@ -1019,37 +1019,35 @@
\end{aligned}
- This means that we can compute matrix product by either 1) multiplying each column of B individually or 2) multiplying each row individually and combining the results. In a neural network, the matrix multiplication is more often represented in the following format: X \times W , where:
+ This means that we can compute the matrix product by either multiplying each column of B individually or multiplying each row individually and combining the results. In a neural network, the matrix multiplication is more often represented in the format X \times W , where:
- X represents the input or activation values
- W represents the weight of the nn.Linear
+ X represents the input or activation values
+ W represents the weight of the Linear layer
- In practice a small example of the operation looks like this:
+ In practice, a small example of the operation looks like this:
- Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either on the column part or row part leading to row and column parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communications primitives.
+ Let’s see how we can parallelize this operation! In tensor parallelism, tensors are split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split on either columns or rows, leading to row or column parallelism. As we’ll see in the following discussion, row and column sharding require different communication primitives.
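Before introducing any communication, we can sanity-check the two matrix identities above with plain tensors (a standalone snippet, not from the book's code):

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 8)   # input, shape (b*s, h)
W = torch.randn(8, 6)   # weight, shape (h, h_out)
reference = X @ W

# Column-wise sharding: split W along its output columns, concatenate the partial results
W_cols = W.chunk(2, dim=1)
column_result = torch.cat([X @ shard for shard in W_cols], dim=1)

# Row-wise sharding: split W along its rows AND X along its columns, sum the partial products
W_rows = W.chunk(2, dim=0)
X_cols = X.chunk(2, dim=1)
row_result = sum(x @ w for x, w in zip(X_cols, W_rows))

assert torch.allclose(reference, column_result, atol=1e-5)
assert torch.allclose(reference, row_result, atol=1e-5)
```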
- Our first option is to use column-wise sharding (also called column-linear ): We'll copy the complete input matrices to each worker, requiring an operation called broadcast , and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an all-gather operation.
+ Our first option is to use column-wise (also called column-linear ) sharding: we'll copy the complete input matrices to each worker, requiring an operation called broadcast , and split the weight matrix by columns. The inputs are then multiplied with the partial weight matrices, and finally the results are combined using an all-gather operation.
- Here's the code implementation of column wise tensor parallelism:
+ Here's the code implementation of column-wise tensor parallelism:
- 👉 Column parallel TP implementation in Picotron (Click to expand)
+ 👉 Column parallel TP implementation in Picotron (click to expand)
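The collapsed Picotron module isn't shown here; stripped down to its essence (forward pass only, ignoring the backward-pass communication it would need for training), a column-parallel linear layer looks something like this:

```python
import torch
import torch.nn as nn
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Sketch: each TP rank holds a column shard of the weight, multiplies the full
    (replicated) input by it, and an all-gather rebuilds the full output."""

    def __init__(self, in_features, out_features, tp_group):
        super().__init__()
        self.tp_group = tp_group
        tp_size = dist.get_world_size(tp_group)
        assert out_features % tp_size == 0
        self.weight = nn.Parameter(torch.empty(out_features // tp_size, in_features))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):                        # x: (b, s, in_features), same on every TP rank
        local_out = x @ self.weight.t()          # (b, s, out_features / tp)
        chunks = [torch.empty_like(local_out)
                  for _ in range(dist.get_world_size(self.tp_group))]
        dist.all_gather(chunks, local_out.contiguous(), group=self.tp_group)
        return torch.cat(chunks, dim=-1)         # (b, s, out_features)
```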
- The second option is called row-wise sharding (also called row-linear ): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a scatter operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.
-
- We see here our fourth distributed primitive: scatter !
+ The second option is called row-wise (or row-linear ) sharding. As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, so we need to use a scatter operation (our fourth distributed communication primitive!) rather than the broadcast operation used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, so this scenario also requires an all-reduce operation:
@@ -1057,7 +1055,7 @@
- 👉 Row parallel TP implementation in Picotron (Click to expand)
+ 👉 Row-parallel TP implementation in Picotron (click to expand)
@@ -1066,35 +1064,36 @@
Now that we have the basic building blocks of TP, let's have a look at how we can effectively combine them inside a transformer layer!
- Tensor Parallelism in a Transformer Block
+ Tensor parallelism in a transformer block
- To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks : Feedforward layers (MLP) and Multi-Head Attention (MHA). We can apply tensor parallelism to both.
+ To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: a feedforward multi-layer perceptron (MLP) block and a multi-head attention (MHA) block. We can apply tensor parallelism to both.
- The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks. This setup is more efficient than starting with "Row Linear" followed by "Column Linear" as we can skip the intermediate all-reduce between both splitted operations.
+ The feedforward part can be parallelized by having a column-linear followed by a row-linear split, which amounts to a broadcast to copy the input and an all-reduce in the forward pass. Note that the broadcast isn’t needed in actual training, where we can make sure inputs are already synced across TP ranks. This setup is more efficient than starting with a row-linear followed by column-linear split, as we can skip the intermediate all-reduce between the split operations.
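A compressed sketch of that column-linear → row-linear pattern (forward pass only; for simplicity the whole default process group is treated as the TP group, and the input is assumed to already be replicated across TP ranks):

```python
import torch.nn as nn
import torch.distributed as dist

class TensorParallelMLP(nn.Module):
    """Column-linear up-projection -> GELU -> row-linear down-projection -> all-reduce."""

    def __init__(self, hidden_dim, ffn_dim, tp_size):
        super().__init__()
        self.up = nn.Linear(hidden_dim, ffn_dim // tp_size, bias=False)    # column shard
        self.down = nn.Linear(ffn_dim // tp_size, hidden_dim, bias=False)  # row shard
        self.act = nn.GELU()

    def forward(self, x):
        z = self.act(self.up(x))                   # (b, s, ffn_dim / tp); no communication needed
        y = self.down(z)                           # partial result of shape (b, s, hidden_dim)
        dist.all_reduce(y, op=dist.ReduceOp.SUM)   # the single all-reduce that combines partial results
        return y
```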
- Now that we’ve found an efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).
-
- We can generally follow a similar approach where Q, K, and V matrices are split in a column-parallel fashion, and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual or a subset of heads. The same approach works as well for multi-query (MQA) or grouped query attention (GQA) where key and values are shared between queries.
+ Now that we've found an efficient schema for the feedforward part of the transformer, let's take a look at the multi-head attention block.
- It's worth noting however that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank (otherwise we cannot compute the attentions independently on each GPU and we'll need additional communication operations). In case we’re using GQA, the TP degree should actually be smaller than the number of K/V heads. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should advantageously not exceed 8. If we use TP=16 for this model, we will need to duplicate the K/V heads on each GPU and make sure they stay in sync.
+ We can generally follow a similar approach here, where the Query (Q), Key (K), and Value (V) matrices are split in a column-parallel fashion and the output projection can be considered a row-linear. With multi-head attention, the column-parallel approach has a very natural interpretation: each GPU computes the attention for an individual or a subset of attention heads. The same approach works as well for multi-query attention (MQA) or grouped query attention (GQA) , where keys and values are shared between queries.
-
- Finally note that Tensor Parallelsim is still not a silver bullet for training. We’ve added several distributed communication primitive directly in the computation path of our model which are therefore hard to fully hide/overlap with computation (like we did in ZeRO), our final performances will be the results of a tradeoff between the computation and memory gains and the added communication overhead. Let's illustrate this:
-
+ We're able to apply tensor parallelism so effectively to both the Attention and MLP blocks because they have dimensions that are naturally independent. The Attention block can be parallelized along the num_attention_heads dimension, as each attention head operates independently. Similarly, the MLP block can be parallelized along the hidden_dim dimension, as operations within the feedforward network are independent along this dimension.
+ It's worth noting, however, that the tensor parallelism degree should not exceed the number of attention heads because we shard the QKV projection along the num_attention_heads dimension. When using Grouped Query Attention (GQA), we have num\_attention\_heads query heads but only num\_kv\_heads key/value heads (with num\_attention\_heads >= num\_kv\_heads ). In this case, we can still set TP = num\_attention\_heads , but we'll need to ensure that the K/V heads stay properly synchronized across GPUs. For instance, Llama-3 8B has 32 query heads but only 8 key/value heads, so while the TP degree could theoretically go up to 32, we would need careful implementation to maintain K/V head synchronization across the tensor-parallel workers.
+
+ Note also that tensor parallelism is not a silver bullet for training. We’ve added several distributed communication primitives directly in the computation path of our model, which are therefore hard to fully hide/overlap with computation (like we did in ZeRO), so our final performance will be the result of a trade-off between the computation and memory gains and the added communication overhead. Let's illustrate this:
+
+
It's possible to partially hide this communication by performing block matrix multiplication coupled with async communication/computation.
- Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This exposed communication overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied.
+ Looking at the timeline of operations in tensor-parallel MLP (the same applies for MHA), we can better understand the trade-offs involved. In the forward pass of each decoder layer, we hit a synchronization point with the all-reduce operation that cannot be overlapped with computation. This exposed communication overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied.
- For example, Megatron-LM/Nanotron implement a partial overlapping of all-gather with FC1 computation where a portion of the matrix multiplication result will start to be sent to the other GPU while the other part is still being computed.
+ For example, Megatron-LM and Nanotron implement a partial overlapping of the all-gather with the first fully connected layer's (FC1) computation, where a portion of the matrix multiplication result gets sent to the other GPU while the remaining part is still being computed.
- Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, TP introduces significant communication requirements that heavily depend on the network infrastructure. The inability to fully hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.
+ Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, TP introduces significant communication requirements that heavily depend on the network infrastructure. The inability to fully hide this particular all-reduce behind computation means it directly adds to the critical path of forward propagation, where the critical path refers to the sequence of operations that determine the minimum time required to complete the forward pass.
- This area of research is still an active area of research, with recent work like Domino exploring novel techniques to maximize this overlap.
+ This is an active area of research, with recent work like Domino exploring novel techniques to maximize this overlap.
Let's take a better look at the trade-off as we scale the TP degree:
@@ -1112,10 +1111,10 @@
While increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.
- In practice and as we see above on the left plot, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. We observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. At higher degrees of parallelism, the communication overhead becomes so high that it quickly dominates the computation time.
-
- This being said, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:
+ In practice, as we see in the left-hand plot above, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. We observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. At higher degrees of parallelism, the communication overhead becomes so high that it quickly dominates the computation time.
+ This being said, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states, and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:
+
- Increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU to the point where we can start fitting a large model on a single node of 8 GPUs.
+ Increasing tensor parallelism reduces the memory needed for model parameters, gradients, and optimizer states on each GPU to the point where we can start fitting a larger model onto a single node of 8 GPUs.
- Is there a way to get even more benefits from this technique? We've seen that layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.
+ Is there a way to get even more benefits from this technique? Layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.
📝 Note
-
One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
+
One interesting note about layer normalization in tensor-parallel training is that since each TP rank sees the same activations after the all-gather, the LayerNorm weights don't actually require an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
- Let's explore next a small and natural extension to tensor parallelism, called Sequence Parallelism which does exactly that.
+ Next, we'll explore a small, natural extension to tensor parallelism called sequence parallelism that does exactly that.
- Sequence Parallelism
+ Sequence parallelism
- Sequence parallelism (SP) involves splitting the activations and computations for the parts of the model not handled by tensor parallelism (TP) such as Dropout and LayerNorm, but along the input sequence dimension rather than across hidden dimension.
+ Sequence parallelism (SP) involves splitting the activations and computations for the parts of the model not handled by tensor parallelism, such as dropout and LayerNorm, but along the input sequence dimension rather than the hidden dimension.
📝 Note
-
The term Sequence Parallelism is a bit overloaded: the Sequence Parallelism in this section is tightly coupled to Tensor Parallelism and applies to dropout and layer norm operation. However, when we will move to longer sequences the attention computation will become a bottleneck, which calls for techniques such as Ring-Attention, which are sometimes also called Sequence Parallelism but we’ll refer to them as Context Parallelism to differentiate the two approaches. So each time you see sequence parallelism, remember that it is used together with tensor parallelism (in contrast to context parallelism, which can be used independently).
+
The term sequence parallelism is a bit overloaded. The sequence parallelism discussed in this section is tightly coupled to tensor parallelism and applies to dropout and layer normalization operations. However, when we move to longer sequences, the attention computation will become a bottleneck, which calls for techniques such as Ring Attention. These are sometimes also referred to as sequence parallelism approaches, but we’ll refer to them as context parallelism instead to differentiate the two approaches. So, when you see "sequence parallelism" in this book, remember that it is used together with tensor parallelism (in contrast to context parallelism, which can be used independently).
@@ -1159,60 +1158,60 @@
where \mu = \text{mean}(x) and \sigma^2 = \text{var}(x) are computed across hidden dimension h .
- So even though these operations are computationally cheap, they still require significant activation memory since they need the complete hidden dimension. SP allows us to shard this memory burden across GPUs by splitting along the sequence dimension instead.
-
- In practice we’ll go from the left diagram to the right:
+ Consequently, even though these operations are computationally cheap, they still require significant activation memory. Sequence parallelism allows us to shard this memory burden across GPUs by splitting along the sequence dimension instead.
+ The following diagram shows how we transition between tensor-parallel and sequence-parallel regions using different collective operations (labeled f and g ). In practice, we’ll go from the left to the right:
+
- The diagram shows how we transition between tensor-parallel and sequence-parallel regions using different collective operations (labeled "f" and "g"). The key challenge is managing these transitions efficiently while keeping memory usage low and maintaining correctness.
+ The key challenge is managing these transitions efficiently while keeping memory usage low and maintaining correctness.
- In the forward pass:
+ In tensor parallelism, in the forward pass:
- "f" is a no-op (no operation) because activations are already duplicated across ranks
- "f*" is an all-reduce to synchronize activations and ensure correctness
+ f is a no-op (no operation) because activations are already duplicated across ranks.
+ f* is an all-reduce to synchronize activations and ensure correctness.
- In the backward pass:
+ And in the backward pass:
- "f*" is a no-op because gradients are already duplicated across ranks
- "f" is an all-reduce to synchronize gradients
+ f* is a no-op because gradients are already duplicated across ranks.
+ f is an all-reduce to synchronize gradients.
- These operations "f" and "f*" are called conjugate pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.
+ These f and f* operations are called conjugate pairs because they complement each other: whichever one is a no-op in the forward pass is an all-reduce in the backward pass, and vice versa.
- For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.
+ For sequence parallelism, we use different operations labeled g and g* . Specifically, we avoid using all-reduce in the SP regions since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.
- So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:
-
+ So what is actually happening here? As a famous LLM would say, let’s take it step by step:
+
-
Initial LayerNorm (SP Region)
+
Initial LayerNorm layer (SP region)
- Input tensors X1 and X2 (b,s/2,h) enter LayerNorm, already split across sequence dimension
- Each GPU computes LayerNorm independently on its sequence chunk and give Y1 and Y2
+ Input tensors X1* and X2* (b,s/2,h) enter, already split across the sequence dimension.
+ Each GPU computes LayerNorm independently on its sequence chunk, giving Y1* and Y2* .
-
First Transition (SP → TP)
+
First transition (SP → TP)
- "g" operation (all-gather) combines Y1 and Y2 back to full sequence length
- Restores Y (b,s,h) since column linear needs full hidden dimension h
+ g operation (all-gather) combines Y1* and Y2* back to the full sequence length.
+ Restores Y (b,s,h) since column-linear layers need the full hidden dimension h .
-
First Linear (TP Region)
+
First linear layer (TP region)
- A1 is a column-linear, so it splits Y along the hidden dimension
- GeLU is applied independently on each GPU
- Z1* is (b,s,h/2)
+ A1 and A2 are column-linear layers, so they split Y along the hidden dimension.
+ GELU is applied independently on each GPU.
+ Z1* and Z2* are (b,s,h/2) .
-
Second Linear (TP Region)
+
Second linear layer (TP region)
- B1 is a row-linear, so it restores the hidden dimension
- W1 is (b,s,h)
+ B1 and B2 are row-linear layers, so they restore the hidden dimension.
+ W1 and W2 are (b,s,h) and need to be summed together.
-
Final Transition (TP → SP)
+
Final transition (TP → SP)
- "g*" operation (reduce-scatter) which reduces for previous row-linear correctness while scattering along sequence dimension
- W1* is (b,s/2,h)
+ g* operation (reduce-scatter) performs the reduction needed for the correctness of the previous row-linear layers while scattering along the sequence dimension.
+ W1* and W2* are (b,s/2,h) .
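To make these transitions a bit more concrete, here is a minimal sketch of how the g / g* pair (and the f / f* pair from the TP region) can be written as autograd functions with torch.distributed. This is an illustration under simplified assumptions (a single process group passed in explicitly, the sequence split along dimension 1), not the actual Megatron-LM or Nanotron implementation:

```python
import torch
import torch.distributed as dist

class FIdentityAllReduce(torch.autograd.Function):
    """'f' in the TP region: identity in the forward pass, all-reduce in the backward pass."""
    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad, None


class GGatherScatter(torch.autograd.Function):
    """'g' at the SP -> TP boundary: all-gather along the sequence dimension in the
    forward pass, reduce-scatter back to the sequence shards in the backward pass
    ('g*' is simply the conjugate: reduce-scatter forward, all-gather backward)."""
    @staticmethod
    def forward(ctx, x, group):          # x: (b, s/tp, h)
        ctx.group = group
        tp = dist.get_world_size(group)
        shards = [torch.empty_like(x) for _ in range(tp)]
        dist.all_gather(shards, x, group=group)
        return torch.cat(shards, dim=1)  # (b, s, h)

    @staticmethod
    def backward(ctx, grad_output):      # grad_output: (b, s, h)
        tp = dist.get_world_size(ctx.group)
        chunks = [c.contiguous() for c in grad_output.chunk(tp, dim=1)]
        grad = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad, chunks, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad, None                # (b, s/tp, h)
```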
@@ -1222,9 +1221,9 @@
- A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. In tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to \frac{b \cdot s \cdot h}{tp} since we always either split along the sequence or hidden dimensions.
+ A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. With tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to \frac{b \cdot s \cdot h}{tp} since we always either split along the sequence or the hidden dimension.
- It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP - believe us, we find it hard to map as well so we made this small table to summarize how the activations (aka hidden_states
) shape change across hidden dimension h and sequence dimension s during a forward pass:
+ It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP+SP - believe us, we find it hard to keep track of as well, so we made this small table to summarize how the shape of the activations (a.k.a. hidden_states ) changes across the hidden dimension h and sequence dimension s during a forward pass:
@@ -1236,24 +1235,24 @@
- Enter TP (Column Linear)
- h: sharded (weight_out is sharded) s: full
- h: sharded (weight_out is sharded) s: all-gather to full
+ Enter TP (column-linear)
+ TP only: h: sharded (weight_out is sharded); s: full
+ TP with SP: h: sharded (weight_out is sharded); s: all-gather to full
- TP Region
- h: sharded s: full
- h: sharded s: full
+ TP region
+ TP only: h: sharded; s: full
+ TP with SP: h: sharded; s: full
- Exit TP (Row Linear)
- h: full (weight_out is full + all-reduce for correctness) s: full
- h: full (weight_out is full + reduce-scatter for correctness) s: reduce-scatter to sharded
+ Exit TP (row-linear)
+ TP only: h: full (weight_out is full + all-reduce for correctness); s: full
+ TP with SP: h: full (weight_out is full + reduce-scatter for correctness); s: reduce-scatter to sharded
- SP Region
- h: full s: full
- h: full s: sharded
+ SP region
+ TP only: h: full; s: full
+ TP with SP: h: full; s: sharded
@@ -1270,14 +1269,14 @@
- Embedding Layer (Row Linear sharded on vocab)
- h: full (weight_out is full + all-reduce for correctness) s: full
- h: full (weight_out is full + reduce-scatter for correctness) s: reduce-scatter to sharded
+ Embedding layer (row-linear, sharded on vocab)
+ TP only: h: full (weight_out is full + all-reduce for correctness); s: full
+ TP with SP: h: full (weight_out is full + reduce-scatter for correctness); s: reduce-scatter to sharded
- By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:
+ By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:
@@ -1290,17 +1289,17 @@
- As we can see, we've again strongly reduced the maximum memory usage per GPU, allowing us to fit sequence lengths of 16k tokens with TP/SP=16, an improvement over the vanilla TP case! (TP=16 is still a bit large as we've seen in the previous section, but we'll see how we can improve this in the next section).
-
- One question you may be asking yourself is whether using TP+SP incurs more communication than vanilla TP? Well, yes and no. In the forward pass of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see the A quick focus on Ring AllReduce section in the appendix) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).
+ We've again strongly reduced the maximum memory usage per GPU, allowing us to fit sequence lengths of 16k tokens with TP+SP=16 - an improvement over the vanilla TP case! (TP=16 is still a bit large, as we saw in the previous section, but we'll see how we can improve this in the next section.)
- If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:
+ One question you may be asking yourself is whether using TP+SP incurs more communication overhead than vanilla TP. Well, yes and no. In the forward pass with vanilla TP we had two all-reduce operations per transformer block, and in SP we have two all-gather and two reduce-scatter operations per transformer block. So, SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather and a reduce-scatter (see the "Ring AllReduce" section in the appendix), they’re actually equivalent in terms of communication cost. The same reasoning applies for the backward pass, as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).
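To convince yourself of this equivalence, here is a minimal sketch (assuming a process group is already initialized and the tensor's first dimension is divisible by the world size) showing that an all-reduce can be decomposed into a reduce-scatter followed by an all-gather:

```python
import torch
import torch.distributed as dist

def all_reduce_via_rs_ag(x: torch.Tensor) -> torch.Tensor:
    """Same result as dist.all_reduce(x), but decomposed into the two
    primitives that TP+SP uses: a reduce-scatter and an all-gather."""
    world = dist.get_world_size()
    shards = [s.contiguous() for s in x.chunk(world, dim=0)]
    partial = torch.empty_like(shards[0])
    dist.reduce_scatter(partial, shards, op=dist.ReduceOp.SUM)  # each rank ends up with one fully reduced shard
    gathered = [torch.empty_like(partial) for _ in range(world)]
    dist.all_gather(gathered, partial)                          # every rank reassembles the full reduced tensor
    return torch.cat(gathered, dim=0)
```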
+ If you’ve been paying close attention, you’ll notice that we’re talking about four communication operations in each layer (two for attention and two for MLP). This is what the MLP profiling looks like when using TP+SP:
+
- Just like vanilla TP, TP+SP can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. Here again, like vanilla TO, TP+SP is usually done only within a node (keeping the TP degree under the number of GPU per nodes, e.g. TP≤8).
+ Just like vanilla TP, TP+SP can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. Here again, like vanilla TP, TP+SP is usually done only within a node (keeping the TP degree under the number of GPUs per node; e.g., TP≤8).
- We can benchmark how this communication overhead becomes increasingly problematic as we scale up tensor parallelism. Let’s measure the throughput and memory utilization as we scale TP with SP for a 3B model with 4096 seqlen:
+ We can benchmark how this communication overhead becomes increasingly problematic as we scale up tensor parallelism. Let’s measure the throughput and memory utilization as we scale TP with SP for a 3B parameter model with a sequence length of 4,096:
@@ -1313,33 +1312,33 @@
- Here again, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees enable processing of significantly larger batch sizes by reducing the activation memory, they also reduce per-GPU throughput, in particular above a threshold corresponding to the number of GPUs per node.
+ Again, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher degrees of parallelism enable processing of significantly larger batch sizes by reducing the activation memory, they also reduce per-GPU throughput, in particular above a threshold corresponding to the number of GPUs per node.
Let’s summarize our observations:
- for both methods we notice the biggest performance drop when we move from TP=8 to TP=16, because that’s when we move from only communicating within a single node (NVLink), to communicating inter-nodes (EFA)
- the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone
+ For both methods, we notice the biggest performance drop when we move from TP=8 to TP=16, because that’s when we move from only communicating within a single node (NVLink) to communicating between nodes (EFA).
+ The activation memory savings when using TP with SP help us fit far bigger batches than with TP alone.
- We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.
+ We've seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.
📝 Note
-
Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to all-reduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is however a small communication overhead since LayerNorm has relatively few parameters.
+
Since LayerNorm layers in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to all-reduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is, however, a small communication overhead since LayerNorm has relatively few parameters.
- However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.
+ Still, there are two limits to TP+SP: if we scale the sequence length the activation memory will still blow up in the TP region, and if the model is too big to fit with TP=8 we will see a massive slowdown due to the inter-node connectivity.
- We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!
+ We can tackle the first problem with context parallelism and the second problem with pipeline parallelism . Let’s first have a look at context parallelism!
Context Parallelism
- With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requirements per GPU significantly as both model weights and activations are distributed across GPUs. However, when training models on longer and longer sequences (e.g. when scaling to 128k or more tokens per sequence) we might still exceed the memory available on a single node as we still have to process a full sequence length when we're inside the TP region.
+ With tensor parallelism + sequence parallelism, we can reduce the memory requirements per GPU significantly as both model weights and activations are distributed across GPUs. However, when training models on longer and longer sequences (e.g., when scaling to 128k or more tokens per sequence), we might still exceed the memory available on a single node as we still have to process a full sequence when we're inside the TP region.
- Moreover, even if we use full recomputation of the activations (which comes at a heavy compute overhead of ~30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length. Let's take a look and see how Context Parallelism can help us:
+ Moreover, even if we use full recomputation of the activations (which incurs a heavy compute overhead of ~30%), we still need to hold in memory some activations at the layer boundaries, which scale linearly with sequence length. Let's take a look and see how context parallelism (CP) can help us:
@@ -1353,91 +1352,93 @@
- The core idea of Context Parallelism is to apply a similar idea to the Sequence Parallelism approach (aka to split along the sequence length) but to the modules where we already apply Tensor Parallelism. We will thus split these modules along two dimensions, thereby also reducing the effect of sequence length. You will find this approach quite intuitive after all we’ve already convered but... there is a trick to it so stay awake!
+ The core idea of context parallelism is similar to sequence parallelism (i.e., splitting along the sequence length), but this approach is applied to the modules where we already apply tensor parallelism. We will thus split these modules along two dimensions, thereby also reducing the effect of sequence length. You should find this approach quite intuitive after all we’ve already covered, but there's a trick to it, so stay awake!
- For Context Parallelism; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model as we’ve done previously with Tensor + Sequence Parallelism.
+ With context parallelism, just like sequence parallelism, we split the input along the sequence dimension - but we now apply this splitting along the full model, instead of only the sequence-parallel regions of the model, as we did previously with TP+SP.
- Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just like data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.
+ Splitting the sequence doesn't affect most modules, like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split, not the weight matrices. Just like with data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the CP group.
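As a rough sketch (with hypothetical helper names), preparing the inputs for context parallelism is just a matter of keeping each rank's slice of the sequence dimension - MLPs and LayerNorms then run on that slice unchanged:

```python
import torch
import torch.distributed as dist

def shard_sequence(batch: torch.Tensor, cp_group) -> torch.Tensor:
    """batch: (b, s, h) -> (b, s / cp_size, h). Only the inputs are split;
    the weights stay replicated, and their gradients are all-reduced over
    the CP group after the backward pass (just like in DP)."""
    cp_size = dist.get_world_size(cp_group)
    cp_rank = dist.get_rank(cp_group)
    return batch.chunk(cp_size, dim=1)[cp_rank].contiguous()
```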
- There is one important exception though as we we need to pay particular attention to the Attention blocks (haha.. pun intended :D). In the attention module each token needs to access key/value pairs from all other sequence tokens or in the case of causal attention at least attends to each previous token.
+ There is one important exception, though: we need to pay particular attention to the attention blocks (haha... pun intended :D). In the attention module, each token needs to access key/value pairs from all other sequence tokens (or, in the case of causal attention, at least attend to each previous token).
- Because Context Parallelism splits the inputs along the sequence dimension across GPUs, the attention module will require full communication between GPUs to exchange the necessary key/value data.
+ Because context parallelism splits the inputs along the sequence dimension across GPUs, the attention module will require full communication between GPUs to exchange the necessary key/value data.
- That sounds very expensive if we do it naively. Is there a way to do this rather efficiently and fast! Thankfully there is: a core technique to handle this communication of key/value pairs efficiently is called Ring Attention .
+ That sounds very expensive if we do it naively. Is there a way to do it more cheaply, and fast? Thankfully, there is a core technique that enables us to handle this communication of key/value pairs efficiently: Ring Attention .
📝 Note
-
Context Parallelism shares some conceptual similarities with Flash Attention (see later for more details) - both techniques rely on online softmax computation to reduce memory usage. While Flash Attention focuses on optimizing the attention computation itself on a single GPU, Context Parallelism achieves memory reduction by distributing the sequence across multiple GPUs.
+
Context parallelism shares some conceptual similarities with FlashAttention, which we'll look at later in the book - both techniques rely on online softmax computation to reduce memory usage. But while FlashAttention focuses on optimizing the attention computation itself on a single GPU, context parallelism achieves memory reduction by distributing the sequence across multiple GPUs.
- Discovering Ring Attention
+ Ring Attention
- In this implementation of the attention mechanism, each GPU first initiates an asynchronous communication operation to send its key/value pairs to other GPUs. While waiting for the other GPUs data, it computes the attention score for the portion of the data it already has in memory. Ideally, a next key/value pair is received from another GPU before this computation finishes, allowing the GPU to start the next round of computation immediately after it finishes its first computation.
+ In this implementation of the attention mechanism, each GPU first initiates an asynchronous communication operation to send its key/value pairs to other GPUs. While waiting for the other GPUs' data, it computes the attention score for the portion of the data it already has in memory. Ideally, the next key/value pair is received from another GPU before this computation finishes, allowing the GPU to start the next round of computation immediately after it finishes its first computation.
- Let's illustrate this. We'll suppose we have 4 GPUs and an input of 4 tokens. Initially, the input sequence is split evenly along the sequence dimension, so each GPU will have just one token along with its corresponding Q/K/V values. Leyt's say Q1, K1, and V1 represent the query, key, and value of the first token, which are located on the 1st GPU. The attention calculation will take 4 time steps to complete. At each time step, each GPU performs these three successive operations:
+ Let's illustrate this. We'll suppose we have four GPUs and an input of four tokens. Initially, the input sequence is split evenly along the sequence dimension, so each GPU will have just one token along with its corresponding Q/K/V values. Let's say Q1, K1, and V1 represent the query, key, and value of the first token, which are located on the first GPU. The attention calculation will take four time steps to complete. At each time step, each GPU performs these three successive operations:
- Send “current keys and values” to the next machine except during the last time step in a non-blocking manner so we can starts the following step before this step is finished
- Locally compute the attention score on the “current keys and values” it already has, which typically involves performing Softmax(\frac{QK^T}{\sqrt{d}}) * V .
- Wait to receive keys and values from the previous GPU and then circle back to step 1. where “current keys and values” are now the key/values just received from the previous GPU.
+ Send current keys and values to the next machine (in all but the last time step) in a non-blocking manner, so we can start the following operation before this one is finished.
+ Locally compute the attention score on the current keys and values, which typically involves performing Softmax(\frac{QK^T}{\sqrt{d}}) * V .
+ Wait to receive keys and values from the previous GPU, and then circle back to step 1, where the current keys and values are now the key/values just received.
- We perform these 3 steps four times to complete the attention calculation.
+ We perform these three steps four times to complete the attention calculation.
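Here's a minimal (and deliberately simplified) sketch of this loop using torch.distributed point-to-point operations. The tensor shapes and the unnormalized softmax accumulation are illustrative only - a real implementation subtracts a running max for numerical stability and handles the causal mask:

```python
import torch
import torch.distributed as dist

def ring_attention(q, k, v, cp_group):
    """q, k, v: this rank's (seq_chunk, d) slices. K/V chunks rotate around the
    ring while each rank accumulates attention outputs for its local queries."""
    rank, world = dist.get_rank(cp_group), dist.get_world_size(cp_group)
    send_to, recv_from = (rank + 1) % world, (rank - 1) % world
    num = torch.zeros_like(q)                     # running sum of exp(scores) @ V
    den = torch.zeros(q.shape[0], 1, dtype=q.dtype, device=q.device)
    cur_k, cur_v = k, v
    for step in range(world):
        if step < world - 1:
            # 1. Send the current K/V chunk onward and post receives, non-blocking.
            next_k, next_v = torch.empty_like(k), torch.empty_like(v)
            reqs = dist.batch_isend_irecv([
                dist.P2POp(dist.isend, cur_k, send_to, group=cp_group),
                dist.P2POp(dist.isend, cur_v, send_to, group=cp_group),
                dist.P2POp(dist.irecv, next_k, recv_from, group=cp_group),
                dist.P2POp(dist.irecv, next_v, recv_from, group=cp_group),
            ])
        # 2. Compute attention on the chunk we already have (overlaps with comms).
        w = (q @ cur_k.transpose(-2, -1) / q.shape[-1] ** 0.5).exp()
        num, den = num + w @ cur_v, den + w.sum(dim=-1, keepdim=True)
        # 3. Wait for the next chunk, then loop back to step 1.
        if step < world - 1:
            for r in reqs:
                r.wait()
            cur_k, cur_v = next_k, next_v
    return num / den
```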
- The whole process with 4 GPUs is shown in the following animation:
+ The whole process with four GPUs is shown in the following animation:
- It's probably obvious to you on this animation why the authors chose to call this approach Ring Attention.
+ It's probably obvious to you from this animation why the authors chose to call this approach Ring Attention !
- There is one big problem though which is that a naive implementation of Ring Attention lead to some strong imbalance between GPU coming from the shape of the causal attention matrix. Let’s take a look at the SoftMax computation by considering the attention score matrix with the causal attention mask:
+ There is one big problem, though, which is that a naive implementation of Ring Attention leads to some strong imbalances between GPUs due to the shape of the causal attention matrix. Let’s take a look at the softmax computation by considering the attention score matrix with the causal attention mask:
- The SoftMax is computed row-wise, which means whenever a GPU has received all the tokens of a row it can be computed. We see that GPU1 can immediately compute it as it starts with tokens 1-4 and GPU1 actually doesn’t need to receive any information from any other GPUs. However, GPU2 will need to wait for the second round to also receive 1-4 and thus have all values for tokens 1-8. Also, GPU1 seems to perform much less work than all the other GPUs.
+ The softmax is computed row-wise, which means whenever a GPU has received all the tokens of a row, it can be computed. We see that GPU 1 can immediately compute it, as it starts with tokens 1-4 and doesn’t need to receive any information from any other GPUs. However, GPU 2 will need to wait for the second round to receive tokens 1-4 and thus have all the values for tokens 1-8. GPU 1 also seems to perform much less work than all the other GPUs.
- Let’s see if we can balance our computations better:
+ Let’s see if we can balance our computations better.
- Zig-Zag Ring Attention – A Balanced Compute Implementation
-
- We need a better way to distribute the input sequences. This can be achieved by assigning the tokens not purely sequential to the GPUs but by mixing the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called Zig-Zag attention and in this new arrangement, the attention mask will show an even distribution of computation but if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.
+ Zig-Zag Ring Attention – A balanced compute implementation
+ We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens to the GPUs in a purely sequential manner, but instead mixing up the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called Zig-Zag Attention. In this new arrangement, the attention mask will show an even distribution of computation, but if you count the number of colored squares, you'll see that the computation is now balanced across all GPUs.
+
+ We show here Zig-Zag Attention, which slightly differs from Striped Attention . For details on the differences, check this GitHub discussion .
+
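To make the assignment concrete, here is a small sketch of one common zig-zag layout (the exact interleaving varies between implementations; Striped Attention, for instance, uses a different pattern): each rank gets one chunk from the start of the sequence and the mirror chunk from the end.

```python
def zigzag_token_indices(seq_len: int, cp_rank: int, cp_size: int) -> list[int]:
    """Split the sequence into 2 * cp_size chunks and give rank i the chunks
    i and (2 * cp_size - 1 - i), balancing early (cheap) and late (expensive)
    rows of the causal attention mask across ranks."""
    n_chunks = 2 * cp_size
    chunk = seq_len // n_chunks
    chunks = [list(range(i * chunk, (i + 1) * chunk)) for i in range(n_chunks)]
    return chunks[cp_rank] + chunks[n_chunks - 1 - cp_rank]

# With seq_len=8 and cp_size=2:
#   rank 0 -> tokens [0, 1, 6, 7]
#   rank 1 -> tokens [2, 3, 4, 5]
```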
- At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.
-
- We have two general ways to overlap computation and communication, either by performing a general all-gather, regrouping all the KV on each GPUs at the same time (in a Zero-3 type of way) or we gather them one-by-one from each GPU to each GPU as needed:
+ You’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.
+ We have two general ways to overlap computation and communication: either by performing a general all-gather, regrouping all the keys and values on each GPU at the same time (in a ZeRO-3 type of way), or by gathering them one by one from each GPU as needed, in a ring.
+
- The key difference between these two implementations lies in their communication patterns and memory usage:
+ The key differences between these two implementations lie in their communication patterns and memory usage:
- 1. AllGather Implementation:
+ 1. All-gather implementation:
- All GPUs simultaneously gather the complete key/value pairs from all other GPUs
- Requires more temporary memory as each GPU needs to store the full KV pairs at once
- Communication happens in one step but with larger memory overhead
+ All GPUs simultaneously gather the complete key/value pairs from all other GPUs.
+ Requires more temporary memory as each GPU needs to store all the K/V pairs at once.
+ Communication happens in one step but with larger memory overhead.
- 2. All-to-All (Ring) Implementation:
+ 2. All-to-all (ring) implementation:
- GPUs exchange KV pairs in a ring-like pattern, one chunk at a time
- More memory efficient as each GPU only needs to store one additional chunk temporarily
- Communication is spread out and overlapped with computation, though with some additional base latency overhead from multiple communication steps
+ GPUs exchange K/V pairs in a ring-like pattern, one chunk at a time.
+ More memory-efficient, as each GPU only needs to store one additional chunk temporarily.
+ Communication is spread out and overlapped with computation, though with some additional base latency overhead from multiple communication steps.
- The All-to-All approach generally offers better memory efficiency at the cost of slightly more complex communication patterns, while the AllGather approach is simpler but requires more temporary memory during the attention computation.
+ The all-to-all approach generally offers better memory efficiency at the cost of a slightly more complex communication pattern, while the all-gather approach is simpler but requires more temporary memory during the attention computation.
We've now seen how we can split a model across one node with TP to tame large models and that we can use CP to tame the activation explosion with long sequences.
- However, we still know that TP doesn't scale well across nodes, so what can we do if the model weights don't easily fit on 1 node? Here come another degree of parallelism, our forth one, called Pipeline Parallelism , to the rescue!
+ However, we still know that TP doesn't scale well across nodes, so what can we do if the model weights don't easily fit on one node? Pipeline parallelism - our fourth degree of parallelism - to the rescue!
Pipeline Parallelism
@@ -1451,8 +1452,7 @@
-
- In the Tensor Parallelism section we saw that trying to scale Tensor parallelism past the number of GPUs per single node (typically 4 or 8) hit a lower bandwidth network called “inter-node connection” which can quite strongly impair our performances. We can see this clearly on e.g. the all-reduce operation when we benchmark it on our cluster across several nodes (each node has 8 GPUs):
+ In the "Tensor Parallelism" section, we saw that trying to scale tensor parallelism past the number of GPUs on a single node - typically 4 or 8 - forces us to use lower-bandwidth network communication, which can significantly impair performance. We can see the effects of this inter-node communication clearly in the all-reduce operation when we benchmark it on our cluster across several nodes (each node here has 8 GPUs):
@@ -1465,13 +1465,13 @@
- Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.
+
- Sequence and context parallelism can help for long sequences but don’t help much if the sequence length is not the root cause of our memory issues but rather the size of the model itself. For large model (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.
+ Sequence and context parallelism can help for long sequences, but they don’t help much if the root cause of our memory issues is not the sequence length but rather the size of the model itself. For large models (70B+ parameters), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning another parallelism dimension: pipeline parallelism (PP).
- Pipeline parallelism is a simple but powerful technique - we split our model's layers across multiple GPUs! For example, if we have 8 GPUs, we could put layers 1-4 on GPU 1, layers 5-8 on GPU 2, and so on. This way, each GPU only needs to store and process a portion of the model's layers, significantly reducing the memory requirements per GPU. Let's see the effect of Pipeline Parallelism in action on the memory usage for a 8B model:
+ Pipeline parallelism is a simple but powerful technique - we split our model's layers across multiple GPUs! For example, if we have 8 GPUs, we could put layers 1-4 on GPU 1, layers 5-8 on GPU 2, and so on. This way, each GPU only needs to store and process a portion of the model's layers, significantly reducing the memory requirements per GPU. Let's see the effect of pipeline parallelism in action on the memory usage for an 8B parameter model:
- This technique may remind you of our discussion on ZeRO-3 where we split the model parameters across GPUs. We compare both techniques in details later in the 5D parallelism in a nutshell section.
+ This technique may remind you of our discussion of ZeRO-3 , where we split the model parameters across GPUs. We compare both techniques in detail later, in the "5D Parallelism in a Nutshell" section.
@@ -1486,115 +1486,115 @@
Looking at the figure above, we notice something interesting: while the model parameters are nicely split across GPUs, the activation memory remains the same on each GPU! This is because each GPU still needs to process the full batch of data, just with different layers. The activations from one GPU's layers will be sent to the next GPU to continue the forward pass.
- This introduces a new type of communication pattern: instead of communicating parameters like we did with ZeRO-3 in data parallelism, we're now passing activation tensors sequentially between GPUs in a "pipeline". While conceptually simple, efficiently implementing this technique is quite tricky. Let's dive right into the details!
+ This introduces a new type of communication pattern: instead of communicating parameters like we did with ZeRO-3 in data parallelism, we're now passing activation tensors sequentially between GPUs in a "pipeline." While conceptually simple, efficiently implementing this technique is quite tricky. Let's dive right into the details!
Splitting layers on various nodes - All forward, all backward
- So, let’s say we simply spread the layers on several devices, e.g. a first GPU will take the first few layers and a second GPU will take the second part of the models and so on. The forward pass through our model now simply involves sequentially passing the batch of data along the model and thus successively using each compute device.
+ To start, let’s say we simply spread the layers across several devices - e.g., a first GPU will take the first few layers, a second GPU will take the second part of the model, and so on. The forward pass through our model now simply involves sequentially passing the batch of data along the model and thus successively using each compute device.
- We have a direct first advantage: the required interconnect bandwidth stays quite low as we only send moderate-sized activations at a handful of location along the model depth. It can make a huge difference versus e.g. communications in Tensor Parallelism, which happens several times within each layer.
+ We have a direct first advantage: the required interconnect bandwidth stays quite low as we only send moderate-sized activations at a handful of locations along the model depth. This can make a huge difference compared to, for example, the TP approach, where communications happen several times within each layer.
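In its most naive form, the forward pass of a pipeline-parallel model is just a receive-compute-send chain. Here is a rough sketch (the activation shape and layer assignment are assumptions for illustration):

```python
import torch
import torch.distributed as dist

def pp_forward(my_layers, micro_batch, rank, world_size, act_shape):
    """Each rank owns a contiguous slice of layers: receive activations from the
    previous stage, run the local layers, send the result to the next stage.
    Note that while one rank computes, all the others are idle - the bubble."""
    if rank == 0:
        x = micro_batch
    else:
        x = torch.empty(act_shape)
        dist.recv(x, src=rank - 1)       # blocking: wait for the previous stage
    for layer in my_layers:
        x = layer(x)
    if rank != world_size - 1:
        dist.send(x, dst=rank + 1)       # hand the activations to the next stage
    return x                             # only the last rank holds the model output
```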
- But maybe you start feeling a glimpse of the troubles to come: “sequentially” and “successively” ?!? This doesn’t sound very efficient in the world of parallel computations, especially after our discussion on computation and communication overlap.
+ But you may be starting to catch a glimpse of the troubles to come: “sequentially” and “successively”?!? This doesn’t sound very efficient in the world of parallel computations, especially after our discussion of computation and communication overlap.
- Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPU busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization is looking when doing a naive and simple forward and backward pass through the model (here the numbers indicate the model layers):
+ Indeed, reader! The main challenge in pipeline parallelism is how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here's how our GPU utilization looks when doing a naive and simple forward and backward pass through the model (here, the numbers indicate the model layers):
- An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.
+
-
The remaining idle time is indicated in grey and usually called the “bubble” and the sight of this probably break your heart after we spent so much time optimizing throughput.
+
The remaining idle time is indicated in gray and usually called the “bubble.” The sight of this probably broke your heart after we spent so much time optimizing throughput.
-
We can quantify how efficient a pipeline setup is by looking at how much time we lose because of the bubble. Let’s say t_f and t_b are the times for the forward and backward pass, respectively, as measured for one microbatch and one stage of the pipeline (a simple assumption is often to have t_b \approx 2 \times t_f which you can see on the above graph). If we could perfectly parallelize the ideal total time would be t_{id}=t_f + t_b . However, we can count on the graph that due to the pipeline bubble there is additional time of t_{pb}=(p-1)*(t_f+t_b) (where p is the degree of pipeline parallelism, i.e the number of GPU on the above graph) ie. the time each GPU is waiting while other GPUs are computing.
+
We can quantify how efficient a pipeline setup is by looking at how much time we lose because of the bubble. Let’s say t_f and t_b are the times for the forward and backward passes, respectively, as measured for one micro-batch and one stage of the pipeline (a simple assumption is often to have t_b \approx 2 \times t_f , as in the above graph). If we could perfectly parallelize, the ideal total time would be t_{id}=t_f + t_b . However, in this example due to the pipeline bubble there is additional time of t_{pb}=(p-1)*(t_f+t_b) (where p is the degree of pipeline parallelism; i.e., the number of GPUs). This is the time each GPU is waiting while other GPUs are computing.
-
We can compute the ratio of the additional bubble time over the ideal time:
+
We can compute the ratio of the additional bubble time over the ideal time as follows:
r_{bubble} = \frac{(p-1)*(t_f+t_b)}{t_f+t_b} = p-1
-
As we add more stages the bubble time thus increases and the utilization drops. As we can see, the bubble can be very large in a naive implementation!
-
Thankfully, various pipeline parallelism schemes have been designed to reduce the size of the bubble .
+
As we add more stages, the bubble time thus increases and the utilization drops. As we can see, the bubble can be very large in a naive implementation!
+
Thankfully, various pipeline parallelism schemes have been designed to reduce the size of the bubble.
-
Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bit-sized portions which can be processed in parallel or almost, like we did before in data parallel for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:
+
Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bite-sized portions that can be processed in parallel (or almost), like we did before in the DP approach, for instance. Now, when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using eight micro-batches:
-
Before the numbers in the diagram indicated the layers but in all pipeline parallel plots from now including this one it indicates a microbatch. You can think of each square here to contain several layers as seen in the previous figure.
+
Before, the numbers in the diagram indicated the layers, but in all pipeline parallel plots from here on they indicate micro-batches. You can think of each square here as containing several layers, as seen in the previous figure.
-
The above schedule is called the all-forward-all-backward (AFAB) schedule as we first do all forward passes and then only all-backward passes. The advantage is that forward and backward steps are still generally sequential and so we're preserving the general organization of our model training code. It makes this PP implementation one of the simplest to implement.
+
The above schedule is called the all forward, all backward (AFAB) schedule, as we first do all the forward passes and then all the backward passes. The advantage is that forward and backward steps are still generally sequential, so we're preserving the general organization of our model training code. This PP implementation is one of the simplest to implement.
-
You can find the full implementation of the AFAB pipeline in picotron:
+
You can find the full implementation of the AFAB pipeline in Picotron:
- 👉 AFAB PP implementation in Picotron (Click to expand)
+ 👉 AFAB PP implementation in Picotron (click to expand)
- Let’s estimate the bubble in this example. The difference with our first example is that the ideal time to process m microbatches is now t_{id} = m*(t_f+t_b) :
+ Let’s estimate the bubble in this example. The difference from our first example is that the ideal time to process m micro-batches is now t_{id} = m*(t_f+t_b) :
r_{bubble} = \frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{m}
- As we can see, we can fight some inefficiencies of pipeline stages by adding more microbatches, reducing the size of the bubble by a factor of m .
+ As we can see, we can fight some of the inefficiencies of pipeline stages by adding more micro-batches, reducing the size of the bubble by a factor of m .
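To get a feel for the numbers, here is a tiny back-of-the-envelope helper (the configurations are just examples):

```python
def afab_bubble_fraction(p: int, m: int) -> float:
    """Extra idle time relative to the ideal compute time for the AFAB schedule:
    r_bubble = (p - 1) / m."""
    return (p - 1) / m

for p, m in [(4, 1), (4, 8), (8, 8), (8, 32)]:
    print(f"p={p:2d}, m={m:2d} -> r_bubble = {afab_bubble_fraction(p, m):.2f}")
# p= 4, m= 1 -> r_bubble = 3.00
# p= 4, m= 8 -> r_bubble = 0.38
# p= 8, m= 8 -> r_bubble = 0.88
# p= 8, m=32 -> r_bubble = 0.22
```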
- However, as annoying as the bubble is the memory storage required for storing all activation. We need to keep all of the activations in memory until we reach the backward stage which lead to a quick memory explosion in these implementations of PP. Can we do better and avoid this memory explosion?
+ However, just as annoying as the bubble is the memory required for storing all the activations. We need to keep all of the activations in memory until we reach the backward stage, which quickly leads to a memory explosion in these implementations of PP. Can we do better and avoid this issue?
- Since the memory explosion is triggered by the activation we store for the backward pass, let’s try to see if we can start performing the backward pass while we are still performing other forward part of the computation. This will allow us to drop some of the activations we need for the backward pass as soon as possible.
+ Since the memory explosion is triggered by the activations we store for the backward pass, let’s see if we can start performing the backward pass while we are still performing the forward part of the computation. This will allow us to drop some of the activations needed for the backward pass as soon as possible.
- One-forward-one-backward and LLama 3.1 schemes
+ One forward, one backward and Llama 3.1 schemes
- This schedule is called one-forward-one-backward (1F1B) as the middle/steady state involves alternatively performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:
+ This schedule is called one forward, one backward (1F1B) because the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:
- If you count carefully you'll see that the bubble still has the same size so our training efficiency is not significantly improved. However we only need to store activations for p micro-batches (where p is the degree of pipeline parallelism) instead of m (where m was the number of microbatches) which can reduce the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches which then will actually reduce the bubble.
+ If you count carefully, you'll see that the bubble still has the same size, so our training efficiency is not significantly improved. However, we only need to store activations for p micro-batches (where p is the degree of pipeline parallelism) instead of m (where m is the number of micro-batches), which can reduce the activation memory explosion we had in the AFAB schedule. As a consequence, we can add more micro-batches, which then will actually reduce the bubble.
- A major complexity of this setup, visible on the above graph is how forward and backward passes are not anymore cleanly sequential but performed in parallel across devices and interleaved. This means we will have to schedule a switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.
+ A major complexity of this setup, visible in the above figure, is how forward and backward passes are not cleanly sequential anymore but rather are performed in parallel across devices and interleaved. This means we will have to schedule a switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.
- This is one of the reason implementing Pipeline Parallelism usually requires rather extensive modifications to training code as well as modeling code.
+ This is one of the reasons implementing pipeline parallelism usually requires rather extensive modifications to training code as well as modeling code.
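To give an idea of the scheduling logic without the communication plumbing, here is a sketch of the order of operations a single rank executes under 1F1B (the warmup rule follows the usual min(p - rank - 1, m) convention; the full implementation linked below also handles the sends/receives and gradient accumulation):

```python
def one_f_one_b_order(rank: int, p: int, m: int) -> list[str]:
    """Per-rank 1F1B order: a warmup of forward passes, a steady state that
    alternates one forward and one backward, then a cooldown of backwards."""
    warmup = min(p - rank - 1, m)
    steady = m - warmup
    ops = [f"F{i}" for i in range(warmup)]
    for i in range(steady):
        ops += [f"F{warmup + i}", f"B{i}"]
    ops += [f"B{steady + i}" for i in range(warmup)]
    return ops

# With p=4 and m=8: the last rank (3) does F0 B0 F1 B1 ... F7 B7 (no warmup),
# while rank 0 does F0 F1 F2 first, then alternates, and finishes with B5 B6 B7.
```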
- You can find a full implementation of 1F1B in picotron as well:
+ You can find a full implementation of 1F1B in Picotron as well:
- 👉 1F1B PP implementation in Picotron (Click to expand)
+ 👉 1F1B PP implementation in Picotron (click to expand)
- Let's take a look at how the 1F1B Pipeline Parallelism schedule scales in practice with some benchmarks on our cluster:
-
-
+ Let's take a look at how the 1F1B pipeline parallelism schedule scales in practice with some benchmarks on our cluster:
+
+
- On the left, with a number of microbatches equal to –or less than– PP degree minus one (m = p - 1 ), we see how detrimental the pipeline bubble can be - performance are low and even drops as we scale PP. The right plot shows that using many more microbatches than PP degree (m = 32 \gg p - 1 ) helps improve low-PP-degree performances while still staying limited at very large PP degree. In practice it's not possible to arbitrarly increase the number of microbatches to maintain the ratio of m \gg p - 1 since we're ultimately constrained by the target global batch size. With a maximal possible number of microbatches as we add more PP degree, we'll ultimately have to increase the bubble size according to r_{bubble} = \frac{p - 1}{m} .
+ On the left, with a number of micro-batches equal to or less than the PP degree minus one (m = p - 1 ), we see how detrimental the pipeline bubble can be - performance is low and even drops as we scale PP. The right-hand plot shows that using many more micro-batches than the PP degree (m = 32 \gg p - 1 ) helps improve low-PP-degree performance, though it's still limited at very large PP degrees. In practice, it's not possible to arbitrarily increase the number of micro-batches to maintain the ratio of m \gg p - 1 since we're ultimately constrained by the target global batch size. So, with the number of micro-batches capped at its maximum, increasing the PP degree will ultimately increase the bubble size according to r_{bubble} = \frac{p - 1}{m} .
- Interestingly, at small number of micro-batches the performance only drops by 14% when scaling from one node (p = 8 ) to two nodes (p = 16 ) - a much better scaling than Tensor Parallelism which typically sees around 43% performance degradation in similar cross-node scenarios. This type of behavior when hitting the lower-bandwith inter-node network makes Pipeline Parallelism particularly attractive for distributed training across multiple nodes.
+ Interestingly, at a small number of micro-batches the performance only drops by 14% when scaling from one node (p = 8 ) to two nodes (p = 16 ) - a much better scaling than we achieve with tensor parallelism, which typically sees around 43% performance degradation in similar cross-node scenarios. This type of behavior when hitting the lower-bandwidth inter-node network makes pipeline parallelism particularly attractive for distributed training across multiple nodes.
- While 1F1B significantly reduces our activation memory footprint, we see on this last graph that the pipeline bubble remains a major efficiency bottleneck. With the bubble size still proportional to the number of pipeline stages, we're leaving valuable GPU compute idle. Can we design an even smarter schedule to minimize this wasted computation time?
+ While 1F1B significantly reduces our activation memory footprint, we see in this last graph that the pipeline bubble remains a major efficiency bottleneck. With the bubble size still proportional to the number of pipeline stages, we're leaving valuable GPU compute idle. Can we design an even smarter schedule to minimize this wasted computation time?
Interleaving stages
- The 1F1B schedule has let us improved memory usage but not much the size of the idle buddle. Any way we could still push this frontier?
+ The 1F1B schedule let us improve memory usage but didn't have much effect on the size of the idle bubble. Is there any way we can push this frontier?
- Well it turns out this is possible if we are willing to bring in a few additional communication operations. Time to talk about interleaved stages .
+ It turns out this is possible if we are willing to bring in a few additional communication operations. Time to talk about interleaved stages !
- Up to now we’ve sliced our model naively along the model depth dimensions, hosting for instance layers 1-4 on the first GPU and layers 5-8 on the second GPU. But there are other ways we could think about slicing our layers, e.g. having odd layers 1, 3, 5, 7 on the first GPU and even layers 2, 4, 6, 8 on the second GPU.
+ Up to now, we’ve sliced our model naively along the model depth dimensions, hosting for instance layers 1-4 on the first GPU and layers 5-8 on the second GPU. But there are other ways we could think about slicing our layers, such as having odd layers (1, 3, 5, 7) on the first GPU and even layers (2, 4, 6, 8) on the second GPU.
- This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model. Let's take a graphical look at how this works:
+ This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model. Let's take a look at how this works:
- An example of interleaved pipeline parallelism for a model with layers distributed across 4 GPUs. Numbers still correspond to the microbatches IDs but for clarity we've colored differently the first and the last layers of the model to illustrate how layers are spread across GPUs.
+
-
As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of v , where v is the number of stages or model chunks per GPUs as we are able to better interleave forward and backward passes.
+
Additional communications are required here, as the model goes through each GPU several times for the same computation that previously took just one pass. However, each forward and backward pass is divided by a factor of v , where v is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes:
@@ -1605,7 +1605,7 @@
-
So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively, the amount of communication also increases by v so it’s a trade off. In the following plot you can see several configurations for a PP setup with p=8 , where the special case of m=1, v=1 corresponds to naive pipeline parallelism and the configurations with v=1 are AFAB or 1F1B setups and v \neq 1 are interleaved configurations.
+
So, we can now decrease the bubble by adding micro-batches and interleaved stages - but note that quantitatively, the amount of communication also increases by v so it’s a trade-off. In the following plot, you can see several configurations for a PP setup with p=8 , where the special case of m=1, v=1 corresponds to naive pipeline parallelism, the configurations with v=1 are AFAB or 1F1B setups, and the v \neq 1 cases are interleaved configurations.
@@ -1619,94 +1619,94 @@
-
Scheduling also becomes more complex here as we have to decide on a given GPU and at a given moment whether we are prioritizing earlier micro-batches going through later layers –meaning that we close the forward and backward loops as fast as possible (so called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible)– or if we prioritize to first have later micro-batches going through earlier layers (so called “breadth-first” i.e. prioritizing filling in the pipeline as much as possible). This choice is explained in detail in the nice "Breadth-Fist Pipeline" paper .
+
Scheduling also becomes more complex here, as we have to decide on a given GPU and at a given moment whether we are prioritizing earlier micro-batches going through later layers - meaning that we close the forward and backward loops as fast as possible (the “depth-first” approach, which prioritizes getting batches out of the model as fast as possible) - or later micro-batches going through earlier layers (the “breadth-first” approach, which prioritizes filling in the pipeline as much as possible). This choice is explained in detail in the "Breadth-First Pipeline Parallelism" paper .
-
You now have all the elements to understand the pipeline parallelism approach in Llama 3.1 which is using a one-forward-one-backward setup with interleaved stages and a priority setting tuneable between depth-first and breadth-first.
+
You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a 1F1B setup with interleaved stages and a priority setting tunable between depth-first and breadth-first:
-
However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero ! These techniques were for instance used in the DeepSeek V3/R1 implementation . Peaked your curiosity? Let’s have a final quick look at these magical schedules before we leave the world of Pipeline Parallelism!
+
However, we haven’t reached the end of the possible pipeline schedules, and recently some methods have been proposed to reduce the bubble to virtually zero ! These techniques were, for instance, used in the DeepSeek-V3/R1 implementation . Piqued your curiosity? Let’s have a final quick look at these magical schedules before we leave the world of pipeline parallelism!
-
Zero Bubble and DualPipe
+
Zero bubble and DualPipe
-
Even more sophisticated ways to reduce the bubble have recently been proposed which reached close to a “zero bubble” regime. The secret here is to split at an even finer-grained level the operations involved in order to interleave them in the most efficient way. For instance the pipeline implementation approach in DeepSeek V3/R1, called DualPipe, reaches close to a zero bubble regime.
+
Even more sophisticated ways to reduce the bubble have recently been proposed that reach close to a “zero bubble” regime, such as the pipeline implementation approach in DeepSeek-V3/R1, called DualPipe. The secret here is to split the operations involved at an even finer-grained level in order to interleave them in the most efficient way.
-
Ultimate "flex" in DeepSeek V3 technical report where the authors indicate that their setup "achiev[ed] a near-zero all-to-all communication overhead".
+
In the DeepSeek-V3 technical report , the authors indicate that their setup achieved "a near-zero all-to-all communication overhead."
-
Let’s briefly see how this can work by summarizing the ZeroBubble work which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separated operations: backward operation for the inputs (B) and the backward operation for the weights (W):
+
Let’s briefly see how this can work by summarizing Sea AI Lab's zero bubble work , which is a precursor to DualPipe. The basic observation here is that the backward pass through a matrix multiplication actually involves two separate operations: the backward operation for the inputs (B ) and the backward operation for the weights (W ).
-
While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimiser step. We can see that in the following diagram:
+
While the output of B , the backward pass for the inputs, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W , is not and generally only needs to be performed before the optimizer step. We can see that in the following diagram (from the Zero Bubble paper):
-
This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of (theoretical) schedule with zero bubble taking advantage for this fine-grained decomposition.
+
This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule at the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.
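To see why this split is possible, here is a sketch of the backward pass of a simple linear layer Y = X W^T, decomposed by hand into the two operations (the function names are ours, for illustration):

```python
import torch

def linear_backward_split(x, weight, grad_output):
    """x: (n, in), weight: (out, in), grad_output = dL/dY: (n, out).
    B must run right away (the previous stage is waiting for dL/dX);
    W can be deferred until just before the optimizer step."""
    def backward_B(grad_output, weight):
        return grad_output @ weight                  # dL/dX: (n, in)

    def backward_W(grad_output, x):
        return grad_output.transpose(-2, -1) @ x     # dL/dW: (out, in)

    return backward_B(grad_output, weight), backward_W(grad_output, x)
```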
-
On the top (Figure 2 from the ZeroBubble paper): the classical 1F1B schedule, interleaving forward and backward pass but keeping a coarse-grained backward pass. On the bottom two graphs (Figure 3 from the ZeroBubble paper), two variantes of the ZeroBubble schedule, splitting the backward operation in a "B" and a "W" finer-grained operations. The last schedule, so-called "ZB-H2" is an example of (theoretical) schedule with zero bubble taking advantage for this fine-grained decomposition.
+
-
DeepSeek’s DualPipe introduced with its V3 technical report an extension of this decomposed approach to the additional case of two streams propagating from both ends of the PP dimension, these streams being interleaved to minimize even further idle time in the GPUs. This schedule is displayed in the following scheduling graph and is even more complex than the previous ones:
+
DeepSeek’s DualPipe, introduced with its V3 technical report , is an extension of this decomposed approach to the additional case of two streams propagating from both ends of the PP dimension, with these streams being interleaved to further minimize idle time in the GPUs. This schedule is displayed in the following scheduling graph - as you can see, it's even more complex than the previous ones:
-
In general, fully optimizing such complex schedules involve carfully measuring the duration of the various fine-grained operations and solving a ILP to minimize the final bubble time. See for instance in the ZeroBubble paper for a discussion of the heuristics and algorithms to perform such a scheduling. As a result, the ZeroBubble and DualPipe schedules are too complex for us to give here code snippets but you should start to have a general idea of the concepts involved.
+
In general, fully optimizing such complex schedules involves carefully measuring the duration of the various fine-grained operations and solving an Integer Linear Programming (ILP) problem to minimize the final bubble time. (See, for instance, the Zero Bubble paper for a discussion of the heuristics and algorithms used to perform such scheduling.) As a result, the zero bubble and DualPipe schedules are too complex for us to give code snippets here, but you should have a general idea of the concepts involved.
-
This concludes our tour into the world of pipeline schedules and bubbles. We hope you enjoyed this guided tour!
+
This concludes our tour of the world of pipeline schedules and bubbles. We hope you enjoyed it!
-
It's now time to turn to the last parallelism method we'll detail and which we can use to train large models efficiently: Expert parallelism .
+
It's now time to turn to the last parallelism method we'll detail, which we can use to train large models efficiently: expert parallelism .
-
Expert parallelism
+
Expert Parallelism
-
This is our last parallelism method to discuss. Before tackling it, if you don't have any exposure to Mixture-of-Experts, feel free to read about them in this previous, much shorter, blog post we published some time ago and which should help you better understand the Mixture-of-Experts (MoE) architecture in general.
+
This is the last parallelism method we're going to discuss. Before tackling it, if you don't have any exposure to Mixture of Experts (MoE) models, you might want to take some time to read about them in this much shorter blog post we published some time ago, which should help you better understand the MoE architecture in general.
-
Mixture-of-expert models have gained recent traction and visibility with models such as GPT-4, Mixtral or more recently DeepSeek-V3/R1. The basic idea is that instead of having a single feedforward module per layer we can have several parallel modules and route tokens through one or the other to be processed differently.
+
The Mixture of Experts paradigm has recently gained traction and visibility with models such as GPT-4, Mixtral , and DeepSeek-V3/R1. The basic idea is that instead of having a single feedforward module per layer, we can have several parallel modules and route tokens through them to be processed differently.
-
Illustration of a MoE layer taken from the Switch Transformers paper
+
-
The design of MoE layers makes it actually easy to implement parallelism across the experts dimension for what we will call Expert parallelism (EP). Since the feedforward layers are fully independent we can simply put each expert's feedforward layer on a different worker. Compared to TP it's much more lightweight, since we don't need to split the matrix multiplication, we just need to route the hidden states of a token to the right expert.
+
The design of MoE layers makes it easy to implement parallelism across the experts dimension, for what we call expert parallelism (EP) . Since the feedforward layers are fully independent, we can simply put each expert's feedforward layer on a different worker. Compared to TP, this approach is much more lightweight, since we don't need to split the matrix multiplication; we just need to route the hidden states of a token to the right expert.
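To give a flavor of what this looks like in code, here is a simplified sketch (our own, not a production MoE implementation - the one-expert-per-rank layout and the assumption that tokens arrive pre-sorted into equal-sized slices per destination are both simplifications) of the dispatch/combine pattern built on torch.distributed's all_to_all:

import torch
import torch.distributed as dist

def ep_forward(hidden_states, expert, ep_group):
    # hidden_states: [tokens, hidden_dim], assumed already sorted by target expert
    # so that an equal-sized slice goes to each EP rank (a simplification)
    world_size = dist.get_world_size(ep_group)
    send_chunks = list(hidden_states.chunk(world_size, dim=0))
    recv_chunks = [torch.empty_like(c) for c in send_chunks]

    # Dispatch: every rank sends one slice of tokens to every other rank
    dist.all_to_all(recv_chunks, send_chunks, group=ep_group)

    # Each rank now holds the tokens routed to its local expert
    expert_out = expert(torch.cat(recv_chunks, dim=0))

    # Combine: send the processed tokens back to the ranks they came from
    return_chunks = list(expert_out.chunk(world_size, dim=0))
    combined = [torch.empty_like(c) for c in return_chunks]
    dist.all_to_all(combined, return_chunks, group=ep_group)
    return torch.cat(combined, dim=0)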
-
In practice, EP will typically be used in conjunction with other forms of parallelism - for instance Data Parallelism. This is because EP only affects the MoE layers and doesn't shard the input tokens (unlike Context Parallelism which shards tokens along the sequence length dimension). This means our GPUs would be doing redundant compute for all the non-MoE blocks if we only used EP. By combining EP with DP, we can efficiently shard both the experts and the input batches across our GPUs, as we can see in the simplified diagram below:
+
In practice, EP is typically used in conjunction with other forms of parallelism, such as data parallelism. This is because EP only affects the MoE layers and doesn't shard the input tokens (unlike context parallelism, which shards tokens along the sequence length dimension). This means our GPUs would be doing redundant computation for all the non-MoE blocks if we only used EP. By combining EP with DP, we can efficiently shard both the experts and the input batches across our GPUs, as you can see in the simplified diagram below:
-
Source: A Survey on Mixture of Experts
+
-
But let's not get ahead of ourselves - our following section will specifically talk about all the interactions between different parallelism strategies, so don't worry if you don't understand yet this last diagram.
+
But let's not get ahead of ourselves - we'll talk about all the interactions between different parallelism strategies in the following section, so don't worry if you don't understand this last diagram yet.
-
In practice, there are a few tricks to make EP work efficiently and they are closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to keep the tokens on a single node and reduce communication overhead. While Expert parallelism has been around for a while it is just now gaining new traction with the MoE architecture gaining more traction.
+
In practice, there are a few tricks to make EP work efficiently, and they are closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to keep the tokens on a single node and reduce communication overhead. While expert parallelism has been around for a while , it is just now gaining new traction with the MoE architecture gaining popularity.
-
We plan to add a more complete example of EP in picotron/nanotron soon, so stay tuned for more!
+
We plan to add a more complete example of EP in Picotron/Nanotron soon, so stay tuned for more!
-
5D parallelism in a nutshell
+
5D Parallelism in a Nutshell
-
Congratulation reader, you have now seen all 5 parallelism strategies you can use to scale model training:
+
Congratulations, reader! You have now seen all five parallelism strategies you can use to scale model training:
- Data Parallelism (DP) – along the batch dimension
- Tensor Parallelism (TP) - along the hidden dimension
- Sequence and Context Parallelism (SP/CP) - along the sequence dimension
- Pipeline Parallelism (PP) - along the model layers
- Expert Parallelism (EP) - along the model experts
+ Data parallelism (DP) – along the batch dimension
+ Tensor parallelism (TP) - along the hidden dimension
+ Sequence and context parallelism (SP/CP) - along the sequence dimension
+ Pipeline parallelism (PP) - along the model layers
+ Expert parallelism (EP) - along the model experts
-
As well as the 3 ZeRO strategies which can be combined with Data Parallelism for memory reduction:
+
as well as the three ZeRO strategies that can be combined with data parallelism for memory reduction:
ZeRO-1 – sharding optimizer states among the DP replicas
ZeRO-2 – sharding optimizer states and gradients among the DP replicas
- ZeRO-3 – sharding optimizer states, gradients and parameters among the DP replicas
+ ZeRO-3 – sharding optimizer states, gradients, and parameters among the DP replicas
-
At this stage, one aspect you are probably curious about is how all these parallelism and ZeRO strategies compare to, and interact with, each other. In other words, which ones should we use and efficiently combine together, and which ones should we rather keep separated?
+
At this stage, one aspect you are probably curious about is how all these parallelism and ZeRO strategies compare to, and interact with, one another. In other words, which ones can we use and efficiently combine together, and which ones should we keep separated?
-
Let’s take a look at the similarities and interplay. We'll start by comparing Pipeline parallelism are ZeRO-3 side-by-side as they have some very close similarities but also important differences.
+
Let’s take a look at the similarities and interplay. We'll start by comparing pipeline parallelism and ZeRO-3 side by side, as they have some very close similarities but also important differences.
-
Pipeline parallelism vs. ZeRO-3 - Both PP and ZeRO-3 are ways to partition the model weights over several GPUs and perform communication/computation along the model depth axis (for example in ZeRO-3, we prefetch the next layer while computing). This means in both cases full layer operations are computed on each device, as opposed to TP or EP for instance in which computation are performed on sub-layer units.
-
In the following we say “a layer” to simplify what should be in general called “a set of layer” (as the basis sharding unit of the model).
-
-
However, there are a few major differences between PP and ZeRO-3 approaches:
+
Both pipeline parallelism and ZeRO-3 are ways to partition the model weights over several GPUs and perform communication/computation along the model depth axis (for example, in ZeRO-3, we prefetch the next layer while computing). This means in both cases full layer operations are computed on each device, as opposed to with TP or EP, for instance, in which computations are performed on sub-layer units.
+
+
However, there are a few major differences between the PP and ZeRO-3 approaches:
-
+
Note here that we say "a layer" to simplify, but the actual sharding unit of the model can vary - it might be multiple layers, a single layer, or even a subset of a layer, depending on the specific implementation.
+
@@ -1717,19 +1717,19 @@
- Each compute unit stores
+ Each compute unit stores...
only a fraction of a layer
a full layer
- Communication is used to transfer
+ Communication is used to transfer...
weights
activations
Orchestration
- model agnostic
- model agnostic
+ Model-agnostic
+ Model-agnostic
Implementation challenges
@@ -1739,17 +1739,17 @@
Scaling considerations
Prefers large mbs and seq\_len to hide comms
- Prefers large \text{grad\_acc} to hide bubble
+ Prefers large grad\_acc to hide bubble
-
As you can see, ZeRO-3 and PP solve the same challenge but involve different approaches and the choice between both will depend whether you decide to focus communication either on weights or on activations. While they can be combined, it's not often done in practice as doing so requires increasing the global batch size significantly to amortize the communication costs, creating a tradeoff between global batch size, model size, network bandwidth, and training efficiency. If you decide to combine them, ZeRO-3 should be configured to keep the weights in memory during the series of PP micro-batches to minimize as much as possible un-necessary communication overhead.
+
As you can see, ZeRO-3 and PP solve the same challenge but involve different approaches, and the choice between them will depend on whether you decide to focus communication on transferring weights or activations. While they can be combined, it's not often done in practice as doing so requires increasing the global batch size significantly to amortize the communication costs, creating a trade-off between global batch size, model size, network bandwidth, and training efficiency. If you decide to combine them, ZeRO-3 should be configured to keep the weights in memory during the series of PP micro-batches to minimize as much as possible unnecessary communication overhead.
-
On the other hand, ZeRO-1 and ZeRO-2, which focus on optimizer states and gradients, can be easily combined with Pipeline Parallelism and are complementary to it. Combining them don't raise any particular new challenge. For instance, the training of DeepSeek-v3 used PP combined with ZeRO-1 (sic).
+
On the other hand, ZeRO-1 and ZeRO-2, which focus on optimizer states and gradients, can be easily combined with pipeline parallelism and are complementary to it. These combinations don't raise any particular new challenges. For instance, the training of DeepSeek-v3 used PP combined with ZeRO-1 (sic) .
-
Tensor Parallelism (with Sequence Parallelism) is naturally complementary and can be combined with both Pipeline Parallelism and ZeRO-3 as it relies on the distributive property of matrix multiplications which allows weights and activations to be sharded and computed independently before being combined.
+
Tensor parallelism (with sequence parallelism ) is naturally complementary to and can be combined with both pipeline parallelism and ZeRO-3, as it relies on the distributive property of matrix multiplications, which allows weights and activations to be sharded and computed independently before being combined.
@@ -1758,13 +1758,13 @@
-
The main reason we don't want to use TP only for parallelism is that, in practice, TP has two limitations we've discussed in the previous sections: First, since its communication operations are part of the critical path of computation, it's difficult to scale well beyond a certain point at which communication overhead begins to dominate. Second, unlike ZeRO and PP which are model-agnostic, TP requires careful handling of activation sharding - sometimes along the hidden dimension (in the TP region) and sometimes along the sequence dimension (in the SP region) - making it more cumbersome to implement correctly and requiring model-specific knowledge to ensure proper sharding patterns throughout.
+
The main reason we don't want to use TP only for parallelism is that, in practice, TP has two limitations (which we've discussed in previous sections). First, since its communication operations are part of the critical path of computation, it's difficult to scale well beyond a certain point, after which communication overhead begins to dominate. Second, unlike ZeRO and PP, which are model-agnostic, TP requires careful handling of activation sharding - sometimes along the hidden dimension (in the TP region) and sometimes along the sequence dimension (in the SP region) - making it more cumbersome to implement correctly and requiring model-specific knowledge to ensure proper sharding patterns throughout.
-
As a consequence, when combining parallelism strategies, TP will typically be kept for high-speed intra-node communications while ZeRO-3 or PP can be used for parallelism groups spanning lower speed inter-node communications as their communication patterns require less bandwidth (for PP) or can be more easily overlapped with computation (for ZeRO-3). The main consideration when combining these techniques is to organize the GPU efficiently in groups for each parallelism dimension to maximize throughput and minimize communication overhead, while being mindful of TP's scaling limitations. For instance, the groups of GPUs communicating for TP should be kept inside nodes.
+
As a consequence, when combining parallelism strategies, TP will typically be kept for high-speed intra-node communications, while ZeRO-3 or PP can be used for parallelism groups spanning lower-speed inter-node communications as their communication patterns require less bandwidth (for PP) or can be more easily overlapped with computation (for ZeRO-3). The main consideration when combining these techniques is to organize the GPUs efficiently in groups for each parallelism dimension to maximize throughput and minimize communication overhead, while being mindful of TP's scaling limitations. For instance, the groups of GPUs communicating for TP should be kept inside nodes.
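As a hedged sketch of how such a layout can be expressed with PyTorch's DeviceMesh utility (the dimension sizes below are made up for a hypothetical 4-node, 32-GPU cluster launched with torchrun): making TP the innermost mesh dimension puts consecutive ranks - which typically sit on the same node - in the same TP group, while the DP and PP groups span nodes.

from torch.distributed.device_mesh import init_device_mesh

# 4 nodes x 8 GPUs = 32 ranks, laid out as (pp=2, dp=2, tp=8).
# The last mesh dimension varies fastest, so each TP group of 8
# consecutive ranks stays inside a single node.
mesh = init_device_mesh("cuda", mesh_shape=(2, 2, 8), mesh_dim_names=("pp", "dp", "tp"))

tp_group = mesh["tp"].get_group()  # intra-node, high-bandwidth (e.g., NVLink)
dp_group = mesh["dp"].get_group()  # spans nodes
pp_group = mesh["pp"].get_group()  # spans nodes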
-
Context Parallelism and Expert Parallelism also help us shard activations, and can be seen as complimentary to TP. The first one handles long sequences while the second enables distributed Mixture of Experts training and they can be combined together without any particular issue.
+
Context parallelism and expert parallelism also help us shard activations, and can be seen as complementary to TP. CP handles long sequences while EP enables distributed Mixture of Experts training, and they can be combined without any particular issues.
-
Context Parallelism (CP) specifically targets the challenge of training with very long sequences by sharding activations along the sequence dimension across GPUs. While most operations like MLPs and LayerNorm can process these sharded sequences independently, attention layers require communication since each token needs access to keys/values from the full sequence. As we saw in CP section , this is handled efficiently through ring attention patterns that overlap computation and communication. CP is particularly valuable when scaling to extreme sequence lengths (128k+ tokens) where, even when using full activation recomputation, the memory requirements for attention would be prohibitive on a single GPU.
+
CP specifically targets the challenge of training with very long sequences by sharding activations along the sequence dimension across GPUs. While most modules, like MLP and LayerNorm, can process these sharded sequences independently, attention blocks require communication since each token needs access to keys/values from the full sequence. As we saw in the CP section , this is handled efficiently through Ring Attention patterns that overlap computation and communication. CP is particularly valuable when scaling to extreme sequence lengths (128k+ tokens) where, even when using full activation recomputation, the memory requirements for attention would be prohibitive on a single GPU.
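As a quick reminder of what CP sharding looks like in practice, here is a minimal sketch (load-balancing tricks for causal attention, such as reordering the shards, are left out):

import torch

def shard_sequence(x, cp_rank, cp_world_size):
    # x: [batch, seq_len, hidden] - each CP rank keeps only its slice of the sequence.
    # MLP and LayerNorm modules can run on this slice as is; attention needs
    # ring-style communication to see the keys/values held by the other ranks.
    return x.chunk(cp_world_size, dim=1)[cp_rank]

x = torch.randn(1, 131_072, 4096)                        # a 128k-token sequence
x_local = shard_sequence(x, cp_rank=0, cp_world_size=8)  # [1, 16384, 4096]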
-
Expert Parallelism (EP) specifically targets the challenge of training Mixture of Experts (MoE) models by sharding specialized "experts" across GPUs and dynamically routing tokens to relevant experts during computation. The key communication operation in EP is the `all-to-all` operations routing tokens to their assigned experts and gathering the results back. While this operation introduces some communication overhead, it enables scaling model capacity significantly since each token is only processed during inference (and training) by a much smaller fraction of the total parameters. In terms of distributed training/inference, partitioning experts across GPUs becomes relevant when models scales to a large number of experts.
-
For instance DeepSeek V3 uses 256 experts.
+
Expert parallelism specifically targets the challenge of training MoE models by sharding specialized "experts" across GPUs and dynamically routing tokens to relevant experts during computation. The key communication operations in EP are the "all-to-all" operations routing tokens to their assigned experts and gathering the results back. While this introduces some communication overhead, it enables scaling model capacity significantly, since each token is only processed during inference (and training) by a much smaller fraction of the total parameters. In terms of distributed training/inference, partitioning experts across GPUs becomes relevant when models scale to a large number of experts.
+
For instance, DeepSeek-V3 uses 256 experts.
@@ -1783,17 +1783,18 @@
📝 Note
-
This similarity between EP and DP in terms of input handling is why some implementations consider Expert Parallelism to be a subgroup of Data Parallelism, with the key difference being that EP uses specialized expert routing rather than having all GPUs process inputs through identical model copies.
+
This similarity between EP and DP in terms of input handling is why some implementations consider expert parallelism to be a subset of data parallelism, with the key difference being that EP uses specialized expert routing rather than having all GPUs process inputs through identical model copies.
-
Scope and focus Let's also quickly summarize the sub-part of the model where some of these different parallelism strategies have the most impact:
+
Scope and focus
+
Let's also quickly summarize the sub-parts of the model where these different parallelism strategies have the most impact:
- Tensor Parallelism (and Sequence Parallelism) affects computation throughout the entire model by sharding both weights and activations.
- Context Parallelism primarily impacts attention layers since that's where cross-sequence communication is required, with other layers operating independently on sharded sequences.
- Expert Parallelism primarly affects the MoE layers (which replace standard MLP blocks), leaving attention and other components unchanged
- Pipeline Parallelism and ZeRO are not especially specific to any sub-module or component with the exception that modules and layers need to be balanced in Pipeline Parallelism, the first and last layers are thus often treated differently due to the additional embedding layers.
+ Tensor parallelism (and sequence parallelism) affects computation throughout the entire model by sharding both weights and activations.
+ Context parallelism primarily impacts attention layers, since that's where cross-sequence communication is required, with other layers operating independently on sharded sequences.
+ Expert parallelism primarily affects the MoE layers (which replace standard MLP blocks), leaving attention layers and other components unchanged.
+ Pipeline parallelism and ZeRO are not especially specific to any submodule or component, with the exception that modules and layers need to be balanced in pipeline parallelism (the first and last layers are thus often treated differently due to the additional embedding layers).
@@ -1806,52 +1807,53 @@
- shards weights and activations along hidden/seq dim
- shards activations along sequence dim
- shards specialized expert weights and activations
+ Shards weights and activations along hidden/seq dim
+ Shards activations along sequence dim
+ Shards specialized expert weights and activations
- communication for matrix multiply operations (column/row linears)
- communication for attention key/values
- communication for token routing to experts
+ Communication for matrix multiplication operations (column/row linear)
+ Communication for attention keys/values
+ Communication for token routing to experts
- model-specific implementation needed
- model-agnostic except for attention
- model-agnostic except for MoE layers
+ Model-specific implementation needed
+ Model-agnostic except for attention
+ Model-agnostic except for MoE layers
Prefers high-bandwidth intra-node communication
Prefers large sequence lengths
- Requires MoEs
+ Requires MoE layers
-
Summarizing it all– Now what about gathering and combining all the techniques we've seen in a single diagram combining them all. Yes, we're up for the challenge!
-
In this summary diagram, you will find illustrated activations and modules for a single transformers layer –in it's MoE variant–. We also illustrate the various directions of parallelism and the communication operations we've been discussing in all the previous sections.
-
+
Summarizing it all
+
Now, what about gathering all the techniques we've seen into a single diagram combining them all? Yes, we're up for the challenge!
+
In this summary diagram, you will find illustrated activations and modules for a single transformer layer, in its MoE variant. We also illustrate the various directions of parallelism and the communication operations we've been discussing in the previous sections.
+
-
We can also represent side-by-side a full overview of the memory savings for each one of these strategies. We'll plot them with different sequence length as well as with selective (top) and full (bottom) recomputation so you can see how they all play with activations:
-
+
We can also give a full overview of the memory savings for each of these strategies. We'll plot them with different sequence lengths as well as with selective (top) and full (bottom) recomputation so you can see how they all play with activations:
+
-
Let's finish this section with a high level view at all of these techniques, their main underlying idea and major bottleneck:
+
Let's finish this section with a high-level view of all of these techniques, their main underlying ideas, and their major bottlenecks:
Method
- Memory savings applies specifically on
+ Memory savings apply specifically to...
Parallel/sharding dimension
Disadvantage
@@ -1870,22 +1872,22 @@
Idle bubble and complex schedules
- TP/SP
+ TP+SP
Model parameters and activations
- Hidden dimension / Sequence length
- Requires high bandwidth communication
+ Hidden dimension/sequence length
+ Requires high-bandwidth communication
CP
Activations
Sequence length
- Add communication overhead in attention modules
+ Adds communication overhead in attention modules
EP
Experts parameters
- Expert dimension
- Requires MoE layers, add routing communication overhead
+ Experts dimension
+ Requires MoE layers, adds routing communication overhead
ZeRO-1
@@ -1908,90 +1910,90 @@
-
Clearly, none of these techniques is a silver bullet for magical scaling and we'll often have to combine them in one way or another. Can we actually come up with a few rules that would help us find a good starting point to choose among –and combine– them? This will be the topic of our next section.
+
Clearly, none of these techniques is a silver bullet for magical scaling, and we'll often have to combine them in one way or another. Can we actually come up with a few rules that will help us find a good starting point to choose among (and combine) them? This will be the topic of the next section.
Finding the Best Training Configuration
-
We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models as well as how and why they can be combined together. There remain a general question: which ones should we choose in the end and how to decide on a specific combination?
+
We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models, as well as how and why they can be combined together. The general question remains, though: Which ones should we choose, and how do we decide on a specific combination?
-
We touched this a little bit in the previous section but let's now walk in details through a possible decision process, step by step, keeping in mind that you'll always have to run a few experiments to find the definitive optimal setup for your compute cluster given its various physical properties, network bandwidth, GPUs per node, memory per GPU, etc.
+
We touched on this a little bit in the previous section, but let's now walk in more detail through a possible decision process, step by step (keeping in mind that you'll always have to run a few experiments to find the definitive optimal setup for your compute cluster given its various physical properties, network bandwidth, GPUs per node, memory per GPU, etc.).
-
Step 1: Fitting a Training Step in Memory
+
Step 1: Fitting a training step in memory
-
First, we need to figure out how we can fit a full model instance on our GPUs. There are two general cases.
+
First, we need to figure out how we can fit a full model instance on our GPUs. There are two general cases:
-
GPU-rich case 🤑 - when you have plenty of GPUs available:
+
1. GPU-rich case 🤑 - when you have plenty of GPUs available:
- For models under 10B parameters, you can use a single parallelism technique, e.g. Tensor Parallelism or ZeRO-3/DP with Full Recompute across 8 GPUs
+ For models under 10B parameters, you can use a single parallelism technique, e.g. tensor parallelism or ZeRO-3/DP with full recompute across 8 GPUs.
For models between 10B and 100B parameters requiring more than 8 GPUs, you have several options:
- Combining Tensor Parallelism (TP=8) with Pipeline Parallelism
- Combining Tensor Parallelism (TP=8) with Data Parallelism (ZeRO-3)
- Using only ZeRO-3 (i.e. only pure Data Parallelism)
+ Combining tensor parallelism (TP=8) with pipeline parallelism
+ Combining tensor parallelism (TP=8) with data parallelism (ZeRO-3)
+ Using only ZeRO-3 (i.e., pure data parallelism)
- At 512+ GPU scale, pure Data Parallelism/ZeRO-3 will start to becomes inefficient due to communication cost - it can be better to then combine DP with either Tensor or Pipeline Parallelism
- At 1024+ GPU scale, a recommended setup can be Tensor Parallelism TP=8 with Data Parallelism (ZeRO-2) and Pipeline Parallelism
+ At 512+ GPU scale, pure data parallelism/ZeRO-3 will start to become inefficient due to communication cost - it can then be better to combine DP with either tensor or pipeline parallelism.
+ At 1024+ GPU scale, a recommended setup may be tensor parallelism (TP=8) with data parallelism (ZeRO-2) and pipeline parallelism.
-
We focus on fitting a single instance for now - even though we may use DP for ZeRO to achieve this goal - we're only interested here in the model-parameters memory savings that it provide when used with ZeRO-3.
+
We focus on fitting a single model instance for now - even though we may use DP with ZeRO to achieve this goal, we're only interested here in the model parameter memory savings it provides when used with ZeRO-3.
-
Special considerations:
+
Special considerations:
- For very long sequences, you will probably want to add Context Parallelism (CP) across nodes.
- For Mixture of Experts architectures, you will advantageously use Expert Parallelism (EP) across nodes.
+ For very long sequences, you will probably want to add context parallelism across nodes.
+ For Mixture of Experts architectures, it will be advantageous to use expert parallelism across nodes.
-
GPU-poor case 😭 - when you might be low on GPU resources:
+
2. GPU-poor case 😭 - when you might be low on GPU resources:
- You can enable full activation recomputation to trade some compute for memory (and train a bit slower).
+ You can enable full activation recomputation to trade some compute for memory (and train a bit more slowly).
You can increase gradient accumulation to process larger batches with limited memory.
Now that we have a first model instance training, we need to make sure we have the right batch size.
-
Step 2: Achieving Target Global Batch Size
+
Step 2: Achieving the target global batch size
-
Depending on where step 1 left us in terms of micro batch size and DP, our current batch size might be too small or too big. It's now time to hit our target batch size.
+
Depending on where step 1 left us in terms of micro-batch size and DP, our current batch size might be too small or too big. It's now time to hit our target batch size.
To increase our current global batch size:
- We can scale up Data Parallelism or gradient accumulation steps
- For long sequences, we can leverage Context Parallelism
+ We can scale up data parallelism or gradient accumulation steps.
+ For long sequences, we can leverage context parallelism.
To decrease our current global batch size:
- We can reduce Data Parallelism in favor of other parallelization strategies
- For long sequences, we can reduce Context Parallelism
+ We can reduce data parallelism in favor of other parallelization strategies.
+ For long sequences, we can reduce context parallelism.
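To make these knobs concrete, here's the simple bookkeeping involved (a small sketch using the notation from earlier in the book; the example numbers are made up):

def global_batch_size(mbs, grad_acc, dp):
    # Samples processed per optimizer step
    return mbs * grad_acc * dp

def global_batch_size_tokens(mbs, grad_acc, dp, seq_len):
    # The same quantity measured in tokens (context parallelism splits each
    # sequence across CP ranks but doesn't change the number of tokens per batch)
    return global_batch_size(mbs, grad_acc, dp) * seq_len

print(global_batch_size(2, 32, 16))               # 1024 samples
print(global_batch_size_tokens(2, 32, 16, 4096))  # 4,194,304 tokens (~4M)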
-
Ok, now we have the model running in the general configuration we want in terms of model size and batch size, but are we training it the fastest way? Let's now start to optimize throughput as much as possible.
+
OK, now we have the model running in the general configuration we want in terms of model size and batch size - but are we training it the fastest way? The final step is to work on optimizing throughput.
-
Step 3: Optimizing Training Throughput
+
Step 3: Optimizing training throughput
-
So we want to make sure the training is running as fast as possible so all our precious GPUs are well utilized at all times. As long as memory and communication aren't bottlenecks we can try the following:
+
We want to make sure the training is running as fast as possible so all our precious GPUs are well utilized at all times. As long as memory and communication aren't bottlenecks, we can try the following:
- Scale up Tensor Parallelism (using the fast intra-node bandwidth) until we reach a degree close to the node size, so that we can reduce other parallelism
- Increase Data Parallelism with ZeRO-3 while keeping target batch size
- When Data Parallelism communication starts to become a bottleneck, transition to using Pipeline Parallelism
- Try scaling up different parallelisms one by one
- Experiment with several micro batch size (mbs) to aim for an optimal balance between max GBS, model size, compute, and communication.
+ Scale up tensor parallelism (using the fast intra-node bandwidth) until we reach a degree close to the node size, so that we can reduce other forms of parallelism.
+ Increase data parallelism with ZeRO-3 while keeping the target batch size.
+ When data parallelism communication starts to become a bottleneck, transition to using pipeline parallelism.
+ Try scaling up different parallelisms one by one.
+ Experiment with micro-batch sizes (mbs ) to aim for an optimal balance between max global batch size, model size, compute, and communication.
Benchmarking thousands of configurations
-
Now that we've covered the step-by-step, let's implement this search process in real-life.
+
Now that we've covered the step-by-step, let's implement this search process in real life.
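As a toy illustration of what this search space looks like (our own sketch, not the actual Nanotron benchmarking scripts), here's how you might enumerate the parallelism layouts compatible with a given cluster size, keeping TP within a node as discussed above:

from itertools import product

def candidate_configs(world_size, max_tp=8):
    # Enumerate (dp, tp, pp) layouts whose product matches the number of GPUs,
    # keeping TP at or below the node size (max_tp)
    configs = []
    for tp, pp in product(range(1, max_tp + 1), range(1, world_size + 1)):
        if world_size % (tp * pp) == 0:
            dp = world_size // (tp * pp)
            configs.append({"dp": dp, "tp": tp, "pp": pp})
    return configs

# 2 nodes of 8 GPUs each:
for cfg in candidate_configs(16):
    print(cfg)

Each of these layouts is then further crossed with micro-batch size, gradient accumulation, ZeRO stage, and recomputation choices, which is how the search quickly grows into the thousands of runs mentioned below.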
-
You will find, in the nanotron repository, several scripts you can use to run all the experiments we discussed above and be able to benchmark your own model and cluster in real life.
+
In the Nanotron repository, you'll find several scripts you can use to run all the experiments discussed previously and benchmark your own model and cluster.
-
We actually ran ourself benchmarks on several thousands of distributed configurations covering every model size we've discussed above as well as a very large number of cluster configurations (namely 1-64 nodes of 8xH100s) we could try in order to produce the results we've covered up to now in this book.
-
We want to take this opportunity to apologize to our co-workers for blocking most of the science cluster and in turn forgive any threats that may have been whispered.
+
We actually ran benchmarks ourselves on several thousand distributed configurations , covering every model size we've discussed here as well as a very large number of cluster configurations (namely, 1-64 nodes of 8xH100s) in order to produce the results we've covered up to now in this book.
+
We want to take this opportunity to apologize to our coworkers for blocking most of the science cluster, and in turn forgive any threats that may have been whispered.
-
Now let's take a step back to gather and analyze the results of all our benchmarks and see if, beyond theory, we can actually discover on real-world data how various configurations fare against each other.
+
Now let's take a step back to gather and analyze the results of all our benchmarks and see if, beyond theory, we can actually discover using real-world data how various configurations fare against each other.
-
All the following benchmarks were conducted with a sequence length of 4096 and a global batch size of 1M tokens. We gathered all the top configurations for each model and cluster size and plotted them in the following heatmaps:
+
All the following benchmarks were conducted with a sequence length of 4,096 and a global batch size of 1M tokens. We gathered all the top configurations for each model and cluster size and plotted them in the following heatmaps:
@@ -2000,46 +2002,46 @@
-
From this high-level visualization, we can draw several important insights:
-
-
-
First, as we increase the number of nodes (higher parallelism), we observe a decrease in efficiency. This effect is particularly pronounced for smaller models, which have a lower compute-to-model-size ratio. While we might typically compensate for small model size by increasing the batch size, we're constrained by our global batch size limit of 1M.
-
+
From this high-level visualization, we can draw several important insights:
+
+
+ First, as we increase the number of nodes (higher parallelism), we observe a decrease in efficiency. This effect is particularly pronounced for smaller models, which have a lower compute to model size ratio. While we might typically compensate for small model size by increasing the batch size, we're constrained by our global batch size limit of 1M.
- Second, Larger models present a different challenge. As model size increases, memory requirements grow substantially. This creates two scenarios with fewer nodes: either the model doesn't fit at all, or it barely fits but runs inefficiently due to operating near the GPU memory limits (see for instance the 80B parameter model training on 4 nodes).
+ Second, larger models present a different challenge. As model size increases, memory requirements grow substantially. This creates two scenarios with fewer nodes: either the model doesn't fit at all, or it fits but runs inefficiently due to operating near the GPU memory limits (see for instance the 80B parameter model training on 4 nodes).
- Finally, our benchmarks show how performance heavily depends on implementation quality. When we first implemented both parallelism strategies, Tensor Parallelism (TP) outperformed Pipeline Parallelism (PP). After optimizing our PP code, it became the faster option. Now that we're improving the communication overlap in our TP implementation, we expect it to regain the performance lead.
-
+ Finally, our benchmarks show how performance heavily depends on implementation quality. When we first implemented both parallelism strategies, tensor parallelism outperformed pipeline parallelism. After optimizing our PP code, it became the faster option. Now that we're improving the communication overlap in our TP implementation, we expect it to regain the performance lead.
+
+
Lessons learned on benchmarking
-
Our goal for this book was not only to discuss theory and implementations but provide actual data points as well. So the plan was simple: let's run every possible distributed configuration for every model and a number of cluster sizes (namely 1-64 nodes of 8xH100s). Even after excluding impossible configuration we still needed to run thousands of experiments.
+
Our goal for this book was not only to discuss theory and implementations, but to provide actual data points as well. So, the plan was simple: let's run every possible distributed configuration for every model and a number of cluster sizes. Even after excluding impossible configurations, we still needed to run thousands of experiments.
- On paper this sounds easy enough: we can easily launch big arrays of jobs on our cluster. However, as soon as we launched the first batches of experiments, troubles began:
+ On paper, this sounds easy enough: we can easily launch big arrays of jobs on our cluster. However, as soon as we launched the first batches of experiments, our troubles began:
- PyTorch processes would sometimes fail to clean up properly
- Slurm job manager would forcefully terminate jobs, leading to node failures
- Simple benchmarks that should take minutes would stretch into hours
- Some jobs would hang indefinitely
+ PyTorch processes would sometimes fail to clean up properly.
+ The Slurm job manager would forcefully terminate jobs, leading to node failures.
+ Simple benchmarks that should have taken minutes would stretch into hours.
+ Some jobs would hang indefinitely.
-
Running all experiments in a finite amount of time required additional engineering and we ended up spending a significant amount of time on things like:
+
Running all the experiments in a finite amount of time required additional engineering, and we ended up spending a significant amount of time on things like:
- Minimizing cluster restart times and optimize idle time
+ Minimizing cluster restart times and optimizing idle time
Analyzing detailed NCCL debug logs
- Understand memory usage patterns and CUDA memory allocator behaviors
- Improving pipeline parallelism performance on multi-node
+ Understanding memory usage patterns and CUDA memory allocator behaviors
+ Improving pipeline parallelism performance on multi-node setups
-
These challenges deserve their own story, but they taught us valuable lessons about the complexities of distributed training infrastructure. What looks simple in theory often requires careful attention to many moving parts in practice.
+
These challenges taught us valuable lessons about the complexities of distributed training infrastructure. What looks simple in theory often requires careful attention to many moving parts in practice.
-
Reproducing theoretical results in practice is challenging, especially given the limited availability of production training code. Through open-source projects like nanotron and picotron , we hope we can help making distributed training techniques more accessible as well as collaborating on simple and efficient codebases that help researchers and practitioners take the most out of their hardware resources.
+
Reproducing theoretical results in real life is challenging, especially given the limited availability of production training code. Through open source projects like Nanotron and Picotron, we hope we can help make distributed training techniques more accessible, as well as collaborate on simple and efficient codebases that help researchers and practitioners get the most out of their hardware resources.
-
On the compute side, GPUs consist of an array of compute units called Streaming Multiprocessors (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see docs for tensor cores for details), each capable of handling multiple threads simultaneously.
+
On the compute side, a GPU consists of an array of compute units called streaming multiprocessors (SMs) . Each SM contains and controls a set of streaming processors, also known as cores . For example, an NVIDIA H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see the docs for tensor cores for details), each capable of handling multiple threads simultaneously.
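You can query these numbers for your own GPU directly from PyTorch (the values in the comments are roughly what we'd expect on an H100; yours will depend on your hardware):

import torch

props = torch.cuda.get_device_properties(0)
print(props.name)                   # e.g. "NVIDIA H100 80GB HBM3"
print(props.multi_processor_count)  # number of SMs, e.g. 132
print(props.total_memory / 1e9)     # global memory in GB, e.g. ~85 GB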
-
The memory side is also highly hierarchical with several layers of cache and memory: Registers are the smallest units and are private to the threads during executions, Shared Memory and L1 cache are shared between the threads running on a single SM, higher up is the L2 cache shared by all SMs, finally there is the Global Memory which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.
+
The memory side is also highly hierarchical, with several layers of cache and memory. Registers are the smallest units and are private to the threads during executions. Shared memory and the L1 cache are shared between the threads running on a single SM. Higher up is the L2 cache shared by all SMs, and finally there is the global memory , which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query.
-
The goal of GPU will be to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute/memory.
+
The goal when using a GPU is to run as many workloads as possible, in parallel, on the available cores, by taking advantage of this hierarchical organization of compute/memory resources.
-
A piece of code running on a core of the GPU is called a kernel . It can be written at a high-level in CUDA or Triton for instance, and is then compiled to Parallel Thread Execution, PTX, the low-level assembly used by NVIDIA GPUs.
+
A piece of code running on a core of the GPU is called a kernel . It can be written at a high level in CUDA or Triton, for instance, and is then compiled to Parallel Thread Execution (PTX), the low-level assembly used by NVIDIA GPUs.
-
To run the kernel, you will also need a specific code part, called host code , which is executed on the CPU/host and will take care of preparing data allocations and loading data and code.
+
To run the kernel, you will also need some host code, which is executed on the CPU/host and takes care of preparing data allocations and loading data and code:
@@ -2134,7 +2136,7 @@
cudaFree(d_C);
}
@@ -2158,23 +2160,23 @@
Figure 5: Host code for a CUDA kernel for adding two vectors from https://blog.codingconfessions.com/p/gpu-computing
-
Kernels are generally scheduled as follow:
+
Kernels are generally scheduled as follows:
- threads are grouped in warps of sizes of 32. All the threads in a warp are synchronized to execute instructions simultaneously but on different parts of the data.
- warps are grouped in larger blocks of more flexible size (e.g. size 256), each block still being assigned to a single SM. An SM may run several blocks in parallel, however, depending on the resources, not all the blocks may get assigned for execution immediately, some can be waitlisted waiting for resources.
+ Threads are grouped in warps , each containing 32 threads. All the threads in a warp are synchronized to execute instructions simultaneously but on different parts of the data.
+ Warps are grouped in larger blocks of more flexible size (for example, there may be 512 or 1,024 threads in a block), with each block assigned to a single SM. An SM may run several blocks in parallel. However, depending on the resources available, not all of the blocks may get assigned for execution immediately; some may be waitlisted until more resources become available.
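For example, the number of blocks launched for a simple elementwise kernel is usually derived from the problem size and the chosen block size with a ceiling division (a generic sketch, not tied to any particular kernel):

def launch_config(num_elements, threads_per_block=256):
    # Threads come in warps of 32; blocks group several warps, and we launch
    # enough blocks to cover every element, rounding up
    warps_per_block = threads_per_block // 32
    blocks_per_grid = (num_elements + threads_per_block - 1) // threads_per_block
    return blocks_per_grid, warps_per_block

print(launch_config(1_000_000))  # (3907, 8)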
-
The main thing to remember from these details is that there are various sizing and allocation constraints (size of the various memories, number of concurrent block and threads in the wraps) which need to be taken into account to use the GPU architecture in the most efficient way.
+
The main thing to retain here is that there are various sizing and allocation constraints (size of the various memories, number of concurrent blocks and threads in the warps) which need to be taken into account to use the GPU architecture in the most efficient way.
-
Most of the time you don’t need to go down to this level of precision and you can luckily reuse the kernels and code prepared by other members of the community. But in any case we want to give you a primer on how to get started with kernels!
+
Most of the time, you don’t need to go down to this level of precision and you can reuse the kernels and code prepared by other members of the community - but we'll give you a few tips on getting started with kernels anyway!
-
How to improve performance with Kernels ?
+
Improving performance with kernels
-
If you’re looking to add a new operation that lacks an optimized kernel or to speed up an existing PyTorch function, writing kernels from scratch might seem like the most direct route. However, creating high-performance CUDA kernels from scratch requires extensive experience and a steep learning curve. Generally a better way to get started is to leverage torch.compile
, which dynamically optimizes PyTorch code by capturing your operations and generating lower-level, high-performance kernels in triton.
+
If you’re looking to add a new operation that lacks an optimized kernel or to speed up an existing PyTorch function, writing kernels from scratch might seem like the most direct route. However, creating high-performance CUDA kernels from scratch requires extensive experience, and there's a steep learning curve. Generally, a better way to get started is to leverage torch.compile
, which dynamically optimizes PyTorch code by capturing your operations and generating lower-level, high-performance kernels in Triton.
-
Let’s suppose you want to write a kernel for an activation function called Exponential Linear Unit:
+
Let’s suppose you want to write a kernel for the Exponential Linear Unit (ELU) activation function:
\text{ELU}(x) = \begin{cases}
\alpha(e^x - 1) & \text{if } x < 0 \\
x & \text{if } x \geq 0
\end{cases}
-
You can start by a simple pytorch implementation and then just add the @torch.compile
decorator on top:
+
You can start by writing a simple PyTorch implementation, and then just add the @torch.compile
decorator on top:
@torch.compile
def elu(x, alpha=1.0):
    return torch.where(x < 0, alpha * (torch.exp(x) - 1), x)
-
The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns):
-
+
As you can see in the following graph, there's a remarkable performance difference between the compiled and non-compiled versions, especially given that we only added a decorator (N here is the number of columns):
+
-
However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the triton kernel generated by @torch.compile . To do so, you simply need to set the environment variable TORCH_LOGS
to "output_code"
:
+
However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the Triton kernel generated by @torch.compile
. To do so, you simply need to set the environment variable TORCH_LOGS
to "output_code"
:
export TORCH_LOGS="output_code"
-
Once you run the Python script with the @torch.compile
decorator, it will generate and output the corresponding Triton kernel, which, in this case, is:
+
Once you run the Python script with the @torch.compile
decorator, it will generate and output the corresponding Triton kernel, which in this case is:
@triton.jit
@@ -2250,35 +2252,35 @@
tl.store(output_ptr + block_indices, output_values, valid_mask)
-
Here, tl.program_id(0)
provides a unique block ID, that we use to determine which section of data that block will process. Using this block ID, block_start
calculates the starting index for each block’s section, while block_indices
specifies the range of indices within that section. A valid_mask
ensures that only indices within num_elements
are processed, safely loading the data with tl.load
. The ELU function is then applied, modifying values based on whether they're negative, and results are written back to memory with tl.store
.
+
Here, tl.program_id(0)
provides a unique block ID, which we use to determine which section of data that block will process. Using this block ID, block_start
calculates the starting index for each block’s section, while block_indices
specifies the range of indices within that section. A valid_mask
ensures that only indices within num_elements
are processed, safely loading the data with tl.load
. The ELU function is then applied, modifying values based on whether they're negative, and results are written back to memory with tl.store
.
-
When we benchmark the generated kernel using triton.testing.Benchmark
we have the following performance:
+
When we benchmark the generated kernel using triton.testing.Benchmark
, we have the following performance:
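If you haven't used triton.testing before, a benchmark definition looks roughly like the following sketch (for brevity we compare the eager and compiled versions of the ELU function rather than the generated Triton kernel; the GB/s estimate assumes one read and one write per element):

import torch
import triton
import triton.testing

def elu(x, alpha=1.0):
    return torch.where(x < 0, alpha * (torch.exp(x) - 1), x)

elu_compiled = torch.compile(elu)

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[2**i for i in range(12, 26)],
        line_arg="provider",
        line_vals=["eager", "compiled"],
        line_names=["PyTorch eager", "@torch.compile"],
        ylabel="GB/s",
        plot_name="elu-performance",
        args={},
    )
)
def benchmark(N, provider):
    x = torch.randn(N, device="cuda")
    fn = elu if provider == "eager" else elu_compiled
    ms = triton.testing.do_bench(lambda: fn(x))
    # One read + one write of N float32 elements, converted to GB/s
    return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

benchmark.run(print_data=True, show_plots=False)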
-
This standalone kernel even demonstrates superior performance with smaller sizes compared to @torch.compile
but this is likely just an artifact of the compilation time of torch.compile
. In any case, instead of starting from scratch, remember that you can start from such generated kernels and focus your attention to optimizing its performance, saving you a lot of time in the process.
+
This standalone kernel even demonstrates superior performance with smaller sizes compared to @torch.compile
, but this is likely just an artifact of the compilation time of torch.compile
. In any case, instead of starting from scratch, remember that you can start from such generated kernels and focus your attention on optimizing their performance, saving you a lot of time.
-
Even in Triton, sometimes, we cannot fully achieve the peak performance of the device due to the language limitations to handle low level details like shared memory and scheduling within streaming multiprocessors (SMs). Triton capabilities are restricted to blocks and scheduling of blocks across SMs. To gain an even deeper control, you will need to implement kernels directly in CUDA, where you will have access to all the underlying low-level details.
+
Even in Triton, sometimes we cannot fully achieve the peak performance of the device due to the language's limitations in handling low-level details like shared memory and scheduling within streaming multiprocessors. Triton's capabilities are restricted to blocks and scheduling of blocks across SMs. To gain even deeper control, you will need to implement kernels directly in CUDA, where you will have access to all the underlying low-level details.
-
Moving down to CUDA, various techniques can be employed to improve the efficiency of kernels. We will just cover a few here: optimizing memory access patterns to reduce latency, using shared memory to store frequently accessed data, and managing thread workloads to minimize idle times.
+
With CUDA, various techniques can be employed to improve the efficiency of kernels. We will cover just a few here: optimizing memory access patterns to reduce latency, using shared memory to store frequently accessed data, and managing thread workloads to minimize idle time.
-
Before we dive deeper in CUDA examples, let's summarize the tools we've seen that let us write kernel code to execute instructions on the GPU:
+
Before we dive deeper into CUDA examples, let's summarize the tools we've seen that let us write kernel code to execute instructions on the GPU:
- Pytorch: easy but slow
- torch.compile: easy, fast, but not flexible
- triton: harder, faster, and more flexible
- CUDA: hardest, fastest, and flexiblest (if you get it right)
+ PyTorch: easy but slow
+ @torch.compile
: easy, fast, but not flexible
+ Triton: harder, faster, and more flexible
+ CUDA: hardest, fastest, and most flexible (if you get it right)
-
Let’s talk about one of the most frequent technique we can use in CUDA: optimizing memory access. The global memory in GPUs (the largest memory in our above graph) has a long latency and low bandwidth in comparison to the cache which often creates a major bottleneck for most applications. Efficiently accessing data from global memory can improve performance by a lot.
+
We'll start by looking at one of the most frequent uses of CUDA: optimizing memory access. The global memory in GPUs (the largest memory area, as you saw earlier) has a long latency and low bandwidth in comparison to the cache, creating a major bottleneck for many applications. Efficiently accessing data from global memory can greatly improve performance.
-
Memory Coalescing
+
Memory coalescing
To effectively utilize the bandwidth of global memory, it is essential to understand its architecture. In CUDA devices, global memory is implemented using DRAM.
-
Memory coalescing takes advantage of how DRAM delivers data in bursts, or ranges of consecutive memory locations, whenever a memory address is accessed. Each time a DRAM location is accessed, a sequence of consecutive locations, including the requested one, is read in parallel by multiple sensors in the DRAM chip. Once read, this data can then be quickly transferred to the processor as a burst. In CUDA, coalescing uses this burst behavior to maximize memory access efficiency by ensuring that threads in a warp—32 threads that execute the same instruction in lockstep (SIMD)—access consecutive memory locations. For instance, if thread 0 accesses location M, thread 1 accesses M + 1, thread 2 accesses M + 2, and so forth, the GPU hardware coalesces or combines these requests into one large, efficient access request for the DRAM burst, rather than handling each access individually.
+
Memory coalescing takes advantage of how DRAM delivers data in bursts whenever a memory address is accessed. Each time a DRAM location is accessed, a sequence of consecutive locations (including the requested one) is read in parallel by multiple sensors in the DRAM chip. Once read, this data can then be quickly transferred to the processor as a burst. In CUDA, coalescing uses this burst behavior to maximize memory access efficiency by ensuring that threads in a warp — a set of 32 threads that execute the same instruction in lockstep — access consecutive memory locations. For instance, if thread 0 accesses location M , thread 1 accesses M + 1 , thread 2 accesses M + 2 , and so forth, the GPU hardware coalesces or combines these requests into one large, efficient access request for the DRAM burst, rather than handling each access individually.
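You can get a rough feel for the cost of poorly coalesced (strided) access without leaving PyTorch. The following is only an analogy, a minimal sketch assuming a CUDA device is available: a plain contiguous copy reads and writes with unit stride, while a transposed copy has to traverse global memory with a large stride and is typically noticeably slower, even though it moves the same amount of data.
import torch

def time_op(fn, iters=50):
    # Time a GPU operation with CUDA events (synchronize around the timed region)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

x = torch.randn(8192, 8192, device="cuda")
contiguous_ms = time_op(lambda: x.clone())           # unit-stride reads and writes
transposed_ms = time_op(lambda: x.t().contiguous())  # strided reads, same data volume
print(f"contiguous copy: {contiguous_ms:.2f} ms | transposed copy: {transposed_ms:.2f} ms")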
Let’s take the example of matrix multiplication. A simple, straightforward implementation would have each thread compute a single element of the output matrix, like this:
@@ -2297,11 +2299,11 @@
}
-
Here’s an excellent visualization of the kernel from this fantastic blogpost :
+
Here’s an excellent visualization of the kernel from Simon Boehm’s fantastic blog post :
-
However, when profiling this kernel with a tool like ncu
, we can see issues, including low memory throughput and uncoalesced memory accesses.
+
However, when profiling this kernel with a tool like ncu
, we can see issues, including low memory throughput and uncoalesced memory accesses:
-
The reason for this is that in this kernel, two threads in the same block with Thread IDs (0, 0)
and (1, 0)
(which will end up in the same warp) will both load from the same column of matrix B
but different rows of matrix A
. Since matrix elements are stored in row-major order (meaning row elements are in consecutive memory addresses, as shown in the figure below) thread (0, 0)
will load A_{0,0} , and thread (1, 0)
will load A_{1,0} in the first iteration i = 0
. These elements are not stored close to each other in memory, and this misalignment will be present at each iteration, thereby preventing memory accesses from being coalesced.
+
The reason for this is that in this kernel, two threads in the same block with thread IDs (0, 0)
and (1, 0)
(which will end up in the same warp) will both load from the same column of matrix B but different rows of matrix A . Since matrix elements are stored in row-major order (meaning row elements are in consecutive memory addresses, as shown in the figure below), thread (0, 0)
will load A_{0,0} and thread (1, 0)
will load A_{1,0} in the first iteration, i = 0 . These elements are not stored close to each other in memory, and this misalignment will be present at each iteration, thereby preventing memory accesses from being coalesced.
-
To improve the performances of our kernel we can change the way coordinates
x and y
are calculated to the following:
+
To improve the performance of our kernel, we can change the way coordinates x
and y
are calculated to the following:
const int x = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
@@ -2330,9 +2332,9 @@
}
-
Instead of using a 2D block, we switch to a 1D block and redefine how we determine the values of x
and y
. In this new method, threads within the same warp (which have close threadIdx.x
values) will share the same x
value but have different y
values. This means that they will load the same row of matrix A
but different columns of matrix B
. As a result, memory accesses can be coalesced for a row-major matrix.
+
Instead of using a 2D block, we switch to a 1D block and redefine how we determine the values of x
and y
. In this new method, threads within the same warp (which have close threadIdx.x
values) will share the same x
value but have different y
values. This means that they will load the same row of matrix A but different columns of matrix B . As a result, memory accesses can be coalesced for a row-major matrix.
-
When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and the GPU's memory throughput has increased by approximately 10 times .
+
When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and the GPU's memory throughput has increased by approximately a factor of 10 .
-
We also notice that the execution time of the kernel decreases by 10x ! Amazing.
-
Now let's cover another technique you will often see mentioned in the litterature: tiling .
+
We also notice that the execution time of the kernel has decreased by 10x . Amazing!
+
Now let's cover another technique you will often see mentioned in the literature: tiling .
Tiling
-
Tiling is a technique that leverages shared memory to optimize memory access patterns. As we mentioned above, the shared memory is a small, fast memory accessible by all threads within a block. It allows data to be reused by multiple threads, reducing the need to repeatedly load data from slower global memory.
+
Tiling is a technique that leverages shared memory to optimize memory access patterns. As we mentioned earlier, the shared memory on a GPU is a small, fast memory area accessible by all threads within a block. It allows data to be reused by multiple threads, reducing the need to repeatedly load data from the slower global memory.
-
In matrix multiplication for example, each thread in a block may need elements from two matrices, say A and B. If each thread independently loads the row and column it needs from global memory, we end up with many redundant loads, as multiple threads in a block will access overlapping data. Instead, we can use tiling to load a block (or tile) of A and B into shared memory just once, allowing all threads in that block to reuse the same shared data.
+
In matrix multiplication, for example, each thread in a block may need elements from two matrices, say A and B . If each thread independently loads the row and column it needs from global memory, we'll end up with many redundant loads, as multiple threads in a block will access overlapping data. Instead, we can use tiling to load a block (or "tile") of A and B into shared memory just once, allowing all threads in that block to reuse the same shared data.
-
In the tiling approach, each iteration involves all threads within a block to cooperatively load two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size BLOCK_SIZE_M
by BLOCK_SIZE_K
) and a tile of matrix B (of size BLOCK_SIZE_K
by BLOCK_SIZE_N
). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.
+
In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles — one from matrix A and another from matrix B — into shared memory. Specifically, the threads load a tile of matrix A (of size BLOCK_SIZE_M
by BLOCK_SIZE_K
) and a tile of matrix B (of size BLOCK_SIZE_K
by BLOCK_SIZE_N
). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all the necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.
@@ -2388,46 +2390,42 @@
}
C[localRow * N + localCol] = sum;
-
For simplicity we consider a square shaped tile.
+
For simplicity, we consider a square-shaped tile.
-
Each thread begins by loading one element from both Matrix A and Matrix B into shared memory. In this scenario, achieving coalesced memory access is straightforward, by assigning threadIdx.x
as the local column index (localCol) , threads within the same warp will access adjacent elements of both matrices. After each thread in the block completes loading its elements into shared memory (ensured by calling __syncthreads()
), they proceed to compute the dot product of the two tiles. Once the threads have iterated through all the tiles—horizontally for Matrix A and vertically for Matrix B —the resulting sum is stored in the corresponding location of Matrix C .
+
Each thread begins by loading one element from both matrix A and matrix B into shared memory. In this scenario, achieving coalesced memory access is straightforward: by assigning threadIdx.x
as the local column index (localCol
), we ensure that threads within the same warp will access adjacent elements of both matrices. After each thread in the block completes loading its elements into shared memory (ensured by calling __syncthreads()
), they proceed to compute the dot product of the two tiles. Once the threads have iterated through all the tiles — horizontally for A and vertically for B - the resulting sum is stored in the corresponding location of matrix C .
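To make the structure of the tiled kernel easier to see, here is the same accumulate-over-tiles logic written as deliberately naive PyTorch loops. This is only a sketch of the algorithm (the tile size and the assumption that all dimensions divide evenly are ours), not a GPU kernel:
import torch

def tiled_matmul(A: torch.Tensor, B: torch.Tensor, tile: int = 32) -> torch.Tensor:
    # Illustrates the tiling structure only; the real kernel keeps everything on the GPU
    M, K = A.shape
    K2, N = B.shape
    assert K == K2 and M % tile == 0 and N % tile == 0 and K % tile == 0
    C = torch.zeros(M, N, dtype=A.dtype, device=A.device)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            acc = torch.zeros(tile, tile, dtype=A.dtype, device=A.device)  # accumulation tile
            for k in range(0, K, tile):
                a_tile = A[i:i + tile, k:k + tile]  # tile of A, loaded once per iteration
                b_tile = B[k:k + tile, j:j + tile]  # tile of B, loaded once per iteration
                acc += a_tile @ b_tile              # partial product added to the accumulator
            C[i:i + tile, j:j + tile] = acc
    return C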
-
When benchmarking this kernel using ncu, we noticed that the memory throughput increased to 410 Gb / s, and the kernel execution time decreased by ~43% achieving a ~6.6 TFLOPs performance
+
When benchmarking this kernel using ncu
, we noticed that the memory throughput increased to 410 Gb/s and the kernel execution time decreased by ~43%, achieving a ~6.6 TFLOPS performance.
-
Thread Coarsening
+
Thread coarsening
-
The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:
+
The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states, which quantify how many cycles were spent in each state, we observe the following:
-
The meaning of these cryptic state names can be found in NVidia's profiling Guide , in the Warp Stall Reasons section. There we can read that:
-
-
"smsp__pcsamp_warps_issue_stalled_mio_throttle
: Warp was stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure."
+
The meaning of these cryptic state names can be found in NVIDIA's Kernel Profiling Guide , in the "Warp Stall Reasons" section. There, we see that smsp__pcsamp_warps_issue_stalled_mio_throttle
indicates "Warp was stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure."
-
So it seems warps are stalling waiting for shared memory accesses to return! To solve this issue we can apply a technique called Thread Coarsening which involves merging several threads into a single coarsened thread. This will significantly reduce shared memory accesses as each coarsened thread can handle multiple output elements.
+
So it seems warps are stalling waiting for shared memory accesses to return! To solve this issue we can apply a technique called thread coarsening , which involves merging several threads into a single coarsened thread. This will significantly reduce shared memory accesses, as each coarsened thread can handle multiple output elements.
-
Let's briefly go through a last important consideration when writing or improving custom kernels: Minimizing Control Divergence .
+
Next, let's briefly go through a last important consideration when writing or improving custom kernels: minimizing control divergence .
-
Minimizing Control Divergence
+
Minimizing control divergence
-
A Streaming Multiprocessor (SM) is built to execute all threads in a warp using the Single Instruction, Multiple Data (SIMD) model. This means that at any given moment, one instruction is fetched and executed simultaneously for all threads within the warp. When a warp is executed, the threads within it operate on different segments of the data but follow the same instruction, hence the name Single Instruction, Multiple Data. The primary advantage of SIMD is its efficiency; the control hardware responsible for instruction fetching and dispatching is shared among multiple execution units. This design minimizes the hardware overhead associated with control functions, allowing a greater portion of the hardware to focus on improving arithmetic throughput.
+
A streaming multiprocessor is built to execute all threads in a warp using the Single Instruction, Multiple Data (SIMD) model. This means that at any given moment, one instruction is fetched and executed simultaneously for all threads within the warp. When a warp is executed, the threads within it operate on different segments of the data but follow the same instruction (hence the name Single Instruction, Multiple Data). The primary advantage of SIMD is its efficiency: the control hardware responsible for instruction fetching and dispatching is shared among multiple execution units. This design minimizes the hardware overhead associated with control functions, allowing a greater portion of the hardware to focus on improving arithmetic throughput.
-
Control divergence occurs when threads within the same warp take different execution paths. For instance, if a conditional statement (like an if
statement) leads to some threads executing one block of code while others execute a different block, the warp must serialize these executions, resulting in idle threads waiting for others to complete. To minimize control divergence, we need to design kernels to ensure that threads within the same warp follow the same execution path. This can be achieved by restructuring code to reduce branching, using data structures that ensure all threads follow similar execution paths, or employing techniques such as predication.
+
Control divergence occurs when threads within the same warp take different execution paths. For instance, if a conditional statement (like an if
statement) leads to some threads executing one block of code while others execute a different block, the warp must serialize these executions, resulting in idle threads waiting for others to complete. To minimize control divergence, we need to design kernels to ensure that threads within the same warp follow the same execution path. This can be achieved by restructuring code to reduce branching, using data structures that ensure all threads follow similar execution paths, or employing techniques such as predication.
-
-
-
We have covered some of the main considerations when writing custom kernels and improving the performance and memory footprint of GPU operations. But there’s one more important concept before moving to a real example which is “fusing kernels”.
+
We have covered some of the main considerations when writing custom kernels and improving the performance and memory footprint of GPU operations. But there’s one more important concept to consider before we move to a real example: fusing kernels .
-
Fused Kernels
+
Fused kernels
-
In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.
+
In several places now, we’ve mentioned how GPU and CPU operations can be asynchronous. In particular, the host code on the CPU can schedule workloads on the GPU in a non-blocking way.
-
Non-blocking can be useful for overlapping communication and computation –as we saw many times along our journey– but can be extended to the more general idea of trying to avoid at all cost going back and forth between host and GPU kernel commands.
-
This idea is beautifully illustrated by Horace He in these diagrams:
+
This can be useful for overlapping communication and computation – as we've seen many times in our journey – but it can also be extended to the more general idea of trying to avoid, at all cost, going back and forth between host and GPU kernel commands.
+
This idea is beautifully illustrated by Horace He in these diagrams:
@@ -2438,72 +2436,71 @@
-
How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.
+
How can we avoid the back and forth shown on the left? Well, the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations as possible together in a single kernel for the GPU to run, called a “fused kernel,” as shown on the right.
-
Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values locally until the succession of computation has been performed.
+
Fused kernels are especially efficient and simple to write for successions of point-like operations that are performed independently of each other on each input token. In this case, there is no point in sending the computed values back to global memory before moving them to SM memory and spinning up a new kernel. It's much more efficient to keep all the values locally until all the computations have been performed.
-
There are many places in a Transformer model where this "fusing" approach can be applied: every time we have a succession of point-wise operations e.g. in the computation involved in the Layer norms.
+
In a Transformer model, this "fusing" approach can be applied every time we have a succession of point-wise operations, such as in the computations involved in the LayerNorm layers.
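You don't always have to write such fused kernels by hand. For chains of point-wise operations, a compiler can often do the fusion for you; here's a minimal sketch (the function and shapes are just illustrative) where we let torch.compile fuse a few element-wise ops instead of launching one kernel per operation:
import torch

@torch.compile  # TorchInductor can fuse this chain of element-wise ops into fewer kernels
def fused_pointwise(x: torch.Tensor) -> torch.Tensor:
    # Each op is applied independently to every element, making the chain a good fusion candidate
    y = x * 1.5
    y = torch.nn.functional.gelu(y)
    y = y + 0.5
    return torch.tanh(y)

x = torch.randn(4096, 4096, device="cuda")
out = fused_pointwise(x)  # the first call compiles; subsequent calls reuse the fused kernel(s)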
-
We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: Flash Attention
+
We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: FlashAttention .
-
Flash Attention 1-3
+
FlashAttention
-
Flash attention was introduced by Tri Dao and proposed to optimize the attention computations by writing custom CUDA kernels make them much faster *and* more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory of the GPU.
-
Note that the global memory of the GPU is confusingly called the "High Bandwidth Memory", HBM 🫠
+
FlashAttention was introduced by Tri Dao and proposed to optimize attention computations by writing custom CUDA kernels to make them much faster and more memory efficient. The idea behind FlashAttention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory.
+
+
The global memory in modern GPUs often uses a technology called High Bandwidth Memory (HBM), which, despite its name, is slower than SRAM in the GPU memory hierarchy. This HBM terminology will be important when we discuss the details of FlashAttention's implementation.
-
A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:
+
A basic implementation of the attention mechanism involves a lot of transfer between memory and workers. It requires materializing the S matrix (where S = QK^T, the attention scores) and the P matrix (where P = softmax(S), the normalized attention weights) in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:
-
Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!
+
Since bandwidth is much lower in HBM, this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!
-
The key element is to compute the S matrices in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix all together in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of O directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not even do we make use of the shared memory but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.
+
The key element is to compute the S matrix in small pieces that can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So, we can compute part of O directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length): the attention matrix.
-
The idea of flash attention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers:
+
The idea of FlashAttention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers. Notably:
- By avoiding to materialize the S matrix we reduce the memory burden of attention
- We also remove a large part of the naive impact of the S^2 cost of attention
+ By avoiding materializing the S matrix, we reduce the memory burden of attention .
+ We also remove a large part of the naive impact of the O(S^2) cost of attention .
-
As a result as well, all variants of linear attention and sub-quadratic approaches to approximate attention –developed shortly after the invention of the transformers architecture– have been mostly put aside in favor of this exact and fast flash attention implementation and mechanism.
+
All variants of linear attention and subquadratic approaches to approximate attention (developed shortly after the invention of the Transformer architecture) have mostly been put aside in favor of this exact and fast FlashAttention implementation and mechanism.
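In day-to-day PyTorch code you rarely call the FlashAttention kernels directly. A common way to benefit from them is through scaled_dot_product_attention, which can dispatch to a FlashAttention-style fused kernel when the inputs are eligible. Here's a minimal sketch; the backend-selection API shown (torch.nn.attention.sdpa_kernel) assumes a recent PyTorch version:
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # recent PyTorch versions

# (batch, heads, seq_len, head_dim) in half precision on the GPU
q = torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict dispatch to the FlashAttention backend for this call
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)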
-
Following Flash-attention 1, two successive improved versions have been released by the same lab: Flash-attention 2 and 3. In comparison to Flash-attention 1, the improvements in Flash-attention 2 and 3 are less about the general attention mechanism than about tailoring its low level implementation more specifically to the GPU by (1) reducing the number of non-matmul operations as much as possible (2) partitioning carefully the workload among wraps and thread blocks (for Flash Attention 2) and carefully optimizing for FP8 and Tensor Core support on the latest Hopper (H100) architecture for Flash Attention 3.
+
Following FlashAttention-1, two successive improved versions were released by the same lab: FlashAttention-2 and -3. In comparison to FlashAttention-1, the improvements in FlashAttention-2 and -3 are less about the general attention mechanism and more about tailoring its low-level implementation more specifically to the GPU by (1) reducing the number of non-matmul operations as much as possible, (2) carefully partitioning the workload among warps and thread blocks (for FlashAttention-2), and (3) carefully optimizing for FP8 and Tensor Core support on the latest Hopper (H100) architecture for FlashAttention-3.
-
Flash attention puts some restrictions on which attention patterns can be sped up. Check out FlexAttention which is a fast and flexible variant.
+
FlashAttention puts some restrictions on which attention patterns can be sped up. Check out FlexAttention , which is a fast and flexible variant.
-
Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.
+
FlashAttention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.
-
-
-
The techniques described so far in this operation-fusion section have required us to implement modeling code changes and write custom kernels for certain operations in order to speed up training.
-
In the final section of our low-level dive in the compute operations themselves, we will take a look at a range of methods that are agnostic to the modeling code and can be used for any model and are so widely used that they have become a standard in the industry: Mixed Precision Training !
+
The techniques described so far in this section have required us to implement modeling code changes and write custom kernels for certain operations in order to speed up training.
+
In the final section of our low-level dive into the compute operations themselves, we will take a look at a range of methods that are agnostic to the modeling code. They can be used for any model, and are so widely used that they have become standard in the industry: up next, mixed precision training !
-
Mixed Precision Training
+
Mixed precision training
-
In various sections along this book, we've talked about lower precisions formats and their impact on the memory requirements for storing activations, parameters and optimizer states. It's now time to dive deeper in the details of these formats and understand better their trade-offs, advantages and limitations.
+
In various sections of this book, we've talked about lower-precision formats and their impact on the memory requirements for storing activations, parameters, and optimizer states. It's now time to dive deeper into the details of these formats and get a better understanding of their trade-offs, advantages, and limitations.
-
Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:
+
Mixed precision training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating-point format , also called FP32 or float32 , which means that every number stored takes up 32 bits, or 4 bytes. The available bits to represent a number are divided into three parts:
Sign: the first bit determines if the number is positive or negative
- Mantissa: determines the significant figures of a number
Exponent: controls the magnitude of the number
+ Mantissa: determines the significant figures of the number
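As a small illustration, here's one way (a sketch using Python's struct module) to pull these three fields out of a float32 value by reinterpreting its 32 bits:
import struct

def fp32_fields(x: float):
    # Reinterpret the 4 bytes of a float32 as an unsigned 32-bit integer
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    sign = bits >> 31               # 1 bit: 0 for positive, 1 for negative
    exponent = (bits >> 23) & 0xFF  # 8 bits, stored with a bias of 127
    mantissa = bits & 0x7FFFFF      # 23 bits of significant figures (implicit leading 1)
    return sign, exponent, mantissa

print(fp32_fields(-5.734e7))  # sign=1 (negative), biased exponent 152, i.e. 2^(152 - 127) = 2^25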
-
The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. - 5.734 \times 10^{7} , where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:
+
The principle of floating-point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. - 5.734 \times 10^{7} , where we first have the sign, followed by the mantissa and the exponent. As such, we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default, there are a range of floating-point formats available in PyTorch:
@@ -2558,48 +2555,47 @@
Note: You might be wondering where the “b” in bfloat16 comes from. The format was developed at Google Brain and thus the “b” stands for “brain”.
-
Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay. Either we can sacrifice more bits on the mantissa or exponent. For this reason there exist also two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:
+
Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice bits in either the mantissa or the exponent. For this reason, there also exist two float8 (FP8) formats, named according to the exponent and mantissa, so we can flexibly choose the most appropriate format. Let's take a look at the possible range of numbers for each format:
-
We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further where e5e2 can maintain float16 range and e4m3 has an even smaller ranger.
+
We can see that float32 spans 80 orders of magnitude, and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range, but e4m3 has an even smaller range.
-
How come some formats are able to maintain the range and others not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:
+
How come some formats are able to maintain the full range while others aren't? Let's investigate their resolutions by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:
-
We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.
-
-
A common metric to measure a formats resolution is epsilon: the first representable number after 1.00 . We can see that for the float32 format 10^{-4} is an upper bound (it’s actually 1.19^{-7} ). For float16 it is ~ 10^{-3} and for bfloat 10x higher still.
+
We can see here that although bfloat16 maintains the range of float32 (unlike float16), it does this at the cost of sacrificing more precision. In the case of float8, the situation is even more dire: e4m3 can represent only 7 and e5m2 only 3 numbers in the interval [1, 2].
-
The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training.
+
A common metric to measure a format's resolution is epsilon : the first representable number after 1.00 . We can see that for the float32 format, 10^{-4} is an upper bound (it’s actually 1.19 \times 10^{-7} ). For float16 it's ~10^{-3} , and for bfloat16 it's 10x higher still.
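You can query these characteristics directly in PyTorch with torch.finfo. Here's a quick sketch (the two float8 dtypes assume a reasonably recent PyTorch build):
import torch

dtypes = [torch.float32, torch.float16, torch.bfloat16,
          torch.float8_e4m3fn, torch.float8_e5m2]  # the float8 dtypes need a recent PyTorch

for dtype in dtypes:
    info = torch.finfo(dtype)
    # eps is the gap between 1.0 and the next representable number, i.e. the epsilon above
    print(f"{str(dtype):>22}: bits={info.bits:2d}  eps={info.eps:.2e}  max={info.max:.2e}")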
-
It turns out we can’t totally abandon float32 and usually will need to maintain some parts in full precision. This is why lower precision training is usually called mixed precision training.
+
The idea of mixed precision training is to use some of these lower-precision formats for certain computations while maintaining the performance of full precision training. As it turns out, we can’t totally abandon float32 and usually will need to do some of the computations in full precision.
-
Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.
+
Let’s now take a look at training models with 16 bits and then see if we can take it a step further, all the way down to 8 bits.
FP16 and BF16 training
-
Naively switching all the tensors and operations to float16 unfortunately doesn’t work and the result is usually diverging losses. However, the original mixed precision training paper came up with three tricks to match float32 trainings:
+
Naively switching all the tensors and operations to float16 unfortunately doesn’t work, and the result is usually diverging losses. However, the original mixed precision training paper came up with three tricks to match float32 training:
- FP32 copy of weights : There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.
- Loss scaling : We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step.
- Accumulation : Finally, when performing certain arithmetic operations in 16-bit precision such as averages or summations, we can also face under or overflows. A solution is then to accumulate intermediate results in float32 during the operation and only cast the final result back to 16 bit precision.
+ FP32 copy of weights: There are two possible issues with FP16 weights. During training, some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to 0, if the updates are very small the difference in magnitude can cause them to underflow during the addition. Once the weights are 0, they will remain at 0 for the rest of training as there is no gradient signal coming through anymore.
+ Loss scaling: We have a similar issue with the gradients as well, as gradients tend to be much smaller than 1 and are thus at risk of underflow. A simple yet effective strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass, and the scaling does not affect training because we unscale before processing the gradients further (e.g., clipping) and the optimization step.
+ Accumulation: Finally, when performing certain arithmetic operations in 16-bit precision (such as averages or summations), we can also face under- or overflows. A solution then is to accumulate intermediate results in FP32 during the operation and only cast the final result back to 16-bit precision.
-
With these techniques, we can get a stable training while benefitting from a higher throughput due to the faster, lower precision arithmetic operations. Naturally, as a curious reader –and by now slightly addicted to maximizing the throughput– you may ask the question: can we go further and faster than 16-bit precision?
+
With these techniques, we can get stable training while benefitting from higher throughput due to the faster, lower-precision arithmetic operations.
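In PyTorch, the FP32 master weights, loss scaling, and FP32 accumulation described above are largely handled for you by autocast and GradScaler. Here's a minimal sketch of one FP16 mixed precision training step; the model, batch, and loss are placeholders:
import torch

model = torch.nn.Linear(1024, 1024).cuda()          # placeholder model (weights stay in FP32)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()                # handles loss scaling and unscaling

for _ in range(10):
    x = torch.randn(32, 1024, device="cuda")        # placeholder batch
    optimizer.zero_grad(set_to_none=True)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = model(x)                               # matmuls run in FP16
        loss = out.float().pow(2).mean()             # toy loss, accumulated in FP32

    scaler.scale(loss).backward()                    # scale the loss to avoid gradient underflow
    scaler.unscale_(optimizer)                       # unscale before e.g. gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)                           # optimizer step on the FP32 master weights
    scaler.update()                                  # adjust the scale factor for the next step
For BF16, which keeps the full float32 range, loss scaling is usually unnecessary and autocast with dtype=torch.bfloat16 is typically enough.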
+
Naturally, as a curious reader – and by now slightly addicted to maximizing the throughput – you may ask the question: Can we go further and faster than 16-bit precision?
Maybe!
FP8 pretraining
-
Even if we perfectly overlap communication with computation, we always eventually run into the low level theoretical FLOPS limit of the hardware itself, i.e. the efficiency of each individual operation on our hardware. This is where numerical precision becomes crucial. For instance, on NVIDIA's H100 GPU, FP8 matrix multiplications (GEMM operations) achieve twice the theoretical FLOPS of bfloat16, making lower-precision training an attractive path for further optimization.
+
Even if we perfectly overlap communication with computation, we always eventually run into the low-level theoretical FLOPS limit of the hardware itself - i.e., the efficiency of each individual operation on our hardware. This is where numerical precision becomes crucial. For instance, on NVIDIA's H100 GPU, FP8 matrix multiplications (GEMM operations) achieve twice the theoretical FLOPS of BF16, making lower-precision training an attractive path for further optimization.
-
Recent research - including FP8-LM , torchao , and DeepSeek-V3 - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.
+
Recent research - including FP8-LM , torchao , and DeepSeek-V3 - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability . At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.
We know that instability increases as learning rates rise for a fixed model size , making FP8 pretraining particularly tricky.
@@ -2608,11 +2604,11 @@
-
The first, successful, very large scale training with FP8 mixed precision was publicly reported on DeepSeek-V3. The authors carefully analyzed each operation of the forward pass (Fprop) as well as the activation (Dgrad) and weight (Wgrad) backward pass. Similar to BF16 mixed precision training, some aggregation and master weights are kept in higher precision while the operations themselves are performed in FP8.
+
The first successful very large scale training with FP8 mixed precision was publicly reported in the DeepSeek-V3 technical report . The authors carefully analyzed each operation of the forward pass (Fprop ) as well as the activation (Dgrad ) and weight (Wgrad ) backward passes. Similar to BF16 mixed precision training, some aggregations and master weights are kept in higher precision while the operations themselves are performed in FP8.
-
In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of activation values, for instance by computing their absolute maximum. DeepSeek-V3 further introduced a specific quantization scheme where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less strongly impacted by outlier values in the activations. There is a number of additional tricks they proposed to further reduce the memory and communication footprint which you can follow in section 3.3. of the DeepSeek-V3 technical report .
+
In order to switch from high precision (e.g., FP32 or BF16) to lower precision (e.g., FP16 or FP8) with a smaller range, we need to normalize the range of activation values, for instance by computing their absolute maximum. DeepSeek-V3 further introduced a specific quantization scheme where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less strongly impacted by outlier values in the activations. The authors also proposed a number of additional tricks to further reduce the memory and communication footprint, which you can read about in section 3.3 of the technical report .
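To make the per-tile idea concrete, here's a toy sketch of quantizing activations to FP8 (e4m3) with one scale per 1x128 tile, computed from the tile's absolute maximum. This is our own illustrative code, not DeepSeek-V3's implementation, and it assumes a PyTorch build with the float8 dtypes:
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3

def quantize_per_tile(x: torch.Tensor, tile: int = 128):
    # x: (rows, cols) activations in BF16/FP32; cols assumed divisible by the tile size
    rows, cols = x.shape
    tiles = x.view(rows, cols // tile, tile)
    # One scale per 1x128 tile, so a single outlier only affects its own tile
    scales = tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    x_fp8 = (tiles / scales).to(torch.float8_e4m3fn)
    return x_fp8, scales  # the scales are kept to dequantize (or rescale) later

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
x_fp8, scales = quantize_per_tile(x)
x_back = (x_fp8.to(torch.bfloat16) * scales).view_as(x)  # dequantize to check the rounding error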
Here’s a summary of a few known approaches to FP8 training:
@@ -2625,126 +2621,125 @@
Accumulated gradients
Model weights
Gradients
-
Optimizer States
-
Total Memory
+
Optimizer states
+
Total memory
- bfloat16 with fp32 mixed precision baseline
- bf16
- fp32
- fp32
- bf16
- bf16
- fp32 + fp32
+ BF16 with FP32 mixed precision baseline
+ BF16
+ FP32
+ FP32
+ BF16
+ BF16
+ FP32 + FP32
4 + 4 + 2 + 2 + 4 + 4 = 20 bytes
- Above without FP32 grad accumulation
- bf16
- fp32
+ The above without FP32 grad accumulation
+ BF16
+ FP32
n/a
- bf16
- bf16
- fp32 + fp32
- 4 + 2 + 2 + 4 + 4 = 16 bytes
+ BF16
+ BF16
+ FP32 + FP32
+ 4 + 2 + 2 + 4 + 4 = 16 bytes (20% reduction)
- Transformer Engine
- fp8
+ Transformer engine
+ FP8
n/a
n/a
- fp32
- fp32
- fp32 + fp32
+ FP32
+ FP32
+ FP32 + FP32
4 + 4 + 4 + 4 = 16 bytes (20% reduction)
FP8-LM's O3 level
- fp8
- fp16
- fp16
- fp8
- fp8
- fp8 + fp16
- 2 + 2 + 1 + 1 + 1 + 2 = 9 bytes (55%)
+ FP8
+ FP16
+ FP16
+ FP8
+ FP8
+ FP8 + FP16
+ 2 + 2 + 1 + 1 + 1 + 2 = 9 bytes (55% reduction)
DeepSeek-V3
- fp8
- fp32
- fp32
- fp8
- bf16
- bf16 + bf16
- 4+4+1+2+2+2 = 15 (25%)
+ FP8
+ FP32
+ FP32
+ FP8
+ BF16
+ BF16 + BF16
+ 4 + 4 + 1 + 2 + 2 + 2 = 15 bytes (25% reduction)
- nanotron's FP8
- fp8
- bf16
- fp32
- fp8
- fp8
- fp8 + fp8
- 2 + 4 + 1 + 1 + 1 + 1 = 10 bytes (50%)
+ Nanotron's FP8
+ FP8
+ BF16
+ FP32
+ FP8
+ FP8
+ FP8 + FP8
+ 2 + 4 + 1 + 1 + 1 + 1 = 10 bytes (50% reduction)
-
Overall, FP8 remains –in early 2025– an experimental technique and methods are still evolving. Given its obvious benefits, it will likely become the standard and soon replace bf16 mixed-precision. To follow an open-source implementations of FP8 training techniques, please head to the nanotron’s implementation in this PR .
+
Overall, FP8 remains (in early 2025) an experimental technique, and methods are still evolving. Given its obvious benefits, it will likely become the standard and soon replace BF16 mixed precision. To see an open source implementation of FP8 training techniques, check out this Nanotron PR .
-
Projecting further into the future, Blackwell, the next generation of NVIDIA chips, have been announced to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.
+
Projecting further into the future, Blackwell, the next generation of NVIDIA chips, has been announced to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.
-
-
-
This last section concluded our long journey in the land of fast and large model training on tens to thousands of GPUs. Time to slowly bring our GPU cluster to rest and take a step back to conclude on all we've learned along the way.
+
This last section concluded our long journey into the land of fast and large model training on tens to thousands of GPUs. Time to slowly bring our GPU cluster to rest and take a step back to reflect on all we've learned along the way!
Conclusion
-
Congratulations, dear reader, you made it to the end! We've completed quite a journey: we started from understanding how to train a simple model on a single GPU, all the way to mastering all the intricate techniques used to efficiently train massive language models like Llama-405B and DeepSeek-V3 on thousands of GPUs. By now, you can read a diagram, like Llama-3's 4D parallel setup, with (relative) ease:
+
Congratulations, dear reader, you made it to the end! We've completed quite a journey: we started with exploring how to train a simple model on a single GPU and went all the way to mastering various intricate techniques used to efficiently train massive language models like Llama-405B and DeepSeek-V3 on thousands of GPUs. By now, you can read a diagram like Llama-3's 4D parallel setup with (relative) ease:
-
Orchestrating large clusters of GPUs to train LLMs efficiently is no easy feat. We learned how to optimize computations and communications between GPUs such that they run with maximum utilization at all times. It involves choosing the right parallelization strategy for a given model and cluster size, overlapping communication and computation where possible, and writing custom kernels that take into account the hardware layout to perform an operation as fast as possible on the GPU.
+
Orchestrating large clusters of GPUs to train LLMs efficiently is no easy feat, but you've learned how to optimize computations and communications between GPUs such that they run with maximum utilization at all times. As you've seen, this involves choosing the right parallelization strategy for a given model and cluster size, overlapping communication and computation where possible, and writing custom kernels that take into account the hardware layout to perform an operation as fast as possible on the GPU.
-
You might still believe that this knowledge is a bit niche and only concerns the small set of people that pretrain LLMs. Historically, that may have been true, but as both the AI builder community and model sizes are growing rapidly, the community of people using distributed techniques for inference, fine-tuning and training is increasing exponentially as well making distributed training setups more and more common. Diving deeper into all things distributed might thus prove very timely.
+
You might still believe that this knowledge is a bit niche and only concerns the small set of people that pretrain LLMs. Historically, that may have been true, but as both the AI builder community and model sizes are growing rapidly, the community of people using distributed techniques for inference, fine-tuning, and training is increasing exponentially as well, making distributed training setups more and more common. Diving deeper into all things distributed might thus prove very timely.
-
This has been a long learning journey, but not just for you! Running thousands of benchmarks on a GPU cluster was more challenging than we anticipated and we want to share a few highlights of our own learning experience as well.
+
This has been a long learning journey, and not just for you! Running thousands of benchmarks on a GPU cluster was more challenging than we anticipated, and we wanted to share a few highlights of our own learning experience as well.
So, what’s next?
-
You now have good overview of the main distributed training concepts but at the same time we just scratched to surface of several of these tools and techniques. There are many ways to dive deep into a subject but here are some steps that we recommend:
+
You now have a good overview of the main distributed training concepts, but we just scratched the surface of several of these tools and techniques. There are many ways to dive deeper into a given subject, but here are some steps that we recommend:
- Carefully read some of the landmark or very recent papers. You can find a very extenside list of the most impactful papers, blog posts and books in References .
- Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” if you implemented it yourself.
- Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get in any ML field!
+ Carefully read some of the landmark or very recent papers. You can find an extensive list of the most impactful papers, blog posts, and books we know of in the References .
+ Start from scratch and implement an algorithm yourself. Often, a method only fully “clicks” if you've actually implemented it.
+ Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get into any ML field!
-
We hope this book helps you get started in distributed training and that you will train the next generation of awesome models to the hum of your GPU cluster!
+
We hope this book helps you get started with distributed training, and that you will train the next generation of awesome models to the hum of your GPU cluster! May the force of open source and open science always be with you.
-
+
+
-
Acknowledgements
+
Acknowledgments
We thank Elie for conducting thorough reviews and creating the audio components using NotebookLM. Special thanks to Hynek for optimizing the frontend performance. We also thank Simon for resolving some issues on the hub.
Discussion page
-
If you want to discuss the content of this blog post, ask questions, propose changes or just say hi, please open a thread on the discussion page .
+
If you want to discuss the content of this book, ask questions, propose changes, or just say hi, please open a thread on the discussion page .
References
-
Landmark LLM Scaling Papers
+
Landmark LLM scaling papers
Megatron-LM
@@ -2768,43 +2763,43 @@
Llama 3
-
The Llama 3 Herd of Models
+
Introduces the Llama 3 herd of models.
DeepSeek-V3
-
DeepSeek's report on architecture and training of the DeepSeek-V3 model.
+
DeepSeek's report on the architecture and training of the DeepSeek-V3 model.
-
Training Frameworks
+
Training frameworks
Nanotron
-
Our framework for training large language models featuring various parallelism strategies
+
Our framework for training large language models, featuring various parallelism strategies.
Megatron-LM
-
NVIDIA's framework for training large language models featuring various parallelism strategies.
+
NVIDIA's framework for training large language models, featuring various parallelism strategies.
DeepSpeed
-
Microsoft's deep learning optimization library featuring ZeRO optimization stages and various parallelism strategies.
+
Microsoft's deep learning optimization library, featuring ZeRO optimization stages and various parallelism strategies.
FairScale
-
PyTorch extension library for large-scale training, offering various parallelism and optimization techniques.
+
A PyTorch extension library for large-scale training, offering various parallelism and optimization techniques.
-
ColossalAI
-
Integrated large-scale model training system with various optimization techniques.
+
Colossal-AI
+
An integrated large-scale model training system with various optimization techniques.
@@ -2815,12 +2810,12 @@
LitGPT
-
Lightning AI's implementation of state-of-the-art open-source LLMs with focus on reproducibility.
+
Lightning AI's implementation of 20+ state-of-the-art open source LLMs, with a focus on reproducibility.
-
DiLoco
-
Training language models across compute clusters with DiLoCo.
+
OpenDiLoCo
+
An open source framework for training language models across compute clusters with DiLoCo.
@@ -2830,7 +2825,7 @@
OSLO
-
OSLO: Open Source for Large-scale Optimization.
+
The Open Source for Large-scale Optimization framework for large-scale modeling.
Debugging
@@ -2847,15 +2842,15 @@
-
Distribution Techniques
+
Distribution techniques
Data parallelism
@@ -2864,7 +2859,7 @@
ZeRO
-
Introduces Zero Redundancy Optimizer for training large models with memory optimization.
+
Introduces the Zero Redundancy Optimizer for training large models with memory optimization.
@@ -2873,7 +2868,7 @@
@@ -2883,28 +2878,28 @@
-
Ring-flash-attention
-
Implementation of ring attention mechanism combined with flash attention for efficient training.
+
Ring Flash Attention
+
Implementation of the Ring Attention mechanism combined with FlashAttention for efficient training.
ZeRO and 3D
-
DeepSpeed's guide to understanding tradeoffs between ZeRO and 3D parallelism strategies.
+
DeepSpeed's guide to understanding the trade-offs between ZeRO and 3D parallelism strategies.
@@ -2913,7 +2908,7 @@
@@ -2921,34 +2916,34 @@
Hardware
Others
@@ -2959,12 +2954,12 @@
@@ -2973,13 +2968,13 @@
@@ -2989,15 +2984,15 @@
thonking.ai
-
Some of Horace He's blogposts - Making GPUs go BRRR..
+
Some of Horace He's blog posts.
@@ -3006,27 +3001,27 @@
A0: Parallel Programming Crash Course
-
Throughout the blogpost we scale LLM training from one to hundreds of GPUs. This will require the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called collective operations . In this section we’ll do a small crash course of all the operations like Broadcast, AllReduce, Scatter and more. Let’s dive in!
+
Throughout this book, we've scaled LLM training from one to hundreds of GPUs. This requires the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called collective operations . In this section, we’ll do a small crash course on those operations - Broadcast , AllReduce , Scatter , and more. Let’s dive in!
-
The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1).
+
The general setup is that we have a number of independent nodes, which could be CPU cores, GPUs, or compute nodes. Each performs some computation, and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1 ).
-
Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with root
that is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.
+
Maybe we need to send the result from one node to all other nodes, or to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with root , that is the target or source of some operations. Let’s start with one of the simplest primitives: a Broadcast operation.
Broadcast
-
A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:
+
A very common pattern is that you have some data on node 1 and you want to share it with all the other nodes so they can do some computation with the data. The Broadcast operation does just that:
-
Collective operations are natively provided by PyTorch so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with dist.initi_process_group
which sets up the communication backend (we’ll talk about NCCL later), it determines how many workers (aka nodes) exists and assigns a rank to each one (which we can get with dist.get_rank
). Finally, it establishes a connection between the workers.
+
Collective operations are provided natively by PyTorch, so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with dist.init_process_group
, which sets up the communication backend (we’ll talk about NCCL later). It determines how many workers (a.k.a. nodes) exist and assigns a rank to each one (which we can get with dist.get_rank
). Finally, it establishes a connection between the workers.
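For reference, here's one way such an init_process helper might look when launching with torchrun (a sketch under the assumption that torchrun sets the usual RANK, LOCAL_RANK, and WORLD_SIZE environment variables, which init_process_group reads):
import os
import torch
import torch.distributed as dist

def init_process(backend: str = "nccl"):
    # torchrun exports RANK, LOCAL_RANK and WORLD_SIZE; init_process_group picks them up
    dist.init_process_group(backend)
    # Bind this process to its own GPU so the collectives below use the right device
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))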
-
To showcase the dist.broadcast
operation, let's create a tensor with non-zero values on rank=0
and tensors full of zeros on the other workers. We then distribute the rank=0
tensor to all other ranks with dist.broadcast(tensor, src=0)
:
+
To showcase the dist.broadcast
operation, let's create a tensor with nonzero values on rank=0
and tensors full of zeros on the other workers. We then distribute the rank=0
tensor to all other ranks with dist.broadcast(tensor, src=0)
:
import torch
@@ -3046,11 +3041,11 @@
print(f"After broadcast on rank {dist.get_rank()}: {tensor}")
init_process()
- example_broadcast()
+ example_broadcast()
-
You can run the above script with torchrun --nproc_per_node=3 dist_op.py
(you’ll need 3 GPUs for this or change nproc_per_node
accordingly) and you should see the following output:
+
You can run the above script with torchrun --nproc_per_node=3 dist_op.py
(you’ll need three GPUs for this, or change nproc_per_node
accordingly), and you should see the following output:
Before broadcast on rank 0: tensor([1., 2., 3., 4., 5.], device='cuda:0')
@@ -3062,20 +3057,20 @@
After broadcast on rank 2: tensor([1., 2., 3., 4., 5.], device='cuda:2')
-
Great, seems like it works as expected. Note that the rank messages can be printed out of order as we have no control over which print statement is executed first (we ordered them here for readability). Now let’s move on to the Reduce and AllReduce patterns!
+
Great, seems like it works as expected. Note that the rank messages can be printed out of order, as we have no control over which print statement is executed first (we ordered them here for readability). Now let’s move on to the Reduce and AllReduce patterns!
Reduce & AllReduce
-
Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function f()
which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:
+
Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function f()
, which may perform, for instance, summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcast to all nodes:
-
Of course no magic “free flying” node that can perform this operation and generally each node does a partial computation in a ring or tree structure of the nodes. Here is a simple example: let’s say we need to compute a sum of numbers on each nodes and our nodes are connected in a ring pattern. The first node sends its number to a neighbour which adds its number to the received number before forwarding it to the next neighbour. At the end of a round along the ring of nodes, the first node will receive the total sum.
+
Of course, there's no magic “free-flying” node that can perform this operation itself; generally, each node does a partial computation, with the nodes organized in a ring or tree structure. Here’s a simple example: let’s say we need to compute the sum of the numbers held on each node, and our nodes are connected in a ring pattern. The first node sends its number to a neighbor, which adds its number to the received number before forwarding it to the next neighbor. At the end of a round through the ring of nodes, the first node will receive the total sum.
-
Here’s the code to run a simple Reduce operation summing the tensors, we specify the operation to use with op=dist.ReduceOp.SUM
(you can find more information on the supported operations in the Pytorch docs ):
+
Here’s the code to run a simple Reduce operation summing the tensors. We specify the operation to use with op=dist.ReduceOp.SUM
(you can find more information on the supported operations in the PyTorch docs ):
def example_reduce():
@@ -3088,7 +3083,7 @@
example_reduce()
-
Note that in the Reduce operation only the tensor on the dst
node is updated:
+
Note that in the Reduce operation, only the tensor on the dst
node is updated:
Before reduce on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
@@ -3100,7 +3095,7 @@
After reduce on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2')
-
Similarly we can perform an AllReduce (we don’t need to specify a destination in this case):
+
Similarly, we can perform an AllReduce as follows (we don’t need to specify a destination in this case):
def example_all_reduce():
@@ -3113,7 +3108,7 @@
example_all_reduce()
-
In this case the result is available on all nodes:
+
In this case, the result is available on all nodes:
Before all_reduce on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
@@ -3125,17 +3120,17 @@
After all_reduce on rank 2: tensor([6., 6., 6., 6., 6.], device='cuda:2')
-
Now let’s turn to our next distributed communication operation. In many real cases, each node individually perform many complex computations and we need to share the final results among nodes. Gather and AllGather are the operations we want to use in this case. Let’s take a look!
+
Now let’s turn to our next distributed communication operation. In many real cases, each node individually performs many complex computations and we need to share the final results among all nodes. Gather and AllGather are the operations we want to use in this case. Let’s take a look!
Gather & AllGather
-
Gather and AllGather are quite similar to the Broadcast in that they allow distributing data among node without modification. The main difference to Broadcast is that there is not one value we need to share from one node to all other nodes but each node has an individual chunk of data that we want to either gather all data on one node (in case of Gather) or gather all data on all nodes (in the case of AllGather). A picture being worth 1000 words, let’s take a look:
-
+
Gather and AllGather are quite similar to the Broadcast operation in that they allow distributing data among nodes without modification. The main difference to Broadcast is that there is not one value we need to share from one node to all other nodes; instead, each node has an individual chunk of data, and we want to either gather all the data on one node (in the case of Gather) or gather all the data on all nodes (in the case of AllGather). A picture being worth a thousand words, let’s take a look:
+
-
Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).
+
The dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).
-
In the case of the gather operation we need to prepare a container objects where the gathered tensors can be stored in this example the gather_list
:
+
In the case of the Gather operation, we need to prepare a container object where the gathered tensors can be stored - in this example, the gather_list
object:
def example_gather():
@@ -3156,7 +3151,7 @@
example_gather()
-
And we see that the `gather_list` indeed contains the tensors of all ranks:
+
And we see that gather_list
indeed contains the tensors of all ranks:
Before gather on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
@@ -3185,7 +3180,7 @@
example_all_gather()
-
And indeed we can see that now each node has all the data:
+
Here, we see that each node now has all the data:
Before all_gather on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
@@ -3203,17 +3198,17 @@
tensor([3., 3., 3., 3., 3.], device='cuda:2')]
-
Now what about the inverse of a gather? In this case we would have all the data on one node and want to distribute/slice it among node, possibly with some intermediate processing? We can use the Scatter, or in the case of an operation in between a Reduce Scatter pattern:
+
What about the inverse of a Gather? In this case, we have all the data on one node and want to distribute/slice it among nodes, possibly with some intermediate processing. We can use the Scatter pattern for this or, in the case where an operation is performed on the data before distributing it, the ReduceScatter pattern.
Scatter & ReduceScatter
-
As the name subtly suggests, the goal of the Scatter operation is to take data on one node and distribute slices of it to all other nodes. It’s thus different from the Broadcast operation which copy data without slicing and it’s the logical the inverse of the Gather operation.
+
As the name suggests, the goal of the Scatter operation is to take data on one node and scatter it across all the nodes, which it does by distributing a slice of the data to each node. It’s thus different from the Broadcast operation, which sends each node a complete copy of the data without slicing it, and it’s the logical inverse of the Gather operation.
-
The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case but instead of moving the result to just one node we also distribute it evenly to all nodes:
+
The ReduceScatter pattern is slightly more complex. As in the AllReduce case , you apply an operation on the data from all nodes. But unlike AllReduce where each node receives the full output tensor, in ReduceScatter each node only receives a slice of the output tensor. The following image illustrates the difference between these operations:
-
The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the src
:
+
The Scatter operation is written in code as the opposite of Gather: instead of preparing a list of tensors as a target, we prepare the source data as a list of tensors we want to distribute. We also need to specify the src
:
def example_scatter():
@@ -3234,7 +3229,7 @@
example_scatter()
-
As a result we can see how the empty tensors got filled with the contents of the scatter_list
+
As a result, the empty tensors get filled with the contents of scatter_list
Rank 0: Tensor to scatter: [tensor([1., 1., 1., 1., 1.], device='cuda:0'),
@@ -3249,7 +3244,7 @@
After scatter on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2')
-
Let’s create more interesting data to demonstrate the ReduceScatter logic: on each node we create a list of 2-elements vector on each node with a power exponent and an offset function of the node rank (it’s a bit hard to imagine so just look below for an example):
+
Let’s create some more interesting data to demonstrate the ReduceScatter logic. On each node, we'll create a list of two-element vectors with a power exponent and an offset function of the node rank:
def example_reduce_scatter():
@@ -3268,47 +3263,46 @@
example_reduce_scatter()
-
Let’s print the pattern of data that we created. We also immediately see the ReduceScatter pattern: the first rank received the sum of the first tensor from each node, and the second rank contains the sum of the second tensor on each node and so on:
+
The print statements reveal the pattern of data that we created. We also immediately see the ReduceScatter pattern in action - the first rank received the sum of the first tensor from each node, the second rank the sum of the second tensor on each node, and so on:
Before ReduceScatter on rank 0: [tensor([1., 2.], device='cuda:0'),
- tensor([1., 4.], device='cuda:0'),
- tensor([1., 8.], device='cuda:0')]
+ tensor([1., 4.], device='cuda:0'),
+ tensor([1., 8.], device='cuda:0')]
Before ReduceScatter on rank 1: [tensor([2., 4.], device='cuda:1'),
- tensor([ 4., 16.], device='cuda:1'),
- tensor([ 8., 64.], device='cuda:1')]
+ tensor([4., 16.], device='cuda:1'),
+ tensor([8., 64.], device='cuda:1')]
Before ReduceScatter on rank 2: [tensor([3., 6.], device='cuda:2'),
- tensor([ 9., 36.], device='cuda:2'),
- tensor([ 27., 216.], device='cuda:2')]
+ tensor([9., 36.], device='cuda:2'),
+ tensor([27., 216.], device='cuda:2')]
- After ReduceScatter on rank 0: tensor([ 6., 12.], device='cuda:0')
+ After ReduceScatter on rank 0: tensor([6., 12.], device='cuda:0')
After ReduceScatter on rank 1: tensor([14., 56.], device='cuda:1')
- After ReduceScatter on rank 2: tensor([ 36., 288.], device='cuda:2')
+ After ReduceScatter on rank 2: tensor([36., 288.], device='cuda:2')
-
Let's have a quick look at a common implementation of AllReduce that uses ReduceScatter and AllGather: Ring AllReduce.
+
Next, let's have a quick look at a common implementation of AllReduce that uses ReduceScatter and AllGather: Ring AllReduce.
-
A quick focus on Ring AllReduce
+
Ring AllReduce
-
Ring AllReduce is one specific implementation of AllReduce, optimized for scalability. Rather than all devices communicating with each other directly, which could create communication bottlenecks, Ring All-Reduce can be broken down into two key steps: ReduceScatter and AllGather. Here's how it works:
+
Ring AllReduce is a specific implementation of AllReduce optimized for scalability. Rather than all devices communicating with each other directly, which could create communication bottlenecks, Ring AllReduce can be broken down into two key steps: ReduceScatter and AllGather. Here's how it works:
ReduceScatter
- Each device splits its data (e.g., gradients) into chunks and sends one chunk to its neighbour. Simultaneously, each device receives a chunk from its other neighbour.
+ Each device splits its data (e.g., gradients) into N chunks (where N is the number of GPUs) and sends one chunk to its neighbor. Simultaneously, each device receives a chunk from its other neighbor.
As each device receives a chunk, it adds (reduces) its corresponding chunk to the received one.
- This process continues around the ring until each device holds a partially reduced chunk, representing a sum of the gradients across all devices for that chunk.
+ This process continues around the ring until each device holds a fully reduced chunk representing the sum of the gradients across all devices for that chunk.
AllGather
- Now, each device needs to collect the fully reduced chunks from other devices.
- The devices start sending their reduced chunks to neighbours.
- Each device forwards the chunks it receives until every device has all the fully reduced chunks, giving each device the complete, summed-up gradient.
-
+ Now, each device needs to collect the fully reduced chunks from the other devices.
+ Each device sends its reduced chunk to its neighbor, and receives the reduced chunk from its other neighbor.
+ The devices continue forwarding the chunks they receive until every device has all the fully reduced chunks, giving each device the complete, summed-up gradient.
-
Let’s illustrate this with the following gifs, where we have 5 GPUs, each with a tensor of length 5. The first animation shows the ReduceScatter step, where, at the end, each GPU receives the reduced results for a specific chunk of data (orange rectangle).
+
Let’s illustrate this with the following gifs, where we have 5 GPUs, each with a tensor of length 5. The first animation shows the ReduceScatter step, where, at the end, each GPU receives the reduced results for a specific chunk of data (the orange rectangle).
@@ -3316,26 +3310,26 @@
-
You may have noticed that each of the N GPUs sends and receives values N-1 times during both the reduce-scatter and all-gather steps. Each GPU sends \frac{K}{N} values per transfer, where K is the total number of values in the array being summed across the GPUs. Therefore, the total amount of data transferred to and from each GPU is 2 \times (N-1) \times \frac{K}{N} . When N (the number of GPUs) is large, the total amount of data transferred to and from each GPU is approximately 2 \times K , where K is the total number of parameters.
+
You may have noticed that each of the N GPUs sends and receives values N-1 times during both the ReduceScatter and AllGather steps. Each GPU sends \frac{K}{N} values per transfer, where K is the total number of values in the array being summed across the GPUs. Therefore, the total amount of data transferred to and from each GPU is 2 \times (N-1) \times \frac{K}{N} . When N (the number of GPUs) is large, the total amount of data transferred to and from each GPU is approximately 2 \times K , where K is the total number of parameters.
There are two key things to keep in mind for AllReduce:
- The communication cost for AllReduce is approximately 2xK when N (the number of GPUs) is large.
- An AllReduce operation can be broken down into a reduce-scatter followed by an all-gather. The communication cost for these two operations is half that of the AllReduce, which is approximately K .
+ The communication cost for AllReduce is approximately 2 \times K when N (the number of GPUs) is large.
+ An AllReduce operation can be broken down into a ReduceScatter followed by an AllGather. The communication cost for these two operations is half that of the AllReduce, which is approximately K .
-
As we can see this implementation can make efficient use of even a limited bandwidth between nodes.
+
As you can see, this implementation can make efficient use of even the limited bandwidth between nodes.
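To make the two phases concrete, here is a minimal single-process simulation of the ring algorithm in plain PyTorch. This is a sketch for intuition only: the chunk and neighbor indexing below are illustrative choices of ours, and a real implementation such as NCCL's overlaps asynchronous sends and receives over the actual interconnect rather than looping in Python:
import torch

def ring_all_reduce_sim(per_node_data):
    n = len(per_node_data)  # number of simulated GPUs in the ring
    # Each simulated GPU splits its tensor into n chunks (length assumed divisible by n)
    chunks = [list(torch.chunk(data.clone(), n)) for data in per_node_data]
    # Phase 1 - ReduceScatter: after n-1 steps, node i holds the fully reduced chunk (i+1) % n
    for step in range(n - 1):
        for i in range(n):
            c = (i - step) % n  # chunk that node i sends to its right neighbor at this step
            chunks[(i + 1) % n][c] = chunks[(i + 1) % n][c] + chunks[i][c]
    # Phase 2 - AllGather: the fully reduced chunks are forwarded around the ring
    for step in range(n - 1):
        for i in range(n):
            c = (i + 1 - step) % n  # fully reduced chunk that node i forwards at this step
            chunks[(i + 1) % n][c] = chunks[i][c].clone()
    return [torch.cat(node_chunks) for node_chunks in chunks]

# 3 simulated GPUs, each holding a tensor filled with (rank + 1): the sum is 6 everywhere
print(ring_all_reduce_sim([torch.full((6,), float(rank + 1)) for rank in range(3)]))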
-
We now have seen the main building block of distributed operations but before we see them in action let’s have a look at a special operation used for synchronization: the Barrier.
+
You've now seen the main building blocks of distributed operations - but before we see them in action, let’s have a look at a special operation used for synchronization: the Barrier operation.
Barrier
-
The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Then only are they allowed to continue with further computations:
+
Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Only then are the nodes allowed to continue with further computations:
-
We can easily simulate delayed nodes by setting up a different sleep time on each node and see how long it takes for all of them to pass the barrier:
+
We can easily simulate delayed nodes by setting up a different sleep time on each node and seeing how long it takes for all of them to pass the barrier:
def example_barrier():
@@ -3350,7 +3344,7 @@
example_barrier()
-
We can see that although the first rank didn’t sleep at all it also took it 2sec to pass the barrier:
+
We can see that although the first rank didn’t sleep at all, it still took about 2 seconds to pass the barrier:
Rank 0 sleeps 0 seconds.
@@ -3362,15 +3356,15 @@
Rank 2 after barrier time delta: 2.0024
-
We need to be careful with synchronizing all nodes like this, as this defeat the purpose of parallel independent operations and might thus slow down the whole processing. In many situations it can be just fine if a fast node already starts processing the next job as the fast node could be slower in a next iteration therefore evening out the delay over the whole process.
+
We need to be careful with synchronizing all nodes like this, as this defeats the purpose of parallel independent operations and might thus slow down the processing as a whole. In many situations, it can be just fine if a fast node starts processing the next job ahead of the others, as the fast node could be slower in the next iteration, thereby evening out the delay over the whole process.
-
Before turning to practical distributed training implementations, let’s first solve a mystery: what the heck is NCCL?
+
Before turning to practical distributed training implementations, let’s first solve a mystery: What the heck is NCCL?
-
NCCL: NVIDIA Collective Communications Library
+
NCCL
-
When training large models on many GPUs we may sometimes strike gold but we will always encounter nickel (or NCCL 🥁)! What’s is that?
+
When training large models on many GPUs, we may sometimes strike gold, but we will always encounter nickel (or NCCL 🥁)! What’s that?
-
There are several libraries that implement collective communication and are support by PyTorch: there’s the classic MPI (Message Passing Interface), there’s Gloo by Meta, and finally there is `NCCL` (NVIDIA Collective Communications Library). They all provide similar functionality in terms of collective communication patterns but are optimized for different hardware setups; NCCL is designed to serve GPU-GPU communication efficiently while MPI and Gloo are setup for CPU-CPU or CPU-GPU communication. PyTorch provides a great guide to decide which one to use:
+
There are several libraries that implement collective communication and are supported by PyTorch: there’s the classic MPI (Message Passing Interface), Gloo by Meta, and finally NCCL (the NVIDIA Collective Communications Library). They all provide similar functionality in terms of collective communication patterns but are optimized for different hardware setups - NCCL is designed to serve GPU-GPU communication efficiently, while MPI and Gloo are set up for CPU-CPU or CPU-GPU communication. PyTorch provides a great guide to decide which one to use, but here's what it boils down to:
GPU training: use NCCL
@@ -3378,24 +3372,22 @@
There are a few finer points in the decision tree that we leave to the reader to explore in the PyTorch guide referenced above.
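As a minimal sketch of what that decision looks like in code (assuming the script is launched with torchrun, which sets the usual RANK, WORLD_SIZE, and MASTER_ADDR/MASTER_PORT environment variables):
import torch
import torch.distributed as dist

# Pick NCCL for GPU training and fall back to Gloo on CPU-only machines
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend)
print(f"Rank {dist.get_rank()}/{dist.get_world_size()} initialized with the {backend} backend")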
-
-
Now that we’ve covered the fundamental operations for distributed training, you should be able to follow the blog post easily.
A1: Distributed Training Profiling
Kernels
-
Let's begin by assuming for now that the kernels are already integrated into PyTorch. As a simple example, we can look at the Layer Normalization function implemented in PyTorch as torch.nn.functional.layer_norm
. There are several methods to profile the kernel that underlies this function. The most straightforward approach might be to use the Python time
module. However, since CUDA operations are asynchronous, measuring time with this method will only capture the overhead associated with launching the kernel in Python, rather than the actual execution time of the kernel itself.
+
Let's begin by assuming for now that the kernels are already integrated into PyTorch. As a simple example, we can look at the layer normalization function implemented in PyTorch as torch.nn.functional.layer_norm
. There are several methods to profile the kernel that underlies this function. The most straightforward approach might be to use the Python time
module. However, since CUDA operations are asynchronous, measuring time with this method will only capture the overhead associated with launching the kernel in Python, rather than the actual execution time of the kernel itself.
To address this, we can utilize torch.cuda.Event
for accurate timing and employ the torch.cuda.synchronize()
directive to ensure we wait for the kernel execution to complete. This approach is demonstrated in the following snippet:
def profile_pytorch(func, input):
- # Create CUDA events to track time. CUDA operations are asynchronous,
+ # Create CUDA events to track time. CUDA operations are asynchronous.
start = torch.cuda.Event(enable_timing=True) # Event to mark the start time
end = torch.cuda.Event(enable_timing=True) # Event to mark the end time
- # Warmup to eliminate any overhead from the first run, which might not reflect
- # the actual performance.
+ # Warm up to eliminate any overhead from the first run, which might not reflect
+ # the actual performance
for _ in range(10):
func(input)
# Record the start time before executing the function
@@ -3404,13 +3396,13 @@
# Record the end time after the function has completed
end.record()
# Synchronize the CUDA operations to ensure all operations are completed
- # before measuring the elapsed time.
+ # before measuring the elapsed time
torch.cuda.synchronize()
- # Calculate and return the elapsed time in milliseconds.
+ # Calculate and return the elapsed time in milliseconds
return start.elapsed_time(end)
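For instance, we might time the layer normalization kernel with this helper as follows (a usage sketch assuming a CUDA device is available; the tensor shape is arbitrary and chosen just for illustration):
x = torch.randn(16, 1024, 1024, device="cuda")

def layer_norm(t):
    return torch.nn.functional.layer_norm(t, normalized_shape=(t.shape[-1],))

print(f"layer_norm took {profile_pytorch(layer_norm, x):.3f} ms")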
-
A more effective approach to profiling is to utilize the PyTorch Profiler, as explained previously. For example, consider the following code:
+
A more efficient approach to profiling is to utilize the PyTorch profiler, as explained previously . For example, consider the following code:
import torch
@@ -3452,12 +3444,12 @@
-
You can also try to inspect the trace as we previously mentioned on chrome://tracing/
+
You can also try to inspect the trace, as we previously mentioned, on chrome://tracing/ .
💡 Tip
-
If you're new to this tool, you can navigate the trace by using the right and left arrow keys. Additionally, you can zoom in and out by holding the Alt key while scrolling left or right with your mouse.
+
If you're new to this tool, you can navigate the trace by using the right and left arrow keys. Additionally, you can zoom in and out by holding down the Alt key while scrolling left or right with your mouse.
@@ -3469,7 +3461,7 @@
-
The sequence begins in the CPU (the upper section) with aten::layer_norm
, progressing to aten::native_layer_norm
, and then transitioning to cudaLaunchKernel
. From there, we move on to the GPU, where the vectorized_layer_norm_kernel
kernel is called.
+
The sequence begins in the CPU (the upper section) with aten::layer_norm
, progressing to aten::native_layer_norm
and then transitioning to cudaLaunchKernel
. From there, we move on to the GPU, where the vectorized_layer_norm_kernel
kernel is called.
📝 Note
@@ -3478,19 +3470,19 @@
-
While the PyTorch Profiler offers a quick performance overview, NVIDIA Nsight Compute (ncu) provides deeper insights into GPU performance, including detailed execution times and memory usage for each kernel. To run the profiler it's very simple:
+
While the PyTorch profiler offers a quick performance overview, the NVIDIA Nsight Compute CLI (ncu
) provides deeper insights into GPU performance, including detailed execution times and memory usage for each kernel. Running the profiler is simple:
ncu --set full python layer_norm.py
-
Where layer_norm.py
is a straightforward file that executes the layer normalization function. This command will generate log outputs, but a more effective way to visualize the results is by setting the output flag:
+
where layer_norm.py is a straightforward file that executes the layer normalization function. This command will generate log output, but a more effective way to visualize the results is by setting the output flag:
ncu --set full -o output python layer_norm.py
-
and open the file output.ncu-rep
with Nsight Compute, you will have a view that looks like this:
+
If you then open the file output.ncu-rep with Nsight Compute, you will have a view that looks like this, with clear warnings about compute and memory utilization, along with tips on how to better balance compute and memory in the kernel and achieve maximal occupancy:
-
With clear warnings about compute and memory utilization, and how to make the kernel better in balancing compute and memory and achieve maximal occupancy.
-
-
CPP extension
+
CPP extension
-
If the kernel you want to profile isn't already integrated into PyTorch, you can use PyTorch's cpp_extension
module to easily compile and run custom CUDA code. The process is straightforward—just create your CUDA kernel in a .cu
file, and use the load
function from the cpp_extension
module to load it in Python.
+
If the kernel you want to profile isn't already integrated into PyTorch, you can use PyTorch's cpp_extension
module to easily compile and run custom CUDA code. The process is straightforward — just create your CUDA kernel in a .cu file, and use the load
function from the cpp_extension
module to load it in Python.
-
The .cu
file would like this for a simple add
kernel:
+
The .cu file would look like this for a simple add
kernel:
#include
@@ -3529,7 +3519,7 @@
}
-
And the python file to load the kernel:
+
And here's the Python file to load the kernel:
import torch
@@ -3556,20 +3546,20 @@
A2: Typical Scales in LLM Training
- Let's get a feel for the typical sizes of things in LLM training. When we talk about memory or compute, we're often counting "elements" - think of these as numbers in tensors. To get the actual memory in bytes, you'll need to multiply by the size of each number (e.g., 2 bytes for bf16, 4 bytes for fp32).
+ Let's get a feel for the typical sizes of things in LLM training. When we talk about memory or compute, we're often counting "elements" - think of these as numbers in tensors. To get the actual memory in bytes, you'll need to multiply by the size of each number (e.g., 2 bytes for BF16, 4 bytes for FP32).
Here are some quick ballpark figures:
- Input tokens: For each batch, we process seq \cdot mbs tokens, where mbs is the micro batch size and seq is the sequence length.
+ Input tokens: For each batch, we process seq \cdot mbs tokens, where mbs is the micro-batch size and seq is the sequence length.
Activations (hidden states): For a single layer, the hidden state tensor is of size seq \cdot mbs \cdot h elements.
- Model weights and gradients: Each weight matrix in your model (like in linears) is about h^2 elements. This is per weight matrix. Gradients have the same size as weights.
+ Model weights and gradients: Each weight matrix in your model (e.g., a linear layer) contains about h^2 elements. Gradients have the same size as weights.
- Optimizer states: For each weight matrix (of elements h^2 ), if you're using an optimizer like Adam with mixed precision training, it keeps momentum and variance states in fp32 precision (2 \cdot h^2 ), plus master weights in fp32 (h^2 ). So total optimizer states will be around (6 \cdot h^2 ) per weight matrix.
+ Optimizer states: For each weight matrix (of h^2 elements), an optimizer like Adam with mixed precision training will keep momentum and variance states in FP32 precision (2 \cdot h^2 ), plus master weights in FP32 (h^2 ). So, the total number of optimizer states will be around (6 \cdot h^2 ) per weight matrix.
- Total model parameters: For each transformer block:
+ Total model parameters: Each transformer block will store:
Attention parameters:
@@ -3577,7 +3567,7 @@
Output projection: h^2 parameters
- MLP parameters with GLU:
+ MLP parameters with Gated Linear Units (GLU):
Gate and up projections: 8h^2 parameters (2 matrices of size h \times 4h )
Down projection: 4h^2 parameters (1 matrix of size 4h \times h )
@@ -3595,21 +3585,21 @@
- Forward and backward pass compute (FLOPs): A very rough estimate for the FLOPs in a forward pass is 2 \cdot num\_tokens \cdot num\_params . And backward pass compute is twice as that: 4 \cdot num\_tokens \cdot num\_params .
+ Forward and backward pass compute (FLOPS): A very rough estimate for the FLOPS in a forward pass is 2 \cdot num\_tokens \cdot num\_params . The backward pass compute is twice that: 4 \cdot num\_tokens \cdot num\_params .
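To make these ballpark figures easy to play with, here is a tiny helper that turns them into numbers. It is a sketch that follows only the approximations above (it ignores embeddings, norms, and biases), and the example model size is made up:
def block_params(h):
    attention = 4 * h * h       # QKV projections (3h^2) plus output projection (h^2)
    mlp_glu = 12 * h * h        # gate and up projections (8h^2) plus down projection (4h^2)
    return attention + mlp_glu  # ~16h^2 parameters per transformer block

def rough_training_flops(num_layers, h, num_tokens):
    num_params = num_layers * block_params(h)
    return 6 * num_tokens * num_params  # 2x forward + 4x backward per token and parameter

# Example: h=4096 and 32 layers (~8.6B parameters) trained on 1e12 tokens
print(f"{rough_training_flops(32, 4096, 1e12):.2e} FLOPS")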
A3: Math for Compute/Communication Overlap
- Using the formulas from the previous section, we can estimate when computation and communication can effectively overlap in distributed training. Let's look at data parallelism (Zero-0) as an example.
+ Using the formulas from the previous section, we can estimate when computation and communication can effectively overlap in distributed training. Let's look at data parallelism (ZeRO-0) as an example.
- Data Parallelism Communication Analysis
+ Data parallelism communication analysis
The total gradient size that needs to be communicated is:
Gradients = Parameters ≈ num\_layers \cdot 16h^2
- During backward pass, these gradients are communicated in buckets (default 25MB). The communication time to all-reduce each bucket is:
+ During the backward pass, these gradients are communicated in buckets (default size 25 MB). The communication time to all-reduce each bucket is:
t_{comm} = t_{comm\_bucket} = \frac{bucket\_size \cdot 2(DP-1)}{DP \cdot peak\_bw}
@@ -3622,7 +3612,7 @@
-
The computation time for backward pass is:
+
The computation time for the backward pass is:
t_{compute} = \frac{4 \cdot num\_tokens \cdot num\_params}{peak\_flops}
@@ -3636,33 +3626,33 @@
This ratio helps determine if communication will become a bottleneck in training. When the ratio is less than 1, communication can be fully overlapped with computation.
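As a quick sanity check, we can plug the two formulas above into a few lines of Python. The hardware figures and model/batch sizes below are made-up assumptions chosen purely for illustration, and we compare the total gradient communication time (in bytes, assuming BF16 gradients) with the total backward compute time:
def dp_overlap_ratio(num_params, num_tokens, dp,
                     bytes_per_grad=2,    # BF16 gradients (assumption)
                     peak_bw=400e9,       # bytes/s - assumed all-reduce bandwidth per GPU
                     peak_flops=400e12):  # FLOPS - assumed per-GPU throughput
    t_comm = num_params * bytes_per_grad * 2 * (dp - 1) / (dp * peak_bw)
    t_compute = 4 * num_tokens * num_params / peak_flops
    return t_comm / t_compute

# 8B parameters, a micro-batch of 4 sequences of 4,096 tokens, DP=64: well below 1, so it overlaps
print(dp_overlap_ratio(num_params=8e9, num_tokens=4 * 4096, dp=64))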
- ZeRO-3 (FSDP) Communication Analysis
+ ZeRO-3 (FSDP) communication analysis
For ZeRO-3, parameters and gradients are sharded across GPUs. Let's analyze the communication pattern for a model with transformer blocks of size 16h^2 parameters each:
- For each transformer block in forward pass:
+ For each transformer block in the forward pass:
- Allgather parameters: 16h^2/DP bytes per rank
+ All-gather parameters: 16h^2/DP bytes per rank
- For each transformer block in backward pass:
+ For each transformer block in the backward pass:
- Allgather parameters: 16h^2/DP bytes per rank
- Reducescatter gradients: 16h^2/DP bytes per rank
+ All-gather parameters: 16h^2/DP bytes per rank
+ Reduce-scatter gradients: 16h^2/DP bytes per rank
Total communication per block: 3 \cdot 16h^2/DP bytes
Total communication for full model: 3 \cdot num\_layers \cdot 16h^2/DP bytes
- The communication time for allgather operations is:
+ The communication time for all-gather operations is:
t_{comm} = 16h^2 \cdot \frac{DP-1}{DP \cdot peak\_bw}
- The computation time for forward pass of one decoder layer is:
+ The computation time for the forward pass of one decoder layer is:
t_{compute} = \frac{32 \cdot seq\_len \cdot mbs \cdot h^2}{peak\_flops}
@@ -3675,32 +3665,32 @@
When this ratio is less than 1, the communication of parameters for the next layer can be hidden behind the computation of the current layer.
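Again, we can do a quick numeric check of this condition using the two formulas above with made-up numbers (and the same loose units as the text):
h, seq, mbs, dp = 4096, 4096, 2, 64
peak_flops, peak_bw = 400e12, 400e9  # assumed hardware figures
t_comm = 16 * h**2 * (dp - 1) / (dp * peak_bw)
t_compute = 32 * seq * mbs * h**2 / peak_flops
print(t_comm / t_compute)  # well below 1 here, so the parameter prefetch can be hidden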
- TP Communication Analysis
+ TP communication analysis
- For Tensor Parallel (TP), activations are sharded across GPUs during linears. Let's analyze the communication pattern:
+ For tensor parallelism, activations are sharded across GPUs in the TP regions (e.g., the MLP block). Let’s analyze the communication pattern:
- For each column linear in forward pass:
+ For each column-linear operation in the forward pass:
- Allgather activations: seq \cdot mbs \cdot h/TP bytes per rank
+ All-gather activations: seq \cdot mbs \cdot h/TP bytes per rank
- For each column linear in backward pass:
+ For each column-linear operation in the backward pass:
- Reducescatter gradients: seq \cdot mbs \cdot h/TP bytes per rank
+ Reduce-scatter gradients: seq \cdot mbs \cdot h/TP bytes per rank
- And vice-versa for row linears. Each transformer block has 2 column linears and 2 row linears.
+ And vice versa for row-linear operations. Each transformer block has 2 column-linear and 2 row-linear operations.
Total communication per block: 8 \cdot seq \cdot mbs \cdot h/TP bytes
Total communication for full model: 8 \cdot num\_layers \cdot seq \cdot mbs \cdot h/TP bytes
- Let's analyze if we can overlap the allgather communication for one layer with the computation of the next linear. The communication time for allgather operations is:
+ Let's take a TP region within a layer and analyze if we can overlap the all-gather communication with the computation of the next linear. The communication time for all-gather operations is:
t_{comm} = \frac{seq \cdot mbs \cdot h \cdot (TP-1)}{TP \cdot peak\_bw}
- While the computation time for the next linear (with parameters h^2 ) is:
+ While the computation time for the next linear layer (with h^2 parameters) is:
t_{compute} = \frac{2 \cdot seq \cdot mbs \cdot h^2}{TP \cdot peak\_flops}
@@ -3711,26 +3701,26 @@
\frac{t_{comm}}{t_{compute}} = \frac{TP-1}{2 \cdot h} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
- This ratio tells us whether we can successfully hide the allgather communication behind the computation of the next linear. Interestingly, the ratio only depends on the hidden size h and tensor parallelism degree TP, not on sequence length or batch size.
+ This ratio tells us whether we can successfully hide the all-gather communication behind the computation of the next linear. Interestingly, the ratio only depends on the hidden size h and tensor parallelism degree tp , not on sequence length or batch size.
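For example, plugging in some illustrative numbers (h, TP, and the hardware figures below are all assumptions, and we keep the same loose units as the formula above):
h, tp = 4096, 8
peak_flops, peak_bw = 400e12, 450e9  # assumed per-GPU throughput and NVLink-class bandwidth
print((tp - 1) / (2 * h) * peak_flops / peak_bw)  # ~0.76 with these numbers - close to the limit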
- PP Communication Analysis
+ PP communication analysis
- For Pipeline Parallel (PP), activations and gradients are communicated between pipeline stages. Let's analyze the communication pattern:
+ For pipeline parallelism, activations and gradients are communicated between pipeline stages. Let's analyze the communication pattern:
- For each microbatch in forward pass:
+ For each micro-batch in the forward pass:
Receive and send activations: 2 \cdot seq \cdot mbs \cdot h bytes
- For each microbatch in backward pass:
+ For each micro-batch in the backward pass:
Receive and send gradients: 2 \cdot seq \cdot mbs \cdot h bytes
- Total communication per microbatch: 4 \cdot seq \cdot mbs \cdot h bytes
- For gradient accumulation steps (gas), total communication: 4 \cdot gas \cdot seq \cdot mbs \cdot h bytes
+ Total communication per micro-batch: 4 \cdot seq \cdot mbs \cdot h bytes
+ For gradient accumulation steps (gas ), total communication: 4 \cdot gas \cdot seq \cdot mbs \cdot h bytes
Let’s analyze if we can overlap the communication of activations/gradients with the computation of the next transformer block. The computation time for transformer blocks in the next pipeline stage is:
@@ -3751,7 +3741,7 @@
\frac{t_{comm}}{t_{compute}} = \frac{peak\_flops}{32 \cdot h \cdot num\_layers\_in\_next\_pp \cdot peak\_bw} \leq 1
-
Similar to TP, this ratio is independent of sequence length and batch size. It depends on the hidden size h, number of layers in the next pipeline stage, and the ratio of compute to P2P bandwidth capabilities of the hardware.
+
As with TP, this ratio is independent of sequence length and batch size. It depends on the hidden size h , the number of layers in the next pipeline stage, and the ratio of compute to P2P bandwidth capabilities of the hardware.
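The same kind of quick check works for PP (all numbers below are illustrative assumptions):
h, num_layers_in_next_pp = 4096, 4
peak_flops, peak_bw = 400e12, 100e9  # assumed compute throughput and P2P bandwidth
print(peak_flops / (32 * h * num_layers_in_next_pp * peak_bw))  # far below 1, so easy to hide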