danieldk committed 132e594 (1 parent: 3ac7aee)

Port vLLM attention kernels
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.so filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,7 @@
  ---
  license: apache-2.0
  ---
+
+ ## attention
+
+ Paged attention kernels from [vLLM](https://github.com/vllm-project/).
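The kernels added below implement paged attention: keys and values live in fixed-size cache blocks, and a per-sequence block table maps logical token positions to physical blocks. The key cache uses the layout `[num_blocks, num_kv_heads, head_size/x, block_size, x]` documented in attention_kernels.cuh, where `x` is the number of cache elements that fit in a 16-byte fetch. The sketch below is illustrative only (the helper name is hypothetical, and it assumes the natural contiguous strides); it mirrors the pointer arithmetic the kernels perform.

```cpp
#include <cstdint>

// Hypothetical host-side helper (not part of this commit): flat offset of one
// key element in a paged KV cache laid out as
// [num_blocks, num_kv_heads, head_size / x, block_size, x].
int64_t key_element_offset(const int* block_table,   // logical block -> physical block id
                           int64_t kv_block_stride,  // elements per physical block
                           int64_t kv_head_stride,   // elements per KV head within a block
                           int block_size, int x,    // x == 16 / sizeof(cache element)
                           int token_idx, int kv_head_idx, int head_dim_idx) {
  const int64_t physical_block = block_table[token_idx / block_size];
  const int block_offset = token_idx % block_size;  // token's slot inside the block
  const int offset1 = head_dim_idx / x;             // which group of x elements
  const int offset2 = head_dim_idx % x;             // position inside that group
  return physical_block * kv_block_stride + kv_head_idx * kv_head_stride +
         offset1 * static_cast<int64_t>(block_size) * x + block_offset * x + offset2;
}
```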
build.toml ADDED
@@ -0,0 +1,44 @@
+ [general]
+ version = "0.0.1"
+
+ [torch]
+ name = "attention"
+ src = [
+   "torch-ext/registration.h",
+   "torch-ext/torch_binding.cpp",
+   "torch-ext/torch_binding.h"
+ ]
+ pyroot = "torch-ext"
+
+ [kernel.cuda_utils]
+ capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+ src = [
+   "cuda-utils/cuda_utils_kernels.cu",
+ ]
+ depends = []
+
+
+ [kernel.paged_attention]
+ capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+ src = [
+   "paged-attention/attention/attention_dtypes.h",
+   "paged-attention/attention/attention_generic.cuh",
+   "paged-attention/attention/attention_kernels.cuh",
+   "paged-attention/attention/attention_utils.cuh",
+   "paged-attention/attention/dtype_bfloat16.cuh",
+   "paged-attention/attention/dtype_float16.cuh",
+   "paged-attention/attention/dtype_float32.cuh",
+   "paged-attention/attention/dtype_fp8.cuh",
+   "paged-attention/attention/paged_attention_v1.cu",
+   "paged-attention/attention/paged_attention_v2.cu",
+   "paged-attention/cache_kernels.cu",
+   "paged-attention/cuda_compat.h",
+   "paged-attention/dispatch_utils.h",
+   "paged-attention/quantization/fp8/amd/hip_float8.h",
+   "paged-attention/quantization/fp8/amd/hip_float8_impl.h",
+   "paged-attention/quantization/fp8/amd/quant_utils.cuh",
+   "paged-attention/quantization/fp8/nvidia/quant_utils.cuh",
+ ]
+ include = [ "." ]
+ depends = [ "torch" ]
+
cuda-utils/cuda_utils_kernels.cu ADDED
@@ -0,0 +1,29 @@
+ #ifdef USE_ROCM
+   #include <hip/hip_runtime.h>
+   #include <hip/hip_runtime_api.h>
+ #endif
+ int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
+   int device, value;
+   if (device_id < 0) {
+     cudaGetDevice(&device);
+   } else {
+     device = device_id;
+   }
+   cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
+                          device);
+   return value;
+ }
+
+ int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
+   int64_t attribute;
+   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+   // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
+
+ #ifdef USE_ROCM
+   attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
+ #else
+   attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
+ #endif
+
+   return get_device_attribute(attribute, device_id);
+ }
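This shared-memory query lets host code reason about the dynamic shared memory budget: the paged attention kernels below keep their FP32 softmax logits in `extern __shared__` memory, one float per token of a partition. A minimal sketch of how such a check might look, assuming that sizing rule; `logits_fit_in_shared_memory` is hypothetical and not part of this commit.

```cpp
#include <cstdint>

// Declared in cuda-utils/cuda_utils_kernels.cu above.
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

// Hypothetical host-side check: would the FP32 logits for one partition (or
// the whole sequence when partitioning is disabled) fit into the device's
// per-block shared memory limit?
bool logits_fit_in_shared_memory(int max_seq_len, int partition_size) {
  const int64_t max_shared =
      get_max_shared_memory_per_block_device_attribute(/*device_id=*/-1);
  const int tokens = partition_size > 0 ? partition_size : max_seq_len;
  const int64_t needed = static_cast<int64_t>(tokens) * sizeof(float);
  return needed <= max_shared;
}
```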
flake.nix ADDED
@@ -0,0 +1,14 @@
+ {
+   description = "Flake for attention kernels";
+
+   inputs = {
+     kernel-builder.url = "git+ssh://git@github.com/huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs ./.;
+ }
paged-attention/attention/attention_dtypes.h ADDED
@@ -0,0 +1,7 @@
+ #pragma once
+
+ #include "attention_generic.cuh"
+ #include "dtype_float16.cuh"
+ #include "dtype_float32.cuh"
+ #include "dtype_bfloat16.cuh"
+ #include "dtype_fp8.cuh"
paged-attention/attention/attention_generic.cuh ADDED
@@ -0,0 +1,65 @@
+ /*
+  * Adapted from
+  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+  * Copyright (c) 2023, The vLLM team.
+  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #pragma once
+
+ #include <stdint.h>
+
+ namespace vllm {
+
+ // A vector type to store Q, K, V elements.
+ template <typename T, int VEC_SIZE>
+ struct Vec {};
+
+ // A vector type to store FP32 accumulators.
+ template <typename T>
+ struct FloatVec {};
+
+ // Template vector operations.
+ template <typename Acc, typename A, typename B>
+ inline __device__ Acc mul(A a, B b);
+
+ template <typename T>
+ inline __device__ float sum(T v);
+
+ template <typename T>
+ inline __device__ float dot(T a, T b) {
+   return sum(mul<T, T, T>(a, b));
+ }
+
+ template <typename A, typename T>
+ inline __device__ float dot(T a, T b) {
+   return sum(mul<A, T, T>(a, b));
+ }
+
+ template <typename T>
+ inline __device__ void zero(T& dst) {
+   constexpr int WORDS = sizeof(T) / 4;
+   union {
+     T raw;
+     uint32_t words[WORDS];
+   } tmp;
+
+ #pragma unroll
+   for (int ii = 0; ii < WORDS; ++ii) {
+     tmp.words[ii] = 0u;
+   }
+   dst = tmp.raw;
+ }
+
+ }  // namespace vllm
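`Vec` and `FloatVec` are trait templates: the dtype_*.cuh headers below specialize them so that a scalar type plus an element count resolves to a packed register type, and `mul`/`sum`/`fma`/`dot` are then overloaded on those packed types. A compile-time illustration, assuming the FP16 specializations from dtype_float16.cuh further down; the example namespace and constants are mine and follow the VEC_SIZE formula used in attention_kernels.cuh.

```cpp
#include <cstdint>
#include "attention_dtypes.h"

namespace example {
// With a thread group of 2 lanes and FP16 keys (stored as uint16_t), each
// group fetches 16 bytes per step, so each lane handles
// 16 / (2 * sizeof(uint16_t)) = 4 halves, packed into a uint2.
constexpr int kThreadGroupSize = 2;
constexpr int kVecSize = 16 / (kThreadGroupSize * sizeof(uint16_t));  // == 4

using K_vec = vllm::Vec<uint16_t, kVecSize>::Type;  // resolves to uint2
using Acc_vec = vllm::FloatVec<K_vec>::Type;        // resolves to Float4_ (FP32 accumulator)

static_assert(sizeof(K_vec) == kVecSize * sizeof(uint16_t),
              "four packed halves occupy 8 bytes");
}  // namespace example
```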
paged-attention/attention/attention_kernels.cuh ADDED
@@ -0,0 +1,676 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * Copyright (c) 2023, The vLLM team.
5
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
+ *
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+
20
+ #include <torch/all.h>
21
+ #include <ATen/cuda/CUDAContext.h>
22
+ #include <c10/cuda/CUDAGuard.h>
23
+ #include <algorithm>
24
+
25
+ #include "attention_dtypes.h"
26
+ #include "attention_utils.cuh"
27
+
28
+ #ifdef USE_ROCM
29
+ #include <hip/hip_bf16.h>
30
+ #include "../quantization/fp8/amd/quant_utils.cuh"
31
+ typedef __hip_bfloat16 __nv_bfloat16;
32
+ #else
33
+ #include "../quantization/fp8/nvidia/quant_utils.cuh"
34
+ #endif
35
+
36
+ #ifndef USE_ROCM
37
+ #define WARP_SIZE 32
38
+ #else
39
+ #define WARP_SIZE warpSize
40
+ #endif
41
+
42
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
43
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
44
+ #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
45
+
46
+ namespace vllm {
47
+
48
+ // Utility function for attention softmax.
49
+ template <int NUM_WARPS>
50
+ inline __device__ float block_sum(float* red_smem, float sum) {
51
+ // Decompose the thread index into warp / lane.
52
+ int warp = threadIdx.x / WARP_SIZE;
53
+ int lane = threadIdx.x % WARP_SIZE;
54
+
55
+ // Compute the sum per warp.
56
+ #pragma unroll
57
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
58
+ sum += VLLM_SHFL_XOR_SYNC(sum, mask);
59
+ }
60
+
61
+ // Warp leaders store the data to shared memory.
62
+ if (lane == 0) {
63
+ red_smem[warp] = sum;
64
+ }
65
+
66
+ // Make sure the data is in shared memory.
67
+ __syncthreads();
68
+
69
+ // The warps compute the final sums.
70
+ if (lane < NUM_WARPS) {
71
+ sum = red_smem[lane];
72
+ }
73
+
74
+ // Parallel reduction inside the warp.
75
+ #pragma unroll
76
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
77
+ sum += VLLM_SHFL_XOR_SYNC(sum, mask);
78
+ }
79
+
80
+ // Broadcast to other threads.
81
+ return VLLM_SHFL_SYNC(sum, 0);
82
+ }
83
+
84
+ // TODO(woosuk): Merge the last two dimensions of the grid.
85
+ // Grid: (num_heads, num_seqs, max_num_partitions).
86
+ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
87
+ int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
88
+ bool IS_BLOCK_SPARSE,
89
+ int PARTITION_SIZE = 0> // Zero means no partitioning.
90
+ __device__ void paged_attention_kernel(
91
+ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
92
+ float* __restrict__ max_logits, // [num_seqs, num_heads,
93
+ // max_num_partitions]
94
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
95
+ // head_size]
96
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
97
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
98
+ // head_size/x, block_size, x]
99
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
100
+ // head_size, block_size]
101
+ const int num_kv_heads, // [num_heads]
102
+ const float scale,
103
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
104
+ const int* __restrict__ seq_lens, // [num_seqs]
105
+ const int max_num_blocks_per_seq,
106
+ const float* __restrict__ alibi_slopes, // [num_heads]
107
+ const int q_stride, const int kv_block_stride, const int kv_head_stride,
108
+ const float* k_scale, const float* v_scale, const int tp_rank,
109
+ const int blocksparse_local_blocks, const int blocksparse_vert_stride,
110
+ const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
111
+ const int seq_idx = blockIdx.y;
112
+ const int partition_idx = blockIdx.z;
113
+ const int max_num_partitions = gridDim.z;
114
+ constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
115
+ const int seq_len = seq_lens[seq_idx];
116
+ if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
117
+ // No work to do. Terminate the thread block.
118
+ return;
119
+ }
120
+
121
+ const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
122
+ const int num_blocks_per_partition =
123
+ USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
124
+
125
+ // [start_block_idx, end_block_idx) is the range of blocks to process.
126
+ const int start_block_idx =
127
+ USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
128
+ const int end_block_idx =
129
+ MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
130
+ const int num_blocks = end_block_idx - start_block_idx;
131
+
132
+ // [start_token_idx, end_token_idx) is the range of tokens to process.
133
+ const int start_token_idx = start_block_idx * BLOCK_SIZE;
134
+ const int end_token_idx =
135
+ MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
136
+ const int num_tokens = end_token_idx - start_token_idx;
137
+
138
+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
139
+ constexpr int NUM_THREAD_GROUPS =
140
+ NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
141
+ // divides NUM_THREADS
142
+ assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
143
+ constexpr int NUM_TOKENS_PER_THREAD_GROUP =
144
+ DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
145
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
146
+ const int thread_idx = threadIdx.x;
147
+ const int warp_idx = thread_idx / WARP_SIZE;
148
+ const int lane = thread_idx % WARP_SIZE;
149
+
150
+ const int head_idx = blockIdx.x;
151
+ const int num_heads = gridDim.x;
152
+ const int num_queries_per_kv = num_heads / num_kv_heads;
153
+ const int kv_head_idx = head_idx / num_queries_per_kv;
154
+ const float alibi_slope =
155
+ alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
156
+
157
+ // A vector type to store a part of a key or a query.
158
+ // The vector size is configured in such a way that the threads in a thread
159
+ // group fetch or compute 16 bytes at a time. For example, if the size of a
160
+ // thread group is 4 and the data type is half, then the vector size is 16 /
161
+ // (4 * sizeof(half)) == 2.
162
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
163
+ using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
164
+ using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
165
+ using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
166
+
167
+ constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
168
+ constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
169
+
170
+ const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
171
+ const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
172
+
173
+ // Load the query to registers.
174
+ // Each thread in a thread group has a different part of the query.
175
+ // For example, if the thread group size is 4, then the first thread in
176
+ // the group has 0, 4, 8, ... th vectors of the query, and the second thread
177
+ // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
178
+ // q is split from a qkv tensor, it may not be contiguous.
179
+ const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
180
+ __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
181
+ #pragma unroll
182
+ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
183
+ i += NUM_THREAD_GROUPS) {
184
+ const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
185
+ q_vecs[thread_group_offset][i] =
186
+ *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
187
+ }
188
+ __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
189
+ // memory fence right before we use q_vecs
190
+
191
+ // Memory planning.
192
+ extern __shared__ char shared_mem[];
193
+ // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
194
+ float* logits = reinterpret_cast<float*>(shared_mem);
195
+ // Workspace for reduction.
196
+ __shared__ float red_smem[2 * NUM_WARPS];
197
+
198
+ // x == THREAD_GROUP_SIZE * VEC_SIZE
199
+ // Each thread group fetches x elements from the key at a time.
200
+ constexpr int x = 16 / sizeof(cache_t);
201
+ float qk_max = -FLT_MAX;
202
+
203
+ // Iterate over the key blocks.
204
+ // Each warp fetches a block of keys for each iteration.
205
+ // Each thread group in a warp fetches a key from the block, and computes
206
+ // dot product with the query.
207
+ const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
208
+
209
+ // blocksparse specific vars
210
+ int bs_block_offset;
211
+ int q_bs_block_id;
212
+ if constexpr (IS_BLOCK_SPARSE) {
213
+ // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
214
+ // blocksparse_block_size);
215
+ q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
216
+ if (blocksparse_head_sliding_step >= 0)
217
+ // sliding on q heads
218
+ bs_block_offset =
219
+ (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
220
+ else
221
+ // sliding on kv heads
222
+ bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
223
+ (-blocksparse_head_sliding_step) +
224
+ 1;
225
+ }
226
+
227
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
228
+ block_idx += NUM_WARPS) {
229
+ // NOTE(woosuk): The block number is stored in int32. However, we cast it to
230
+ // int64 because int32 can lead to overflow when this variable is multiplied
231
+ // by large numbers (e.g., kv_block_stride).
232
+ // For blocksparse attention: skip computation on blocks that are not
233
+ // attended
234
+ if constexpr (IS_BLOCK_SPARSE) {
235
+ const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
236
+ const bool is_remote =
237
+ ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
238
+ const bool is_local =
239
+ (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
240
+ if (!is_remote && !is_local) {
241
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
242
+ const int physical_block_offset =
243
+ (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
244
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
245
+
246
+ if (thread_group_offset == 0) {
247
+ // NOTE(linxihui): assign very large number to skipped tokens to
248
+ // avoid contribution to the sumexp softmax normalizer. This will
249
+ // not be used at computing sum(softmax*v) as the blocks will be
250
+ // skipped.
251
+ logits[token_idx - start_token_idx] = -FLT_MAX;
252
+ }
253
+ }
254
+ continue;
255
+ }
256
+ }
257
+ const int64_t physical_block_number =
258
+ static_cast<int64_t>(block_table[block_idx]);
259
+
260
+ // Load a key to registers.
261
+ // Each thread in a thread group has a different part of the key.
262
+ // For example, if the thread group size is 4, then the first thread in
263
+ // the group has 0, 4, 8, ... th vectors of the key, and the second thread
264
+ // has 1, 5, 9, ... th vectors of the key, and so on.
265
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
266
+ const int physical_block_offset =
267
+ (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
268
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
269
+ K_vec k_vecs[NUM_VECS_PER_THREAD];
270
+
271
+ #pragma unroll
272
+ for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
273
+ const cache_t* k_ptr =
274
+ k_cache + physical_block_number * kv_block_stride +
275
+ kv_head_idx * kv_head_stride + physical_block_offset * x;
276
+ const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
277
+ const int offset1 = (vec_idx * VEC_SIZE) / x;
278
+ const int offset2 = (vec_idx * VEC_SIZE) % x;
279
+
280
+ if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
281
+ k_vecs[j] = *reinterpret_cast<const K_vec*>(
282
+ k_ptr + offset1 * BLOCK_SIZE * x + offset2);
283
+ } else {
284
+ // Vector conversion from Quant_vec to K_vec.
285
+ Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
286
+ k_ptr + offset1 * BLOCK_SIZE * x + offset2);
287
+ k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
288
+ k_vec_quant, *k_scale);
289
+ }
290
+ }
291
+
292
+ // Compute dot product.
293
+ // This includes a reduction across the threads in the same thread group.
294
+ float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
295
+ q_vecs[thread_group_offset], k_vecs);
296
+ // Add the ALiBi bias if slopes are given.
297
+ qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
298
+
299
+ if (thread_group_offset == 0) {
300
+ // Store the partial reductions to shared memory.
301
+ // NOTE(woosuk): It is required to zero out the masked logits.
302
+ const bool mask = token_idx >= seq_len;
303
+ logits[token_idx - start_token_idx] = mask ? 0.f : qk;
304
+ // Update the max value.
305
+ qk_max = mask ? qk_max : fmaxf(qk_max, qk);
306
+ }
307
+ }
308
+ }
309
+
310
+ // Perform reduction across the threads in the same warp to get the
311
+ // max qk value for each "warp" (not across the thread block yet).
312
+ // The 0-th thread of each thread group already has its max qk value.
313
+ #pragma unroll
314
+ for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
315
+ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
316
+ }
317
+ if (lane == 0) {
318
+ red_smem[warp_idx] = qk_max;
319
+ }
320
+ __syncthreads();
321
+
322
+ // TODO(woosuk): Refactor this part.
323
+ // Get the max qk value for the sequence.
324
+ qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
325
+ #pragma unroll
326
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
327
+ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
328
+ }
329
+ // Broadcast the max qk value to all threads.
330
+ qk_max = VLLM_SHFL_SYNC(qk_max, 0);
331
+
332
+ // Get the sum of the exp values.
333
+ float exp_sum = 0.f;
334
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
335
+ float val = __expf(logits[i] - qk_max);
336
+ logits[i] = val;
337
+ exp_sum += val;
338
+ }
339
+ exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
340
+
341
+ // Compute softmax.
342
+ const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
343
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
344
+ logits[i] *= inv_sum;
345
+ }
346
+ __syncthreads();
347
+
348
+ // If partitioning is enabled, store the max logit and exp_sum.
349
+ if (USE_PARTITIONING && thread_idx == 0) {
350
+ float* max_logits_ptr = max_logits +
351
+ seq_idx * num_heads * max_num_partitions +
352
+ head_idx * max_num_partitions + partition_idx;
353
+ *max_logits_ptr = qk_max;
354
+ float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
355
+ head_idx * max_num_partitions + partition_idx;
356
+ *exp_sums_ptr = exp_sum;
357
+ }
358
+
359
+ // Each thread will fetch 16 bytes from the value cache at a time.
360
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
361
+ using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
362
+ using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
363
+ using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
364
+ using Float_L_vec = typename FloatVec<L_vec>::Type;
365
+
366
+ constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
367
+ constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
368
+ constexpr int NUM_ROWS_PER_THREAD =
369
+ DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
370
+
371
+ // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
372
+ float accs[NUM_ROWS_PER_THREAD];
373
+ #pragma unroll
374
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
375
+ accs[i] = 0.f;
376
+ }
377
+
378
+ scalar_t zero_value;
379
+ zero(zero_value);
380
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
381
+ block_idx += NUM_WARPS) {
382
+ // NOTE(woosuk): The block number is stored in int32. However, we cast it to
383
+ // int64 because int32 can lead to overflow when this variable is multiplied
384
+ // by large numbers (e.g., kv_block_stride).
385
+ // For blocksparse attention: skip computation on blocks that are not
386
+ // attended
387
+ if constexpr (IS_BLOCK_SPARSE) {
388
+ int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
389
+ if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
390
+ !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
391
+ continue;
392
+ }
393
+ }
394
+ const int64_t physical_block_number =
395
+ static_cast<int64_t>(block_table[block_idx]);
396
+ const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
397
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
398
+ L_vec logits_vec;
399
+ from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
400
+ start_token_idx));
401
+
402
+ const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
403
+ kv_head_idx * kv_head_stride;
404
+ #pragma unroll
405
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
406
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
407
+ if (row_idx < HEAD_SIZE) {
408
+ const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
409
+ V_vec v_vec;
410
+
411
+ if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
412
+ v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
413
+ } else {
414
+ V_quant_vec v_quant_vec =
415
+ *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
416
+ // Vector conversion from V_quant_vec to V_vec.
417
+ v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
418
+ *v_scale);
419
+ }
420
+ if (block_idx == num_seq_blocks - 1) {
421
+ // NOTE(woosuk): When v_vec contains the tokens that are out of the
422
+ // context, we should explicitly zero out the values since they may
423
+ // contain NaNs. See
424
+ // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
425
+ scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
426
+ #pragma unroll
427
+ for (int j = 0; j < V_VEC_SIZE; j++) {
428
+ v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
429
+ }
430
+ }
431
+ accs[i] += dot(logits_vec, v_vec);
432
+ }
433
+ }
434
+ }
435
+
436
+ // Perform reduction within each warp.
437
+ #pragma unroll
438
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
439
+ float acc = accs[i];
440
+ #pragma unroll
441
+ for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
442
+ acc += VLLM_SHFL_XOR_SYNC(acc, mask);
443
+ }
444
+ accs[i] = acc;
445
+ }
446
+
447
+ // NOTE(woosuk): A barrier is required because the shared memory space for
448
+ // logits is reused for the output.
449
+ __syncthreads();
450
+
451
+ // Perform reduction across warps.
452
+ float* out_smem = reinterpret_cast<float*>(shared_mem);
453
+ #pragma unroll
454
+ for (int i = NUM_WARPS; i > 1; i /= 2) {
455
+ int mid = i / 2;
456
+ // Upper warps write to shared memory.
457
+ if (warp_idx >= mid && warp_idx < i) {
458
+ float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
459
+ #pragma unroll
460
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
461
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
462
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
463
+ dst[row_idx] = accs[i];
464
+ }
465
+ }
466
+ }
467
+ __syncthreads();
468
+
469
+ // Lower warps update the output.
470
+ if (warp_idx < mid) {
471
+ const float* src = &out_smem[warp_idx * HEAD_SIZE];
472
+ #pragma unroll
473
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
474
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
475
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
476
+ accs[i] += src[row_idx];
477
+ }
478
+ }
479
+ }
480
+ __syncthreads();
481
+ }
482
+
483
+ // Write the final output.
484
+ if (warp_idx == 0) {
485
+ scalar_t* out_ptr =
486
+ out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
487
+ head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
488
+ #pragma unroll
489
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
490
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
491
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
492
+ from_float(*(out_ptr + row_idx), accs[i]);
493
+ }
494
+ }
495
+ }
496
+ }
497
+
498
+ // Grid: (num_heads, num_seqs, 1).
499
+ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
500
+ int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
501
+ bool IS_BLOCK_SPARSE>
502
+ __global__ void paged_attention_v1_kernel(
503
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
504
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
505
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
506
+ // head_size/x, block_size, x]
507
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
508
+ // head_size, block_size]
509
+ const int num_kv_heads, // [num_heads]
510
+ const float scale,
511
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
512
+ const int* __restrict__ seq_lens, // [num_seqs]
513
+ const int max_num_blocks_per_seq,
514
+ const float* __restrict__ alibi_slopes, // [num_heads]
515
+ const int q_stride, const int kv_block_stride, const int kv_head_stride,
516
+ const float* k_scale, const float* v_scale, const int tp_rank,
517
+ const int blocksparse_local_blocks, const int blocksparse_vert_stride,
518
+ const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
519
+ paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
520
+ KV_DTYPE, IS_BLOCK_SPARSE>(
521
+ /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
522
+ v_cache, num_kv_heads, scale, block_tables, seq_lens,
523
+ max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
524
+ kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
525
+ blocksparse_vert_stride, blocksparse_block_size,
526
+ blocksparse_head_sliding_step);
527
+ }
528
+
529
+ // Grid: (num_heads, num_seqs, max_num_partitions).
530
+ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
531
+ int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
532
+ bool IS_BLOCK_SPARSE,
533
+ int PARTITION_SIZE>
534
+ __global__ void paged_attention_v2_kernel(
535
+ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
536
+ float* __restrict__ max_logits, // [num_seqs, num_heads,
537
+ // max_num_partitions]
538
+ scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
539
+ // max_num_partitions, head_size]
540
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
541
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
542
+ // head_size/x, block_size, x]
543
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
544
+ // head_size, block_size]
545
+ const int num_kv_heads, // [num_heads]
546
+ const float scale,
547
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
548
+ const int* __restrict__ seq_lens, // [num_seqs]
549
+ const int max_num_blocks_per_seq,
550
+ const float* __restrict__ alibi_slopes, // [num_heads]
551
+ const int q_stride, const int kv_block_stride, const int kv_head_stride,
552
+ const float* k_scale, const float* v_scale, const int tp_rank,
553
+ const int blocksparse_local_blocks, const int blocksparse_vert_stride,
554
+ const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
555
+ paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
556
+ KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
557
+ exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
558
+ block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
559
+ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
560
+ blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
561
+ blocksparse_head_sliding_step);
562
+ }
563
+
564
+ // Grid: (num_heads, num_seqs).
565
+ template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
566
+ int PARTITION_SIZE>
567
+ __global__ void paged_attention_v2_reduce_kernel(
568
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
569
+ const float* __restrict__ exp_sums, // [num_seqs, num_heads,
570
+ // max_num_partitions]
571
+ const float* __restrict__ max_logits, // [num_seqs, num_heads,
572
+ // max_num_partitions]
573
+ const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
574
+ // max_num_partitions, head_size]
575
+ const int* __restrict__ seq_lens, // [num_seqs]
576
+ const int max_num_partitions) {
577
+ const int num_heads = gridDim.x;
578
+ const int head_idx = blockIdx.x;
579
+ const int seq_idx = blockIdx.y;
580
+ const int seq_len = seq_lens[seq_idx];
581
+ const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
582
+ if (num_partitions == 1) {
583
+ // No need to reduce. Only copy tmp_out to out.
584
+ scalar_t* out_ptr =
585
+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
586
+ const scalar_t* tmp_out_ptr =
587
+ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
588
+ head_idx * max_num_partitions * HEAD_SIZE;
589
+ for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
590
+ out_ptr[i] = tmp_out_ptr[i];
591
+ }
592
+ // Terminate the thread block.
593
+ return;
594
+ }
595
+
596
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
597
+ const int warp_idx = threadIdx.x / WARP_SIZE;
598
+ const int lane = threadIdx.x % WARP_SIZE;
599
+
600
+ // Size: 2 * num_partitions.
601
+ extern __shared__ char shared_mem[];
602
+ // Workspace for reduction.
603
+ __shared__ float red_smem[2 * NUM_WARPS];
604
+
605
+ // Load max logits to shared memory.
606
+ float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
607
+ const float* max_logits_ptr = max_logits +
608
+ seq_idx * num_heads * max_num_partitions +
609
+ head_idx * max_num_partitions;
610
+ float max_logit = -FLT_MAX;
611
+ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
612
+ const float l = max_logits_ptr[i];
613
+ shared_max_logits[i] = l;
614
+ max_logit = fmaxf(max_logit, l);
615
+ }
616
+ __syncthreads();
617
+
618
+ // Get the global max logit.
619
+ // Reduce within the warp.
620
+ #pragma unroll
621
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
622
+ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
623
+ }
624
+ if (lane == 0) {
625
+ red_smem[warp_idx] = max_logit;
626
+ }
627
+ __syncthreads();
628
+ // Reduce across warps.
629
+ max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
630
+ #pragma unroll
631
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
632
+ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
633
+ }
634
+ // Broadcast the max value to all threads.
635
+ max_logit = VLLM_SHFL_SYNC(max_logit, 0);
636
+
637
+ // Load rescaled exp sums to shared memory.
638
+ float* shared_exp_sums =
639
+ reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
640
+ const float* exp_sums_ptr = exp_sums +
641
+ seq_idx * num_heads * max_num_partitions +
642
+ head_idx * max_num_partitions;
643
+ float global_exp_sum = 0.0f;
644
+ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
645
+ float l = shared_max_logits[i];
646
+ float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
647
+ global_exp_sum += rescaled_exp_sum;
648
+ shared_exp_sums[i] = rescaled_exp_sum;
649
+ }
650
+ __syncthreads();
651
+ global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
652
+ const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
653
+
654
+ // Aggregate tmp_out to out.
655
+ const scalar_t* tmp_out_ptr =
656
+ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
657
+ head_idx * max_num_partitions * HEAD_SIZE;
658
+ scalar_t* out_ptr =
659
+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
660
+ #pragma unroll
661
+ for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
662
+ float acc = 0.0f;
663
+ for (int j = 0; j < num_partitions; ++j) {
664
+ acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
665
+ inv_global_exp_sum;
666
+ }
667
+ from_float(out_ptr[i], acc);
668
+ }
669
+ }
670
+
671
+ } // namespace vllm
672
+
673
+ #undef WARP_SIZE
674
+ #undef MAX
675
+ #undef MIN
676
+ #undef DIVIDE_ROUND_UP
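In equation form, the split-KV (v2) scheme above works as follows: for each partition $j$ the main kernel emits that partition's max logit $m_j$, its exponent sum $s_j$, and a partial output $o_j$ that is already softmax-normalized within the partition. `paged_attention_v2_reduce_kernel` then recombines them with the standard streaming-softmax rescaling, which is exactly what the `shared_exp_sums` / `inv_global_exp_sum` code computes:

$$
m = \max_j m_j, \qquad
\mathrm{out} = \frac{\sum_j s_j \, e^{m_j - m} \, o_j}{\sum_j s_j \, e^{m_j - m}}.
$$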
paged-attention/attention/attention_utils.cuh ADDED
@@ -0,0 +1,57 @@
+ /*
+  * Adapted from
+  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+  * Copyright (c) 2023, The vLLM team.
+  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #pragma once
+
+ #include "../cuda_compat.h"
+ #include "attention_dtypes.h"
+
+ #include <float.h>
+ #include <type_traits>
+
+ namespace vllm {
+
+ // Q*K^T operation.
+ template <int THREAD_GROUP_SIZE, typename Vec, int N>
+ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
+   using A_vec = typename FloatVec<Vec>::Type;
+   // Compute the parallel products for Q*K^T (treat vector lanes separately).
+   A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
+ #pragma unroll
+   for (int ii = 1; ii < N; ++ii) {
+     qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
+   }
+
+   // Finalize the reduction across lanes.
+   float qk = sum(qk_vec);
+ #pragma unroll
+   for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+     qk += VLLM_SHFL_XOR_SYNC(qk, mask);
+   }
+   return qk;
+ }
+
+ template <typename T, int THREAD_GROUP_SIZE>
+ struct Qk_dot {
+   template <typename Vec, int N>
+   static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
+     return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+   }
+ };
+
+ }  // namespace vllm
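The `VLLM_SHFL_XOR_SYNC` loop in `qk_dot_` is a butterfly reduction: with a thread group of `G` lanes, each round adds the partial sum from the lane whose index differs in one bit, so after log2(G) rounds every lane in the group holds the full dot product. A host-side emulation of that pattern (illustration only, not part of this commit):

```cpp
#include <array>
#include <cstdio>

int main() {
  constexpr int G = 4;
  // Per-lane partial dot products, as produced before the shuffle loop.
  std::array<float, G> lane = {1.0f, 2.0f, 3.0f, 4.0f};
  for (int mask = G / 2; mask >= 1; mask /= 2) {
    std::array<float, G> next = lane;
    for (int i = 0; i < G; ++i) {
      next[i] = lane[i] + lane[i ^ mask];  // emulates VLLM_SHFL_XOR_SYNC
    }
    lane = next;
  }
  std::printf("every lane now holds %.1f\n", lane[0]);  // prints 10.0
  return 0;
}
```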
paged-attention/attention/dtype_bfloat16.cuh ADDED
@@ -0,0 +1,463 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * and
5
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
6
+ * Copyright (c) 2023, The vLLM team.
7
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
8
+ *
9
+ * Licensed under the Apache License, Version 2.0 (the "License");
10
+ * you may not use this file except in compliance with the License.
11
+ * You may obtain a copy of the License at
12
+ *
13
+ * http://www.apache.org/licenses/LICENSE-2.0
14
+ *
15
+ * Unless required by applicable law or agreed to in writing, software
16
+ * distributed under the License is distributed on an "AS IS" BASIS,
17
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ * See the License for the specific language governing permissions and
19
+ * limitations under the License.
20
+ */
21
+ #pragma once
22
+
23
+ #include "attention_generic.cuh"
24
+ #include "dtype_float32.cuh"
25
+
26
+ #ifndef USE_ROCM
27
+ #include <cuda_bf16.h>
28
+ #include <cuda_fp16.h>
29
+ #else
30
+ #include <hip/hip_bf16.h>
31
+ #include <hip/hip_fp16.h>
32
+
33
+ typedef __hip_bfloat162 __nv_bfloat162;
34
+ typedef __hip_bfloat16 __nv_bfloat16;
35
+ #endif
36
+
37
+ #include <stdint.h>
38
+
39
+ namespace vllm {
40
+
41
+ // Define custom BF16 vector data types.
42
+ struct bf16_4_t {
43
+ __nv_bfloat162 x;
44
+ __nv_bfloat162 y;
45
+ };
46
+
47
+ struct bf16_8_t {
48
+ __nv_bfloat162 x;
49
+ __nv_bfloat162 y;
50
+ __nv_bfloat162 z;
51
+ __nv_bfloat162 w;
52
+ };
53
+
54
+ // BF16 vector types for Q, K, V.
55
+ template <>
56
+ struct Vec<__nv_bfloat16, 1> {
57
+ using Type = __nv_bfloat16;
58
+ };
59
+ template <>
60
+ struct Vec<__nv_bfloat16, 2> {
61
+ using Type = __nv_bfloat162;
62
+ };
63
+ template <>
64
+ struct Vec<__nv_bfloat16, 4> {
65
+ using Type = bf16_4_t;
66
+ };
67
+ template <>
68
+ struct Vec<__nv_bfloat16, 8> {
69
+ using Type = bf16_8_t;
70
+ };
71
+
72
+ // FP32 accumulator vector types corresponding to Vec.
73
+ template <>
74
+ struct FloatVec<__nv_bfloat16> {
75
+ using Type = float;
76
+ };
77
+ template <>
78
+ struct FloatVec<__nv_bfloat162> {
79
+ using Type = float2;
80
+ };
81
+ template <>
82
+ struct FloatVec<bf16_4_t> {
83
+ using Type = Float4_;
84
+ };
85
+ template <>
86
+ struct FloatVec<bf16_8_t> {
87
+ using Type = Float8_;
88
+ };
89
+
90
+ // Utility functions for type conversions.
91
+ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
92
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
93
+ assert(false);
94
+ #else
95
+ return __bfloat1622float2(val);
96
+ #endif
97
+ __builtin_unreachable(); // Suppress missing return statement warning
98
+ }
99
+
100
+ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
101
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
102
+ assert(false);
103
+ #else
104
+ return __bfloat162bfloat162(val);
105
+ #endif
106
+ __builtin_unreachable(); // Suppress missing return statement warning
107
+ }
108
+
109
+ // Vector addition.
110
+ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
111
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
112
+ assert(false);
113
+ #else
114
+ #ifndef USE_ROCM
115
+ return a + b;
116
+ #else
117
+ return __hadd(a, b);
118
+ #endif
119
+ #endif
120
+ __builtin_unreachable(); // Suppress missing return statement warning
121
+ }
122
+
123
+ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
124
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
125
+ assert(false);
126
+ #else
127
+ return __hadd2(a, b);
128
+ #endif
129
+ __builtin_unreachable(); // Suppress missing return statement warning
130
+ }
131
+
132
+ inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
133
+ bf16_4_t c;
134
+ c.x = add(a.x, b.x);
135
+ c.y = add(a.y, b.y);
136
+ return c;
137
+ }
138
+
139
+ inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
140
+ bf16_8_t c;
141
+ c.x = add(a.x, b.x);
142
+ c.y = add(a.y, b.y);
143
+ c.z = add(a.z, b.z);
144
+ c.w = add(a.w, b.w);
145
+ return c;
146
+ }
147
+
148
+ inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
149
+ float2 fa = bf1622float2(a);
150
+ return add(fa, fb);
151
+ }
152
+
153
+ inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
154
+ Float4_ fc;
155
+ fc.x = add(a.x, fb.x);
156
+ fc.y = add(a.y, fb.y);
157
+ return fc;
158
+ }
159
+
160
+ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
161
+ Float8_ fc;
162
+ fc.x = add(a.x, fb.x);
163
+ fc.y = add(a.y, fb.y);
164
+ fc.z = add(a.z, fb.z);
165
+ fc.w = add(a.w, fb.w);
166
+ return fc;
167
+ }
168
+
169
+ // Vector multiplication.
170
+ template <>
171
+ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
172
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
173
+ assert(false);
174
+ #else
175
+ return __hmul(a, b);
176
+ #endif
177
+ __builtin_unreachable(); // Suppress missing return statement warning
178
+ }
179
+
180
+ template <>
181
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
182
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
183
+ assert(false);
184
+ #else
185
+ return __hmul2(a, b);
186
+ #endif
187
+ __builtin_unreachable(); // Suppress missing return statement warning
188
+ }
189
+
190
+ template <>
191
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
192
+ return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
193
+ }
194
+
195
+ template <>
196
+ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
197
+ bf16_4_t c;
198
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
199
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
200
+ return c;
201
+ }
202
+
203
+ template <>
204
+ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
205
+ __nv_bfloat162 s = bf162bf162(a);
206
+ bf16_4_t c;
207
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
208
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
209
+ return c;
210
+ }
211
+
212
+ template <>
213
+ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
214
+ bf16_8_t c;
215
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
216
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
217
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
218
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
219
+ return c;
220
+ }
221
+
222
+ template <>
223
+ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
224
+ __nv_bfloat162 s = bf162bf162(a);
225
+ bf16_8_t c;
226
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
227
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
228
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
229
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
230
+ return c;
231
+ }
232
+
233
+ template <>
234
+ inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
235
+ float fa = __bfloat162float(a);
236
+ float fb = __bfloat162float(b);
237
+ return fa * fb;
238
+ }
239
+
240
+ template <>
241
+ inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
242
+ float2 fa = bf1622float2(a);
243
+ float2 fb = bf1622float2(b);
244
+ return mul<float2, float2, float2>(fa, fb);
245
+ }
246
+
247
+ template <>
248
+ inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
249
+ return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
250
+ }
251
+
252
+ template <>
253
+ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
254
+ Float4_ fc;
255
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
256
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
257
+ return fc;
258
+ }
259
+
260
+ template <>
261
+ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
262
+ __nv_bfloat162 s = bf162bf162(a);
263
+ Float4_ fc;
264
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
265
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
266
+ return fc;
267
+ }
268
+
269
+ template <>
270
+ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
271
+ Float8_ fc;
272
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
273
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
274
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
275
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
276
+ return fc;
277
+ }
278
+
279
+ template <>
280
+ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
281
+ __nv_bfloat162 s = bf162bf162(a);
282
+ Float8_ fc;
283
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
284
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
285
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
286
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
287
+ return fc;
288
+ }
289
+
290
+ // Vector fused multiply-add.
291
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
292
+ __nv_bfloat162 c) {
293
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
294
+ assert(false);
295
+ #else
296
+ return __hfma2(a, b, c);
297
+ #endif
298
+ __builtin_unreachable(); // Suppress missing return statement warning
299
+ }
300
+
301
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
302
+ __nv_bfloat162 c) {
303
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
304
+ assert(false);
305
+ #else
306
+ return __hfma2(bf162bf162(a), b, c);
307
+ #endif
308
+ __builtin_unreachable(); // Suppress missing return statement warning
309
+ }
310
+
311
+ inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
312
+ bf16_4_t d;
313
+ d.x = fma(a.x, b.x, c.x);
314
+ d.y = fma(a.y, b.y, c.y);
315
+ return d;
316
+ }
317
+
318
+ inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
319
+ __nv_bfloat162 s = bf162bf162(a);
320
+ bf16_4_t d;
321
+ d.x = fma(s, b.x, c.x);
322
+ d.y = fma(s, b.y, c.y);
323
+ return d;
324
+ }
325
+
326
+ inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
327
+ bf16_8_t d;
328
+ d.x = fma(a.x, b.x, c.x);
329
+ d.y = fma(a.y, b.y, c.y);
330
+ d.z = fma(a.z, b.z, c.z);
331
+ d.w = fma(a.w, b.w, c.w);
332
+ return d;
333
+ }
334
+
335
+ inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
336
+ __nv_bfloat162 s = bf162bf162(a);
337
+ bf16_8_t d;
338
+ d.x = fma(s, b.x, c.x);
339
+ d.y = fma(s, b.y, c.y);
340
+ d.z = fma(s, b.z, c.z);
341
+ d.w = fma(s, b.w, c.w);
342
+ return d;
343
+ }
344
+
345
+ inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
346
+ return __bfloat162float(a) * __bfloat162float(b) + fc;
347
+ }
348
+
349
+ inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
350
+ float2 fa = bf1622float2(a);
351
+ float2 fb = bf1622float2(b);
352
+ return fma(fa, fb, fc);
353
+ }
354
+
355
+ inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
356
+ return fma(bf162bf162(a), b, fc);
357
+ }
358
+
359
+ inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
360
+ Float4_ fd;
361
+ fd.x = fma(a.x, b.x, fc.x);
362
+ fd.y = fma(a.y, b.y, fc.y);
363
+ return fd;
364
+ }
365
+
366
+ inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
367
+ __nv_bfloat162 s = bf162bf162(a);
368
+ Float4_ fd;
369
+ fd.x = fma(s, b.x, fc.x);
370
+ fd.y = fma(s, b.y, fc.y);
371
+ return fd;
372
+ }
373
+
374
+ inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
375
+ Float8_ fd;
376
+ fd.x = fma(a.x, b.x, fc.x);
377
+ fd.y = fma(a.y, b.y, fc.y);
378
+ fd.z = fma(a.z, b.z, fc.z);
379
+ fd.w = fma(a.w, b.w, fc.w);
380
+ return fd;
381
+ }
382
+
383
+ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
384
+ __nv_bfloat162 s = bf162bf162(a);
385
+ Float8_ fd;
386
+ fd.x = fma(s, b.x, fc.x);
387
+ fd.y = fma(s, b.y, fc.y);
388
+ fd.z = fma(s, b.z, fc.z);
389
+ fd.w = fma(s, b.w, fc.w);
390
+ return fd;
391
+ }
392
+
393
+ // Vector sum.
394
+ template <>
395
+ inline __device__ float sum(__nv_bfloat16 v) {
396
+ return __bfloat162float(v);
397
+ }
398
+
399
+ template <>
400
+ inline __device__ float sum(__nv_bfloat162 v) {
401
+ float2 vf = bf1622float2(v);
402
+ return vf.x + vf.y;
403
+ }
404
+
405
+ template <>
406
+ inline __device__ float sum(bf16_4_t v) {
407
+ return sum(v.x) + sum(v.y);
408
+ }
409
+
410
+ template <>
411
+ inline __device__ float sum(bf16_8_t v) {
412
+ return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
413
+ }
414
+
415
+ // From float32 to bfloat16.
416
+ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
417
+ dst = __float2bfloat16(src);
418
+ }
419
+
420
+ inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
421
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
422
+ assert(false);
423
+ #else
424
+ dst = __float22bfloat162_rn(src);
425
+ #endif
426
+ }
427
+
428
+ inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
429
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
430
+ assert(false);
431
+ #else
432
+ dst.x = __float22bfloat162_rn(src.x);
433
+ dst.y = __float22bfloat162_rn(src.y);
434
+ #endif
435
+ }
436
+
437
+ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
438
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
439
+ assert(false);
440
+ #else
441
+ dst.x = __float22bfloat162_rn(src.x);
442
+ dst.y = __float22bfloat162_rn(src.y);
443
+ dst.z = __float22bfloat162_rn(src.z);
444
+ dst.w = __float22bfloat162_rn(src.w);
445
+ #endif
446
+ }
447
+
448
+ // From bfloat16 to float32.
449
+ inline __device__ float to_float(__nv_bfloat16 u) {
450
+ return __bfloat162float(u);
451
+ }
452
+
453
+ // Zero-out a variable.
454
+ inline __device__ void zero(__nv_bfloat16& dst) {
455
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
456
+ assert(false);
457
+ #else
458
+ // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
459
+ dst = __ushort_as_bfloat16((unsigned short)0x0000U);
460
+ #endif
461
+ }
462
+
463
+ } // namespace vllm
paged-attention/attention/dtype_float16.cuh ADDED
@@ -0,0 +1,504 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * and
5
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
6
+ * Copyright (c) 2023, The vLLM team.
7
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
8
+ *
9
+ * Licensed under the Apache License, Version 2.0 (the "License");
10
+ * you may not use this file except in compliance with the License.
11
+ * You may obtain a copy of the License at
12
+ *
13
+ * http://www.apache.org/licenses/LICENSE-2.0
14
+ *
15
+ * Unless required by applicable law or agreed to in writing, software
16
+ * distributed under the License is distributed on an "AS IS" BASIS,
17
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ * See the License for the specific language governing permissions and
19
+ * limitations under the License.
20
+ */
21
+ #pragma once
22
+
23
+ #include "attention_generic.cuh"
24
+ #include "dtype_float32.cuh"
25
+
26
+ #ifdef USE_ROCM
27
+ #include <hip/hip_fp16.h>
28
+ #endif
29
+
30
+ #include <stdint.h>
31
+
32
+ namespace vllm {
33
+
34
+ // FP16 vector types for Q, K, V.
35
+ template <>
36
+ struct Vec<uint16_t, 1> {
37
+ using Type = uint16_t;
38
+ };
39
+ template <>
40
+ struct Vec<uint16_t, 2> {
41
+ using Type = uint32_t;
42
+ };
43
+ template <>
44
+ struct Vec<uint16_t, 4> {
45
+ using Type = uint2;
46
+ };
47
+ template <>
48
+ struct Vec<uint16_t, 8> {
49
+ using Type = uint4;
50
+ };
51
+
52
+ // FP32 accumulator vector types corresponding to Vec.
53
+ template <>
54
+ struct FloatVec<uint16_t> {
55
+ using Type = float;
56
+ };
57
+ template <>
58
+ struct FloatVec<uint32_t> {
59
+ using Type = float2;
60
+ };
61
+ template <>
62
+ struct FloatVec<uint2> {
63
+ using Type = Float4_;
64
+ };
65
+ template <>
66
+ struct FloatVec<uint4> {
67
+ using Type = Float8_;
68
+ };
69
+
70
+ // Utility functions for type conversions.
71
+ inline __device__ uint32_t h0_h0(uint16_t a) {
72
+ #ifndef USE_ROCM
73
+ uint32_t b;
74
+ asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
75
+ return b;
76
+ #else
77
+ union {
78
+ uint32_t u32;
79
+ uint16_t u16[2];
80
+ } tmp;
81
+ tmp.u16[0] = a;
82
+ tmp.u16[1] = a;
83
+ return tmp.u32;
84
+ #endif
85
+ }
86
+
87
+ inline __device__ float half_to_float(uint16_t h) {
88
+ float f;
89
+ #ifndef USE_ROCM
90
+ asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
91
+ #else
92
+ asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
93
+ #endif
94
+ return f;
95
+ }
96
+
97
+ inline __device__ float2 half2_to_float2(uint32_t v) {
98
+ #ifndef USE_ROCM
99
+ uint16_t lo, hi;
100
+ asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
101
+ return make_float2(half_to_float(lo), half_to_float(hi));
102
+ #else
103
+ union {
104
+ uint32_t u32;
105
+ uint16_t u16[2];
106
+ } tmp;
107
+ tmp.u32 = v;
108
+ float2 ret;
109
+ ret.x = half_to_float(tmp.u16[0]);
110
+ ret.y = half_to_float(tmp.u16[1]);
111
+ return ret;
112
+ #endif
113
+ }
114
+
115
+ inline __device__ uint16_t float_to_half(float f) {
116
+ union {
117
+ uint32_t u32;
118
+ uint16_t u16[2];
119
+ } tmp;
120
+ #ifndef USE_ROCM
121
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
122
+ #else
123
+ asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
124
+ #endif
125
+ return tmp.u16[0];
126
+ }
127
+
128
+ inline __device__ uint32_t float2_to_half2(float2 f) {
129
+ union {
130
+ uint32_t u32;
131
+ uint16_t u16[2];
132
+ } tmp;
133
+ #ifndef USE_ROCM
134
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
135
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
136
+ : "=r"(tmp.u32)
137
+ : "f"(f.y), "f"(f.x));
138
+ #else
139
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
140
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
141
+ #endif
142
+ #else
143
+ tmp.u16[0] = float_to_half(f.x);
144
+ tmp.u16[1] = float_to_half(f.y);
145
+ #endif
146
+ return tmp.u32;
147
+ }
148
+
149
+ // Vector addition.
150
+ inline __device__ uint16_t add(uint16_t a, uint16_t b) {
151
+ uint16_t c;
152
+ #ifndef USE_ROCM
153
+ asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
154
+ #else
155
+ asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
156
+ #endif
157
+ return c;
158
+ }
159
+
160
+ inline __device__ uint32_t add(uint32_t a, uint32_t b) {
161
+ uint32_t c;
162
+ #ifndef USE_ROCM
163
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
164
+ #else
165
+ asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
166
+ #endif
167
+ return c;
168
+ }
169
+
170
+ inline __device__ uint2 add(uint2 a, uint2 b) {
171
+ uint2 c;
172
+ c.x = add(a.x, b.x);
173
+ c.y = add(a.y, b.y);
174
+ return c;
175
+ }
176
+
177
+ inline __device__ uint4 add(uint4 a, uint4 b) {
178
+ uint4 c;
179
+ c.x = add(a.x, b.x);
180
+ c.y = add(a.y, b.y);
181
+ c.z = add(a.z, b.z);
182
+ c.w = add(a.w, b.w);
183
+ return c;
184
+ }
185
+
186
+ inline __device__ float2 add(uint32_t a, float2 fb) {
187
+ float2 fa = half2_to_float2(a);
188
+ return add(fa, fb);
189
+ }
190
+
191
+ inline __device__ Float4_ add(uint2 a, Float4_ fb) {
192
+ Float4_ fc;
193
+ fc.x = add(a.x, fb.x);
194
+ fc.y = add(a.y, fb.y);
195
+ return fc;
196
+ }
197
+
198
+ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
199
+ Float8_ fc;
200
+ fc.x = add(a.x, fb.x);
201
+ fc.y = add(a.y, fb.y);
202
+ fc.z = add(a.z, fb.z);
203
+ fc.w = add(a.w, fb.w);
204
+ return fc;
205
+ }
206
+
207
+ // Vector multiplication.
208
+ template <>
209
+ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
210
+ uint16_t c;
211
+ #ifndef USE_ROCM
212
+ asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
213
+ #else
214
+ asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
215
+ #endif
216
+ return c;
217
+ }
218
+
219
+ template <>
220
+ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
221
+ uint32_t c;
222
+ #ifndef USE_ROCM
223
+ asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
224
+ #else
225
+ asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
226
+ #endif
227
+ return c;
228
+ }
229
+
230
+ template <>
231
+ inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
232
+ return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
233
+ }
234
+
235
+ template <>
236
+ inline __device__ uint2 mul(uint2 a, uint2 b) {
237
+ uint2 c;
238
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
239
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
240
+ return c;
241
+ }
242
+
243
+ template <>
244
+ inline __device__ uint2 mul(uint16_t a, uint2 b) {
245
+ uint32_t s = h0_h0(a);
246
+ uint2 c;
247
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
248
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
249
+ return c;
250
+ }
251
+
252
+ template <>
253
+ inline __device__ uint4 mul(uint4 a, uint4 b) {
254
+ uint4 c;
255
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
256
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
257
+ c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
258
+ c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
259
+ return c;
260
+ }
261
+
262
+ template <>
263
+ inline __device__ uint4 mul(uint16_t a, uint4 b) {
264
+ uint32_t s = h0_h0(a);
265
+ uint4 c;
266
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
267
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
268
+ c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
269
+ c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
270
+ return c;
271
+ }
272
+
273
+ template <>
274
+ inline __device__ float mul(uint16_t a, uint16_t b) {
275
+ float fa = half_to_float(a);
276
+ float fb = half_to_float(b);
277
+ return fa * fb;
278
+ }
279
+
280
+ template <>
281
+ inline __device__ float2 mul(uint32_t a, uint32_t b) {
282
+ float2 fa = half2_to_float2(a);
283
+ float2 fb = half2_to_float2(b);
284
+ return mul<float2, float2, float2>(fa, fb);
285
+ }
286
+
287
+ template <>
288
+ inline __device__ float2 mul(uint16_t a, uint32_t b) {
289
+ return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
290
+ }
291
+
292
+ template <>
293
+ inline __device__ Float4_ mul(uint2 a, uint2 b) {
294
+ Float4_ fc;
295
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
296
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
297
+ return fc;
298
+ }
299
+
300
+ template <>
301
+ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
302
+ uint32_t s = h0_h0(a);
303
+ Float4_ fc;
304
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
305
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
306
+ return fc;
307
+ }
308
+
309
+ template <>
310
+ inline __device__ Float8_ mul(uint4 a, uint4 b) {
311
+ Float8_ fc;
312
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
313
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
314
+ fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
315
+ fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
316
+ return fc;
317
+ }
318
+
319
+ template <>
320
+ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
321
+ uint32_t s = h0_h0(a);
322
+ Float8_ fc;
323
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
324
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
325
+ fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
326
+ fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
327
+ return fc;
328
+ }
329
+
330
+ // Vector fused multiply-add.
331
+ inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
332
+ uint32_t d;
333
+ #ifndef USE_ROCM
334
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
335
+ : "=r"(d)
336
+ : "r"(a), "r"(b), "r"(c));
337
+ #else
338
+ asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
339
+ : "=v"(d)
340
+ : "v"(a), "v"(b), "v"(c));
341
+ #endif
342
+ return d;
343
+ }
344
+
345
+ inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
346
+ return fma(h0_h0(a), b, c);
347
+ }
348
+
349
+ inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
350
+ uint2 d;
351
+ d.x = fma(a.x, b.x, c.x);
352
+ d.y = fma(a.y, b.y, c.y);
353
+ return d;
354
+ }
355
+
356
+ inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
357
+ uint32_t s = h0_h0(a);
358
+ uint2 d;
359
+ d.x = fma(s, b.x, c.x);
360
+ d.y = fma(s, b.y, c.y);
361
+ return d;
362
+ }
363
+
364
+ inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
365
+ uint4 d;
366
+ d.x = fma(a.x, b.x, c.x);
367
+ d.y = fma(a.y, b.y, c.y);
368
+ d.z = fma(a.z, b.z, c.z);
369
+ d.w = fma(a.w, b.w, c.w);
370
+ return d;
371
+ }
372
+
373
+ inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
374
+ uint32_t s = h0_h0(a);
375
+ uint4 d;
376
+ d.x = fma(s, b.x, c.x);
377
+ d.y = fma(s, b.y, c.y);
378
+ d.z = fma(s, b.z, c.z);
379
+ d.w = fma(s, b.w, c.w);
380
+ return d;
381
+ }
382
+
383
+ inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
384
+ float fa = half_to_float(a);
385
+ float fb = half_to_float(b);
386
+ return fa * fb + fc;
387
+ }
388
+
389
+ inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
390
+ float2 fa = half2_to_float2(a);
391
+ float2 fb = half2_to_float2(b);
392
+ return fma(fa, fb, fc);
393
+ }
394
+
395
+ inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
396
+ return fma(h0_h0(a), b, fc);
397
+ }
398
+
399
+ inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
400
+ Float4_ fd;
401
+ fd.x = fma(a.x, b.x, fc.x);
402
+ fd.y = fma(a.y, b.y, fc.y);
403
+ return fd;
404
+ }
405
+
406
+ inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
407
+ uint32_t s = h0_h0(a);
408
+ Float4_ fd;
409
+ fd.x = fma(s, b.x, fc.x);
410
+ fd.y = fma(s, b.y, fc.y);
411
+ return fd;
412
+ }
413
+
414
+ inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
415
+ Float8_ fd;
416
+ fd.x = fma(a.x, b.x, fc.x);
417
+ fd.y = fma(a.y, b.y, fc.y);
418
+ fd.z = fma(a.z, b.z, fc.z);
419
+ fd.w = fma(a.w, b.w, fc.w);
420
+ return fd;
421
+ }
422
+
423
+ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
424
+ uint32_t s = h0_h0(a);
425
+ Float8_ fd;
426
+ fd.x = fma(s, b.x, fc.x);
427
+ fd.y = fma(s, b.y, fc.y);
428
+ fd.z = fma(s, b.z, fc.z);
429
+ fd.w = fma(s, b.w, fc.w);
430
+ return fd;
431
+ }
432
+
433
+ // Vector sum.
434
+ template <>
435
+ inline __device__ float sum(uint16_t v) {
436
+ return half_to_float(v);
437
+ }
438
+
439
+ template <>
440
+ inline __device__ float sum(uint32_t v) {
441
+ float2 tmp = half2_to_float2(v);
442
+ return tmp.x + tmp.y;
443
+ }
444
+
445
+ template <>
446
+ inline __device__ float sum(uint2 v) {
447
+ uint32_t c = add(v.x, v.y);
448
+ return sum(c);
449
+ }
450
+
451
+ template <>
452
+ inline __device__ float sum(uint4 v) {
453
+ uint32_t c = add(v.x, v.y);
454
+ c = add(c, v.z);
455
+ c = add(c, v.w);
456
+ return sum(c);
457
+ }
458
+
459
+ // From float32 to float16.
460
+ inline __device__ void from_float(uint16_t& dst, float src) {
461
+ dst = float_to_half(src);
462
+ }
463
+
464
+ inline __device__ void from_float(uint32_t& dst, float2 src) {
465
+ dst = float2_to_half2(src);
466
+ }
467
+
468
+ inline __device__ void from_float(uint2& dst, Float4_ src) {
469
+ dst.x = float2_to_half2(src.x);
470
+ dst.y = float2_to_half2(src.y);
471
+ }
472
+
473
+ inline __device__ void from_float(uint4& dst, Float8_ src) {
474
+ dst.x = float2_to_half2(src.x);
475
+ dst.y = float2_to_half2(src.y);
476
+ dst.z = float2_to_half2(src.z);
477
+ dst.w = float2_to_half2(src.w);
478
+ }
479
+
480
+ // From float16 to float32.
481
+ inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
482
+
483
+ inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
484
+
485
+ inline __device__ Float4_ to_float(uint2 u) {
486
+ Float4_ tmp;
487
+ tmp.x = half2_to_float2(u.x);
488
+ tmp.y = half2_to_float2(u.y);
489
+ return tmp;
490
+ }
491
+
492
+ inline __device__ Float8_ to_float(uint4 u) {
493
+ Float8_ tmp;
494
+ tmp.x = half2_to_float2(u.x);
495
+ tmp.y = half2_to_float2(u.y);
496
+ tmp.z = half2_to_float2(u.z);
497
+ tmp.w = half2_to_float2(u.w);
498
+ return tmp;
499
+ }
500
+
501
+ // Zero-out a variable.
502
+ inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
503
+
504
+ } // namespace vllm
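
The `Vec` and `FloatVec` specializations above are how the kernels pick a packed fp16 storage type and its fp32 accumulator for a given vector width. A minimal host-side sketch of that mapping follows; the alias names and the chosen width are illustrative assumptions, not part of the port:

```cpp
// Illustrative only: the Vec/FloatVec pairing for 8 fp16 elements.
// Assumes attention_generic.cuh and dtype_float16.cuh are in scope.
#include <cstdint>

constexpr int VEC_SIZE = 8;                          // hypothetical vector width
using K_vec = vllm::Vec<uint16_t, VEC_SIZE>::Type;   // uint4: 8 packed halfs
using K_vec_acc = vllm::FloatVec<K_vec>::Type;       // Float8_: 8 floats

static_assert(sizeof(K_vec) == VEC_SIZE * sizeof(uint16_t), "8 halfs in 16 bytes");
static_assert(sizeof(K_vec_acc) == VEC_SIZE * sizeof(float), "8 floats in 32 bytes");
```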
paged-attention/attention/dtype_float32.cuh ADDED
@@ -0,0 +1,251 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * and
5
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
6
+ * Copyright (c) 2023, The vLLM team.
7
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
8
+ *
9
+ * Licensed under the Apache License, Version 2.0 (the "License");
10
+ * you may not use this file except in compliance with the License.
11
+ * You may obtain a copy of the License at
12
+ *
13
+ * http://www.apache.org/licenses/LICENSE-2.0
14
+ *
15
+ * Unless required by applicable law or agreed to in writing, software
16
+ * distributed under the License is distributed on an "AS IS" BASIS,
17
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ * See the License for the specific language governing permissions and
19
+ * limitations under the License.
20
+ */
21
+ #pragma once
22
+
23
+ #include "attention_generic.cuh"
24
+
25
+ #include <stdint.h>
26
+
27
+ namespace vllm {
28
+
29
+ // Define custom FP32 vector data types.
30
+ struct Float4_ {
31
+ float2 x;
32
+ float2 y;
33
+ };
34
+
35
+ struct Float8_ {
36
+ float2 x;
37
+ float2 y;
38
+ float2 z;
39
+ float2 w;
40
+ };
41
+
42
+ // FP32 vector types for Q, K, V.
43
+ template <>
44
+ struct Vec<float, 1> {
45
+ using Type = float;
46
+ };
47
+ template <>
48
+ struct Vec<float, 2> {
49
+ using Type = float2;
50
+ };
51
+ template <>
52
+ struct Vec<float, 4> {
53
+ using Type = float4;
54
+ };
55
+
56
+ // FP32 accumulator vector types corresponding to Vec.
57
+ template <>
58
+ struct FloatVec<float> {
59
+ using Type = float;
60
+ };
61
+ template <>
62
+ struct FloatVec<float2> {
63
+ using Type = float2;
64
+ };
65
+ template <>
66
+ struct FloatVec<float4> {
67
+ using Type = float4;
68
+ };
69
+
70
+ // Vector addition.
71
+ inline __device__ float add(float a, float b) { return a + b; }
72
+
73
+ inline __device__ float2 add(float2 a, float2 b) {
74
+ float2 c;
75
+ c.x = add(a.x, b.x);
76
+ c.y = add(a.y, b.y);
77
+ return c;
78
+ }
79
+
80
+ inline __device__ float4 add(float4 a, float4 b) {
81
+ float4 c;
82
+ c.x = add(a.x, b.x);
83
+ c.y = add(a.y, b.y);
84
+ c.z = add(a.z, b.z);
85
+ c.w = add(a.w, b.w);
86
+ return c;
87
+ }
88
+
89
+ // Vector multiplication.
90
+ template <>
91
+ inline __device__ float mul<float, float>(float a, float b) {
92
+ return a * b;
93
+ }
94
+
95
+ template <>
96
+ inline __device__ float2 mul(float2 a, float2 b) {
97
+ float2 c;
98
+ c.x = a.x * b.x;
99
+ c.y = a.y * b.y;
100
+ return c;
101
+ }
102
+
103
+ template <>
104
+ inline __device__ float2 mul(float a, float2 b) {
105
+ float2 c;
106
+ c.x = a * b.x;
107
+ c.y = a * b.y;
108
+ return c;
109
+ }
110
+
111
+ template <>
112
+ inline __device__ float4 mul(float4 a, float4 b) {
113
+ float4 c;
114
+ c.x = a.x * b.x;
115
+ c.y = a.y * b.y;
116
+ c.z = a.z * b.z;
117
+ c.w = a.w * b.w;
118
+ return c;
119
+ }
120
+
121
+ template <>
122
+ inline __device__ float4 mul(float a, float4 b) {
123
+ float4 c;
124
+ c.x = a * b.x;
125
+ c.y = a * b.y;
126
+ c.z = a * b.z;
127
+ c.w = a * b.w;
128
+ return c;
129
+ }
130
+
131
+ // Vector fused multiply-add.
132
+ inline __device__ float fma(float a, float b, float c) { return a * b + c; }
133
+
134
+ inline __device__ float2 fma(float2 a, float2 b, float2 c) {
135
+ float2 d;
136
+ d.x = fma(a.x, b.x, c.x);
137
+ d.y = fma(a.y, b.y, c.y);
138
+ return d;
139
+ }
140
+
141
+ inline __device__ float2 fma(float a, float2 b, float2 c) {
142
+ float2 d;
143
+ d.x = fma(a, b.x, c.x);
144
+ d.y = fma(a, b.y, c.y);
145
+ return d;
146
+ }
147
+
148
+ inline __device__ float4 fma(float4 a, float4 b, float4 c) {
149
+ float4 d;
150
+ d.x = fma(a.x, b.x, c.x);
151
+ d.y = fma(a.y, b.y, c.y);
152
+ d.z = fma(a.z, b.z, c.z);
153
+ d.w = fma(a.w, b.w, c.w);
154
+ return d;
155
+ }
156
+
157
+ inline __device__ float4 fma(float a, float4 b, float4 c) {
158
+ float4 d;
159
+ d.x = fma(a, b.x, c.x);
160
+ d.y = fma(a, b.y, c.y);
161
+ d.z = fma(a, b.z, c.z);
162
+ d.w = fma(a, b.w, c.w);
163
+ return d;
164
+ }
165
+
166
+ inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
167
+ Float4_ d;
168
+ d.x = fma(a, b.x, c.x);
169
+ d.y = fma(a, b.y, c.y);
170
+ return d;
171
+ }
172
+
173
+ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
174
+ Float8_ d;
175
+ d.x = fma(a, b.x, c.x);
176
+ d.y = fma(a, b.y, c.y);
177
+ d.z = fma(a, b.z, c.z);
178
+ d.w = fma(a, b.w, c.w);
179
+ return d;
180
+ }
181
+
182
+ // Vector sum.
183
+ template <>
184
+ inline __device__ float sum(float v) {
185
+ return v;
186
+ }
187
+
188
+ template <>
189
+ inline __device__ float sum(float2 v) {
190
+ return v.x + v.y;
191
+ }
192
+
193
+ template <>
194
+ inline __device__ float sum(float4 v) {
195
+ return v.x + v.y + v.z + v.w;
196
+ }
197
+
198
+ template <>
199
+ inline __device__ float sum(Float4_ v) {
200
+ return v.x.x + v.x.y + v.y.x + v.y.y;
201
+ }
202
+
203
+ template <>
204
+ inline __device__ float sum(Float8_ v) {
205
+ return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
206
+ }
207
+
208
+ // Vector dot product.
209
+ inline __device__ float dot(float a, float b) { return a * b; }
210
+
211
+ inline __device__ float dot(float2 a, float2 b) {
212
+ float2 c = mul<float2, float2, float2>(a, b);
213
+ return c.x + c.y;
214
+ }
215
+
216
+ inline __device__ float dot(Float4_ a, Float4_ b) {
217
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
218
+ acc = fma(a.y, b.y, acc);
219
+ return acc.x + acc.y;
220
+ }
221
+
222
+ inline __device__ float dot(Float8_ a, Float8_ b) {
223
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
224
+ acc = fma(a.y, b.y, acc);
225
+ acc = fma(a.z, b.z, acc);
226
+ acc = fma(a.w, b.w, acc);
227
+ return acc.x + acc.y;
228
+ }
229
+
230
+ // From float to float.
231
+ inline __device__ void from_float(float& dst, float src) { dst = src; }
232
+
233
+ inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
234
+
235
+ inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
236
+
237
+ // From float to float.
238
+ inline __device__ float to_float(float u) { return u; }
239
+
240
+ inline __device__ float2 to_float(float2 u) { return u; }
241
+
242
+ inline __device__ float4 to_float(float4 u) { return u; }
243
+
244
+ inline __device__ Float4_ to_float(Float4_ u) { return u; }
245
+
246
+ inline __device__ Float8_ to_float(Float8_ u) { return u; }
247
+
248
+ // Zero-out a variable.
249
+ inline __device__ void zero(float& dst) { dst = 0.f; }
250
+
251
+ } // namespace vllm
paged-attention/attention/dtype_fp8.cuh ADDED
@@ -0,0 +1,41 @@
+ #pragma once
+
+ #include "attention_generic.cuh"
+
+ #include <stdint.h>
+ #ifdef ENABLE_FP8
+ #ifndef USE_ROCM
+ #include <cuda_fp8.h>
+ #endif // USE_ROCM
+ #endif // ENABLE_FP8
+
+ namespace vllm {
+
+ enum class Fp8KVCacheDataType {
+ kAuto = 0,
+ kFp8E4M3 = 1,
+ kFp8E5M2 = 2,
+ };
+
+ // fp8 vector types for quantization of kv cache
+ template <>
+ struct Vec<uint8_t, 1> {
+ using Type = uint8_t;
+ };
+
+ template <>
+ struct Vec<uint8_t, 2> {
+ using Type = uint16_t;
+ };
+
+ template <>
+ struct Vec<uint8_t, 4> {
+ using Type = uint32_t;
+ };
+
+ template <>
+ struct Vec<uint8_t, 8> {
+ using Type = uint2;
+ };
+
+ } // namespace vllm
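
These `Vec<uint8_t, N>` specializations pack N quantized fp8 values (one byte each) into a single integer register. A small sketch of the resulting storage sizes; the `Fp8Vec` alias is a hypothetical name used only for illustration:

```cpp
// Illustrative only: N fp8 bytes pack into an N-byte integer type
// (uint2 for N = 8). Assumes dtype_fp8.cuh is in scope.
template <int N>
using Fp8Vec = typename vllm::Vec<uint8_t, N>::Type;

static_assert(sizeof(Fp8Vec<1>) == 1, "uint8_t");
static_assert(sizeof(Fp8Vec<2>) == 2, "uint16_t");
static_assert(sizeof(Fp8Vec<4>) == 4, "uint32_t");
static_assert(sizeof(Fp8Vec<8>) == 8, "uint2");
```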
paged-attention/attention/paged_attention_v1.cu ADDED
@@ -0,0 +1,196 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * Copyright (c) 2023, The vLLM team.
5
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
+ *
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+
20
+ #include "attention_kernels.cuh"
21
+
22
+ #ifndef USE_ROCM
23
+ #define WARP_SIZE 32
24
+ #else
25
+ #define WARP_SIZE warpSize
26
+ #endif
27
+
28
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
29
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
30
+ #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
31
+
32
+ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
33
+ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
34
+ ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
35
+ BLOCK_SIZE, NUM_THREADS, \
36
+ KV_DTYPE, IS_BLOCK_SPARSE>), \
37
+ shared_mem_size); \
38
+ vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
39
+ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
40
+ <<<grid, block, shared_mem_size, stream>>>( \
41
+ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
42
+ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
43
+ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
44
+ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
45
+ blocksparse_vert_stride, blocksparse_block_size, \
46
+ blocksparse_head_sliding_step);
47
+
48
+ // TODO(woosuk): Tune NUM_THREADS.
49
+ template <typename T, typename CACHE_T, int BLOCK_SIZE,
50
+ vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
51
+ int NUM_THREADS = 128>
52
+ void paged_attention_v1_launcher(
53
+ torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
54
+ torch::Tensor& value_cache, int num_kv_heads, float scale,
55
+ torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
56
+ const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
57
+ torch::Tensor& v_scale, const int tp_rank,
58
+ const int blocksparse_local_blocks, const int blocksparse_vert_stride,
59
+ const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
60
+ int num_seqs = query.size(0);
61
+ int num_heads = query.size(1);
62
+ int head_size = query.size(2);
63
+ int max_num_blocks_per_seq = block_tables.size(1);
64
+ int q_stride = query.stride(0);
65
+ int kv_block_stride = key_cache.stride(0);
66
+ int kv_head_stride = key_cache.stride(1);
67
+
68
+ [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
69
+ assert(head_size % thread_group_size == 0);
70
+
71
+ // NOTE: alibi_slopes is optional.
72
+ const float* alibi_slopes_ptr =
73
+ alibi_slopes
74
+ ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
75
+ : nullptr;
76
+
77
+ T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
78
+ T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
79
+ CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
80
+ CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
81
+ int* block_tables_ptr = block_tables.data_ptr<int>();
82
+ int* seq_lens_ptr = seq_lens.data_ptr<int>();
83
+ const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
84
+ const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
85
+
86
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
87
+ int padded_max_seq_len =
88
+ DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
89
+ int logits_size = padded_max_seq_len * sizeof(float);
90
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
91
+ // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
92
+ // Keep that in sync with the logic here!
93
+ int shared_mem_size = std::max(logits_size, outputs_size);
94
+
95
+ dim3 grid(num_heads, num_seqs, 1);
96
+ dim3 block(NUM_THREADS);
97
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
98
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
99
+ switch (head_size) {
100
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
101
+ // head sizes that we use in the model. However, we can easily extend this
102
+ // to support any head size which is a multiple of 16.
103
+ case 32:
104
+ LAUNCH_PAGED_ATTENTION_V1(32);
105
+ break;
106
+ case 64:
107
+ LAUNCH_PAGED_ATTENTION_V1(64);
108
+ break;
109
+ case 80:
110
+ LAUNCH_PAGED_ATTENTION_V1(80);
111
+ break;
112
+ case 96:
113
+ LAUNCH_PAGED_ATTENTION_V1(96);
114
+ break;
115
+ case 112:
116
+ LAUNCH_PAGED_ATTENTION_V1(112);
117
+ break;
118
+ case 120:
119
+ LAUNCH_PAGED_ATTENTION_V1(120);
120
+ break;
121
+ case 128:
122
+ LAUNCH_PAGED_ATTENTION_V1(128);
123
+ break;
124
+ case 192:
125
+ LAUNCH_PAGED_ATTENTION_V1(192);
126
+ break;
127
+ case 256:
128
+ LAUNCH_PAGED_ATTENTION_V1(256);
129
+ break;
130
+ default:
131
+ TORCH_CHECK(false, "Unsupported head size: ", head_size);
132
+ break;
133
+ }
134
+ }
135
+
136
+ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
137
+ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
138
+ IS_BLOCK_SPARSE>( \
139
+ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
140
+ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
141
+ blocksparse_local_blocks, blocksparse_vert_stride, \
142
+ blocksparse_block_size, blocksparse_head_sliding_step);
143
+
144
+ #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
145
+ if (is_block_sparse) { \
146
+ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
147
+ } else { \
148
+ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
149
+ }
150
+
151
+ // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
152
+ // 1, 2, 4, 64, 128, 256.
153
+ #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
154
+ switch (block_size) { \
155
+ case 8: \
156
+ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
157
+ break; \
158
+ case 16: \
159
+ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
160
+ break; \
161
+ case 32: \
162
+ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
163
+ break; \
164
+ default: \
165
+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
166
+ break; \
167
+ }
168
+
169
+ void paged_attention_v1(
170
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
171
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
172
+ torch::Tensor&
173
+ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
174
+ torch::Tensor&
175
+ value_cache, // [num_blocks, num_heads, head_size, block_size]
176
+ int64_t num_kv_heads, // [num_heads]
177
+ double scale,
178
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
179
+ torch::Tensor& seq_lens, // [num_seqs]
180
+ int64_t block_size, int64_t max_seq_len,
181
+ const std::optional<torch::Tensor>& alibi_slopes,
182
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
183
+ torch::Tensor& v_scale, const int64_t tp_rank,
184
+ const int64_t blocksparse_local_blocks,
185
+ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
186
+ const int64_t blocksparse_head_sliding_step) {
187
+ const bool is_block_sparse = (blocksparse_vert_stride > 1);
188
+
189
+ DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
190
+ CALL_V1_LAUNCHER_BLOCK_SIZE)
191
+ }
192
+
193
+ #undef WARP_SIZE
194
+ #undef MAX
195
+ #undef MIN
196
+ #undef DIVIDE_ROUND_UP
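
The v1 launcher sizes its dynamic shared memory as the maximum of the softmax logits buffer (one float per padded sequence position) and the per-warp output buffer. A back-of-the-envelope sketch for one hypothetical configuration; all concrete numbers below are assumptions for illustration, not defaults of the port:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const int BLOCK_SIZE = 16, NUM_THREADS = 128, WARP_SIZE = 32;  // assumed
  const int max_seq_len = 1000, head_size = 128;                 // assumed

  const int padded_max_seq_len =
      (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;             // 1008
  const int logits_size = static_cast<int>(padded_max_seq_len * sizeof(float));  // 4032 B
  const int num_warps = NUM_THREADS / WARP_SIZE;                            // 4
  const int outputs_size = static_cast<int>((num_warps / 2) * head_size * sizeof(float));  // 1024 B
  const int shared_mem_size = std::max(logits_size, outputs_size);          // 4032 B

  std::printf("shared_mem_size = %d bytes\n", shared_mem_size);
  return 0;
}
```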
paged-attention/attention/paged_attention_v2.cu ADDED
@@ -0,0 +1,206 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4
+ * Copyright (c) 2023, The vLLM team.
5
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
+ *
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+
20
+ #include "attention_kernels.cuh"
21
+
22
+ #ifndef USE_ROCM
23
+ #define WARP_SIZE 32
24
+ #else
25
+ #define WARP_SIZE warpSize
26
+ #endif
27
+
28
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
29
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
30
+ #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
31
+
32
+ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
33
+ vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
34
+ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
35
+ PARTITION_SIZE> \
36
+ <<<grid, block, shared_mem_size, stream>>>( \
37
+ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
38
+ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
39
+ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
40
+ kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
41
+ blocksparse_local_blocks, blocksparse_vert_stride, \
42
+ blocksparse_block_size, blocksparse_head_sliding_step); \
43
+ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
44
+ PARTITION_SIZE> \
45
+ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
46
+ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
47
+ max_num_partitions);
48
+
49
+ template <typename T, typename CACHE_T, int BLOCK_SIZE,
50
+ vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
51
+ int NUM_THREADS = 128, int PARTITION_SIZE = 512>
52
+ void paged_attention_v2_launcher(
53
+ torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
54
+ torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
55
+ torch::Tensor& value_cache, int num_kv_heads, float scale,
56
+ torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
57
+ const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
58
+ torch::Tensor& v_scale, const int tp_rank,
59
+ const int blocksparse_local_blocks, const int blocksparse_vert_stride,
60
+ const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
61
+ int num_seqs = query.size(0);
62
+ int num_heads = query.size(1);
63
+ int head_size = query.size(2);
64
+ int max_num_blocks_per_seq = block_tables.size(1);
65
+ int q_stride = query.stride(0);
66
+ int kv_block_stride = key_cache.stride(0);
67
+ int kv_head_stride = key_cache.stride(1);
68
+
69
+ [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
70
+ assert(head_size % thread_group_size == 0);
71
+
72
+ // NOTE: alibi_slopes is optional.
73
+ const float* alibi_slopes_ptr =
74
+ alibi_slopes
75
+ ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
76
+ : nullptr;
77
+
78
+ T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
79
+ float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
80
+ float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
81
+ T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
82
+ T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
83
+ CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
84
+ CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
85
+ int* block_tables_ptr = block_tables.data_ptr<int>();
86
+ int* seq_lens_ptr = seq_lens.data_ptr<int>();
87
+ const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
88
+ const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
89
+
90
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
91
+ int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
92
+ int logits_size = PARTITION_SIZE * sizeof(float);
93
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
94
+
95
+ // For paged attention v2 kernel.
96
+ dim3 grid(num_heads, num_seqs, max_num_partitions);
97
+ int shared_mem_size = std::max(logits_size, outputs_size);
98
+ // For paged attention v2 reduce kernel.
99
+ dim3 reduce_grid(num_heads, num_seqs);
100
+ int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
101
+
102
+ dim3 block(NUM_THREADS);
103
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
104
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
105
+ switch (head_size) {
106
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
107
+ // head sizes that we use in the model. However, we can easily extend this
108
+ // to support any head size which is a multiple of 16.
109
+ case 32:
110
+ LAUNCH_PAGED_ATTENTION_V2(32);
111
+ break;
112
+ case 64:
113
+ LAUNCH_PAGED_ATTENTION_V2(64);
114
+ break;
115
+ case 80:
116
+ LAUNCH_PAGED_ATTENTION_V2(80);
117
+ break;
118
+ case 96:
119
+ LAUNCH_PAGED_ATTENTION_V2(96);
120
+ break;
121
+ case 112:
122
+ LAUNCH_PAGED_ATTENTION_V2(112);
123
+ break;
124
+ case 120:
125
+ LAUNCH_PAGED_ATTENTION_V2(120);
126
+ break;
127
+ case 128:
128
+ LAUNCH_PAGED_ATTENTION_V2(128);
129
+ break;
130
+ case 192:
131
+ LAUNCH_PAGED_ATTENTION_V2(192);
132
+ break;
133
+ case 256:
134
+ LAUNCH_PAGED_ATTENTION_V2(256);
135
+ break;
136
+ default:
137
+ TORCH_CHECK(false, "Unsupported head size: ", head_size);
138
+ break;
139
+ }
140
+ }
141
+
142
+ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
143
+ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
144
+ IS_BLOCK_SPARSE>( \
145
+ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
146
+ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
147
+ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
148
+ blocksparse_vert_stride, blocksparse_block_size, \
149
+ blocksparse_head_sliding_step);
150
+
151
+ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
152
+ if (is_block_sparse) { \
153
+ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
154
+ } else { \
155
+ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
156
+ }
157
+
158
+ // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
159
+ // 1, 2, 4, 64, 128, 256.
160
+ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
161
+ switch (block_size) { \
162
+ case 8: \
163
+ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
164
+ break; \
165
+ case 16: \
166
+ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
167
+ break; \
168
+ case 32: \
169
+ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
170
+ break; \
171
+ default: \
172
+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
173
+ break; \
174
+ }
175
+
176
+ void paged_attention_v2(
177
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
178
+ torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
179
+ torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
180
+ torch::Tensor&
181
+ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
182
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
183
+ torch::Tensor&
184
+ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
185
+ torch::Tensor&
186
+ value_cache, // [num_blocks, num_heads, head_size, block_size]
187
+ int64_t num_kv_heads, // [num_heads]
188
+ double scale,
189
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
190
+ torch::Tensor& seq_lens, // [num_seqs]
191
+ int64_t block_size, int64_t max_seq_len,
192
+ const std::optional<torch::Tensor>& alibi_slopes,
193
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
194
+ torch::Tensor& v_scale, const int64_t tp_rank,
195
+ const int64_t blocksparse_local_blocks,
196
+ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
197
+ const int64_t blocksparse_head_sliding_step) {
198
+ const bool is_block_sparse = (blocksparse_vert_stride > 1);
199
+ DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
200
+ CALL_V2_LAUNCHER_BLOCK_SIZE)
201
+ }
202
+
203
+ #undef WARP_SIZE
204
+ #undef MAX
205
+ #undef MIN
206
+ #undef DIVIDE_ROUND_UP
paged-attention/cache_kernels.cu ADDED
@@ -0,0 +1,419 @@
1
+ #include <torch/all.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #include "cuda_compat.h"
6
+ #include "dispatch_utils.h"
7
+
8
+ #ifdef USE_ROCM
9
+ #include "quantization/fp8/amd/quant_utils.cuh"
10
+ #else
11
+ #include "quantization/fp8/nvidia/quant_utils.cuh"
12
+ #endif
13
+
14
+ #include <algorithm>
15
+ #include <cassert>
16
+ #include <map>
17
+ #include <vector>
18
+
19
+ #ifdef USE_ROCM
20
+ #include <hip/hip_bf16.h>
21
+ typedef __hip_bfloat16 __nv_bfloat16;
22
+ #endif
23
+
24
+ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
25
+ const torch::Tensor& block_mapping) {
26
+ torch::Device src_device = src.device();
27
+ torch::Device dst_device = dst.device();
28
+ cudaMemcpyKind memcpy_type;
29
+ if (src_device.is_cuda() && dst_device.is_cuda()) {
30
+ TORCH_CHECK(src_device.index() == dst_device.index(),
31
+ "src and dst must be on the same GPU");
32
+ memcpy_type = cudaMemcpyDeviceToDevice;
33
+ } else if (src_device.is_cuda() && dst_device.is_cpu()) {
34
+ memcpy_type = cudaMemcpyDeviceToHost;
35
+ } else if (src_device.is_cpu() && dst_device.is_cuda()) {
36
+ memcpy_type = cudaMemcpyHostToDevice;
37
+ } else {
38
+ TORCH_CHECK(false, "Invalid device combination");
39
+ }
40
+
41
+ // NOTE(youkaichao): keep in mind that `block_mapping` should be
42
+ // a cpu tensor, otherwise every `item` call will require a gpu-cpu
43
+ // synchronization.
44
+ TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
45
+
46
+ char* src_ptr = static_cast<char*>(src.data_ptr());
47
+ char* dst_ptr = static_cast<char*>(dst.data_ptr());
48
+
49
+ const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
50
+ const at::cuda::OptionalCUDAGuard device_guard(
51
+ src_device.is_cuda() ? src_device : dst_device);
52
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
53
+ // NOTE(woosuk): This can be slow if the number of blocks is large.
54
+ const int64_t num_blocks = block_mapping.size(0);
55
+ for (size_t i = 0; i < num_blocks; i++) {
56
+ int64_t src_block_number = block_mapping[i][0].item<int64_t>();
57
+ int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
58
+ int64_t src_offset = src_block_number * block_size_in_bytes;
59
+ int64_t dst_offset = dst_block_number * block_size_in_bytes;
60
+ cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
61
+ block_size_in_bytes, memcpy_type, stream);
62
+ }
63
+ }
64
+
65
+ namespace vllm {
66
+
67
+ // Grid: (num_layers, num_pairs)
68
+ template <typename scalar_t>
69
+ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
70
+ int64_t* value_cache_ptrs,
71
+ const int64_t* __restrict__ block_mapping,
72
+ const int numel_per_block) {
73
+ const int layer_idx = blockIdx.x;
74
+ const int pair_idx = blockIdx.y;
75
+
76
+ scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
77
+ scalar_t* value_cache =
78
+ reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
79
+ int64_t src_block_number = block_mapping[2 * pair_idx];
80
+ int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
81
+
82
+ const int64_t src_block_offset = src_block_number * numel_per_block;
83
+ const int64_t dst_block_offset = dst_block_number * numel_per_block;
84
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
85
+ int64_t src_offset = src_block_offset + i;
86
+ int64_t dst_offset = dst_block_offset + i;
87
+ key_cache[dst_offset] = key_cache[src_offset];
88
+ }
89
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
90
+ int64_t src_offset = src_block_offset + i;
91
+ int64_t dst_offset = dst_block_offset + i;
92
+ value_cache[dst_offset] = value_cache[src_offset];
93
+ }
94
+ }
95
+
96
+ } // namespace vllm
97
+
98
+ // Note: the key_caches and value_caches vectors are constant but
99
+ // not the Tensors they contain. The vectors need to be const refs
100
+ // in order to satisfy pytorch's C++ operator registration code.
101
+ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
102
+ std::vector<torch::Tensor> const& value_caches,
103
+ const torch::Tensor& block_mapping) {
104
+ int num_layers = key_caches.size();
105
+ TORCH_CHECK(num_layers == value_caches.size());
106
+ if (num_layers == 0) {
107
+ return;
108
+ }
109
+ torch::Device cache_device = key_caches[0].device();
110
+ TORCH_CHECK(cache_device.is_cuda());
111
+
112
+ // Create data structures for the kernel.
113
+ // Create an array of pointers to the key and value caches.
114
+ int64_t key_cache_ptrs[num_layers];
115
+ int64_t value_cache_ptrs[num_layers];
116
+ for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
117
+ key_cache_ptrs[layer_idx] =
118
+ reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
119
+ value_cache_ptrs[layer_idx] =
120
+ reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
121
+ }
122
+
123
+ // block_mapping is a 2D tensor with shape (num_pairs, 2).
124
+ int num_pairs = block_mapping.size(0);
125
+
126
+ // Move the data structures to the GPU.
127
+ // NOTE: This synchronizes the CPU and GPU.
128
+ torch::Tensor key_cache_ptrs_tensor =
129
+ torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
130
+ .to(cache_device);
131
+ torch::Tensor value_cache_ptrs_tensor =
132
+ torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
133
+ .to(cache_device);
134
+
135
+ // Launch the kernel.
136
+ const int numel_per_block = key_caches[0][0].numel();
137
+ dim3 grid(num_layers, num_pairs);
138
+ dim3 block(std::min(1024, numel_per_block));
139
+ const at::cuda::OptionalCUDAGuard device_guard(cache_device);
140
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
141
+ VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
142
+ key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
143
+ vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
144
+ key_cache_ptrs_tensor.data_ptr<int64_t>(),
145
+ value_cache_ptrs_tensor.data_ptr<int64_t>(),
146
+ block_mapping.data_ptr<int64_t>(), numel_per_block);
147
+ }));
148
+ }
149
+
150
+ namespace vllm {
151
+
152
+ template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
153
+ __global__ void reshape_and_cache_kernel(
154
+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
155
+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
156
+ cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
157
+ // block_size, x]
158
+ cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
159
+ // block_size]
160
+ const int64_t* __restrict__ slot_mapping, // [num_tokens]
161
+ const int key_stride, const int value_stride, const int num_heads,
162
+ const int head_size, const int block_size, const int x,
163
+ const float* k_scale, const float* v_scale) {
164
+ const int64_t token_idx = blockIdx.x;
165
+ const int64_t slot_idx = slot_mapping[token_idx];
166
+ if (slot_idx < 0) {
167
+ // Padding token that should be ignored.
168
+ return;
169
+ }
170
+
171
+ const int64_t block_idx = slot_idx / block_size;
172
+ const int64_t block_offset = slot_idx % block_size;
173
+
174
+ const int n = num_heads * head_size;
175
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
176
+ const int64_t src_key_idx = token_idx * key_stride + i;
177
+ const int64_t src_value_idx = token_idx * value_stride + i;
178
+
179
+ const int head_idx = i / head_size;
180
+ const int head_offset = i % head_size;
181
+ const int x_idx = head_offset / x;
182
+ const int x_offset = head_offset % x;
183
+
184
+ const int64_t tgt_key_idx =
185
+ block_idx * num_heads * (head_size / x) * block_size * x +
186
+ head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
187
+ block_offset * x + x_offset;
188
+ const int64_t tgt_value_idx =
189
+ block_idx * num_heads * head_size * block_size +
190
+ head_idx * head_size * block_size + head_offset * block_size +
191
+ block_offset;
192
+ scalar_t tgt_key = key[src_key_idx];
193
+ scalar_t tgt_value = value[src_value_idx];
194
+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
195
+ key_cache[tgt_key_idx] = tgt_key;
196
+ value_cache[tgt_value_idx] = tgt_value;
197
+ } else {
198
+ key_cache[tgt_key_idx] =
199
+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
200
+ value_cache[tgt_value_idx] =
201
+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
202
+ }
203
+ }
204
+ }
205
+
206
+ template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
207
+ __global__ void reshape_and_cache_flash_kernel(
208
+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
209
+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
210
+ cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
211
+ // head_size]
212
+ cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
213
+ // head_size]
214
+ const int64_t* __restrict__ slot_mapping, // [num_tokens]
215
+ const int block_stride, const int key_stride, const int value_stride,
216
+ const int num_heads, const int head_size, const int block_size,
217
+ const float* k_scale, const float* v_scale) {
218
+ const int64_t token_idx = blockIdx.x;
219
+ const int64_t slot_idx = slot_mapping[token_idx];
220
+ // NOTE: slot_idx can be -1 if the token is padded
221
+ if (slot_idx < 0) {
222
+ return;
223
+ }
224
+ const int64_t block_idx = slot_idx / block_size;
225
+ const int64_t block_offset = slot_idx % block_size;
226
+ const int n = num_heads * head_size;
227
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
228
+ const int64_t src_key_idx = token_idx * key_stride + i;
229
+ const int64_t src_value_idx = token_idx * value_stride + i;
230
+ const int head_idx = i / head_size;
231
+ const int head_offset = i % head_size;
232
+ const int64_t tgt_key_value_idx = block_idx * block_stride +
233
+ block_offset * num_heads * head_size +
234
+ head_idx * head_size + head_offset;
235
+ scalar_t tgt_key = key[src_key_idx];
236
+ scalar_t tgt_value = value[src_value_idx];
237
+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
238
+ key_cache[tgt_key_value_idx] = tgt_key;
239
+ value_cache[tgt_key_value_idx] = tgt_value;
240
+ } else {
241
+ key_cache[tgt_key_value_idx] =
242
+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
243
+ value_cache[tgt_key_value_idx] =
244
+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
245
+ }
246
+ }
247
+ }
248
+ } // namespace vllm
249
+
250
+ // KV_T is the stored data type of kv-cache.
251
+ // CACHE_T is the data type of key and value tensors.
252
+ // KV_DTYPE is the real data type of kv-cache.
253
+ #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
254
+ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
255
+ <<<grid, block, 0, stream>>>( \
256
+ reinterpret_cast<KV_T*>(key.data_ptr()), \
257
+ reinterpret_cast<KV_T*>(value.data_ptr()), \
258
+ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
259
+ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
260
+ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
261
+ num_heads, head_size, block_size, x, \
262
+ reinterpret_cast<const float*>(k_scale.data_ptr()), \
263
+ reinterpret_cast<const float*>(v_scale.data_ptr()));
264
+
265
+ void reshape_and_cache(
266
+ torch::Tensor& key, // [num_tokens, num_heads, head_size]
267
+ torch::Tensor& value, // [num_tokens, num_heads, head_size]
268
+ torch::Tensor&
269
+ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
270
+ torch::Tensor&
271
+ value_cache, // [num_blocks, num_heads, head_size, block_size]
272
+ torch::Tensor& slot_mapping, // [num_tokens]
273
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
274
+ torch::Tensor& v_scale) {
275
+ int num_tokens = key.size(0);
276
+ int num_heads = key.size(1);
277
+ int head_size = key.size(2);
278
+ int block_size = key_cache.size(3);
279
+ int x = key_cache.size(4);
280
+
281
+ int key_stride = key.stride(0);
282
+ int value_stride = value.stride(0);
283
+
284
+ dim3 grid(num_tokens);
285
+ dim3 block(std::min(num_heads * head_size, 512));
286
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
287
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
288
+
289
+ DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
290
+ CALL_RESHAPE_AND_CACHE)
291
+ }
292
+
293
+ // KV_T is the stored data type of kv-cache.
294
+ // CACHE_T is the data type of key and value tensors.
295
+ // KV_DTYPE is the real data type of kv-cache.
296
+ #define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
297
+ vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
298
+ <<<grid, block, 0, stream>>>( \
299
+ reinterpret_cast<KV_T*>(key.data_ptr()), \
300
+ reinterpret_cast<KV_T*>(value.data_ptr()), \
301
+ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
302
+ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
303
+ slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
304
+ value_stride, num_heads, head_size, block_size, \
305
+ reinterpret_cast<const float*>(k_scale.data_ptr()), \
306
+ reinterpret_cast<const float*>(v_scale.data_ptr()));
307
+
308
+ void reshape_and_cache_flash(
309
+ torch::Tensor& key, // [num_tokens, num_heads, head_size]
310
+ torch::Tensor& value, // [num_tokens, num_heads, head_size]
311
+ torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
312
+ torch::Tensor&
313
+ value_cache, // [num_blocks, block_size, num_heads, head_size]
314
+ torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
315
+ const std::string& kv_cache_dtype, torch::Tensor& k_scale,
316
+ torch::Tensor& v_scale) {
317
+ // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
318
+ // slot_mapping.size(0) because of padding for CUDA graphs.
319
+ // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
320
+ // both include padding.
321
+ // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
322
+ // since key includes padding for CUDA graphs, while slot_mapping does not.
323
+ // In this case, slot_mapping.size(0) represents the actual number of tokens
324
+ // before padding.
325
+ // For compatibility with both cases, we use slot_mapping.size(0) as the
326
+ // number of tokens.
327
+ int num_tokens = slot_mapping.size(0);
328
+ int num_heads = key.size(1);
329
+ int head_size = key.size(2);
330
+ int block_size = key_cache.size(1);
331
+
332
+ int key_stride = key.stride(0);
333
+ int value_stride = value.stride(0);
334
+ int block_stride = key_cache.stride(0);
335
+ TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
336
+
337
+ dim3 grid(num_tokens);
338
+ dim3 block(std::min(num_heads * head_size, 512));
339
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
340
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
341
+
342
+ DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
343
+ CALL_RESHAPE_AND_CACHE_FLASH);
344
+ }
345
+
346
+ namespace vllm {
347
+
348
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
349
+ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
350
+ Tout* __restrict__ dst_cache,
351
+ const float scale,
352
+ const int64_t block_stride) {
353
+ const int64_t block_idx = blockIdx.x;
354
+ for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
355
+ int64_t idx = block_idx * block_stride + i;
356
+ dst_cache[idx] =
357
+ fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
358
+ }
359
+ }
360
+
361
+ } // namespace vllm
362
+
363
+ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
364
+ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
365
+ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
366
+ reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
367
+
368
+ // Only for testing.
369
+ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
370
+ const double scale, const std::string& kv_cache_dtype) {
371
+ torch::Device src_device = src_cache.device();
372
+ torch::Device dst_device = dst_cache.device();
373
+ TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
374
+ TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
375
+ TORCH_CHECK(src_device.index() == dst_device.index(),
376
+ "src and dst must be on the same GPU");
377
+ at::cuda::OptionalCUDAGuard device_guard(src_device);
378
+
379
+ int64_t num_blocks = src_cache.size(0);
380
+ int64_t block_stride = src_cache.stride(0);
381
+
382
+ dim3 grid(num_blocks);
383
+ dim3 block(std::min(block_stride, int64_t(512)));
384
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
385
+
386
+ if (kv_cache_dtype == "auto") {
387
+ if (src_cache.dtype() == at::ScalarType::Float) {
388
+ CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
389
+ } else if (src_cache.dtype() == at::ScalarType::Half) {
390
+ CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
391
+ } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
392
+ CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
393
+ } else if (dst_cache.dtype() == at::ScalarType::Float) {
394
+ CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
395
+ } else if (dst_cache.dtype() == at::ScalarType::Half) {
396
+ CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
397
+ } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
398
+ CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
399
+ }
400
+ } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
401
+ if (src_cache.dtype() == at::ScalarType::Float) {
402
+ CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
403
+ } else if (src_cache.dtype() == at::ScalarType::Half) {
404
+ CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
405
+ } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
406
+ CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
407
+ vllm::Fp8KVCacheDataType::kFp8E4M3);
408
+ } else if (dst_cache.dtype() == at::ScalarType::Float) {
409
+ CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
410
+ } else if (dst_cache.dtype() == at::ScalarType::Half) {
411
+ CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
412
+ } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
413
+ CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
414
+ vllm::Fp8KVCacheDataType::kFp8E4M3);
415
+ }
416
+ } else {
417
+ TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
418
+ }
419
+ }
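
`convert_fp8` above is exposed only for testing. As a rough illustration of how its dispatch table is exercised, here is a hedged, minimal host-side sketch; the tensor shape, the helper name and the standalone declaration are assumptions for the example (only the `convert_fp8` signature comes from the code above), and the "fp8" path assumes a build with `ENABLE_FP8`.

```cpp
// Minimal sketch (not part of this port): quantize a half-precision cache
// tensor to fp8 bytes and back, using the convert_fp8 test entry point above.
#include <string>
#include <torch/torch.h>

void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);

void fp8_cache_round_trip() {
  // [num_blocks, block_size, num_heads, head_size], contiguous.
  auto src = torch::randn({8, 16, 4, 128},
                          torch::dtype(torch::kFloat16).device(torch::kCUDA));
  auto quantized = torch::empty_like(src, torch::dtype(torch::kUInt8));
  auto restored = torch::empty_like(src);

  // HP -> FP8: values are divided by `scale` before quantization ...
  convert_fp8(quantized, src, /*scale=*/1.0, "fp8");
  // ... and FP8 -> HP: values are multiplied by `scale` on the way back.
  convert_fp8(restored, quantized, /*scale=*/1.0, "fp8");
  // `restored` now approximates `src` up to fp8 rounding error.
}
```
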
paged-attention/cuda_compat.h ADDED
@@ -0,0 +1,49 @@
1
+ #pragma once
2
+
3
+ #ifdef USE_ROCM
4
+ #include <hip/hip_runtime.h>
5
+ #endif
6
+
7
+ #ifndef USE_ROCM
8
+ #define WARP_SIZE 32
9
+ #else
10
+ #define WARP_SIZE warpSize
11
+ #endif
12
+
13
+ #ifndef USE_ROCM
14
+ #define VLLM_LDG(arg) __ldg(arg)
15
+ #else
16
+ #define VLLM_LDG(arg) *(arg)
17
+ #endif
18
+
19
+ #ifndef USE_ROCM
20
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
21
+ __shfl_xor_sync(uint32_t(-1), var, lane_mask)
22
+ #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
23
+ __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
24
+ #else
25
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
26
+ #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
27
+ __shfl_xor(var, lane_mask, width)
28
+ #endif
29
+
30
+ #ifndef USE_ROCM
31
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
32
+ #else
33
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
34
+ #endif
35
+
36
+ #ifndef USE_ROCM
37
+ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
38
+ __shfl_down_sync(uint32_t(-1), var, lane_delta)
39
+ #else
40
+ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
41
+ #endif
42
+
43
+ #ifndef USE_ROCM
44
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
45
+ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
46
+ #else
47
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
48
+ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
49
+ #endif
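
These wrappers exist so the same warp-level code compiles on both CUDA and ROCm. A minimal sketch of the kind of reduction they are typically used for (the `warp_sum` helper is illustrative and not part of this port):

```cpp
// Hedged sketch: butterfly warp reduction built on the portability macros.
#include "cuda_compat.h"

__device__ __forceinline__ float warp_sum(float val) {
  // Each XOR shuffle halves the distance between exchanging lanes, so after
  // log2(WARP_SIZE) steps every lane holds the sum over the whole warp.
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;
}
```
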
paged-attention/dispatch_utils.h ADDED
@@ -0,0 +1,49 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
4
+ */
5
+ #pragma once
6
+
7
+ #include <torch/all.h>
8
+
9
+ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
10
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
11
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
12
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
13
+
14
+ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
15
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
16
+
17
+ // TODO(luka/varun): use FP8_TYPE macro after refactoring
18
+ #ifndef USE_ROCM
19
+ #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
20
+ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
21
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
22
+ #else
23
+ #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
24
+ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
25
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
26
+ #endif
27
+
28
+ #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
29
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
30
+
31
+ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
32
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
33
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
34
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
35
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
36
+
37
+ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
38
+ AT_DISPATCH_SWITCH(TYPE, NAME, \
39
+ VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
40
+
41
+ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
42
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
43
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
44
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
45
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
46
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
47
+
48
+ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
49
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
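
The macros above wrap PyTorch's `AT_DISPATCH_*` helpers and bind `scalar_t` inside a lambda. A minimal sketch of a host launcher using them; `scale_kernel`, `scale_inplace` and the launch configuration are assumptions made for the example:

```cpp
// Hedged sketch: dispatch a templated CUDA kernel over float/half/bfloat16.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"

template <typename scalar_t>
__global__ void scale_kernel(scalar_t* __restrict__ data, const float factor,
                             const int64_t n) {
  int64_t idx = int64_t(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < n) {
    data[idx] = static_cast<scalar_t>(static_cast<float>(data[idx]) * factor);
  }
}

void scale_inplace(torch::Tensor& t, float factor) {
  int64_t n = t.numel();
  int blocks = static_cast<int>((n + 255) / 256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Expands to an AT_DISPATCH_SWITCH over Float, Half and BFloat16 and binds
  // `scalar_t` inside the lambda body.
  VLLM_DISPATCH_FLOATING_TYPES(t.scalar_type(), "scale_kernel", [&] {
    scale_kernel<scalar_t><<<blocks, 256, 0, stream>>>(
        t.data_ptr<scalar_t>(), factor, n);
  });
}
```
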
paged-attention/quantization/fp8/amd/hip_float8.h ADDED
@@ -0,0 +1,137 @@
1
+ #pragma once
2
+
3
+ #ifdef __HIPCC__
4
+ #include <hip/hip_runtime.h>
5
+ #else
6
+ #include <type_traits>
7
+ #include <stdint.h>
8
+ #include <math.h>
9
+ #include <iostream>
10
+ #endif
11
+
12
+ #include "hip_float8_impl.h"
13
+
14
+ struct alignas(1) hip_fp8 {
15
+ struct from_bits_t {};
16
+ HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
17
+ return from_bits_t();
18
+ }
19
+ uint8_t data;
20
+
21
+ hip_fp8() = default;
22
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
23
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
24
+ explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
25
+ : data(v) {}
26
+
27
+ #ifdef __HIP__MI300__
28
+ // NOTE: ON-DEVICE... always optimal bias
29
+ explicit HIP_FP8_DEVICE hip_fp8(float v)
30
+ : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
31
+
32
+ explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
33
+ : hip_fp8(static_cast<float>(v)) {}
34
+
35
+ // Host only implementation using s/w simulation
36
+ explicit HIP_FP8_HOST
37
+ #else // __HIP__MI300__
38
+ // both Host and DEVICE for non-MI300 using s/w simulation
39
+ explicit HIP_FP8_HOST_DEVICE
40
+ #endif // __HIP__MI300__
41
+ hip_fp8(float v) {
42
+ data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
43
+ true /*clip*/>(v);
44
+ }
45
+
46
+ explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
47
+ : hip_fp8(static_cast<float>(v)) {}
48
+
49
+ #ifdef __HIP__MI300__
50
+ // upcast using device specific intrinsic
51
+ explicit inline HIP_FP8_DEVICE operator float() const {
52
+ float fval;
53
+ uint32_t i32val = static_cast<uint32_t>(data);
54
+
55
+ // upcast
56
+ asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
57
+ : "=v"(fval)
58
+ : "v"(i32val));
59
+
60
+ return fval;
61
+ }
62
+
63
+ explicit inline HIP_FP8_HOST operator float() const
64
+ #else // __HIP__MI300__
65
+ explicit inline HIP_FP8_HOST_DEVICE operator float() const
66
+ #endif // __HIP__MI300__
67
+ {
68
+ return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
69
+ data);
70
+ }
71
+ };
72
+
73
+ namespace std {
74
+ inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
75
+ inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
76
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
77
+ } // namespace std
78
+
79
+ // Special operator overloading
80
+ inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
81
+ return os << float(f8);
82
+ }
83
+
84
+ // all + operator overloading with mixed types
85
+ // mixed types, always converts to f32, does computation in f32, and returns
86
+ // float
87
+ inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
88
+ return (fa + float(b));
89
+ }
90
+
91
+ inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
92
+ return (float(a) + fb);
93
+ }
94
+
95
+ inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
96
+ return hip_fp8(float(a) + float(b));
97
+ }
98
+
99
+ inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
100
+ return a = hip_fp8(float(a) + float(b));
101
+ }
102
+
103
+ // overloading multiplication, always returns float,
104
+ inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
105
+ return float(a) * float(b);
106
+ }
107
+
108
+ inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
109
+ return (a * float(b));
110
+ }
111
+
112
+ inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
113
+ return (float(a) * b);
114
+ }
115
+
116
+ inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
117
+ return ((float)a * float(b));
118
+ }
119
+
120
+ inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
121
+ return ((float)a * float(b));
122
+ }
123
+
124
+ // overloading for compare
125
+ inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
126
+ return (a.data == b.data);
127
+ }
128
+ inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
129
+ return (a.data != b.data);
130
+ }
131
+
132
+ inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
133
+ return static_cast<float>(a) >= static_cast<float>(b);
134
+ }
135
+ inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
136
+ return static_cast<float>(a) > static_cast<float>(b);
137
+ }
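
Since the non-MI300 path is a pure software conversion, the `hip_fp8` wrapper can also be exercised on the host. A small illustrative round trip (the program is not part of this port; the printed value is simply the nearest NANOO e4m3 representation of the input):

```cpp
// Hedged sketch: host-only round trip through the software fp8 path.
#include <iostream>
#include "hip_float8.h"

int main() {
  hip_fp8 q(3.1415926f);   // fp32 -> fp8 (e4m3, NANOO encoding)
  float back = float(q);   // fp8 -> fp32
  std::cout << "bits=0x" << std::hex << unsigned(q.data) << std::dec
            << " value=" << back << "\n";  // prints the nearest e4m3 value
  return 0;
}
```
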
paged-attention/quantization/fp8/amd/hip_float8_impl.h ADDED
@@ -0,0 +1,316 @@
1
+ #pragma once
2
+
3
+ #if defined(__HIPCC__) && \
4
+ (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
5
+ #define __HIP__MI300__
6
+ #endif
7
+
8
+ #ifdef __HIPCC__
9
+ #define HIP_FP8_HOST_DEVICE __host__ __device__
10
+ #define HIP_FP8_HOST __host__
11
+ #define HIP_FP8_DEVICE __device__
12
+ #else
13
+ #define HIP_FP8_HOST_DEVICE
14
+ #define HIP_FP8_HOST
15
+ #define HIP_FP8_DEVICE
16
+ #endif
17
+
18
+ namespace hip_fp8_impl {
19
+
20
+ #ifdef __HIP__MI300__
21
+ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
22
+ uint8_t i8data;
23
+ union {
24
+ float fval;
25
+ uint32_t i32val;
26
+ uint8_t i8val[4]; // NOTE: not endian independent
27
+ } val;
28
+
29
+ uint32_t ival = 0;
30
+ val.fval = v;
31
+
32
+ if ((val.i32val & 0x7F800000) !=
33
+ 0x7F800000) { /// propagate NAN/INF, no clipping
34
+ val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
35
+ }
36
+
37
+ ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
38
+ false); // false -> WORD0
39
+ val.i32val = ival;
40
+ i8data = val.i8val[0];
41
+
42
+ return i8data;
43
+ }
44
+ #endif // __HIP__MI300__
45
+
46
+ HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
47
+ #if defined(__HIPCC__) || defined(__CUDA_ARCH__)
48
+ HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
49
+ #endif
50
+
51
+ template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
52
+ HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
53
+ uint32_t rng = 0) {
54
+ #ifdef __HIPCC__
55
+ constexpr bool is_half = std::is_same<T, _Float16>::value;
56
+ #else
57
+ constexpr bool is_half = false;
58
+ #endif
59
+ constexpr bool is_float = std::is_same<T, float>::value;
60
+ static_assert(wm + we == 7, "wm+we==7");
61
+ static_assert(is_half || is_float, "Only half and float can be cast to f8");
62
+
63
+ const int mfmt = (sizeof(T) == 4) ? 23 : 10;
64
+ uint32_t x;
65
+ if (sizeof(T) == 4) {
66
+ x = reinterpret_cast<uint32_t&>(_x);
67
+ } else {
68
+ x = reinterpret_cast<uint16_t&>(_x);
69
+ }
70
+
71
+ uint32_t head, mantissa;
72
+ int exponent, bias;
73
+ uint32_t sign;
74
+
75
+ if (sizeof(T) == 4) {
76
+ head = x & 0xFF800000;
77
+ mantissa = x & 0x7FFFFF;
78
+ exponent = (head >> 23) & 0xFF;
79
+ sign = head >> 31;
80
+ bias = 127;
81
+ } else {
82
+ head = x & 0xFC00;
83
+ mantissa = x & 0x3FF;
84
+ exponent = (head >> 10) & 0x1F;
85
+ sign = head >> 15;
86
+ bias = 15;
87
+ }
88
+
89
+ uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
90
+
91
+ // Deal with inf and NaNs
92
+ if (negative_zero_nan) {
93
+ if (sizeof(T) == 4) {
94
+ if ((x & 0x7F800000) == 0x7F800000) {
95
+ return 0x80;
96
+ }
97
+ } else {
98
+ // if(__hisinf(x) || __hisnan(x))
99
+ if ((x & 0x7C00) == 0x7C00) {
100
+ return 0x80;
101
+ }
102
+ }
103
+ } else {
104
+ if (sizeof(T) == 4) {
105
+ if ((x & 0x7F800000) == 0x7F800000) {
106
+ return signed_inf + (mantissa != 0 ? 1 : 0);
107
+ }
108
+ } else {
109
+ if ((x & 0x7C00) == 0x7C00) {
110
+ return signed_inf + (mantissa != 0 ? 1 : 0);
111
+ }
112
+ }
113
+ }
114
+ if (x == 0) {
115
+ return 0;
116
+ }
117
+
118
+ // First we need to check whether the value is normal or denormal, since the
119
+ // implicit 1 differs. Then the exponent is adjusted to align with the f8
120
+ // exponent and the mantissa is shifted accordingly. For stochastic rounding,
121
+ // rng is added to the mantissa before truncating; for RNE nothing is added.
122
+ // Finally, check for a carry and adjust the exponent and mantissa again.
123
+
124
+ // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
125
+ // bits
126
+ const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
127
+ const int f8_denormal_act_exponent =
128
+ 1 - f8_bias; // actual exponent of f8 denormal
129
+ // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
130
+ // f8_exponent is the converted f8 exponent with bias encoding
131
+ // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
132
+ // the difference needs to be adjusted and mantissa shifted
133
+ int act_exponent, f8_exponent, exponent_diff;
134
+
135
+ if (exponent == 0) { // fp32/fp16 is in denormal.
136
+ /* fp32 denormals are below 2^-127, so they are usually not a concern here; we
137
+ are mostly concerned with fp16. In this case f8 is usually denormal as well,
138
+ but there are exceptions: fp16 denormals have exponent bias 15 while bf8 with
139
+ NANOO has exponent bias 16, so some fp16 denormal numbers are actually bf8
140
+ (NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15. fp16 numbers with
141
+ exponent==0 (actual exponent -14) whose highest mantissa bit is 1 are bf8
142
+ (NANOO) normals. In this case, the fp16 mantissa should be shifted left by 1 */
143
+ act_exponent = exponent - bias + 1;
144
+ exponent_diff =
145
+ f8_denormal_act_exponent -
146
+ act_exponent; // actual exponent is exponent-bias+1 as it is denormal
147
+ } else { // fp32/fp16 is normal with implicit 1
148
+ act_exponent = exponent - bias;
149
+ if (act_exponent <= f8_denormal_act_exponent) {
150
+ /* This is the case where fp32/fp16 is normal but falls into the f8 denormal
151
+ range. For example, in fp8 nanoo mode the denormal exponent is -7, but if the
152
+ fp32/fp16 actual exponent is -7 the value is actually larger due to the
153
+ implicit 1, so it needs to be adjusted to -6 and the mantissa shifted right
154
+ by 1. So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
155
+ exponent_diff = f8_denormal_act_exponent - act_exponent;
156
+ } else { // both fp32/fp16 and f8 are in normal range
157
+ exponent_diff = 0; // exponent_diff=0 does not mean there is no
158
+ // difference for this case, act_exponent could be
159
+ // larger. Just that it does not need shift mantissa
160
+ }
161
+ mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
162
+ }
163
+
164
+ bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
165
+ static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
166
+ /* This part is a bit tricky. The judgment of whether it is a tie needs to be
167
+ done before we shift right as shift right could rip off some residual part
168
+ and make something not midpoint look like midpoint. For example, the fp16
169
+ number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
170
+ shift right by 4 bits, it would look like midpoint.
171
+ */
172
+
173
+ if (exponent_diff > 0) {
174
+ mantissa >>= exponent_diff;
175
+ } else if (exponent_diff == -1) {
176
+ mantissa <<= -exponent_diff;
177
+ }
178
+ bool implicit_one = mantissa & (1 << mfmt);
179
+ // if there is no implicit 1, it means the f8 is denormal and need to adjust
180
+ // to denorm exponent
181
+ f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
182
+ f8_bias - (implicit_one ? 0 : 1);
183
+
184
+ // Now we have the exponent and mantissa adjusted
185
+ uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
186
+ bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
187
+ // that is not truncated is 1
188
+ mantissa +=
189
+ (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
190
+ drop_mask;
191
+
192
+ // Now we deal with overflow
193
+ if (f8_exponent == 0) {
194
+ if ((1 << mfmt) & mantissa) {
195
+ f8_exponent = 1; // denormal overflow to become normal, promote exponent
196
+ }
197
+ } else {
198
+ if ((1 << (mfmt + 1)) & mantissa) {
199
+ mantissa >>= 1;
200
+ f8_exponent++;
201
+ }
202
+ }
203
+
204
+ mantissa >>= (mfmt - wm);
205
+
206
+ // above range: quantize to maximum possible float of the same sign
207
+ const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
208
+ if (f8_exponent > max_exp) {
209
+ if (clip) {
210
+ mantissa = (1 << wm) - 1;
211
+ f8_exponent = max_exp;
212
+ } else {
213
+ return signed_inf;
214
+ }
215
+ }
216
+
217
+ if (f8_exponent == 0 && mantissa == 0) {
218
+ return negative_zero_nan ? 0 : (sign << 7);
219
+ }
220
+ mantissa &= (1 << wm) - 1;
221
+ return (sign << 7) | (f8_exponent << wm) | mantissa;
222
+ }
223
+
224
+ template <int we, int wm, typename T = float, bool negative_zero_nan = true>
225
+ inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
226
+ #ifdef __HIPCC__
227
+ constexpr bool is_half = std::is_same<T, _Float16>::value;
228
+ #else
229
+ constexpr bool is_half = false;
230
+ #endif
231
+ constexpr bool is_float = std::is_same<T, float>::value;
232
+ static_assert(is_half || is_float, "only half and float are supported");
233
+
234
+ constexpr int weo = is_half ? 5 : 8;
235
+ constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
236
+
237
+ T fInf, fNegInf, fNaN, fNeg0;
238
+
239
+ #ifdef __HIPCC__
240
+ if (is_half) {
241
+ const uint16_t ihInf = 0x7C00;
242
+ const uint16_t ihNegInf = 0xFC00;
243
+ const uint16_t ihNaN = 0x7C01;
244
+ const uint16_t ihNeg0 = 0x8000;
245
+ fInf = reinterpret_cast<const _Float16&>(ihInf);
246
+ fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
247
+ fNaN = reinterpret_cast<const _Float16&>(ihNaN);
248
+ fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
249
+ } else
250
+ #endif
251
+ if (is_float) {
252
+ const uint32_t ifInf = 0x7F800000;
253
+ const uint32_t ifNegInf = 0xFF800000;
254
+ const uint32_t ifNaN = 0x7F800001;
255
+ const uint32_t ifNeg0 = 0x80000000;
256
+ fInf = reinterpret_cast<const float&>(ifInf);
257
+ fNegInf = reinterpret_cast<const float&>(ifNegInf);
258
+ fNaN = reinterpret_cast<const float&>(ifNaN);
259
+ fNeg0 = reinterpret_cast<const float&>(ifNeg0);
260
+ }
261
+
262
+ if (x == 0) {
263
+ return 0;
264
+ }
265
+
266
+ uint32_t sign = x >> 7;
267
+ uint32_t mantissa = x & ((1 << wm) - 1);
268
+ int exponent = (x & 0x7F) >> wm;
269
+ if (negative_zero_nan) {
270
+ if (x == 0x80) {
271
+ return fNaN;
272
+ }
273
+ } else {
274
+ if (x == 0x80) {
275
+ return fNeg0;
276
+ }
277
+ if (exponent == ((1 << we) - 1)) {
278
+ return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
279
+ }
280
+ }
281
+ typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
282
+ if (we == 5 && is_half && !negative_zero_nan) {
283
+ retval = x << 8;
284
+ return reinterpret_cast<const T&>(retval);
285
+ }
286
+
287
+ const int exp_low_cutoff =
288
+ (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
289
+
290
+ // subnormal input
291
+ if (exponent == 0) {
292
+ // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
293
+ int sh = 1 + clz(mantissa) - (32 - wm);
294
+ mantissa <<= sh;
295
+ exponent += 1 - sh;
296
+ mantissa &= ((1 << wm) - 1);
297
+ }
298
+ exponent += exp_low_cutoff - 1;
299
+ mantissa <<= wmo - wm;
300
+
301
+ // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
302
+ if (exponent <= 0) {
303
+ mantissa |= 1 << wmo;
304
+ mantissa >>= 1 - exponent;
305
+ exponent = 0;
306
+ }
307
+
308
+ if (sizeof(T) == 2) {
309
+ retval = (sign << 15) | (exponent << 10) | mantissa;
310
+ } else {
311
+ retval = (sign << 31) | (exponent << 23) | mantissa;
312
+ }
313
+ return reinterpret_cast<const T&>(retval);
314
+ }
315
+
316
+ } // namespace hip_fp8_impl
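
The `to_float8`/`from_float8` templates can likewise be driven directly on the host. A hedged sketch using the same parameters as the `hip_fp8` wrapper (we=4, wm=3, NANOO, clip=true); with clipping enabled, out-of-range inputs saturate to ±240, matching the clamp used in the MI300 hardware path above:

```cpp
// Hedged sketch (not part of this port): exercise the raw conversion templates
// on the host with the parameters used by hip_fp8.
#include <cstdio>
#include "hip_float8.h"  // pulls in hip_float8_impl.h and its dependencies

int main() {
  const float inputs[] = {0.0f, 0.1f, 1.0f, 240.0f, 1000.0f};
  for (float x : inputs) {
    uint8_t bits =
        hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
                                true /*clip*/>(x);
    float back = hip_fp8_impl::from_float8<4, 3, float, true>(bits);
    // 1000.0f clips to 240.0f, the largest finite NANOO e4m3 magnitude.
    std::printf("%10.4f -> 0x%02x -> %10.4f\n", x, unsigned(bits), back);
  }
  return 0;
}
```
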
paged-attention/quantization/fp8/amd/quant_utils.cuh ADDED
@@ -0,0 +1,577 @@
1
+ #pragma once
2
+ #include "hip_float8.h"
3
+
4
+ #include <hip/hip_fp16.h>
5
+ #include <hip/hip_bf16.h>
6
+ #include <hip/hip_bfloat16.h>
7
+
8
+ #include "../../../attention/dtype_fp8.cuh"
9
+ #include "../../../attention/dtype_float32.cuh"
10
+ #include "../../../attention/dtype_bfloat16.cuh"
11
+
12
+ namespace vllm {
13
+ #ifdef USE_ROCM
14
+
15
+ namespace fp8 {
16
+ #ifdef ENABLE_FP8
17
+
18
+ template <typename Tout, typename Tin>
19
+ __inline__ __device__ Tout vec_conversion(const Tin& x) {
20
+ return x;
21
+ }
22
+
23
+ template <typename Tout, typename Tin>
24
+ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
25
+ const float scale) {
26
+ return x;
27
+ }
28
+
29
+ // fp8 -> half
30
+ template <>
31
+ __inline__ __device__ uint16_t
32
+ vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
33
+ hip_fp8 f8{a, hip_fp8::from_bits()};
34
+ __half_raw res;
35
+ res.data = static_cast<float>(f8);
36
+ return res.x;
37
+ }
38
+
39
+ // fp8x2 -> half2
40
+ template <>
41
+ __inline__ __device__ uint32_t
42
+ vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
43
+ #if defined(__HIP__MI300__) && \
44
+ defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
45
+ const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
46
+ union {
47
+ __half2_raw h2r;
48
+ uint32_t ui32;
49
+ } tmp;
50
+ tmp.h2r.x.data = f2[0];
51
+ tmp.h2r.y.data = f2[1];
52
+ return tmp.ui32;
53
+ #else
54
+ union {
55
+ uint16_t u16[2];
56
+ uint32_t u32;
57
+ } tmp;
58
+
59
+ tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
60
+ tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
61
+ return tmp.u32;
62
+ #endif
63
+ }
64
+
65
+ // fp8x4 -> half2x2
66
+ template <>
67
+ __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
68
+ union {
69
+ uint2 u32x2;
70
+ uint32_t u32[2];
71
+ } tmp;
72
+ tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
73
+ tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
74
+ return tmp.u32x2;
75
+ }
76
+
77
+ // fp8x8 -> half2x4
78
+ template <>
79
+ __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
80
+ union {
81
+ uint4 u64x2;
82
+ uint2 u64[2];
83
+ } tmp;
84
+ tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
85
+ tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
86
+ return tmp.u64x2;
87
+ }
88
+
89
+ using __nv_bfloat16 = __hip_bfloat16;
90
+
91
+ // fp8 -> __nv_bfloat16
92
+ template <>
93
+ __inline__ __device__ __nv_bfloat16
94
+ vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
95
+ hip_fp8 f8{a, hip_fp8::from_bits()};
96
+ float f{f8};
97
+ return __float2bfloat16(f);
98
+ }
99
+
100
+ using __nv_bfloat162 = __hip_bfloat162;
101
+
102
+ // fp8x2 -> __nv_bfloat162
103
+ template <>
104
+ __inline__ __device__ __nv_bfloat162
105
+ vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
106
+ __nv_bfloat162 res;
107
+ res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
108
+ res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
109
+ return res;
110
+ }
111
+
112
+ // fp8x4 -> bf16_4_t
113
+ template <>
114
+ __inline__ __device__ bf16_4_t
115
+ vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
116
+ bf16_4_t res;
117
+ res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
118
+ res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
119
+ return res;
120
+ }
121
+
122
+ // fp8x8 -> bf16_8_t
123
+ template <>
124
+ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
125
+ bf16_4_t tmp1, tmp2;
126
+ tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
127
+ tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
128
+ bf16_8_t res;
129
+ res.x = tmp1.x;
130
+ res.y = tmp1.y;
131
+ res.z = tmp2.x;
132
+ res.w = tmp2.y;
133
+ return res;
134
+ }
135
+
136
+ // fp8 -> float
137
+ template <>
138
+ __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
139
+ hip_fp8 fp8{a, hip_fp8::from_bits()};
140
+ return static_cast<float>(fp8);
141
+ }
142
+
143
+ // fp8x2 -> float2
144
+ template <>
145
+ __inline__ __device__ float2
146
+ vec_conversion<float2, uint16_t>(const uint16_t& a) {
147
+ #if defined(__HIP__MI300__) && \
148
+ defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
149
+ float2 res;
150
+ const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
151
+ res.x = f2[0];
152
+ res.y = f2[1];
153
+ return res;
154
+ #else
155
+ float2 res;
156
+ res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
157
+ res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
158
+ return res;
159
+ #endif
160
+ }
161
+
162
+ // fp8x4 -> float4
163
+ template <>
164
+ __inline__ __device__ Float4_
165
+ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
166
+ Float4_ res;
167
+ res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
168
+ res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
169
+ return res;
170
+ }
171
+
172
+ // fp8x8 -> float8
173
+ template <>
174
+ __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
175
+ Float4_ tmp1, tmp2;
176
+ tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
177
+ tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
178
+ Float8_ res;
179
+ res.x = tmp1.x;
180
+ res.y = tmp1.y;
181
+ res.z = tmp2.x;
182
+ res.w = tmp2.y;
183
+ return res;
184
+ }
185
+
186
+ // half -> fp8
187
+ template <>
188
+ __inline__ __device__ uint8_t
189
+ vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
190
+ __half_raw tmp;
191
+ tmp.x = a;
192
+
193
+ hip_fp8 f8{static_cast<float>(tmp.data)};
194
+ return f8.data;
195
+ }
196
+
197
+ // bf16 -> fp8
198
+ template <>
199
+ __inline__ __device__ uint8_t
200
+ vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
201
+ hip_fp8 res{__bfloat162float(a)};
202
+ return res.data;
203
+ }
204
+
205
+ // float -> fp8
206
+ template <>
207
+ __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
208
+ hip_fp8 f8(a);
209
+ return f8.data;
210
+ }
211
+
212
+ // fp8x4 -> float4
213
+ template <>
214
+ __inline__ __device__ float4
215
+ vec_conversion<float4, uint32_t>(const uint32_t& a) {
216
+ Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
217
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
218
+ return res;
219
+ }
220
+
221
+ // float2 -> half2
222
+ template <>
223
+ __inline__ __device__ uint32_t
224
+ vec_conversion<uint32_t, float2>(const float2& a) {
225
+ union {
226
+ half2 float16;
227
+ uint32_t uint32;
228
+ };
229
+
230
+ float16 = __float22half2_rn(a);
231
+ return uint32;
232
+ }
233
+
234
+ // Float4 -> half2x2
235
+ template <>
236
+ __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
237
+ uint2 b;
238
+ float2 val;
239
+ val.x = a.x.x;
240
+ val.y = a.x.y;
241
+ b.x = vec_conversion<uint32_t, float2>(val);
242
+
243
+ val.x = a.y.x;
244
+ val.y = a.y.y;
245
+ b.y = vec_conversion<uint32_t, float2>(val);
246
+ return b;
247
+ }
248
+
249
+ // Float4 -> float4
250
+ template <>
251
+ __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
252
+ float4 b;
253
+ b.x = a.x.x;
254
+ b.y = a.x.y;
255
+ b.z = a.y.x;
256
+ b.w = a.y.y;
257
+ return b;
258
+ }
259
+
260
+ // Float8 -> half2x4
261
+ template <>
262
+ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
263
+ uint4 b;
264
+ b.x = vec_conversion<uint32_t, float2>(a.x);
265
+ b.y = vec_conversion<uint32_t, float2>(a.y);
266
+ b.z = vec_conversion<uint32_t, float2>(a.z);
267
+ b.w = vec_conversion<uint32_t, float2>(a.w);
268
+ return b;
269
+ }
270
+
271
+ // float2 -> bfloat162
272
+ template <>
273
+ __inline__ __device__ __nv_bfloat162
274
+ vec_conversion<__nv_bfloat162, float2>(const float2& a) {
275
+ __nv_bfloat162 b = __float22bfloat162_rn(a);
276
+ return b;
277
+ }
278
+
279
+ // Float4 -> bfloat162x2
280
+ template <>
281
+ __inline__ __device__ bf16_4_t
282
+ vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
283
+ bf16_4_t b;
284
+ b.x = __float22bfloat162_rn(a.x);
285
+ b.y = __float22bfloat162_rn(a.y);
286
+ return b;
287
+ }
288
+
289
+ // Float8 -> bfloat162x4
290
+ template <>
291
+ __inline__ __device__ bf16_8_t
292
+ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
293
+ bf16_8_t b;
294
+ b.x = __float22bfloat162_rn(a.x);
295
+ b.y = __float22bfloat162_rn(a.y);
296
+ b.z = __float22bfloat162_rn(a.z);
297
+ b.w = __float22bfloat162_rn(a.w);
298
+ return b;
299
+ }
300
+
301
+ /* Scaled and vectorized conversions, for data exchange between the high- and
302
+ low-precision domains.
303
+
304
+ Convention of the scale in the API:
305
+ FP8_data = Quantize(High_Precision_data / scale), i.e.
306
+ Quantize(HP / scale) => FP8 and Dequant(FP8) * scale => HP.
307
+
308
+ */
309
+
310
+ // fp8 -> half
311
+ template <>
312
+ __inline__ __device__ uint16_t
313
+ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
314
+ hip_fp8 f8{a, hip_fp8::from_bits()};
315
+ __half_raw res;
316
+ res.data = static_cast<float>(f8) * scale;
317
+ return res.x;
318
+ }
319
+
320
+ // fp8x2 -> half2
321
+ template <>
322
+ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
323
+ const uint16_t& a, const float scale) {
324
+ #if defined(__HIP__MI300__) && \
325
+ defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
326
+ const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
327
+ union {
328
+ __half2_raw h2r;
329
+ uint32_t ui32;
330
+ } tmp;
331
+ tmp.h2r.x.data = f2[0] * scale;
332
+ tmp.h2r.y.data = f2[1] * scale;
333
+ return tmp.ui32;
334
+ #else
335
+ union {
336
+ uint16_t u16[2];
337
+ uint32_t u32;
338
+ } tmp;
339
+
340
+ tmp.u16[0] =
341
+ scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
342
+ tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
343
+ static_cast<uint8_t>(a >> 8U), scale);
344
+ return tmp.u32;
345
+ #endif
346
+ }
347
+
348
+ // fp8x4 -> half2x2
349
+ template <>
350
+ __inline__ __device__ uint2
351
+ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
352
+ union {
353
+ uint2 u32x2;
354
+ uint32_t u32[2];
355
+ } tmp;
356
+ tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
357
+ tmp.u32[1] =
358
+ scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
359
+ return tmp.u32x2;
360
+ }
361
+
362
+ // fp8x8 -> half2x4
363
+ template <>
364
+ __inline__ __device__ uint4
365
+ scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
366
+ union {
367
+ uint4 u64x2;
368
+ uint2 u64[2];
369
+ } tmp;
370
+ tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
371
+ tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
372
+ return tmp.u64x2;
373
+ }
374
+
375
+ using __nv_bfloat16 = __hip_bfloat16;
376
+
377
+ // fp8 -> __nv_bfloat16
378
+ template <>
379
+ __inline__ __device__ __nv_bfloat16
380
+ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
381
+ const float scale) {
382
+ hip_fp8 f8{a, hip_fp8::from_bits()};
383
+ float f{f8};
384
+ return __float2bfloat16(f * scale);
385
+ }
386
+
387
+ using __nv_bfloat162 = __hip_bfloat162;
388
+
389
+ // fp8x2 -> __nv_bfloat162
390
+ template <>
391
+ __inline__ __device__ __nv_bfloat162
392
+ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
393
+ const float scale) {
394
+ __nv_bfloat162 res;
395
+ res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
396
+ res.y =
397
+ scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
398
+ return res;
399
+ }
400
+
401
+ // fp8x4 -> bf16_4_t
402
+ template <>
403
+ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
404
+ const uint32_t& a, const float scale) {
405
+ bf16_4_t res;
406
+ res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
407
+ res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
408
+ scale);
409
+ return res;
410
+ }
411
+
412
+ // fp8x8 -> bf16_8_t
413
+ template <>
414
+ __inline__ __device__ bf16_8_t
415
+ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
416
+ bf16_4_t tmp1, tmp2;
417
+ tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
418
+ tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
419
+ bf16_8_t res;
420
+ res.x = tmp1.x;
421
+ res.y = tmp1.y;
422
+ res.z = tmp2.x;
423
+ res.w = tmp2.y;
424
+ return res;
425
+ }
426
+
427
+ // fp8 -> float
428
+ template <>
429
+ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
430
+ const uint8_t& a, const float scale) {
431
+ hip_fp8 fp8{a, hip_fp8::from_bits()};
432
+ return static_cast<float>(fp8) * scale;
433
+ }
434
+
435
+ // fp8x2 -> float2
436
+ template <>
437
+ __inline__ __device__ float2
438
+ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
439
+ #if defined(__HIP__MI300__) && \
440
+ defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
441
+ float2 res;
442
+ const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
443
+ res.x = f2[0] * scale;
444
+ res.y = f2[1] * scale;
445
+ return res;
446
+ #else
447
+ float2 res;
448
+ res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
449
+ res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
450
+ scale);
451
+ return res;
452
+ #endif
453
+ }
454
+
455
+ // fp8x4 -> float4
456
+ template <>
457
+ __inline__ __device__ Float4_
458
+ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
459
+ Float4_ res;
460
+ res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
461
+ res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
462
+ return res;
463
+ }
464
+
465
+ // fp8x8 -> float8
466
+ template <>
467
+ __inline__ __device__ Float8_
468
+ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
469
+ Float4_ tmp1, tmp2;
470
+ tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
471
+ tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
472
+ Float8_ res;
473
+ res.x = tmp1.x;
474
+ res.y = tmp1.y;
475
+ res.z = tmp2.x;
476
+ res.w = tmp2.y;
477
+ return res;
478
+ }
479
+
480
+ /* Quantize(HP / scale) => FP8 */
481
+
482
+ // TODO(Hai): add vectorized versions of these conversions
483
+
484
+ // half -> fp8
485
+ template <>
486
+ __inline__ __device__ uint8_t
487
+ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
488
+ __half_raw tmp;
489
+ tmp.x = a;
490
+
491
+ hip_fp8 f8{static_cast<float>(tmp.data) / scale};
492
+ return f8.data;
493
+ }
494
+
495
+ // bf16 -> fp8
496
+ template <>
497
+ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
498
+ const __nv_bfloat16& a, const float scale) {
499
+ hip_fp8 res{__bfloat162float(a) / scale};
500
+ return res.data;
501
+ }
502
+
503
+ // float -> fp8
504
+ template <>
505
+ __inline__ __device__ uint8_t
506
+ scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
507
+ hip_fp8 f8(a / scale);
508
+ return f8.data;
509
+ }
510
+
511
+ // fp8x4 -> float4
512
+ template <>
513
+ __inline__ __device__ float4
514
+ scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
515
+ Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
516
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
517
+ return res;
518
+ }
519
+ #endif // ENABLE_FP8
520
+
521
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
522
+ __inline__ __device__ Tout convert(const Tin& x) {
523
+ #ifdef ENABLE_FP8
524
+ if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
525
+ return vec_conversion<Tout, Tin>(x);
526
+ }
527
+ #endif
528
+ assert(false);
529
+ return {}; // Squash missing return statement warning
530
+ }
531
+
532
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
533
+ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
534
+ #ifdef ENABLE_FP8
535
+ if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
536
+ return scaled_vec_conversion<Tout, Tin>(x, scale);
537
+ }
538
+ #endif
539
+ assert(false);
540
+ return {}; // Squash missing return statement warning
541
+ }
542
+
543
+ // The following macro is used to dispatch the conversion function based on
544
+ // the data type of the key and value cache. The FN is a macro that calls a
545
+ // function with template<typename scalar_t, typename cache_t,
546
+ // Fp8KVCacheDataType kv_dt>.
547
+ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
548
+ if (KV_DTYPE == "auto") { \
549
+ if (SRC_DTYPE == at::ScalarType::Float) { \
550
+ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
551
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
552
+ FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
553
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
554
+ FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
555
+ } else { \
556
+ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
557
+ } \
558
+ } else { \
559
+ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
560
+ if (SRC_DTYPE == at::ScalarType::Float) { \
561
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
562
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
563
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
564
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
565
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
566
+ } else { \
567
+ TORCH_CHECK(false, \
568
+ "Unsupported input type of kv cache: ", SRC_DTYPE); \
569
+ } \
570
+ } else { \
571
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
572
+ } \
573
+ }
574
+
575
+ } // namespace fp8
576
+ #endif // USE_ROCM
577
+ } // namespace vllm
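
`DISPATCH_BY_KV_CACHE_DTYPE` is the glue the cache kernels earlier in this commit use via `CALL_RESHAPE_AND_CACHE_FLASH`. Below is a hedged sketch of that calling pattern; `quantize_kernel`, `CALL_QUANTIZE` and `quantize_cache` are illustrative stand-ins (only the macro, `Fp8KVCacheDataType` and `scaled_convert` come from this port), and the fp8 branch again assumes `ENABLE_FP8`:

```cpp
// Hedged sketch of the dispatch pattern used by the cache kernels.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "quantization/fp8/amd/quant_utils.cuh"  // nvidia/ on CUDA builds

template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt>
__global__ void quantize_kernel(const scalar_t* __restrict__ src,
                                cache_t* __restrict__ dst, const float scale,
                                const int64_t n) {
  int64_t i = int64_t(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
      dst[i] = src[i];  // "auto": cache keeps the source dtype, no conversion
    } else {
      dst[i] =
          vllm::fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[i], scale);
    }
  }
}

#define CALL_QUANTIZE(SCALAR_T, CACHE_T, KV_DTYPE)         \
  quantize_kernel<SCALAR_T, CACHE_T, KV_DTYPE>             \
      <<<blocks, 256, 0, stream>>>(                        \
          reinterpret_cast<SCALAR_T*>(src.data_ptr()),     \
          reinterpret_cast<CACHE_T*>(dst.data_ptr()), scale, n);

void quantize_cache(torch::Tensor& src, torch::Tensor& dst, float scale,
                    const std::string& kv_cache_dtype) {
  int64_t n = src.numel();
  int blocks = static_cast<int>((n + 255) / 256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Picks (scalar_t, cache_t, kv_dt) from the tensor dtype and the cache dtype
  // string, e.g. (uint16_t, uint8_t, kFp8E4M3) for half plus "fp8".
  DISPATCH_BY_KV_CACHE_DTYPE(src.dtype(), kv_cache_dtype, CALL_QUANTIZE);
}
```
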
paged-attention/quantization/fp8/nvidia/quant_utils.cuh ADDED
@@ -0,0 +1,573 @@
1
+ #pragma once
2
+
3
+ #include "../../../attention/attention_dtypes.h"
4
+ #include <assert.h>
5
+ #include <float.h>
6
+ #include <stdint.h>
7
+ #include <type_traits>
8
+
9
+ namespace vllm {
10
+ #ifndef USE_ROCM
11
+
12
+ namespace fp8 {
13
+ #ifdef ENABLE_FP8
14
+
15
+ #if 0 // Disable the following code to reduce the binary size.
16
+ template <typename Tout, typename Tin>
17
+ __inline__ __device__ Tout
18
+ vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
19
+ return x;
20
+ }
21
+
22
+ // fp8 -> half
23
+ template <>
24
+ __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
25
+ const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
26
+ __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
27
+ return res.x;
28
+ }
29
+
30
+ // fp8x2 -> half2
31
+ template <>
32
+ __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
33
+ const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
34
+ union {
35
+ uint16_t u16[2];
36
+ uint32_t u32;
37
+ } tmp;
38
+ __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
39
+ tmp.u16[0] = res.x;
40
+ tmp.u16[1] = res.y;
41
+ return tmp.u32;
42
+ }
43
+
44
+ // fp8x4 -> half2x2
45
+ template <>
46
+ __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
47
+ const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
48
+ union {
49
+ uint2 u32x2;
50
+ uint32_t u32[2];
51
+ } tmp;
52
+ tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
53
+ tmp.u32[1] =
54
+ vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
55
+ return tmp.u32x2;
56
+ }
57
+
58
+ // fp8x8 -> half2x4
59
+ template <>
60
+ __inline__ __device__ uint4 vec_conversion<uint4, uint2>(
61
+ const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
62
+ union {
63
+ uint4 u64x2;
64
+ uint2 u64[2];
65
+ } tmp;
66
+ tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
67
+ tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
68
+ return tmp.u64x2;
69
+ }
70
+
71
+ // fp8 -> __nv_bfloat16
72
+ template <>
73
+ __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
74
+ const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
75
+ // Note there is no direct convert function from fp8 to bf16.
76
+ // fp8 -> half
77
+ __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
78
+ // half -> float -> bf16
79
+ float tmp = half_to_float(res.x);
80
+ return __float2bfloat16(tmp);
81
+ }
82
+
83
+ // fp8x2 -> __nv_bfloat162
84
+ template <>
85
+ __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
86
+ const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
87
+ __nv_bfloat162 res;
88
+ res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
89
+ res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
90
+ return res;
91
+ }
92
+
93
+ // fp8x4 -> bf16_4_t
94
+ template <>
95
+ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
96
+ const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
97
+ bf16_4_t res;
98
+ res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
99
+ res.y =
100
+ vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
101
+ return res;
102
+ }
103
+
104
+ // fp8x8 -> bf16_8_t
105
+ template <>
106
+ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
107
+ const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
108
+ bf16_4_t tmp1, tmp2;
109
+ tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
110
+ tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
111
+ bf16_8_t res;
112
+ res.x = tmp1.x;
113
+ res.y = tmp1.y;
114
+ res.z = tmp2.x;
115
+ res.w = tmp2.y;
116
+ return res;
117
+ }
118
+
119
+ // fp8 -> float
120
+ template <>
121
+ __inline__ __device__ float
122
+ vec_conversion<float, uint8_t>(const uint8_t &a,
123
+ const __nv_fp8_interpretation_t fp8_type) {
124
+ // fp8 -> half
125
+ uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
126
+ // half -> float
127
+ return half_to_float(tmp);
128
+ }
129
+
130
+ // fp8x2 -> float2
131
+ template <>
132
+ __inline__ __device__ float2 vec_conversion<float2, uint16_t>(
133
+ const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
134
+ // fp8x2 -> half2
135
+ uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
136
+ // half2 -> float2
137
+ return half2_to_float2(tmp);
138
+ }
139
+
140
+ // fp8x4 -> float4
141
+ template <>
142
+ __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
143
+ const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
144
+ Float4_ res;
145
+ res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
146
+ res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
147
+ return res;
148
+ }
149
+
150
+ // fp8x8 -> float8
151
+ template <>
152
+ __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
153
+ const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
154
+ Float4_ tmp1, tmp2;
155
+ tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
156
+ tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
157
+ Float8_ res;
158
+ res.x = tmp1.x;
159
+ res.y = tmp1.y;
160
+ res.z = tmp2.x;
161
+ res.w = tmp2.y;
162
+ return res;
163
+ }
164
+
165
+ // half -> fp8
166
+ template <>
167
+ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
168
+ const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
169
+ __half_raw tmp;
170
+ tmp.x = a;
171
+ __nv_fp8_storage_t res =
172
+ __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
173
+ return (uint8_t)res;
174
+ }
175
+
176
+ // bf16 -> fp8
177
+ template <>
178
+ __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
179
+ const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
180
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
181
+ assert(false);
182
+ #else
183
+ __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
184
+ __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
185
+ return (uint8_t)res;
186
+ #endif
187
+ }
188
+
189
+ // float -> fp8
190
+ template <>
191
+ __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
192
+ const float &a, const __nv_fp8_interpretation_t fp8_type) {
193
+ __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
194
+ return (uint8_t)res;
195
+ }
196
+
197
+ // fp8x4 -> float4
198
+ template <>
199
+ __inline__ __device__ float4 vec_conversion<float4, uint32_t>(
200
+ const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
201
+ Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
202
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
203
+ return res;
204
+ }
205
+
206
+ template <>
207
+ __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
208
+ const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
209
+ union {
210
+ half2 float16;
211
+ uint32_t uint32;
212
+ };
213
+
214
+ float16 = __float22half2_rn(a);
215
+ return uint32;
216
+ }
217
+
218
+ template <>
219
+ __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
220
+ const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
221
+ uint2 b;
222
+ float2 val;
223
+ val.x = a.x.x;
224
+ val.y = a.x.y;
225
+ b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
226
+
227
+ val.x = a.y.x;
228
+ val.y = a.y.y;
229
+ b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
230
+
231
+ return b;
232
+ }
233
+
234
+ template <>
235
+ __inline__ __device__ float4 vec_conversion<float4, Float4_>(
236
+ const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
237
+ float4 b;
238
+ b.x = a.x.x;
239
+ b.y = a.x.y;
240
+ b.z = a.y.x;
241
+ b.w = a.y.y;
242
+ return b;
243
+ }
244
+
245
+ template <>
246
+ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
247
+ const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
248
+ uint4 b;
249
+ b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
250
+ b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
251
+ b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
252
+ b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
253
+ return b;
254
+ }
255
+
256
+ template <>
257
+ __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
258
+ const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
259
+ __nv_bfloat162 b;
260
+ from_float(b, a);
261
+ return b;
262
+ }
263
+
264
+ template <>
265
+ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
266
+ const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
267
+ bf16_4_t b;
268
+ from_float(b, a);
269
+ return b;
270
+ }
271
+
272
+ template <>
273
+ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
274
+ const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
275
+ bf16_8_t b;
276
+ from_float(b, a);
277
+ return b;
278
+ }
279
+ #endif
280
+
281
+ /* Scaled and vectorized conversions, for data exchange between the high- and
282
+ low-precision domains. Convention of the scale in the API:
283
+ FP8_data = Quantize(High_Precision_data / scale), i.e.
284
+ Quantize(HP / scale) => FP8 and Dequant(FP8) * scale => HP.
285
+ */
286
+
287
+ template <typename Tout, typename Tin>
288
+ __inline__ __device__ Tout scaled_vec_conversion(
289
+ const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
290
+ return x;
291
+ }
292
+
293
+ // fp8 -> half
294
+ template <>
295
+ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
296
+ const uint8_t& a, const float scale,
297
+ const __nv_fp8_interpretation_t fp8_type) {
298
+ __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
299
+ return float_to_half(half_to_float(tmp.x) * scale);
300
+ }
301
+
302
+ // fp8x2 -> half2
303
+ template <>
304
+ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
305
+ const uint16_t& a, const float scale,
306
+ const __nv_fp8_interpretation_t fp8_type) {
307
+ union {
308
+ uint16_t u16[2];
309
+ uint32_t u32;
310
+ } tmp;
311
+ __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
312
+ tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
313
+ tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
314
+ return tmp.u32;
315
+ }
316
+
317
+ // fp8x4 -> half2x2
318
+ template <>
319
+ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
320
+ const uint32_t& a, const float scale,
321
+ const __nv_fp8_interpretation_t fp8_type) {
322
+ union {
323
+ uint2 u32x2;
324
+ uint32_t u32[2];
325
+ } tmp;
326
+ tmp.u32[0] =
327
+ scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
328
+ tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
329
+ scale, fp8_type);
330
+ return tmp.u32x2;
331
+ }
332
+
333
+ // fp8x8 -> half2x4
334
+ template <>
335
+ __inline__ __device__ uint4
336
+ scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
337
+ const __nv_fp8_interpretation_t fp8_type) {
338
+ union {
339
+ uint4 u64x2;
340
+ uint2 u64[2];
341
+ } tmp;
342
+ tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
343
+ tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
344
+ return tmp.u64x2;
345
+ }
346
+
347
+ // fp8 -> __nv_bfloat16
348
+ template <>
349
+ __inline__ __device__ __nv_bfloat16
350
+ scaled_vec_conversion<__nv_bfloat16, uint8_t>(
351
+ const uint8_t& a, const float scale,
352
+ const __nv_fp8_interpretation_t fp8_type) {
353
+ // Note there is no direct convert function from fp8 to bf16.
354
+ // fp8 -> half
355
+ __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
356
+ // half -> float -> bf16
357
+ float tmp = half_to_float(res.x);
358
+ return __float2bfloat16(tmp * scale);
359
+ }
360
+
361
+ // fp8x2 -> __nv_bfloat162
362
+ template <>
363
+ __inline__ __device__ __nv_bfloat162
364
+ scaled_vec_conversion<__nv_bfloat162, uint16_t>(
365
+ const uint16_t& a, const float scale,
366
+ const __nv_fp8_interpretation_t fp8_type) {
367
+ __nv_bfloat162 res;
368
+ res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
369
+ fp8_type);
370
+ res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
371
+ scale, fp8_type);
372
+ return res;
373
+ }
374
+
375
+ // fp8x4 -> bf16_4_t
376
+ template <>
377
+ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
378
+ const uint32_t& a, const float scale,
379
+ const __nv_fp8_interpretation_t fp8_type) {
380
+ bf16_4_t res;
381
+ res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
382
+ fp8_type);
383
+ res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
384
+ scale, fp8_type);
385
+ return res;
386
+ }
387
+
388
+ // fp8x8 -> bf16_8_t
389
+ template <>
390
+ __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
391
+ const uint2& a, const float scale,
392
+ const __nv_fp8_interpretation_t fp8_type) {
393
+ bf16_4_t tmp1, tmp2;
394
+ tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
395
+ tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
396
+ bf16_8_t res;
397
+ res.x = tmp1.x;
398
+ res.y = tmp1.y;
399
+ res.z = tmp2.x;
400
+ res.w = tmp2.y;
401
+ return res;
402
+ }
403
+
404
+ // fp8 -> float
405
+ template <>
406
+ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
407
+ const uint8_t& a, const float scale,
408
+ const __nv_fp8_interpretation_t fp8_type) {
409
+ // fp8 -> half
410
+ __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
411
+ uint16_t tmp = res.x;
412
+
413
+ // half -> float
414
+ return half_to_float(tmp) * scale;
415
+ }
416
+
417
+ // fp8x2 -> float2
418
+ template <>
419
+ __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
420
+ const uint16_t& a, const float scale,
421
+ const __nv_fp8_interpretation_t fp8_type) {
422
+ // fp8x2 -> half2
423
+ uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
424
+ // half2 -> float2
425
+ return half2_to_float2(tmp);
426
+ }
427
+
428
+ // fp8x4 -> float4
429
+ template <>
430
+ __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
431
+ const uint32_t& a, const float scale,
432
+ const __nv_fp8_interpretation_t fp8_type) {
433
+ Float4_ res;
434
+ res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
435
+ res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
436
+ fp8_type);
437
+ return res;
438
+ }
439
+
440
+ // fp8x8 -> float8
441
+ template <>
442
+ __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
443
+ const uint2& a, const float scale,
444
+ const __nv_fp8_interpretation_t fp8_type) {
445
+ Float4_ tmp1, tmp2;
446
+ tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
447
+ tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
448
+ Float8_ res;
449
+ res.x = tmp1.x;
450
+ res.y = tmp1.y;
451
+ res.z = tmp2.x;
452
+ res.w = tmp2.y;
453
+ return res;
454
+ }
455
+
456
+ // half -> fp8
457
+ template <>
458
+ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
459
+ const uint16_t& a, const float scale,
460
+ const __nv_fp8_interpretation_t fp8_type) {
461
+ __nv_fp8_storage_t res =
462
+ __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
463
+ return (uint8_t)res;
464
+ }
465
+
466
+ // bf16 -> fp8
467
+ template <>
468
+ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
469
+ const __nv_bfloat16& a, const float scale,
470
+ const __nv_fp8_interpretation_t fp8_type) {
471
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
472
+ assert(false);
473
+ #else
474
+ __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
475
+ __NV_SATFINITE, fp8_type);
476
+ return (uint8_t)res;
477
+ #endif
478
+ __builtin_unreachable(); // Suppress missing return statement warning
479
+ }
480
+
481
+ // float -> fp8
482
+ template <>
483
+ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
484
+ const float& a, const float scale,
485
+ const __nv_fp8_interpretation_t fp8_type) {
486
+ __nv_fp8_storage_t res =
487
+ __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
488
+ return (uint8_t)res;
489
+ }
490
+
491
+ // fp8x4 -> float4
492
+ template <>
493
+ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
494
+ const uint32_t& a, const float scale,
495
+ const __nv_fp8_interpretation_t fp8_type) {
496
+ Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
497
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
498
+ return res;
499
+ }
500
+ #endif // ENABLE_FP8
501
+
502
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
503
+ __inline__ __device__ Tout convert(const Tin& x) {
504
+ #if 0 // Disable the following code to reduce the binary size.
505
+ if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
506
+ return vec_conversion<Tout, Tin>(x, __NV_E4M3);
507
+ } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
508
+ return vec_conversion<Tout, Tin>(x, __NV_E5M2);
509
+ }
510
+ #endif
511
+ assert(false);
512
+ __builtin_unreachable(); // Suppress missing return statement warning
513
+ }
514
+
515
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
516
+ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
517
+ #ifdef ENABLE_FP8
518
+ if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
519
+ return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
520
+ } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
521
+ return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
522
+ }
523
+ #endif
524
+ assert(false);
525
+ __builtin_unreachable(); // Suppress missing return statement warning
526
+ }
527
+
528
+ // The following macro is used to dispatch the conversion function based on
529
+ // the data type of the key and value cache. The FN is a macro that calls a
530
+ // function with template<typename scalar_t, typename cache_t,
531
+ // Fp8KVCacheDataType kv_dt>.
532
+ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
533
+ if (KV_DTYPE == "auto") { \
534
+ if (SRC_DTYPE == at::ScalarType::Float) { \
535
+ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
536
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
537
+ FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
538
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
539
+ FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
540
+ } else { \
541
+ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
542
+ } \
543
+ } else { \
544
+ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
545
+ if (SRC_DTYPE == at::ScalarType::Float) { \
546
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
547
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
548
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
549
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
550
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
551
+ } else { \
552
+ TORCH_CHECK(false, \
553
+ "Unsupported input type of kv cache: ", SRC_DTYPE); \
554
+ } \
555
+ } else if (KV_DTYPE == "fp8_e5m2") { \
556
+ if (SRC_DTYPE == at::ScalarType::Float) { \
557
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
558
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
559
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
560
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
561
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
562
+ } else { \
563
+ TORCH_CHECK(false, \
564
+ "Unsupported input type of kv cache: ", SRC_DTYPE); \
565
+ } \
566
+ } else { \
567
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
568
+ } \
569
+ }
570
+
571
+ } // namespace fp8
572
+ #endif // not USE_ROCM
573
+ } // namespace vllm
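The kernels above implement scaled fp8 conversion: dequantization widens the fp8 value and multiplies by `scale`, while quantization divides by `scale` before narrowing. A minimal Python sketch of the same semantics, assuming a PyTorch build that exposes `torch.float8_e4m3fn`; the helper names are illustrative and not part of this extension:

```python
import torch


def scaled_quant(x: torch.Tensor, scale: float) -> torch.Tensor:
    # float -> fp8: divide by the scale before narrowing
    # (mirrors __nv_cvt_float_to_fp8(a / scale, ...) above).
    return (x / scale).to(torch.float8_e4m3fn)


def scaled_dequant(x_fp8: torch.Tensor, scale: float) -> torch.Tensor:
    # fp8 -> float: widen first, then multiply by the scale
    # (mirrors half_to_float(res.x) * scale above).
    return x_fp8.to(torch.float32) * scale


x = torch.randn(8)
scale = x.abs().max().item() / 448.0  # 448 is the largest finite e4m3 value
roundtrip = scaled_dequant(scaled_quant(x, scale), scale)
print(torch.max((x - roundtrip).abs()))
```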
tests/kernels/__init__.py ADDED
File without changes
tests/kernels/allclose_default.py ADDED
@@ -0,0 +1,14 @@
1
+ import torch
2
+
3
+ # Reference default values of atol and rtol are from
4
+ # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
5
+ default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
6
+ default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
7
+
8
+
9
+ def get_default_atol(output) -> float:
10
+ return default_atol[output.dtype]
11
+
12
+
13
+ def get_default_rtol(output) -> float:
14
+ return default_rtol[output.dtype]
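A hedged usage sketch for the helpers above, assuming it runs inside the `tests.kernels` package; `out` and `ref` are illustrative placeholders for a kernel output and its reference:

```python
import torch

from .allclose_default import get_default_atol, get_default_rtol

out = torch.randn(4, 8, dtype=torch.float16)  # e.g. a kernel output
ref = out.clone()                             # e.g. a reference result
torch.testing.assert_close(
    out, ref, atol=get_default_atol(out), rtol=get_default_rtol(out)
)
```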
tests/kernels/conftest.py ADDED
@@ -0,0 +1,158 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import attention as ops
4
+ import pytest
5
+ import torch
6
+
7
+
8
+ @pytest.fixture()
9
+ def kv_cache_factory():
10
+ return create_kv_caches_with_random
11
+
12
+
13
+ @pytest.fixture()
14
+ def kv_cache_factory_flashinfer():
15
+ return create_kv_caches_with_random_flash
16
+
17
+
18
+ STR_DTYPE_TO_TORCH_DTYPE = {
19
+ "half": torch.half,
20
+ "bfloat16": torch.bfloat16,
21
+ "float": torch.float,
22
+ "fp8": torch.uint8,
23
+ "fp8_e4m3": torch.uint8,
24
+ "fp8_e5m2": torch.uint8,
25
+ }
26
+
27
+
28
+ def create_kv_caches_with_random(
29
+ num_blocks: int,
30
+ block_size: int,
31
+ num_layers: int,
32
+ num_heads: int,
33
+ head_size: int,
34
+ cache_dtype: Optional[Union[str, torch.dtype]],
35
+ model_dtype: Optional[Union[str, torch.dtype]] = None,
36
+ seed: int = 0,
37
+ device: Optional[str] = "cuda",
38
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
39
+
40
+ if cache_dtype == "fp8" and head_size % 16:
41
+ raise ValueError(
42
+ f"Does not support key cache of type fp8 with head_size {head_size}"
43
+ )
44
+ from attention.platforms import current_platform
45
+
46
+ current_platform.seed_everything(seed)
47
+
48
+ torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
49
+
50
+ scale = head_size**-0.5
51
+ x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
52
+ key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
53
+ key_caches: List[torch.Tensor] = []
54
+ for _ in range(num_layers):
55
+ key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
56
+ if cache_dtype in ["auto", "half", "bfloat16", "float"]:
57
+ key_cache.uniform_(-scale, scale)
58
+ elif cache_dtype == "fp8":
59
+ _generate_random_fp8(key_cache, -scale, scale)
60
+ else:
61
+ raise ValueError(f"Does not support key cache of type {cache_dtype}")
62
+ key_caches.append(key_cache)
63
+
64
+ value_cache_shape = (num_blocks, num_heads, head_size, block_size)
65
+ value_caches: List[torch.Tensor] = []
66
+ for _ in range(num_layers):
67
+ value_cache = torch.empty(
68
+ size=value_cache_shape, dtype=torch_dtype, device=device
69
+ )
70
+ if cache_dtype in ["auto", "half", "bfloat16", "float"]:
71
+ value_cache.uniform_(-scale, scale)
72
+ elif cache_dtype == "fp8":
73
+ _generate_random_fp8(value_cache, -scale, scale)
74
+ else:
75
+ raise ValueError(f"Does not support value cache of type {cache_dtype}")
76
+ value_caches.append(value_cache)
77
+ return key_caches, value_caches
78
+
79
+
80
+ def create_kv_caches_with_random_flash(
81
+ num_blocks: int,
82
+ block_size: int,
83
+ num_layers: int,
84
+ num_heads: int,
85
+ head_size: int,
86
+ cache_dtype: Optional[Union[str, torch.dtype]],
87
+ model_dtype: Optional[Union[str, torch.dtype]] = None,
88
+ seed: int = 0,
89
+ device: Optional[str] = "cuda",
90
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
91
+ from attention.platforms import current_platform
92
+
93
+ current_platform.seed_everything(seed)
94
+
95
+ torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
96
+ key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
97
+ scale = head_size**-0.5
98
+
99
+ key_caches: List[torch.Tensor] = []
100
+ value_caches: List[torch.Tensor] = []
101
+
102
+ for _ in range(num_layers):
103
+ key_value_cache = torch.empty(
104
+ size=key_value_cache_shape, dtype=torch_dtype, device=device
105
+ )
106
+ if cache_dtype in ["auto", "half", "bfloat16", "float"]:
107
+ key_value_cache.uniform_(-scale, scale)
108
+ elif cache_dtype == "fp8":
109
+ _generate_random_fp8(key_value_cache, -scale, scale)
110
+ else:
111
+ raise ValueError(f"Does not support key cache of type {cache_dtype}")
112
+ key_caches.append(key_value_cache[:, 0])
113
+ value_caches.append(key_value_cache[:, 1])
114
+ return key_caches, value_caches
115
+
116
+
117
+ def get_kv_cache_torch_dtype(
118
+ cache_dtype: Optional[Union[str, torch.dtype]],
119
+ model_dtype: Optional[Union[str, torch.dtype]] = None,
120
+ ) -> torch.dtype:
121
+ if isinstance(cache_dtype, str):
122
+ if cache_dtype == "auto":
123
+ if isinstance(model_dtype, str):
124
+ torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
125
+ elif isinstance(model_dtype, torch.dtype):
126
+ torch_dtype = model_dtype
127
+ else:
128
+ raise ValueError(f"Invalid model dtype: {model_dtype}")
129
+ elif cache_dtype in ["half", "bfloat16", "float"]:
130
+ torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
131
+ elif cache_dtype == "fp8":
132
+ torch_dtype = torch.uint8
133
+ else:
134
+ raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
135
+ elif isinstance(cache_dtype, torch.dtype):
136
+ torch_dtype = cache_dtype
137
+ else:
138
+ raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
139
+ return torch_dtype
140
+
141
+
142
+ def _generate_random_fp8(
143
+ tensor: torch.Tensor,
144
+ low: float,
145
+ high: float,
146
+ ) -> None:
147
+ # NOTE(zhaoyang): Due to the NaN and Inf representations of the fp8 data types,
148
+ # Inf or NaN values may occur if we directly use torch.randint
149
+ # to generate random fp8 data.
150
+ # For example, s.11111.00 in fp8e5m2 format represents Inf.
151
+ # | E4M3 | E5M2
152
+ # -----|-------------|-------------------
153
+ # Inf | N/A | s.11111.00
154
+ # NaN | s.1111.111 | s.11111.{01,10,11}
155
+ tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
156
+ tensor_tmp.uniform_(low, high)
157
+ ops.convert_fp8(tensor, tensor_tmp)
158
+ del tensor_tmp
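A small sketch of how the factory above is typically called (tests reach it through the `kv_cache_factory` fixture). The shapes in the comments follow directly from the code above; a CUDA device is assumed:

```python
import torch

key_caches, value_caches = create_kv_caches_with_random(
    num_blocks=128,
    block_size=16,
    num_layers=1,
    num_heads=8,
    head_size=128,
    cache_dtype="auto",
    model_dtype=torch.float16,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# key_cache:   [num_blocks, num_heads, head_size // x, block_size, x], x = 16 // element_size
# value_cache: [num_blocks, num_heads, head_size, block_size]
print(key_cache.shape, value_cache.shape)
```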
tests/kernels/test_attention.py ADDED
@@ -0,0 +1,418 @@
1
+ import random
2
+ from typing import List, Optional, Tuple
3
+
4
+ import attention as ops
5
+ import pytest
6
+ import torch
7
+ from attention.platforms import current_platform
8
+
9
+ from .allclose_default import get_default_atol, get_default_rtol
10
+ from .utils import get_max_shared_memory_bytes, opcheck
11
+
12
+ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
13
+ # This will change depending on the compute capability (available shared memory).
14
+ # - 512 is subtracted as a safety buffer
15
+ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
16
+ # There may not be enough gpu memory due to large NUM_BLOCKS.
17
+ # Reduce NUM_BLOCKS when it happens.
18
+ NUM_BLOCKS = 4321 # Arbitrary values for testing
19
+ PARTITION_SIZE = 512
20
+ # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
21
+ DTYPES = (
22
+ [torch.half, torch.bfloat16, torch.float]
23
+ if not current_platform.is_rocm()
24
+ else [torch.half, torch.bfloat16]
25
+ )
26
+ NUM_GEN_SEQS = [7] # Arbitrary values for testing
27
+ NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
28
+ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
29
+
30
+ # This should be kept in sync with get_supported_head_sizes() in
31
+ # vllm.attention.ops.paged_attn.PagedAttention
32
+ HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
33
+
34
+ BLOCK_SIZES = [16, 32]
35
+ USE_ALIBI = [False, True]
36
+ KV_CACHE_DTYPE = ["auto", "fp8"]
37
+ SEEDS = [0]
38
+ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
39
+
40
+
41
+ def ref_masked_attention(
42
+ query: torch.Tensor,
43
+ key: torch.Tensor,
44
+ value: torch.Tensor,
45
+ scale: float,
46
+ attn_mask: Optional[torch.Tensor] = None,
47
+ ) -> torch.Tensor:
48
+ attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
49
+ if attn_mask is not None:
50
+ attn_weights = attn_weights + attn_mask.float()
51
+ attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
52
+ out = torch.einsum("hqk,khd->qhd", attn_weights, value)
53
+ return out
54
+
55
+
56
+ def ref_single_query_cached_kv_attention(
57
+ output: torch.Tensor,
58
+ query: torch.Tensor,
59
+ num_queries_per_kv: int,
60
+ key_cache: torch.Tensor,
61
+ value_cache: torch.Tensor,
62
+ block_tables: torch.Tensor,
63
+ seq_lens: torch.Tensor,
64
+ scale: float,
65
+ alibi_slopes: Optional[torch.Tensor],
66
+ ) -> None:
67
+ num_query_heads = query.shape[1]
68
+ num_kv_heads = value_cache.shape[1]
69
+ head_size = value_cache.shape[2]
70
+ block_size = value_cache.shape[3]
71
+ num_seqs = query.shape[0]
72
+
73
+ block_tables_lst = block_tables.cpu().tolist()
74
+ seq_lens_lst = seq_lens.cpu().tolist()
75
+ for i in range(num_seqs):
76
+ q = query[i].unsqueeze(0)
77
+ block_table = block_tables_lst[i]
78
+ seq_len = int(seq_lens_lst[i])
79
+
80
+ keys_lst: List[torch.Tensor] = []
81
+ values_lst: List[torch.Tensor] = []
82
+ for j in range(seq_len):
83
+ block_number = int(block_table[j // block_size])
84
+ block_offset = j % block_size
85
+
86
+ k = key_cache[block_number, :, :, block_offset, :]
87
+ k = k.reshape(num_kv_heads, head_size)
88
+ keys_lst.append(k)
89
+
90
+ v = value_cache[block_number, :, :, block_offset]
91
+ values_lst.append(v)
92
+ keys = torch.stack(keys_lst, dim=0)
93
+ values = torch.stack(values_lst, dim=0)
94
+ if num_queries_per_kv > 1:
95
+ # Handle MQA and GQA
96
+ keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
97
+ values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
98
+
99
+ alibi_bias = None
100
+ if alibi_slopes is not None:
101
+ # Create the ALiBi bias used in the paged attention kernel.
102
+ position_ids = torch.arange(seq_len).int()
103
+ alibi_bias = (position_ids - seq_len + 1).float()
104
+ alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
105
+
106
+ out = ref_masked_attention(q, keys, values, scale, alibi_bias)
107
+ out = out.view(num_query_heads, head_size)
108
+ output[i].copy_(out, non_blocking=True)
109
+
110
+
111
+ @pytest.mark.parametrize(
112
+ "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
113
+ )
114
+ @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
115
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
116
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
117
+ @pytest.mark.parametrize("use_alibi", USE_ALIBI)
118
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
119
+ @pytest.mark.parametrize("dtype", DTYPES)
120
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
121
+ @pytest.mark.parametrize("seed", SEEDS)
122
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
123
+ def test_paged_attention(
124
+ kv_cache_factory,
125
+ version: str,
126
+ num_seqs: int,
127
+ num_heads: Tuple[int, int],
128
+ head_size: int,
129
+ use_alibi: bool,
130
+ block_size: int,
131
+ dtype: torch.dtype,
132
+ kv_cache_dtype: str,
133
+ seed: int,
134
+ device: str,
135
+ ) -> None:
136
+ if (kv_cache_dtype == "fp8" and head_size % 16) or (
137
+ version == "rocm" and head_size not in (64, 128)
138
+ ):
139
+ pytest.skip()
140
+
141
+ current_platform.seed_everything(seed)
142
+ torch.set_default_device(device)
143
+ scale = float(1.0 / (head_size**0.5))
144
+ num_query_heads, num_kv_heads = num_heads
145
+ query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
146
+ query.uniform_(-scale, scale)
147
+
148
+ assert num_query_heads % num_kv_heads == 0
149
+ num_queries_per_kv = num_query_heads // num_kv_heads
150
+ alibi_slopes = None
151
+ if use_alibi:
152
+ alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
153
+
154
+ seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
155
+ seq_lens[-1] = MAX_SEQ_LEN
156
+ max_seq_len = max(seq_lens)
157
+ seq_lens = torch.tensor(seq_lens, dtype=torch.int)
158
+
159
+ # Create the block tables.
160
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
161
+ block_tables_lst: List[List[int]] = []
162
+ for _ in range(num_seqs):
163
+ block_table = [
164
+ random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
165
+ ]
166
+ block_tables_lst.append(block_table)
167
+
168
+ block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
169
+
170
+ # Create the KV caches.
171
+ key_caches, value_caches = kv_cache_factory(
172
+ NUM_BLOCKS,
173
+ block_size,
174
+ 1,
175
+ num_kv_heads,
176
+ head_size,
177
+ kv_cache_dtype,
178
+ dtype,
179
+ seed,
180
+ device,
181
+ )
182
+ key_cache, value_cache = key_caches[0], value_caches[0]
183
+
184
+ # Using default kv_scale
185
+ k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
186
+
187
+ # Call the paged attention kernel.
188
+ output = torch.empty_like(query)
189
+ if version == "v1":
190
+ ops.paged_attention_v1(
191
+ output,
192
+ query,
193
+ key_cache,
194
+ value_cache,
195
+ num_kv_heads,
196
+ scale,
197
+ block_tables,
198
+ seq_lens,
199
+ block_size,
200
+ max_seq_len,
201
+ alibi_slopes,
202
+ kv_cache_dtype,
203
+ k_scale,
204
+ v_scale,
205
+ )
206
+
207
+ opcheck(
208
+ ops.ops.paged_attention_v1,
209
+ (
210
+ output,
211
+ query,
212
+ key_cache,
213
+ value_cache,
214
+ num_kv_heads,
215
+ scale,
216
+ block_tables,
217
+ seq_lens,
218
+ block_size,
219
+ max_seq_len,
220
+ alibi_slopes,
221
+ kv_cache_dtype,
222
+ k_scale,
223
+ v_scale,
224
+ 0,
225
+ 0,
226
+ 0,
227
+ 64,
228
+ 0,
229
+ ),
230
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
231
+ )
232
+
233
+ elif version in ("v2", "rocm"):
234
+ num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
235
+ assert PARTITION_SIZE % block_size == 0
236
+ num_seqs, num_heads, head_size = output.shape
237
+ tmp_output = torch.empty(
238
+ size=(num_seqs, num_heads, num_partitions, head_size),
239
+ dtype=output.dtype,
240
+ )
241
+ exp_sums = torch.empty(
242
+ size=(num_seqs, num_heads, num_partitions),
243
+ dtype=torch.float32,
244
+ )
245
+ max_logits = torch.empty_like(exp_sums)
246
+ if version == "v2":
247
+ ops.paged_attention_v2(
248
+ output,
249
+ exp_sums,
250
+ max_logits,
251
+ tmp_output,
252
+ query,
253
+ key_cache,
254
+ value_cache,
255
+ num_kv_heads,
256
+ scale,
257
+ block_tables,
258
+ seq_lens,
259
+ block_size,
260
+ max_seq_len,
261
+ alibi_slopes,
262
+ kv_cache_dtype,
263
+ k_scale,
264
+ v_scale,
265
+ )
266
+
267
+ opcheck(
268
+ ops.ops.paged_attention_v2,
269
+ (
270
+ output,
271
+ exp_sums,
272
+ max_logits,
273
+ tmp_output,
274
+ query,
275
+ key_cache,
276
+ value_cache,
277
+ num_kv_heads,
278
+ scale,
279
+ block_tables,
280
+ seq_lens,
281
+ block_size,
282
+ max_seq_len,
283
+ alibi_slopes,
284
+ kv_cache_dtype,
285
+ k_scale,
286
+ v_scale,
287
+ 0,
288
+ 0,
289
+ 0,
290
+ 64,
291
+ 0,
292
+ ),
293
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
294
+ )
295
+
296
+ else:
297
+ ops.paged_attention_rocm(
298
+ output,
299
+ exp_sums,
300
+ max_logits,
301
+ tmp_output,
302
+ query,
303
+ key_cache,
304
+ value_cache,
305
+ num_kv_heads,
306
+ scale,
307
+ block_tables,
308
+ seq_lens,
309
+ block_size,
310
+ max_seq_len,
311
+ alibi_slopes,
312
+ kv_cache_dtype,
313
+ k_scale,
314
+ v_scale,
315
+ )
316
+
317
+ opcheck(
318
+ torch.ops._rocm_C.paged_attention,
319
+ (
320
+ output,
321
+ exp_sums,
322
+ max_logits,
323
+ tmp_output,
324
+ query,
325
+ key_cache,
326
+ value_cache,
327
+ num_kv_heads,
328
+ scale,
329
+ block_tables,
330
+ seq_lens,
331
+ block_size,
332
+ max_seq_len,
333
+ alibi_slopes,
334
+ kv_cache_dtype,
335
+ k_scale,
336
+ v_scale,
337
+ ),
338
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
339
+ )
340
+
341
+ else:
342
+ raise AssertionError(f"Unknown version: {version}")
343
+
344
+ # Run the reference implementation.
345
+ if kv_cache_dtype == "fp8":
346
+ # Convert cache data back to dtype.
347
+ x = 16 // torch.tensor([], dtype=dtype).element_size()
348
+ key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
349
+ dequantized_key_cache = torch.empty(
350
+ size=key_cache_shape, dtype=dtype, device=device
351
+ )
352
+ ops.convert_fp8(dequantized_key_cache, key_cache)
353
+ key_cache = dequantized_key_cache
354
+
355
+ value_cache_shape = value_cache.shape
356
+ dequantized_value_cache = torch.empty(
357
+ size=value_cache_shape, dtype=dtype, device=device
358
+ )
359
+ ops.convert_fp8(dequantized_value_cache, value_cache)
360
+ value_cache = dequantized_value_cache
361
+
362
+ ref_output = torch.empty_like(query)
363
+ ref_single_query_cached_kv_attention(
364
+ ref_output,
365
+ query,
366
+ num_queries_per_kv,
367
+ key_cache,
368
+ value_cache,
369
+ block_tables,
370
+ seq_lens,
371
+ scale,
372
+ alibi_slopes,
373
+ )
374
+
375
+ # NOTE(woosuk): Due to the kernel-level differences in the two
376
+ # implementations, there is a small numerical difference in the two
377
+ # outputs. Thus, we use a relaxed tolerance for the test.
378
+ atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
379
+ rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
380
+
381
+ # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
382
+ # so we use a relaxed tolerance for the test.
383
+ atol, rtol = 1e-3, 1e-5
384
+ if kv_cache_dtype == "fp8":
385
+ atol, rtol = 1e-2, 1e-5
386
+ torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
387
+
388
+
389
+ def ref_multi_query_kv_attention(
390
+ cu_seq_lens: List[int],
391
+ query: torch.Tensor,
392
+ key: torch.Tensor,
393
+ value: torch.Tensor,
394
+ scale: float,
395
+ dtype: torch.dtype,
396
+ ) -> torch.Tensor:
397
+ num_seqs = len(cu_seq_lens) - 1
398
+ ref_outputs: List[torch.Tensor] = []
399
+ for i in range(num_seqs):
400
+ start_idx = cu_seq_lens[i]
401
+ end_idx = cu_seq_lens[i + 1]
402
+ seq_len = end_idx - start_idx
403
+
404
+ # Create attention mask.
405
+ attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
406
+ attn_mask = attn_mask * torch.finfo(dtype).min
407
+ attn_mask = attn_mask.to(dtype=dtype)
408
+
409
+ ref_output = ref_masked_attention(
410
+ query[start_idx:end_idx],
411
+ key[start_idx:end_idx],
412
+ value[start_idx:end_idx],
413
+ scale,
414
+ attn_mask=attn_mask,
415
+ )
416
+ ref_outputs.append(ref_output)
417
+
418
+ return torch.cat(ref_outputs, dim=0)
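The reference implementations above walk the KV cache through a block table. A tiny standalone sketch of that addressing arithmetic, with illustrative values only:

```python
# Paged-KV addressing: token position j of a sequence is resolved via its block table.
block_size = 16
block_table = [7, 3, 42]   # physical block ids assigned to one sequence
j = 37                     # token position within the sequence

block_number = block_table[j // block_size]  # -> block_table[2] == 42
block_offset = j % block_size                # -> 5
# key   lives at key_cache[block_number, :, :, block_offset, :]
# value lives at value_cache[block_number, :, :, block_offset]
print(block_number, block_offset)
```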
tests/kernels/test_cache.py ADDED
@@ -0,0 +1,486 @@
1
+ import random
2
+ from typing import List, Tuple
3
+
4
+ import attention as ops
5
+ import pytest
6
+ import torch
7
+ from attention.platforms import current_platform
8
+
9
+ from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
+
11
+ COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
12
+ DTYPES = [torch.half, torch.bfloat16, torch.float]
13
+ NUM_TOKENS = [42] # Arbitrary values for testing
14
+ NUM_LAYERS = [1] # Arbitrary values for testing
15
+ NUM_HEADS = [8] # Arbitrary values for testing
16
+ HEAD_SIZES = [64, 80, 120, 256]
17
+ BLOCK_SIZES = [8, 16, 32]
18
+
19
+ # Arbitrary values for testing
20
+ # Don't make it too large, e.g. [1024, 36000] will OOM
21
+ NUM_BLOCKS = [1024, 10000]
22
+
23
+ NUM_MAPPINGS = [256] # Arbitrary values for testing
24
+ SEEDS = [0]
25
+ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
26
+
27
+ # We assume fp8 is always enabled for testing.
28
+ KV_CACHE_DTYPE = ["auto", "fp8"]
29
+
30
+
31
+ @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
32
+ @pytest.mark.parametrize("num_layers", NUM_LAYERS)
33
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
34
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
35
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
36
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
37
+ @pytest.mark.parametrize("dtype", DTYPES)
38
+ @pytest.mark.parametrize("seed", SEEDS)
39
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
40
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
41
+ @torch.inference_mode()
42
+ def test_copy_blocks(
43
+ kv_cache_factory,
44
+ num_mappings: int,
45
+ num_layers: int,
46
+ num_heads: int,
47
+ head_size: int,
48
+ block_size: int,
49
+ num_blocks: int,
50
+ dtype: torch.dtype,
51
+ seed: int,
52
+ kv_cache_dtype: str,
53
+ device: str,
54
+ ) -> None:
55
+ if kv_cache_dtype == "fp8" and head_size % 16:
56
+ pytest.skip()
57
+ current_platform.seed_everything(seed)
58
+ torch.set_default_device(device)
59
+ # Generate random block mappings where each source block is mapped to two
60
+ # destination blocks.
61
+ assert 2 * num_mappings <= num_blocks
62
+ src_blocks = random.sample(range(num_blocks), num_mappings)
63
+ remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
64
+ dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
65
+ block_mapping: List[Tuple[int, int]] = []
66
+ for i in range(num_mappings):
67
+ src = src_blocks[i]
68
+ dst1 = dst_blocks[2 * i]
69
+ dst2 = dst_blocks[2 * i + 1]
70
+ block_mapping.append((src, dst1))
71
+ block_mapping.append((src, dst2))
72
+
73
+ # Create the KV caches.
74
+ key_caches, value_caches = kv_cache_factory(
75
+ num_blocks,
76
+ block_size,
77
+ num_layers,
78
+ num_heads,
79
+ head_size,
80
+ kv_cache_dtype,
81
+ dtype,
82
+ seed,
83
+ device,
84
+ )
85
+
86
+ # Clone the KV caches.
87
+ cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
88
+ cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
89
+
90
+ # Call the copy blocks kernel.
91
+ block_mapping_tensor = torch.tensor(
92
+ block_mapping, dtype=torch.int64, device=device
93
+ ).view(-1, 2)
94
+
95
+ opcheck(
96
+ ops.ops.copy_blocks,
97
+ (key_caches, value_caches, block_mapping_tensor),
98
+ test_utils=DEFAULT_OPCHECK_TEST_UTILS,
99
+ cond=(head_size == HEAD_SIZES[0]),
100
+ )
101
+ ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
102
+
103
+ # Run the reference implementation.
104
+ for src, dst in block_mapping:
105
+ for cloned_key_cache in cloned_key_caches:
106
+ cloned_key_cache[dst].copy_(cloned_key_cache[src])
107
+ for cloned_value_cache in cloned_value_caches:
108
+ cloned_value_cache[dst].copy_(cloned_value_cache[src])
109
+
110
+ # Compare the results.
111
+ for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
112
+ torch.testing.assert_close(key_cache, cloned_key_cache)
113
+ for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
114
+ torch.testing.assert_close(value_cache, cloned_value_cache)
115
+
116
+
117
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
118
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
119
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
120
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
121
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
122
+ @pytest.mark.parametrize("dtype", DTYPES)
123
+ @pytest.mark.parametrize("seed", SEEDS)
124
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
125
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
126
+ @torch.inference_mode()
127
+ def test_reshape_and_cache(
128
+ kv_cache_factory,
129
+ num_tokens: int,
130
+ num_heads: int,
131
+ head_size: int,
132
+ block_size: int,
133
+ num_blocks: int,
134
+ dtype: torch.dtype,
135
+ seed: int,
136
+ device: str,
137
+ kv_cache_dtype: str,
138
+ ) -> None:
139
+ if kv_cache_dtype == "fp8" and head_size % 16:
140
+ pytest.skip()
141
+ current_platform.seed_everything(seed)
142
+ torch.set_default_device(device)
143
+ # Create a random slot mapping.
144
+ num_slots = block_size * num_blocks
145
+ slot_mapping_lst = random.sample(range(num_slots), num_tokens)
146
+ slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
147
+
148
+ qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
149
+ _, key, value = qkv.unbind(dim=1)
150
+
151
+ # Create the KV caches.
152
+ key_caches, value_caches = kv_cache_factory(
153
+ num_blocks,
154
+ block_size,
155
+ 1,
156
+ num_heads,
157
+ head_size,
158
+ kv_cache_dtype,
159
+ dtype,
160
+ seed,
161
+ device,
162
+ )
163
+ key_cache, value_cache = key_caches[0], value_caches[0]
164
+
165
+ # Clone the KV caches.
166
+ if kv_cache_dtype == "fp8":
167
+ cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
168
+ ops.convert_fp8(cloned_key_cache, key_cache)
169
+ cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
170
+ ops.convert_fp8(cloned_value_cache, value_cache)
171
+ else:
172
+ cloned_key_cache = key_cache.clone()
173
+ cloned_value_cache = value_cache.clone()
174
+
175
+ # Using default kv_scale
176
+ k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
177
+
178
+ # Call the reshape_and_cache kernel.
179
+ opcheck(
180
+ ops.ops.reshape_and_cache,
181
+ (
182
+ key,
183
+ value,
184
+ key_cache,
185
+ value_cache,
186
+ slot_mapping,
187
+ kv_cache_dtype,
188
+ k_scale,
189
+ v_scale,
190
+ ),
191
+ cond=(head_size == HEAD_SIZES[0]),
192
+ )
193
+ ops.reshape_and_cache(
194
+ key,
195
+ value,
196
+ key_cache,
197
+ value_cache,
198
+ slot_mapping,
199
+ kv_cache_dtype,
200
+ k_scale,
201
+ v_scale,
202
+ )
203
+
204
+ if kv_cache_dtype == "fp8":
205
+ result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
206
+ ops.convert_fp8(result_key_cache, key_cache)
207
+ result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
208
+ ops.convert_fp8(result_value_cache, value_cache)
209
+
210
+ # Run the reference implementation.
211
+ reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
212
+ block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
213
+ block_indicies_lst = block_indicies.cpu().tolist()
214
+ block_offsets = slot_mapping % block_size
215
+ block_offsets_lst = block_offsets.cpu().tolist()
216
+ for i in range(num_tokens):
217
+ block_idx = block_indicies_lst[i]
218
+ block_offset = block_offsets_lst[i]
219
+ cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
220
+ cloned_value_cache[block_idx, :, :, block_offset] = value[i]
221
+
222
+ if kv_cache_dtype == "fp8":
223
+ torch.testing.assert_close(
224
+ result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
225
+ )
226
+ torch.testing.assert_close(
227
+ result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
228
+ )
229
+ else:
230
+ torch.testing.assert_close(key_cache, cloned_key_cache)
231
+ torch.testing.assert_close(value_cache, cloned_value_cache)
232
+
233
+
234
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
235
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
236
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
237
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
238
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
239
+ @pytest.mark.parametrize("dtype", DTYPES)
240
+ @pytest.mark.parametrize("seed", SEEDS)
241
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
242
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
243
+ @torch.inference_mode()
244
+ def test_reshape_and_cache_flash(
245
+ kv_cache_factory_flashinfer,
246
+ num_tokens: int,
247
+ num_heads: int,
248
+ head_size: int,
249
+ block_size: int,
250
+ num_blocks: int,
251
+ dtype: torch.dtype,
252
+ seed: int,
253
+ device: str,
254
+ kv_cache_dtype: str,
255
+ ) -> None:
256
+ current_platform.seed_everything(seed)
257
+ torch.set_default_device(device)
258
+
259
+ # Create a random slot mapping.
260
+ num_slots = block_size * num_blocks
261
+ slot_mapping_lst = random.sample(range(num_slots), num_tokens)
262
+ slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
263
+
264
+ qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
265
+ _, key, value = qkv.unbind(dim=1)
266
+
267
+ # Create the KV caches.
268
+ key_caches, value_caches = kv_cache_factory_flashinfer(
269
+ num_blocks,
270
+ block_size,
271
+ 1,
272
+ num_heads,
273
+ head_size,
274
+ kv_cache_dtype,
275
+ dtype,
276
+ device=device,
277
+ )
278
+ key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
279
+ del key_caches
280
+ del value_caches
281
+
282
+ k_scale = (key.amax() / 256.0).to(torch.float32)
283
+ v_scale = (value.amax() / 256.0).to(torch.float32)
284
+
285
+ # Clone the KV caches.
286
+ if kv_cache_dtype == "fp8":
287
+ cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
288
+ ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
289
+ cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
290
+ ops.convert_fp8(cloned_value_cache, value_cache, v_scale, kv_cache_dtype)
291
+ else:
292
+ cloned_key_cache = key_cache.clone()
293
+ cloned_value_cache = value_cache.clone()
294
+
295
+ # Call the reshape_and_cache kernel.
296
+ opcheck(
297
+ ops.ops.reshape_and_cache_flash,
298
+ (
299
+ key,
300
+ value,
301
+ key_cache,
302
+ value_cache,
303
+ slot_mapping,
304
+ kv_cache_dtype,
305
+ k_scale,
306
+ v_scale,
307
+ ),
308
+ cond=(head_size == HEAD_SIZES[0]),
309
+ )
310
+ ops.reshape_and_cache_flash(
311
+ key,
312
+ value,
313
+ key_cache,
314
+ value_cache,
315
+ slot_mapping,
316
+ kv_cache_dtype,
317
+ k_scale,
318
+ v_scale,
319
+ )
320
+
321
+ if kv_cache_dtype == "fp8":
322
+ result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
323
+ ops.convert_fp8(
324
+ result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
325
+ )
326
+ result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
327
+ ops.convert_fp8(
328
+ result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
329
+ )
330
+
331
+ # Run the reference implementation.
332
+ block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
333
+ block_indicies_lst = block_indicies.cpu().tolist()
334
+ block_offsets = slot_mapping % block_size
335
+ block_offsets_lst = block_offsets.cpu().tolist()
336
+ for i in range(num_tokens):
337
+ block_idx = block_indicies_lst[i]
338
+ block_offset = block_offsets_lst[i]
339
+ cloned_key_cache[block_idx, block_offset, :, :] = key[i]
340
+ cloned_value_cache[block_idx, block_offset, :, :] = value[i]
341
+
342
+ if kv_cache_dtype == "fp8":
343
+ torch.testing.assert_close(
344
+ result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
345
+ )
346
+ torch.testing.assert_close(
347
+ result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
348
+ )
349
+ else:
350
+ torch.testing.assert_close(key_cache, cloned_key_cache)
351
+ torch.testing.assert_close(value_cache, cloned_value_cache)
352
+
353
+
354
+ @pytest.mark.parametrize("direction", COPYING_DIRECTION)
355
+ @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
356
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
357
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
358
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
359
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
360
+ @pytest.mark.parametrize("dtype", DTYPES)
361
+ @pytest.mark.parametrize("seed", SEEDS)
362
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
363
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
364
+ @torch.inference_mode()
365
+ def test_swap_blocks(
366
+ kv_cache_factory,
367
+ direction: Tuple[str, str],
368
+ num_mappings: int,
369
+ num_heads: int,
370
+ head_size: int,
371
+ block_size: int,
372
+ num_blocks: int,
373
+ dtype: torch.dtype,
374
+ seed: int,
375
+ device: str,
376
+ kv_cache_dtype: str,
377
+ ) -> None:
378
+ if kv_cache_dtype == "fp8" and "cpu" in direction:
379
+ pytest.skip()
380
+ if kv_cache_dtype == "fp8" and head_size % 16:
381
+ pytest.skip()
382
+
383
+ current_platform.seed_everything(seed)
384
+
385
+ src_device = device if direction[0] == "cuda" else "cpu"
386
+ dst_device = device if direction[1] == "cuda" else "cpu"
387
+
388
+ src_blocks = random.sample(range(num_blocks), num_mappings)
389
+ # For the same device, mapping must not overlap
390
+ if src_device == dst_device:
391
+ remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
392
+ dst_blocks = random.sample(remaining_blocks, num_mappings)
393
+ else:
394
+ dst_blocks = random.sample(range(num_blocks), num_mappings)
395
+
396
+ block_mapping = list(zip(src_blocks, dst_blocks))
397
+ block_mapping_tensor = torch.tensor(
398
+ block_mapping, dtype=torch.int64, device="cpu"
399
+ ).view(-1, 2)
400
+
401
+ # Create the KV caches on the first device.
402
+ src_key_caches, src_value_caches = kv_cache_factory(
403
+ num_blocks,
404
+ block_size,
405
+ 1,
406
+ num_heads,
407
+ head_size,
408
+ kv_cache_dtype,
409
+ dtype,
410
+ seed,
411
+ src_device,
412
+ )
413
+
414
+ # Create the KV caches on the second device.
415
+ dist_key_caches, dist_value_caches = kv_cache_factory(
416
+ num_blocks,
417
+ block_size,
418
+ 1,
419
+ num_heads,
420
+ head_size,
421
+ kv_cache_dtype,
422
+ dtype,
423
+ seed,
424
+ dst_device,
425
+ )
426
+
427
+ src_key_caches_clone = src_key_caches[0].clone()
428
+ src_value_caches_clone = src_value_caches[0].clone()
429
+
430
+ # Call the swap_blocks kernel.
431
+ do_opcheck = head_size == HEAD_SIZES[0]
432
+ opcheck(
433
+ ops.ops.swap_blocks,
434
+ (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
435
+ cond=do_opcheck,
436
+ )
437
+ opcheck(
438
+ ops.ops.swap_blocks,
439
+ (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
440
+ cond=do_opcheck,
441
+ )
442
+
443
+ ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
444
+ ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
445
+
446
+ for src, dst in block_mapping:
447
+ torch.testing.assert_close(
448
+ src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
449
+ )
450
+ torch.testing.assert_close(
451
+ src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
452
+ )
453
+
454
+
455
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
456
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
457
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
458
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
459
+ @pytest.mark.parametrize("dtype", DTYPES)
460
+ @pytest.mark.parametrize("seed", SEEDS)
461
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
462
+ @torch.inference_mode()
463
+ def test_fp8_e4m3_conversion(
464
+ num_heads: int,
465
+ head_size: int,
466
+ block_size: int,
467
+ num_blocks: int,
468
+ dtype: torch.dtype,
469
+ seed: int,
470
+ device: str,
471
+ ) -> None:
472
+ current_platform.seed_everything(seed)
473
+
474
+ low = -224.0
475
+ high = 224.0
476
+ shape = (num_blocks, num_heads, head_size, block_size)
477
+ cache = torch.empty(shape, dtype=dtype, device=device)
478
+ cache.uniform_(low, high)
479
+
480
+ cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
481
+ ops.convert_fp8(cache_fp8, cache)
482
+
483
+ converted_cache = torch.empty_like(cache)
484
+ ops.convert_fp8(converted_cache, cache_fp8)
485
+
486
+ torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
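The cache tests above all build their block mappings the same way: an `int64` tensor of `(src, dst)` pairs. A minimal sketch of that layout; the values are illustrative and the commented calls assume the built `attention` extension:

```python
import torch

# A source block may be mapped to several destination blocks.
pairs = [(0, 10), (0, 11), (5, 12)]
block_mapping = torch.tensor(pairs, dtype=torch.int64).view(-1, 2)

# ops.copy_blocks(key_caches, value_caches, block_mapping)  # duplicate blocks in place
# ops.swap_blocks(src_cache, dst_cache, block_mapping)      # move blocks between caches/devices
print(block_mapping.shape)  # torch.Size([3, 2])
```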
tests/kernels/utils.py ADDED
@@ -0,0 +1,92 @@
1
+ """Kernel test utils"""
2
+
3
+ import itertools
4
+ import random
5
+ import unittest
6
+ from functools import lru_cache
7
+ from numbers import Number
8
+ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
9
+
10
+ import pytest
11
+ import torch
12
+ from torch._prims_common import TensorLikeType
13
+
14
+ # For now, disable "test_aot_dispatch_dynamic" since there are some
15
+ # bugs related to this test in PyTorch 2.4.
16
+ DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
17
+ "test_schema",
18
+ "test_autograd_registration",
19
+ "test_faketensor",
20
+ )
21
+
22
+ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
23
+ "test_schema",
24
+ "test_autograd_registration",
25
+ "test_faketensor",
26
+ "test_aot_dispatch_dynamic",
27
+ )
28
+
29
+
30
+ # Copied/modified from torch._refs.__init__.py
31
+ def fp8_allclose(
32
+ a: TensorLikeType,
33
+ b: TensorLikeType,
34
+ rtol: float = 1e-05,
35
+ atol: float = 1e-08,
36
+ equal_nan: bool = False,
37
+ ) -> bool:
38
+ """
39
+ Reference implementation of torch.allclose
40
+ """
41
+ torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
42
+
43
+ return bool(
44
+ torch.all(
45
+ torch.isclose(
46
+ a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
47
+ )
48
+ ).item()
49
+ )
50
+
51
+
52
+ def compute_max_diff(output, output_ref):
53
+ return torch.mean(torch.abs(output - output_ref)) / torch.mean(
54
+ torch.abs(output_ref)
55
+ )
56
+
57
+
58
+ # A special version of op check that has a restricted default set of test_utils
59
+ # and a patched version of allclose that supports fp8 types.
60
+ def opcheck(
61
+ op: Union[
62
+ torch._ops.OpOverload,
63
+ torch._ops.OpOverloadPacket,
64
+ torch._library.custom_ops.CustomOpDef,
65
+ ],
66
+ args: Tuple[Any, ...],
67
+ kwargs: Optional[Dict[str, Any]] = None,
68
+ *,
69
+ test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
70
+ raise_exception: bool = True,
71
+ cond: bool = True
72
+ ) -> Dict[str, str]:
73
+ with unittest.mock.patch("torch.allclose", new=fp8_allclose):
74
+ return (
75
+ torch.library.opcheck(
76
+ op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
77
+ )
78
+ if cond
79
+ else {}
80
+ )
81
+
82
+
83
+ @lru_cache(maxsize=None)
84
+ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
85
+ """Returns the maximum shared memory per thread block in bytes."""
86
+ from attention import ops
87
+
88
+ max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
89
+ # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
90
+ # will fail
91
+ assert max_shared_mem > 0, "max_shared_mem can not be zero"
92
+ return int(max_shared_mem)
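A short sketch of how `get_max_shared_memory_bytes` is consumed; this mirrors the top of `test_attention.py` and assumes the compiled extension plus a visible CUDA device:

```python
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Upper bound on sequence length for the paged attention kernels,
# with 512 kept back as a safety buffer.
max_seq_len = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
print(max_seq_len)
```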
torch-ext/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
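For orientation, the package exposes two layers: the typed Python wrappers re-exported here and the raw registered ops under `ops`. A sketch, assuming the compiled extension is importable as `attention`:

```python
import attention

attention.paged_attention_v1      # Python wrapper defined in _custom_ops.py
attention.ops.paged_attention_v1  # underlying torch.library op, used with torch.library.opcheck in the tests
```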
torch-ext/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # paged attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
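A hedged round-trip sketch for `convert_fp8` with a non-default scale, assuming a CUDA device and the built extension. As in the kernels above, quantization divides by the scale and dequantization multiplies by it; 448 is the largest finite e4m3 value, used here only to pick a scale:

```python
import torch

import attention as ops

cache = torch.randn(4, 8, 16, 16, dtype=torch.float16, device="cuda")
scale = (cache.abs().amax() / 448.0).item()

# float16 -> fp8: values are divided by `scale` before narrowing.
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache_fp8, cache, scale, kv_dtype="fp8")

# fp8 -> float16: values are widened and multiplied by `scale`.
restored = torch.empty_like(cache)
ops.convert_fp8(restored, cache_fp8, scale, kv_dtype="fp8")
torch.testing.assert_close(cache, restored, atol=1e-2, rtol=1e-1)
```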
torch-ext/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(device_id)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
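A short sketch of how the shim above is used by the tests (illustrative only):

```python
import torch

from attention.platforms import current_platform

current_platform.seed_everything(0)  # seeds random, numpy and torch in one call

# ROCm builds skip float32 in the attention tests.
dtypes = (
    [torch.half, torch.bfloat16]
    if current_platform.is_rocm()
    else [torch.half, torch.bfloat16, torch.float]
)
```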
torch-ext/registration.h ADDED
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+
3
+ #include <Python.h>
4
+
5
+ #define _CONCAT(A, B) A##B
6
+ #define CONCAT(A, B) _CONCAT(A, B)
7
+
8
+ #define _STRINGIFY(A) #A
9
+ #define STRINGIFY(A) _STRINGIFY(A)
10
+
11
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
12
+ // could be a macro instead of a literal token.
13
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
14
+
15
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
16
+ // could be a macro instead of a literal token.
17
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
18
+ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
19
+
20
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
21
+ // via python's import statement.
22
+ #define REGISTER_EXTENSION(NAME) \
23
+ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
24
+ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
25
+ STRINGIFY(NAME), nullptr, 0, nullptr}; \
26
+ return PyModule_Create(&module); \
27
+ }
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,95 @@
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+
5
+ #include "torch_binding.h"
6
+
7
+ // Note on op signatures:
8
+ // The X_meta signatures are for the meta functions corresponding to op X.
9
+ // They must be kept in sync with the signature for X. Generally, only
10
+ // functions that return Tensors require a meta function.
11
+ //
12
+ // See the following links for detailed docs on op registration and function
13
+ // schemas.
14
+ // https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
15
+ // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
16
+
17
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
18
+ // Attention ops
19
+ // Compute the attention between an input query and the cached
20
+ // keys/values using PagedAttention.
21
+ ops.def(
22
+ "paged_attention_v1("
23
+ " Tensor! out, Tensor query, Tensor key_cache,"
24
+ " Tensor value_cache, int num_kv_heads, float scale,"
25
+ " Tensor block_tables, Tensor seq_lens, int block_size,"
26
+ " int max_seq_len, Tensor? alibi_slopes,"
27
+ " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
28
+ " int tp_rank, int blocksparse_local_blocks,"
29
+ " int blocksparse_vert_stride, int blocksparse_block_size,"
30
+ " int blocksparse_head_sliding_step) -> ()");
31
+ ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
32
+
33
+ // PagedAttention V2.
34
+ ops.def(
35
+ "paged_attention_v2("
36
+ " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
37
+ " Tensor! tmp_out, Tensor query, Tensor key_cache,"
38
+ " Tensor value_cache, int num_kv_heads, float scale,"
39
+ " Tensor block_tables, Tensor seq_lens, int block_size,"
40
+ " int max_seq_len, Tensor? alibi_slopes,"
41
+ " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
42
+ " int tp_rank, int blocksparse_local_blocks,"
43
+ " int blocksparse_vert_stride, int blocksparse_block_size,"
44
+ " int blocksparse_head_sliding_step) -> ()");
45
+ ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
46
+
47
+ // Swap in (out) the cache blocks from src to dst.
48
+ ops.def(
49
+ "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
50
+ ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
51
+
52
+ // Copy the cache blocks from src to dst.
53
+ ops.def(
54
+ "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
55
+ "Tensor block_mapping) -> ()");
56
+ ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
57
+
58
+ // Reshape the key and value tensors and cache them.
59
+ ops.def(
60
+ "reshape_and_cache(Tensor key, Tensor value,"
61
+ " Tensor! key_cache, Tensor! value_cache,"
62
+ " Tensor slot_mapping,"
63
+ " str kv_cache_dtype,"
64
+ " Tensor k_scale, Tensor v_scale) -> ()");
65
+ ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
66
+
67
+ // Reshape the key and value tensors and cache them.
68
+ ops.def(
69
+ "reshape_and_cache_flash(Tensor key, Tensor value,"
70
+ " Tensor! key_cache,"
71
+ " Tensor! value_cache,"
72
+ " Tensor slot_mapping,"
73
+ " str kv_cache_dtype,"
74
+ " Tensor k_scale, Tensor v_scale) -> ()");
75
+ ops.impl("reshape_and_cache_flash", torch::kCUDA,
76
+ &reshape_and_cache_flash);
77
+
78
+ // Gets the specified device attribute.
79
+ ops.def("get_device_attribute(int attribute, int device_id) -> int");
80
+ ops.impl("get_device_attribute", &get_device_attribute);
81
+
82
+ // Gets the maximum shared memory per block device attribute.
83
+ ops.def(
84
+ "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
85
+ ops.impl("get_max_shared_memory_per_block_device_attribute",
86
+ &get_max_shared_memory_per_block_device_attribute);
87
+
88
+ // Convert the key and value cache to fp8 data type.
89
+ ops.def(
90
+ "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
91
+ "str kv_cache_dtype) -> ()");
92
+ ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
93
+ }
94
+
95
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
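The schema strings above map one-to-one onto Python call signatures: `Tensor!` marks an argument the op mutates in place, `Tensor?` accepts `None`, and `int`/`float`/`str` take plain Python scalars. Most ops are bound to the CUDA dispatch key, while the two device-attribute helpers are registered without one, so they can be called with plain integers and no tensors. A hedged sketch of calling them, again assuming the namespace is `attention`:

    import torch

    ops = torch.ops.attention

    # cudaDevAttrMaxThreadsPerBlock has enum value 1 in the CUDA runtime API.
    max_threads = ops.get_device_attribute(1, 0)

    # Convenience op that selects the right "max shared memory per block"
    # attribute for CUDA or ROCm.
    max_shared = ops.get_max_shared_memory_per_block_device_attribute(0)
    print(max_threads, max_shared)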
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,56 @@
+ #pragma once
+
+ #include <torch/torch.h>
+
+ void paged_attention_v1(
+     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+     int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+     torch::Tensor& v_scale, const int64_t tp_rank,
+     const int64_t blocksparse_local_blocks,
+     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+     const int64_t blocksparse_head_sliding_step);
+
+ void paged_attention_v2(
+     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+     int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+     torch::Tensor& v_scale, const int64_t tp_rank,
+     const int64_t blocksparse_local_blocks,
+     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+     const int64_t blocksparse_head_sliding_step);
+
+ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
+                  const torch::Tensor& block_mapping);
+
+ // Note: the key_caches and value_caches vectors are constant but
+ // not the Tensors they contain. The vectors need to be const refs
+ // in order to satisfy pytorch's C++ operator registration code.
+ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                  std::vector<torch::Tensor> const& value_caches,
+                  const torch::Tensor& block_mapping);
+
+ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
+                        torch::Tensor& key_cache, torch::Tensor& value_cache,
+                        torch::Tensor& slot_mapping,
+                        const std::string& kv_cache_dtype,
+                        torch::Tensor& k_scale, torch::Tensor& v_scale);
+
+ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
+                              torch::Tensor& key_cache,
+                              torch::Tensor& value_cache,
+                              torch::Tensor& slot_mapping,
+                              const std::string& kv_cache_dtype,
+                              torch::Tensor& k_scale, torch::Tensor& v_scale);
+
+ int64_t get_device_attribute(int64_t attribute, int64_t device_id);
+
+ int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
+
+ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
+                  const double scale, const std::string& kv_cache_dtype);
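These declarations fix the C++ signatures that the schemas in torch_binding.cpp bind to; the cache ops all operate on paged KV caches. The sketch below shows how new keys and values might be written into such a cache with `reshape_and_cache`. The cache layout is an assumption carried over from upstream vLLM rather than something stated in this repo: `key_cache` is `[num_blocks, num_heads, head_size // x, block_size, x]` with `x = 16 // element_size`, and `value_cache` is `[num_blocks, num_heads, head_size, block_size]`; the `attention` namespace is again assumed.

    import torch

    ops = torch.ops.attention

    num_blocks, block_size = 16, 16
    num_heads, head_size = 8, 128
    dtype = torch.float16
    x = 16 // torch.tensor([], dtype=dtype).element_size()  # 8 for fp16

    key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x,
                            dtype=dtype, device="cuda")
    value_cache = torch.zeros(num_blocks, num_heads, head_size, block_size,
                              dtype=dtype, device="cuda")

    num_tokens = 4
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device="cuda")
    value = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device="cuda")

    # Each token lands in slot = block_number * block_size + block_offset.
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

    # The scales only matter for fp8 caches; with kv_cache_dtype="auto" the
    # values are stored in the cache dtype unchanged.
    k_scale = torch.ones(1, dtype=torch.float32, device="cuda")
    v_scale = torch.ones(1, dtype=torch.float32, device="cuda")

    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          "auto", k_scale, v_scale)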