Port vLLM attention kernels
- .gitattributes +1 -0
- README.md +4 -0
- build.toml +44 -0
- cuda-utils/cuda_utils_kernels.cu +29 -0
- flake.nix +14 -0
- paged-attention/attention/attention_dtypes.h +7 -0
- paged-attention/attention/attention_generic.cuh +65 -0
- paged-attention/attention/attention_kernels.cuh +676 -0
- paged-attention/attention/attention_utils.cuh +57 -0
- paged-attention/attention/dtype_bfloat16.cuh +463 -0
- paged-attention/attention/dtype_float16.cuh +504 -0
- paged-attention/attention/dtype_float32.cuh +251 -0
- paged-attention/attention/dtype_fp8.cuh +41 -0
- paged-attention/attention/paged_attention_v1.cu +196 -0
- paged-attention/attention/paged_attention_v2.cu +206 -0
- paged-attention/cache_kernels.cu +419 -0
- paged-attention/cuda_compat.h +49 -0
- paged-attention/dispatch_utils.h +49 -0
- paged-attention/quantization/fp8/amd/hip_float8.h +137 -0
- paged-attention/quantization/fp8/amd/hip_float8_impl.h +316 -0
- paged-attention/quantization/fp8/amd/quant_utils.cuh +577 -0
- paged-attention/quantization/fp8/nvidia/quant_utils.cuh +573 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/allclose_default.py +14 -0
- tests/kernels/conftest.py +158 -0
- tests/kernels/test_attention.py +418 -0
- tests/kernels/test_cache.py +486 -0
- tests/kernels/utils.py +92 -0
- torch-ext/attention/__init__.py +21 -0
- torch-ext/attention/_custom_ops.py +173 -0
- torch-ext/attention/platforms.py +62 -0
- torch-ext/registration.h +27 -0
- torch-ext/torch_binding.cpp +95 -0
- torch-ext/torch_binding.h +56 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,7 @@
 ---
 license: apache-2.0
 ---
+
+## attention
+
+Paged attention kernels from [vLLM](https://github.com/vllm-project/).
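For orientation before the file-by-file diff, here is a minimal sketch (Python/PyTorch, with made-up sizes) of how the paged KV cache addressed by these kernels works: a per-sequence block table maps logical token positions to physical cache blocks, so a sequence's KV entries do not need a contiguous allocation. The tensor layouts follow the shape comments in the kernel signatures below; everything else is illustrative.

import torch

num_seqs, max_blocks_per_seq, block_size = 2, 4, 16
num_blocks, num_kv_heads, head_size = 32, 8, 128

# Value-cache layout used by the kernels: [num_blocks, num_kv_heads, head_size, block_size].
v_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size)
block_tables = torch.randint(0, num_blocks, (num_seqs, max_blocks_per_seq), dtype=torch.int32)
seq_lens = torch.tensor([37, 5], dtype=torch.int32)

# Logical token t of sequence s lives in physical block block_tables[s, t // block_size]
# at offset t % block_size inside that block.
s, t = 0, int(seq_lens[0]) - 1
phys_block = int(block_tables[s, t // block_size])
v_last_token = v_cache[phys_block, :, :, t % block_size]  # [num_kv_heads, head_size]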
build.toml
ADDED
@@ -0,0 +1,44 @@
+[general]
+version = "0.0.1"
+
+[torch]
+name = "attention"
+src = [
+  "torch-ext/registration.h",
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h"
+]
+pyroot = "torch-ext"
+
+[kernel.cuda_utils]
+capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+src = [
+  "cuda-utils/cuda_utils_kernels.cu",
+]
+depends = []
+
+
+[kernel.paged_attention]
+capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+src = [
+  "paged-attention/attention/attention_dtypes.h",
+  "paged-attention/attention/attention_generic.cuh",
+  "paged-attention/attention/attention_kernels.cuh",
+  "paged-attention/attention/attention_utils.cuh",
+  "paged-attention/attention/dtype_bfloat16.cuh",
+  "paged-attention/attention/dtype_float16.cuh",
+  "paged-attention/attention/dtype_float32.cuh",
+  "paged-attention/attention/dtype_fp8.cuh",
+  "paged-attention/attention/paged_attention_v1.cu",
+  "paged-attention/attention/paged_attention_v2.cu",
+  "paged-attention/cache_kernels.cu",
+  "paged-attention/cuda_compat.h",
+  "paged-attention/dispatch_utils.h",
+  "paged-attention/quantization/fp8/amd/hip_float8.h",
+  "paged-attention/quantization/fp8/amd/hip_float8_impl.h",
+  "paged-attention/quantization/fp8/amd/quant_utils.cuh",
+  "paged-attention/quantization/fp8/nvidia/quant_utils.cuh",
+]
+include = [ "." ]
+depends = [ "torch" ]
+
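The `capabilities` lists above pin the CUDA compute capabilities the kernels are compiled for. As a quick local sanity check (not part of the port itself), a small sketch that compares the installed GPU against that list using the standard PyTorch query:

import torch

SUPPORTED = {"7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"}

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    cap = f"{major}.{minor}"
    print(f"compute capability {cap}:",
          "in build list" if cap in SUPPORTED else "not in build list")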
cuda-utils/cuda_utils_kernels.cu
ADDED
@@ -0,0 +1,29 @@
+#ifdef USE_ROCM
+  #include <hip/hip_runtime.h>
+  #include <hip/hip_runtime_api.h>
+#endif
+int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
+  int device, value;
+  if (device_id < 0) {
+    cudaGetDevice(&device);
+  } else {
+    device = device_id;
+  }
+  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
+                         device);
+  return value;
+}
+
+int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
+  int64_t attribute;
+  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+  // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
+
+#ifdef USE_ROCM
+  attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
+#else
+  attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
+#endif
+
+  return get_device_attribute(attribute, device_id);
+}
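For reference, the same query can be made from the host without the extension. This is a hedged ctypes sketch, not part of the port: the attribute id 97 comes from the comment above (cudaDevAttrMaxSharedMemoryPerBlockOptin), while the library name and error handling are simplifying assumptions. The value bounds how large a softmax logits buffer the attention kernels can keep in shared memory.

import ctypes

def max_shared_memory_per_block_optin(device_id: int = 0) -> int:
    cudart = ctypes.CDLL("libcudart.so")           # assumes CUDA runtime is on the loader path
    value = ctypes.c_int(0)
    CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97
    err = cudart.cudaDeviceGetAttribute(
        ctypes.byref(value), CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device_id)
    if err != 0:
        raise RuntimeError(f"cudaDeviceGetAttribute failed with error {err}")
    return value.value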
flake.nix
ADDED
@@ -0,0 +1,14 @@
+{
+  description = "Flake for attention kernels";
+
+  inputs = {
+    kernel-builder.url = "git+ssh://[email protected]/huggingface/kernel-builder";
+  };
+
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
+    kernel-builder.lib.genFlakeOutputs ./.;
+}
paged-attention/attention/attention_dtypes.h
ADDED
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float16.cuh"
+#include "dtype_float32.cuh"
+#include "dtype_bfloat16.cuh"
+#include "dtype_fp8.cuh"
paged-attention/attention/attention_generic.cuh
ADDED
@@ -0,0 +1,65 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template <typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template <typename T>
+struct FloatVec {};
+
+// Template vector operations.
+template <typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+template <typename T>
+inline __device__ float sum(T v);
+
+template <typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<T, T, T>(a, b));
+}
+
+template <typename A, typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+template <typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ++ii) {
+    tmp.words[ii] = 0u;
+  }
+  dst = tmp.raw;
+}
+
+}  // namespace vllm
paged-attention/attention/attention_kernels.cuh
ADDED
@@ -0,0 +1,676 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <algorithm>
+
+#include "attention_dtypes.h"
+#include "attention_utils.cuh"
+
+#ifdef USE_ROCM
+  #include <hip/hip_bf16.h>
+  #include "../quantization/fp8/amd/quant_utils.cuh"
+typedef __hip_bfloat16 __nv_bfloat16;
+#else
+  #include "../quantization/fp8/nvidia/quant_utils.cuh"
+#endif
+
+#ifndef USE_ROCM
+  #define WARP_SIZE 32
+#else
+  #define WARP_SIZE warpSize
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+namespace vllm {
+
+// Utility function for attention softmax.
+template <int NUM_WARPS>
+inline __device__ float block_sum(float* red_smem, float sum) {
+  // Decompose the thread index into warp / lane.
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // Compute the sum per warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Warp leaders store the data to shared memory.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The warps compute the final sums.
+  if (lane < NUM_WARPS) {
+    sum = red_smem[lane];
+  }
+
+  // Parallel reduction inside the warp.
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return VLLM_SHFL_SYNC(sum, 0);
+}
+
+// TODO(woosuk): Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
+          int PARTITION_SIZE = 0>  // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
+                                 // head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float* k_scale, const float* v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const int seq_idx = blockIdx.y;
+  const int partition_idx = blockIdx.z;
+  const int max_num_partitions = gridDim.z;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const int seq_len = seq_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
+  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
+                                        // divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int thread_idx = threadIdx.x;
+  const int warp_idx = thread_idx / WARP_SIZE;
+  const int lane = thread_idx % WARP_SIZE;
+
+  const int head_idx = blockIdx.x;
+  const int num_heads = gridDim.x;
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
+  const float alibi_slope =
+      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query to registers.
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the the thread group size is 4, then the first thread in
+  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+  }
+  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with a
+                    // memory wall right before we use q_vecs
+
+  // Memory planning.
+  extern __shared__ char shared_mem[];
+  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
+  float* logits = reinterpret_cast<float*>(shared_mem);
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(cache_t);
+  float qk_max = -FLT_MAX;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+
+  // blocksparse specific vars
+  int bs_block_offset;
+  int q_bs_block_id;
+  if constexpr (IS_BLOCK_SPARSE) {
+    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
+    // blocksparse_block_size);
+    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
+    if (blocksparse_head_sliding_step >= 0)
+      // sliding on q heads
+      bs_block_offset =
+          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
+    else
+      // sliding on kv heads
+      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
+                            (-blocksparse_head_sliding_step) +
+                        1;
+  }
+
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      const bool is_remote =
+          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
+      const bool is_local =
+          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
+      if (!is_remote && !is_local) {
+        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+          const int physical_block_offset =
+              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+
+          if (thread_group_offset == 0) {
+            // NOTE(linxihui): assign very large number to skipped tokens to
+            // avoid contribution to the sumexp softmax normalizer. This will
+            // not be used at computing sum(softmax*v) as the blocks will be
+            // skipped.
+            logits[token_idx - start_token_idx] = -FLT_MAX;
+          }
+        }
+        continue;
+      }
+    }
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the the thread group size is 4, then the first thread in
+    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
+    // has 1, 5, 9, ... th vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset =
+          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const cache_t* k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+
+        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        } else {
+          // Vector conversion from Quant_vec to K_vec.
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
+              k_vec_quant, *k_scale);
+        }
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE(woosuk): It is required to zero out the masked logits.
+        const bool mask = token_idx >= seq_len;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+      }
+    }
+  }
+
+  // Perform reduction across the threads in the same warp to get the
+  // max qk value for each "warp" (not across the thread block yet).
+  // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  __syncthreads();
+
+  // TODO(woosuk): Refactor this part.
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = VLLM_SHFL_SYNC(qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    float val = __expf(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+  // Compute softmax.
+  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  __syncthreads();
+
+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0) {
+    float* max_logits_ptr = max_logits +
+                            seq_idx * num_heads * max_num_partitions +
+                            head_idx * max_num_partitions + partition_idx;
+    *max_logits_ptr = qk_max;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+                          head_idx * max_num_partitions + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+
+  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  scalar_t zero_value;
+  zero(zero_value);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
+          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
+        continue;
+      }
+    }
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
+                                                           start_token_idx));
+
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+                           kv_head_idx * kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        V_vec v_vec;
+
+        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
+          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        } else {
+          V_quant_vec v_quant_vec =
+              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          // Vector conversion from V_quant_vec to V_vec.
+          v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
+                                                                    *v_scale);
+        }
+        if (block_idx == num_seq_blocks - 1) {
+          // NOTE(woosuk): When v_vec contains the tokens that are out of the
+          // context, we should explicitly zero out the values since they may
+          // contain NaNs. See
+          // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+  // Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE(woosuk): A barrier is required because the shared memory space for
+  // logits is reused for the output.
+  __syncthreads();
+
+  // Perform reduction across warps.
+  float* out_smem = reinterpret_cast<float*>(shared_mem);
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[i];
+        }
+      }
+    }
+    __syncthreads();
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const float* src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[i] += src[row_idx];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        from_float(*(out_ptr + row_idx), accs[i]);
+      }
+    }
+  }
+}
+
+// Grid: (num_heads, num_seqs, 1).
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE>
+__global__ void paged_attention_v1_kernel(
+    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float* k_scale, const float* v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE, IS_BLOCK_SPARSE>(
+      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
+      v_cache, num_kv_heads, scale, block_tables, seq_lens,
+      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
+      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
+      blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
+          int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                     // max_num_partitions, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float* k_scale, const float* v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
+      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
+      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
+      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
+}
+
+// Grid: (num_heads, num_seqs).
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int seq_len = seq_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+
+  // Size: 2 * num_partitions.
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = VLLM_SHFL_SYNC(max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums =
+      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums +
+                              seq_idx * num_heads * max_num_partitions +
+                              head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const scalar_t* tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
+    }
+    from_float(out_ptr[i], acc);
+  }
+}
+
+}  // namespace vllm
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
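To make the kernel above easier to read, here is a plain PyTorch reference (float32 math, no ALiBi, no FP8 cache, no block-sparsity, no partitioning) of what paged_attention_v1_kernel computes, written against the tensor layouts documented in the signatures above. Sizes and the function name are illustrative; this is a reading aid, not the extension's API.

import torch

def paged_attention_ref(q, k_cache, v_cache, block_tables, seq_lens, scale):
    # q:            [num_seqs, num_heads, head_size]
    # k_cache:      [num_blocks, num_kv_heads, head_size // x, block_size, x]
    # v_cache:      [num_blocks, num_kv_heads, head_size, block_size]
    # block_tables: [num_seqs, max_num_blocks_per_seq] (int)
    # seq_lens:     [num_seqs] (int)
    num_seqs, num_heads, head_size = q.shape
    num_kv_heads = k_cache.shape[1]
    block_size = v_cache.shape[3]
    out = torch.empty_like(q)
    for s in range(num_seqs):
        seq_len = int(seq_lens[s])
        num_blocks = (seq_len + block_size - 1) // block_size  # DIVIDE_ROUND_UP
        # Gather this sequence's K/V from the paged cache into contiguous [seq_len, ...] tensors.
        blocks = block_tables[s, :num_blocks].long()
        k = k_cache[blocks]                       # [num_blocks, num_kv_heads, head_size//x, block_size, x]
        k = k.permute(0, 3, 1, 2, 4).reshape(num_blocks * block_size, num_kv_heads, head_size)
        v = v_cache[blocks]                       # [num_blocks, num_kv_heads, head_size, block_size]
        v = v.permute(0, 3, 1, 2).reshape(num_blocks * block_size, num_kv_heads, head_size)
        k, v = k[:seq_len], v[:seq_len]
        for h in range(num_heads):
            kv_h = h // (num_heads // num_kv_heads)    # grouped-query mapping, as in the kernel
            logits = scale * (k[:, kv_h].float() @ q[s, h].float())   # [seq_len]
            probs = torch.softmax(logits, dim=0)                      # FP32 softmax, like the kernel
            out[s, h] = (probs.unsqueeze(1) * v[:, kv_h].float()).sum(0).to(q.dtype)
    return out

The v2 path splits the same computation into PARTITION_SIZE-token partitions along the z grid dimension and writes per-partition exp_sums and max_logits, which paged_attention_v2_reduce_kernel then combines with the usual rescale-by-exp(max_i - global_max) trick.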
paged-attention/attention/attention_utils.cuh
ADDED
@@ -0,0 +1,57 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "../cuda_compat.h"
+#include "attention_dtypes.h"
+
+#include <float.h>
+#include <type_traits>
+
+namespace vllm {
+
+// Q*K^T operation.
+template <int THREAD_GROUP_SIZE, typename Vec, int N>
+inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
+  using A_vec = typename FloatVec<Vec>::Type;
+  // Compute the parallel products for Q*K^T (treat vector lanes separately).
+  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
+#pragma unroll
+  for (int ii = 1; ii < N; ++ii) {
+    qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
+  }
+
+  // Finalize the reduction across lanes.
+  float qk = sum(qk_vec);
+#pragma unroll
+  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
+  }
+  return qk;
+}
+
+template <typename T, int THREAD_GROUP_SIZE>
+struct Qk_dot {
+  template <typename Vec, int N>
+  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
+    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+  }
+};
+
+}  // namespace vllm
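The XOR-shuffle loop in qk_dot_ is the same butterfly reduction used by block_sum in attention_kernels.cuh. A tiny host-side model (Python, illustrative only) shows why it works: after log2(width) rounds every participating lane holds the full sum, which is why qk_dot_ can simply return qk from each lane of the thread group.

def butterfly_sum(lane_values):
    width = len(lane_values)           # must be a power of two, like a warp or thread group
    vals = list(lane_values)
    mask = width // 2
    while mask >= 1:
        # Each "lane" adds the value held by its XOR partner, modelling __shfl_xor_sync.
        vals = [vals[i] + vals[i ^ mask] for i in range(width)]
        mask //= 2
    return vals

print(butterfly_sum([1.0, 2.0, 3.0, 4.0]))  # [10.0, 10.0, 10.0, 10.0]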
paged-attention/attention/dtype_bfloat16.cuh
ADDED
@@ -0,0 +1,463 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float32.cuh"
+
+#ifndef USE_ROCM
+  #include <cuda_bf16.h>
+  #include <cuda_fp16.h>
+#else
+  #include <hip/hip_bf16.h>
+  #include <hip/hip_fp16.h>
+
+typedef __hip_bfloat162 __nv_bfloat162;
+typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
+#include <stdint.h>
+
+namespace vllm {
+
+// Define custom BF16 vector data types.
+struct bf16_4_t {
+  __nv_bfloat162 x;
+  __nv_bfloat162 y;
+};
+
+struct bf16_8_t {
+  __nv_bfloat162 x;
+  __nv_bfloat162 y;
+  __nv_bfloat162 z;
+  __nv_bfloat162 w;
+};
+
+// BF16 vector types for Q, K, V.
+template <>
+struct Vec<__nv_bfloat16, 1> {
+  using Type = __nv_bfloat16;
+};
+template <>
+struct Vec<__nv_bfloat16, 2> {
+  using Type = __nv_bfloat162;
+};
+template <>
+struct Vec<__nv_bfloat16, 4> {
+  using Type = bf16_4_t;
+};
+template <>
+struct Vec<__nv_bfloat16, 8> {
+  using Type = bf16_8_t;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template <>
+struct FloatVec<__nv_bfloat16> {
+  using Type = float;
+};
+template <>
+struct FloatVec<__nv_bfloat162> {
+  using Type = float2;
+};
+template <>
+struct FloatVec<bf16_4_t> {
+  using Type = Float4_;
+};
+template <>
+struct FloatVec<bf16_8_t> {
+  using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __bfloat1622float2(val);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __bfloat162bfloat162(val);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+// Vector addition.
+inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  #ifndef USE_ROCM
+  return a + b;
+  #else
+  return __hadd(a, b);
+  #endif
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __hadd2(a, b);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
+  bf16_4_t c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
+  bf16_8_t c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
+  float2 fa = bf1622float2(a);
+  return add(fa, fb);
+}
+
+inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
+  Float4_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  return fc;
+}
+
+inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
+  Float8_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  fc.z = add(a.z, fb.z);
+  fc.w = add(a.w, fb.w);
+  return fc;
+}
+
+// Vector multiplication.
+template <>
+inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __hmul(a, b);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+template <>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __hmul2(a, b);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+template <>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
+  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+template <>
+inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
+  bf16_4_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  return c;
+}
+
+template <>
+inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_4_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  return c;
+}
+
+template <>
+inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
+  bf16_8_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+  return c;
+}
+
+template <>
+inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_8_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
+  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
+  return c;
+}
+
+template <>
+inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+  float fa = __bfloat162float(a);
+  float fb = __bfloat162float(b);
+  return fa * fb;
+}
+
+template <>
+inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+  float2 fa = bf1622float2(a);
+  float2 fb = bf1622float2(b);
+  return mul<float2, float2, float2>(fa, fb);
+}
+
+template <>
+inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
+  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+template <>
+inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
+  Float4_ fc;
+  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  return fc;
+}
+
+template <>
+inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  Float4_ fc;
+  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  return fc;
+}
+
+template <>
+inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
+  Float8_ fc;
+  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+  return fc;
+}
+
+template <>
+inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  Float8_ fc;
+  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
+  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
+  return fc;
+}
+
+// Vector fused multiply-add.
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
+                                     __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __hfma2(a, b, c);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
+                                     __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __hfma2(bf162bf162(a), b, c);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
+  bf16_4_t d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  return d;
+}
+
+inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_4_t d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  return d;
+}
+
+inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
+  bf16_8_t d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  d.z = fma(a.z, b.z, c.z);
+  d.w = fma(a.w, b.w, c.w);
+  return d;
+}
+
+inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_8_t d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  d.z = fma(s, b.z, c.z);
+  d.w = fma(s, b.w, c.w);
+  return d;
+}
+
+inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
+  return __bfloat162float(a) * __bfloat162float(b) + fc;
+}
+
+inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
+  float2 fa = bf1622float2(a);
+  float2 fb = bf1622float2(b);
+  return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
+  return fma(bf162bf162(a), b, fc);
|
357 |
+
}
|
358 |
+
|
359 |
+
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
|
360 |
+
Float4_ fd;
|
361 |
+
fd.x = fma(a.x, b.x, fc.x);
|
362 |
+
fd.y = fma(a.y, b.y, fc.y);
|
363 |
+
return fd;
|
364 |
+
}
|
365 |
+
|
366 |
+
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
|
367 |
+
__nv_bfloat162 s = bf162bf162(a);
|
368 |
+
Float4_ fd;
|
369 |
+
fd.x = fma(s, b.x, fc.x);
|
370 |
+
fd.y = fma(s, b.y, fc.y);
|
371 |
+
return fd;
|
372 |
+
}
|
373 |
+
|
374 |
+
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
|
375 |
+
Float8_ fd;
|
376 |
+
fd.x = fma(a.x, b.x, fc.x);
|
377 |
+
fd.y = fma(a.y, b.y, fc.y);
|
378 |
+
fd.z = fma(a.z, b.z, fc.z);
|
379 |
+
fd.w = fma(a.w, b.w, fc.w);
|
380 |
+
return fd;
|
381 |
+
}
|
382 |
+
|
383 |
+
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
|
384 |
+
__nv_bfloat162 s = bf162bf162(a);
|
385 |
+
Float8_ fd;
|
386 |
+
fd.x = fma(s, b.x, fc.x);
|
387 |
+
fd.y = fma(s, b.y, fc.y);
|
388 |
+
fd.z = fma(s, b.z, fc.z);
|
389 |
+
fd.w = fma(s, b.w, fc.w);
|
390 |
+
return fd;
|
391 |
+
}
|
392 |
+
|
393 |
+
// Vector sum.
|
394 |
+
template <>
|
395 |
+
inline __device__ float sum(__nv_bfloat16 v) {
|
396 |
+
return __bfloat162float(v);
|
397 |
+
}
|
398 |
+
|
399 |
+
template <>
|
400 |
+
inline __device__ float sum(__nv_bfloat162 v) {
|
401 |
+
float2 vf = bf1622float2(v);
|
402 |
+
return vf.x + vf.y;
|
403 |
+
}
|
404 |
+
|
405 |
+
template <>
|
406 |
+
inline __device__ float sum(bf16_4_t v) {
|
407 |
+
return sum(v.x) + sum(v.y);
|
408 |
+
}
|
409 |
+
|
410 |
+
template <>
|
411 |
+
inline __device__ float sum(bf16_8_t v) {
|
412 |
+
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
|
413 |
+
}
|
414 |
+
|
415 |
+
// From float32 to bfloat16.
|
416 |
+
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
|
417 |
+
dst = __float2bfloat16(src);
|
418 |
+
}
|
419 |
+
|
420 |
+
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
|
421 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
422 |
+
assert(false);
|
423 |
+
#else
|
424 |
+
dst = __float22bfloat162_rn(src);
|
425 |
+
#endif
|
426 |
+
}
|
427 |
+
|
428 |
+
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
|
429 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
430 |
+
assert(false);
|
431 |
+
#else
|
432 |
+
dst.x = __float22bfloat162_rn(src.x);
|
433 |
+
dst.y = __float22bfloat162_rn(src.y);
|
434 |
+
#endif
|
435 |
+
}
|
436 |
+
|
437 |
+
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
438 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
439 |
+
assert(false);
|
440 |
+
#else
|
441 |
+
dst.x = __float22bfloat162_rn(src.x);
|
442 |
+
dst.y = __float22bfloat162_rn(src.y);
|
443 |
+
dst.z = __float22bfloat162_rn(src.z);
|
444 |
+
dst.w = __float22bfloat162_rn(src.w);
|
445 |
+
#endif
|
446 |
+
}
|
447 |
+
|
448 |
+
// From bfloat16 to float32.
|
449 |
+
inline __device__ float to_float(__nv_bfloat16 u) {
|
450 |
+
return __bfloat162float(u);
|
451 |
+
}
|
452 |
+
|
453 |
+
// Zero-out a variable.
|
454 |
+
inline __device__ void zero(__nv_bfloat16& dst) {
|
455 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
456 |
+
assert(false);
|
457 |
+
#else
|
458 |
+
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
|
459 |
+
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
|
460 |
+
#endif
|
461 |
+
}
|
462 |
+
|
463 |
+
} // namespace vllm
|
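For orientation, here is a minimal sketch (not part of the ported sources) of how these overloaded `mul`/`sum` primitives are meant to compose into the dot products used by the attention kernel. It assumes the generic `mul`/`sum` templates from `attention_generic.cuh`, the `bf16_4_t` and `FloatVec` definitions from earlier in this header, and an SM80+ device; the helper name `bf16_dot4` is hypothetical.

```cpp
// Illustration only: dot product of two bf16 4-vectors, accumulated in fp32,
// written purely in terms of the primitives defined above.
#include "attention_dtypes.h"  // assumed to be on the include path

__device__ float bf16_dot4(vllm::bf16_4_t q, vllm::bf16_4_t k) {
  // Accumulator type comes from FloatVec (Float4_, i.e. fp32 pairs).
  using Acc = typename vllm::FloatVec<vllm::bf16_4_t>::Type;
  // Element-wise products, widened to fp32.
  Acc prod = vllm::mul<Acc, vllm::bf16_4_t, vllm::bf16_4_t>(q, k);
  // Horizontal reduction to a single float.
  return vllm::sum(prod);
}
```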
paged-attention/attention/dtype_float16.cuh
ADDED
@@ -0,0 +1,504 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {

// FP16 vector types for Q, K, V.
template <>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template <>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template <>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template <>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<uint16_t> {
  using Type = float;
};
template <>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template <>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template <>
struct FloatVec<uint4> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
               : "=r"(tmp.u32)
               : "f"(f.y), "f"(f.x));
  #else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
  #endif
#else
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

template <>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(d)
               : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
               : "=v"(d)
               : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template <>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template <>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template <>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template <>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }

inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }

}  // namespace vllm
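Note that fp16 values are carried around as raw `uint16_t`/`uint32_t` words and converted with inline PTX (or GCN) assembly, which is why the same header builds on both CUDA and ROCm. A minimal sketch of the intended round trip, assuming the vllm headers above are included; the helper name `fp16x2_roundtrip` is hypothetical:

```cpp
// Illustration only: two fp32 values -> packed fp16x2 word -> back to fp32.
__device__ float2 fp16x2_roundtrip(float2 v) {
  uint32_t h2;                     // two half values packed in one 32-bit register
  vllm::from_float(h2, v);         // float2_to_half2 (cvt.rn.f16x2.f32 on SM80+)
  return vllm::to_float(h2);       // half2_to_float2
}
```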
paged-attention/attention/dtype_float32.cuh
ADDED
@@ -0,0 +1,251 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace vllm {

// Define custom FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template <>
struct Vec<float, 1> {
  using Type = float;
};
template <>
struct Vec<float, 2> {
  using Type = float2;
};
template <>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<float> {
  using Type = float;
};
template <>
struct FloatVec<float2> {
  using Type = float2;
};
template <>
struct FloatVec<float4> {
  using Type = float4;
};

// Vector addition.
inline __device__ float add(float a, float b) { return a + b; }

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication.
template <>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template <>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template <>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template <>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template <>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}

// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

// Vector sum.
template <>
inline __device__ float sum(float v) {
  return v;
}

template <>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template <>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template <>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template <>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) { return a * b; }

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float to float.
inline __device__ void from_float(float& dst, float src) { dst = src; }

inline __device__ void from_float(float2& dst, float2 src) { dst = src; }

inline __device__ void from_float(float4& dst, float4 src) { dst = src; }

// From float to float.
inline __device__ float to_float(float u) { return u; }

inline __device__ float2 to_float(float2 u) { return u; }

inline __device__ float4 to_float(float4 u) { return u; }

inline __device__ Float4_ to_float(Float4_ u) { return u; }

inline __device__ Float8_ to_float(Float8_ u) { return u; }

// Zero-out a variable.
inline __device__ void zero(float& dst) { dst = 0.f; }

}  // namespace vllm
paged-attention/attention/dtype_fp8.cuh
ADDED
@@ -0,0 +1,41 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8
  #ifndef USE_ROCM
    #include <cuda_fp8.h>
  #endif  // USE_ROCM
#endif    // ENABLE_FP8

namespace vllm {

enum class Fp8KVCacheDataType {
  kAuto = 0,
  kFp8E4M3 = 1,
  kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> {
  using Type = uint8_t;
};

template <>
struct Vec<uint8_t, 2> {
  using Type = uint16_t;
};

template <>
struct Vec<uint8_t, 4> {
  using Type = uint32_t;
};

template <>
struct Vec<uint8_t, 8> {
  using Type = uint2;
};

}  // namespace vllm
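Here `Fp8KVCacheDataType` selects how the quantized cache bytes are interpreted (auto, E4M3, or E5M2), while the `Vec<uint8_t, N>` specializations map groups of N quantized bytes onto plain integer words so they can be loaded and stored in one transaction. A small sketch of that mapping (illustration only, assuming the header above is included):

```cpp
// Four fp8 cache values travel as a single 32-bit word.
using Fp8Quad = typename vllm::Vec<uint8_t, 4>::Type;  // uint32_t
static_assert(sizeof(Fp8Quad) == 4 * sizeof(uint8_t),
              "four fp8 values pack into one 32-bit word");
```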
paged-attention/attention/paged_attention_v1.cu
ADDED
@@ -0,0 +1,196 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "attention_kernels.cuh"

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                \
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                     \
      ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE,        \
                                              BLOCK_SIZE, NUM_THREADS,      \
                                              KV_DTYPE, IS_BLOCK_SPARSE>),  \
      shared_mem_size);                                                     \
  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,        \
                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE>   \
      <<<grid, block, shared_mem_size, stream>>>(                           \
          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,    \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,      \
          k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks,      \
          blocksparse_vert_stride, blocksparse_block_size,                  \
          blocksparse_head_sliding_step);

// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128>
void paged_attention_v1_launcher(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_seq_len =
      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_seq_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 32:
      LAUNCH_PAGED_ATTENTION_V1(32);
      break;
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 120:
      LAUNCH_PAGED_ATTENTION_V1(120);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 192:
      LAUNCH_PAGED_ATTENTION_V1(192);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)   \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
                              IS_BLOCK_SPARSE>(                               \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
      seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank,         \
      blocksparse_local_blocks, blocksparse_vert_stride,                      \
      blocksparse_block_size, blocksparse_head_sliding_step);

#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
  if (is_block_sparse) {                                                   \
    CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);       \
  } else {                                                                 \
    CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);      \
  }

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);         \
      break;                                                      \
    case 16:                                                      \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
      break;                                                      \
    case 32:                                                      \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);        \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v1(
    torch::Tensor& out,    // [num_seqs, num_heads, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);

  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V1_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
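The v1 launcher sizes dynamic shared memory as the larger of the softmax-logit buffer for the padded sequence and the per-warp partial-output scratch. A small host-side sanity check of that arithmetic (not part of the ported sources; the concrete block size, thread count, head size, and sequence length below are illustrative assumptions):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const int BLOCK_SIZE = 16, NUM_THREADS = 128, WARP_SIZE = 32;
  const int max_seq_len = 4096, head_size = 128;
  const int num_warps = NUM_THREADS / WARP_SIZE;
  // Same arithmetic as paged_attention_v1_launcher above.
  int padded_max_seq_len =
      (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
  int logits_size = padded_max_seq_len * static_cast<int>(sizeof(float));
  int outputs_size =
      (num_warps / 2) * head_size * static_cast<int>(sizeof(float));
  std::printf("shared_mem_size = %d bytes\n",
              std::max(logits_size, outputs_size));
  return 0;
}
```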
paged-attention/attention/paged_attention_v2.cu
ADDED
@@ -0,0 +1,206 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "attention_kernels.cuh"

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,           \
                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,      \
                                  PARTITION_SIZE>                              \
      <<<grid, block, shared_mem_size, stream>>>(                              \
          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
          value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
          seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
          kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank,  \
          blocksparse_local_blocks, blocksparse_vert_stride,                   \
          blocksparse_block_size, blocksparse_head_sliding_step);              \
  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,            \
                                         PARTITION_SIZE>                       \
      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr,    \
          max_num_partitions);

template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 32:
      LAUNCH_PAGED_ATTENTION_V2(32);
      break;
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 120:
      LAUNCH_PAGED_ATTENTION_V2(120);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 192:
      LAUNCH_PAGED_ATTENTION_V2(192);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)    \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,                \
                              IS_BLOCK_SPARSE>(                                \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
      num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes,  \
      k_scale, v_scale, tp_rank, blocksparse_local_blocks,                     \
      blocksparse_vert_stride, blocksparse_block_size,                         \
      blocksparse_head_sliding_step);

#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
  if (is_block_sparse) {                                                   \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);       \
  } else {                                                                 \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);      \
  }

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);         \
      break;                                                      \
    case 16:                                                      \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
      break;                                                      \
    case 32:                                                      \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);        \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v2(
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V2_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
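Unlike v1, the v2 launcher splits each sequence into `PARTITION_SIZE`-token partitions handled by separate thread blocks (the grid's z dimension) and then runs a reduce kernel that combines one max logit and one exp-sum per partition. The partitioning arithmetic, reproduced as a standalone check (not part of the ported sources; the 8192-token sequence length is an illustrative assumption):

```cpp
#include <cstdio>

int main() {
  const int PARTITION_SIZE = 512;  // tokens handled per thread block (v2 default)
  const int max_seq_len = 8192;    // assumed example value
  int max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
  // The reduce kernel keeps one max logit and one exp-sum per partition.
  int reduce_shared_mem_size =
      2 * max_num_partitions * static_cast<int>(sizeof(float));
  std::printf("partitions = %d, reduce shared memory = %d bytes\n",
              max_num_partitions, reduce_shared_mem_size);
  return 0;
}
```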
paged-attention/cache_kernels.cu
ADDED
@@ -0,0 +1,419 @@
1 |
+
#include <torch/all.h>
|
2 |
+
#include <ATen/cuda/CUDAContext.h>
|
3 |
+
#include <c10/cuda/CUDAGuard.h>
|
4 |
+
|
5 |
+
#include "cuda_compat.h"
|
6 |
+
#include "dispatch_utils.h"
|
7 |
+
|
8 |
+
#ifdef USE_ROCM
|
9 |
+
#include "quantization/fp8/amd/quant_utils.cuh"
|
10 |
+
#else
|
11 |
+
#include "quantization/fp8/nvidia/quant_utils.cuh"
|
12 |
+
#endif
|
13 |
+
|
14 |
+
#include <algorithm>
|
15 |
+
#include <cassert>
|
16 |
+
#include <map>
|
17 |
+
#include <vector>
|
18 |
+
|
19 |
+
#ifdef USE_ROCM
|
20 |
+
#include <hip/hip_bf16.h>
|
21 |
+
typedef __hip_bfloat16 __nv_bfloat16;
|
22 |
+
#endif
|
23 |
+
|
24 |
+
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
25 |
+
const torch::Tensor& block_mapping) {
|
26 |
+
torch::Device src_device = src.device();
|
27 |
+
torch::Device dst_device = dst.device();
|
28 |
+
cudaMemcpyKind memcpy_type;
|
29 |
+
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
30 |
+
TORCH_CHECK(src_device.index() == dst_device.index(),
|
31 |
+
"src and dst must be on the same GPU");
|
32 |
+
memcpy_type = cudaMemcpyDeviceToDevice;
|
33 |
+
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
34 |
+
memcpy_type = cudaMemcpyDeviceToHost;
|
35 |
+
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
|
36 |
+
memcpy_type = cudaMemcpyHostToDevice;
|
37 |
+
} else {
|
38 |
+
TORCH_CHECK(false, "Invalid device combination");
|
39 |
+
}
|
40 |
+
|
41 |
+
// NOTE(youkaichao): keep in mind that `block_mapping` should be
|
42 |
+
// a cpu tensor, otherwise every `item` call will require a gpu-cpu
|
43 |
+
// synchronization.
|
44 |
+
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
|
45 |
+
|
46 |
+
char* src_ptr = static_cast<char*>(src.data_ptr());
|
47 |
+
char* dst_ptr = static_cast<char*>(dst.data_ptr());
|
48 |
+
|
49 |
+
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
50 |
+
const at::cuda::OptionalCUDAGuard device_guard(
|
51 |
+
src_device.is_cuda() ? src_device : dst_device);
|
52 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
53 |
+
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
54 |
+
const int64_t num_blocks = block_mapping.size(0);
|
55 |
+
for (size_t i = 0; i < num_blocks; i++) {
|
56 |
+
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
|
57 |
+
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
|
58 |
+
int64_t src_offset = src_block_number * block_size_in_bytes;
|
59 |
+
int64_t dst_offset = dst_block_number * block_size_in_bytes;
|
60 |
+
cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
|
61 |
+
block_size_in_bytes, memcpy_type, stream);
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
namespace vllm {
|
66 |
+
|
67 |
+
// Grid: (num_layers, num_pairs)
|
68 |
+
template <typename scalar_t>
|
69 |
+
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
|
70 |
+
int64_t* value_cache_ptrs,
|
71 |
+
const int64_t* __restrict__ block_mapping,
|
72 |
+
const int numel_per_block) {
|
73 |
+
const int layer_idx = blockIdx.x;
|
74 |
+
const int pair_idx = blockIdx.y;
|
75 |
+
|
76 |
+
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
|
77 |
+
scalar_t* value_cache =
|
78 |
+
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
79 |
+
int64_t src_block_number = block_mapping[2 * pair_idx];
|
80 |
+
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
|
81 |
+
|
82 |
+
const int64_t src_block_offset = src_block_number * numel_per_block;
|
83 |
+
const int64_t dst_block_offset = dst_block_number * numel_per_block;
|
84 |
+
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
85 |
+
int64_t src_offset = src_block_offset + i;
|
86 |
+
int64_t dst_offset = dst_block_offset + i;
|
87 |
+
key_cache[dst_offset] = key_cache[src_offset];
|
88 |
+
}
|
89 |
+
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
90 |
+
int64_t src_offset = src_block_offset + i;
|
91 |
+
int64_t dst_offset = dst_block_offset + i;
|
92 |
+
value_cache[dst_offset] = value_cache[src_offset];
|
93 |
+
}
|
94 |
+
}
|
95 |
+
|
96 |
+
} // namespace vllm
|
97 |
+
|
98 |
+
// Note: the key_caches and value_caches vectors are constant but
|
99 |
+
// not the Tensors they contain. The vectors need to be const refs
|
100 |
+
// in order to satisfy pytorch's C++ operator registration code.
|
101 |
+
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
102 |
+
std::vector<torch::Tensor> const& value_caches,
|
103 |
+
const torch::Tensor& block_mapping) {
|
104 |
+
int num_layers = key_caches.size();
|
105 |
+
TORCH_CHECK(num_layers == value_caches.size());
|
106 |
+
if (num_layers == 0) {
|
107 |
+
return;
|
108 |
+
}
|
109 |
+
torch::Device cache_device = key_caches[0].device();
|
110 |
+
TORCH_CHECK(cache_device.is_cuda());
|
111 |
+
|
112 |
+
// Create data structures for the kernel.
|
113 |
+
// Create an array of pointers to the key and value caches.
|
114 |
+
int64_t key_cache_ptrs[num_layers];
|
115 |
+
int64_t value_cache_ptrs[num_layers];
|
116 |
+
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
|
117 |
+
key_cache_ptrs[layer_idx] =
|
118 |
+
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
|
119 |
+
value_cache_ptrs[layer_idx] =
|
120 |
+
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
121 |
+
}
|
122 |
+
|
123 |
+
// block_mapping is a 2D tensor with shape (num_pairs, 2).
|
124 |
+
int num_pairs = block_mapping.size(0);
|
125 |
+
|
126 |
+
// Move the data structures to the GPU.
|
127 |
+
// NOTE: This synchronizes the CPU and GPU.
|
128 |
+
torch::Tensor key_cache_ptrs_tensor =
|
129 |
+
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
|
130 |
+
.to(cache_device);
|
131 |
+
torch::Tensor value_cache_ptrs_tensor =
|
132 |
+
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
|
133 |
+
.to(cache_device);
|
134 |
+
|
135 |
+
// Launch the kernel.
|
136 |
+
const int numel_per_block = key_caches[0][0].numel();
|
137 |
+
dim3 grid(num_layers, num_pairs);
|
138 |
+
dim3 block(std::min(1024, numel_per_block));
|
139 |
+
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
|
140 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
141 |
+
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
142 |
+
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
143 |
+
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
144 |
+
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
145 |
+
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
146 |
+
block_mapping.data_ptr<int64_t>(), numel_per_block);
|
147 |
+
}));
|
148 |
+
}
|
149 |
+
|
150 |
+
namespace vllm {
|
151 |
+
|
152 |
+
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
153 |
+
__global__ void reshape_and_cache_kernel(
|
154 |
+
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
155 |
+
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
156 |
+
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
|
157 |
+
// block_size, x]
|
158 |
+
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
|
159 |
+
// block_size]
|
160 |
+
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
161 |
+
const int key_stride, const int value_stride, const int num_heads,
|
162 |
+
const int head_size, const int block_size, const int x,
|
163 |
+
const float* k_scale, const float* v_scale) {
|
164 |
+
const int64_t token_idx = blockIdx.x;
|
165 |
+
const int64_t slot_idx = slot_mapping[token_idx];
|
166 |
+
if (slot_idx < 0) {
|
167 |
+
// Padding token that should be ignored.
|
168 |
+
return;
|
169 |
+
}
|
170 |
+
|
171 |
+
const int64_t block_idx = slot_idx / block_size;
|
172 |
+
const int64_t block_offset = slot_idx % block_size;
|
173 |
+
|
174 |
+
const int n = num_heads * head_size;
|
175 |
+
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
176 |
+
const int64_t src_key_idx = token_idx * key_stride + i;
|
177 |
+
const int64_t src_value_idx = token_idx * value_stride + i;
|
178 |
+
|
179 |
+
const int head_idx = i / head_size;
|
180 |
+
const int head_offset = i % head_size;
|
181 |
+
const int x_idx = head_offset / x;
|
182 |
+
const int x_offset = head_offset % x;
|
183 |
+
|
184 |
+
const int64_t tgt_key_idx =
|
185 |
+
block_idx * num_heads * (head_size / x) * block_size * x +
|
186 |
+
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
|
187 |
+
block_offset * x + x_offset;
|
188 |
+
const int64_t tgt_value_idx =
|
189 |
+
block_idx * num_heads * head_size * block_size +
|
190 |
+
head_idx * head_size * block_size + head_offset * block_size +
|
191 |
+
block_offset;
|
192 |
+
scalar_t tgt_key = key[src_key_idx];
|
193 |
+
scalar_t tgt_value = value[src_value_idx];
|
194 |
+
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
195 |
+
key_cache[tgt_key_idx] = tgt_key;
|
196 |
+
value_cache[tgt_value_idx] = tgt_value;
|
197 |
+
} else {
|
198 |
+
key_cache[tgt_key_idx] =
|
199 |
+
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
200 |
+
value_cache[tgt_value_idx] =
|
201 |
+
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
202 |
+
}
|
203 |
+
}
|
204 |
+
}
|
205 |
+
|
206 |
+
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
207 |
+
__global__ void reshape_and_cache_flash_kernel(
|
208 |
+
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
209 |
+
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
210 |
+
cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
|
211 |
+
// head_size]
|
212 |
+
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
|
213 |
+
// head_size]
|
214 |
+
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
215 |
+
const int block_stride, const int key_stride, const int value_stride,
|
216 |
+
const int num_heads, const int head_size, const int block_size,
|
217 |
+
const float* k_scale, const float* v_scale) {
|
218 |
+
const int64_t token_idx = blockIdx.x;
|
219 |
+
const int64_t slot_idx = slot_mapping[token_idx];
|
220 |
+
// NOTE: slot_idx can be -1 if the token is padded
|
221 |
+
if (slot_idx < 0) {
|
222 |
+
return;
|
223 |
+
}
|
224 |
+
const int64_t block_idx = slot_idx / block_size;
|
225 |
+
const int64_t block_offset = slot_idx % block_size;
|
226 |
+
const int n = num_heads * head_size;
|
227 |
+
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
228 |
+
const int64_t src_key_idx = token_idx * key_stride + i;
|
229 |
+
const int64_t src_value_idx = token_idx * value_stride + i;
|
230 |
+
const int head_idx = i / head_size;
|
231 |
+
const int head_offset = i % head_size;
|
232 |
+
const int64_t tgt_key_value_idx = block_idx * block_stride +
|
233 |
+
block_offset * num_heads * head_size +
|
234 |
+
head_idx * head_size + head_offset;
|
235 |
+
scalar_t tgt_key = key[src_key_idx];
|
236 |
+
scalar_t tgt_value = value[src_value_idx];
|
237 |
+
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
238 |
+
key_cache[tgt_key_value_idx] = tgt_key;
|
239 |
+
value_cache[tgt_key_value_idx] = tgt_value;
|
240 |
+
} else {
|
241 |
+
key_cache[tgt_key_value_idx] =
|
242 |
+
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
243 |
+
value_cache[tgt_key_value_idx] =
|
244 |
+
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
245 |
+
}
|
246 |
+
}
|
247 |
+
}
|
248 |
+
} // namespace vllm
|
249 |
+
|
250 |
+
// KV_T is the data type of the key and value tensors.
|
252 |
+
// CACHE_T is the stored data type of the kv-cache.
|
252 |
+
// KV_DTYPE is the real data type of kv-cache.
|
253 |
+
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
254 |
+
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
255 |
+
<<<grid, block, 0, stream>>>( \
|
256 |
+
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
257 |
+
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
258 |
+
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
259 |
+
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
260 |
+
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
261 |
+
num_heads, head_size, block_size, x, \
|
262 |
+
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
263 |
+
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
264 |
+
|
265 |
+
void reshape_and_cache(
|
266 |
+
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
267 |
+
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
268 |
+
torch::Tensor&
|
269 |
+
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
270 |
+
torch::Tensor&
|
271 |
+
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
272 |
+
torch::Tensor& slot_mapping, // [num_tokens]
|
273 |
+
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
274 |
+
torch::Tensor& v_scale) {
|
275 |
+
int num_tokens = key.size(0);
|
276 |
+
int num_heads = key.size(1);
|
277 |
+
int head_size = key.size(2);
|
278 |
+
int block_size = key_cache.size(3);
|
279 |
+
int x = key_cache.size(4);
|
280 |
+
|
281 |
+
int key_stride = key.stride(0);
|
282 |
+
int value_stride = value.stride(0);
|
283 |
+
|
284 |
+
dim3 grid(num_tokens);
|
285 |
+
dim3 block(std::min(num_heads * head_size, 512));
|
286 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
287 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
288 |
+
|
289 |
+
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
290 |
+
CALL_RESHAPE_AND_CACHE)
|
291 |
+
}
|
292 |
+
|
293 |
+
// KV_T is the data type of the key and value tensors.
|
294 |
+
// CACHE_T is the stored data type of the kv-cache.
|
295 |
+
// KV_DTYPE is the real data type of kv-cache.
|
296 |
+
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
297 |
+
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
298 |
+
<<<grid, block, 0, stream>>>( \
|
299 |
+
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
300 |
+
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
301 |
+
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
302 |
+
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
303 |
+
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
304 |
+
value_stride, num_heads, head_size, block_size, \
|
305 |
+
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
306 |
+
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
307 |
+
|
308 |
+
void reshape_and_cache_flash(
|
309 |
+
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
310 |
+
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
311 |
+
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
|
312 |
+
torch::Tensor&
|
313 |
+
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
314 |
+
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
315 |
+
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
316 |
+
torch::Tensor& v_scale) {
|
317 |
+
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
318 |
+
// slot_mapping.size(0) because of padding for CUDA graphs.
|
319 |
+
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
320 |
+
// both include padding.
|
321 |
+
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
|
322 |
+
// since key includes padding for CUDA graphs, while slot_mapping does not.
|
323 |
+
// In this case, slot_mapping.size(0) represents the actual number of tokens
|
324 |
+
// before padding.
|
325 |
+
// For compatibility with both cases, we use slot_mapping.size(0) as the
|
326 |
+
// number of tokens.
|
327 |
+
int num_tokens = slot_mapping.size(0);
|
328 |
+
int num_heads = key.size(1);
|
329 |
+
int head_size = key.size(2);
|
330 |
+
int block_size = key_cache.size(1);
|
331 |
+
|
332 |
+
int key_stride = key.stride(0);
|
333 |
+
int value_stride = value.stride(0);
|
334 |
+
int block_stride = key_cache.stride(0);
|
335 |
+
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
336 |
+
|
337 |
+
dim3 grid(num_tokens);
|
338 |
+
dim3 block(std::min(num_heads * head_size, 512));
|
339 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
340 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
341 |
+
|
342 |
+
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
343 |
+
CALL_RESHAPE_AND_CACHE_FLASH);
|
344 |
+
}
|
345 |
+
|
346 |
+
namespace vllm {
|
347 |
+
|
348 |
+
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
349 |
+
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
350 |
+
Tout* __restrict__ dst_cache,
|
351 |
+
const float scale,
|
352 |
+
const int64_t block_stride) {
|
353 |
+
const int64_t block_idx = blockIdx.x;
|
354 |
+
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
355 |
+
int64_t idx = block_idx * block_stride + i;
|
356 |
+
dst_cache[idx] =
|
357 |
+
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
|
358 |
+
}
|
359 |
+
}
|
360 |
+
|
361 |
+
} // namespace vllm
|
362 |
+
|
363 |
+
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
|
364 |
+
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
365 |
+
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
366 |
+
reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
|
367 |
+
|
368 |
+
// Only for testing.
|
369 |
+
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
370 |
+
const double scale, const std::string& kv_cache_dtype) {
|
371 |
+
torch::Device src_device = src_cache.device();
|
372 |
+
torch::Device dst_device = dst_cache.device();
|
373 |
+
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
374 |
+
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
375 |
+
TORCH_CHECK(src_device.index() == dst_device.index(),
|
376 |
+
"src and dst must be on the same GPU");
|
377 |
+
at::cuda::OptionalCUDAGuard device_guard(src_device);
|
378 |
+
|
379 |
+
int64_t num_blocks = src_cache.size(0);
|
380 |
+
int64_t block_stride = src_cache.stride(0);
|
381 |
+
|
382 |
+
dim3 grid(num_blocks);
|
383 |
+
dim3 block(std::min(block_stride, int64_t(512)));
|
384 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
385 |
+
|
386 |
+
if (kv_cache_dtype == "auto") {
|
387 |
+
if (src_cache.dtype() == at::ScalarType::Float) {
|
388 |
+
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
|
389 |
+
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
390 |
+
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
|
391 |
+
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
392 |
+
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
|
393 |
+
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
394 |
+
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
395 |
+
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
396 |
+
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
397 |
+
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
398 |
+
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
399 |
+
}
|
400 |
+
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
|
401 |
+
if (src_cache.dtype() == at::ScalarType::Float) {
|
402 |
+
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
403 |
+
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
404 |
+
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
405 |
+
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
406 |
+
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
|
407 |
+
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
408 |
+
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
409 |
+
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
410 |
+
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
411 |
+
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
412 |
+
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
413 |
+
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
|
414 |
+
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
415 |
+
}
|
416 |
+
} else {
|
417 |
+
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
|
418 |
+
}
|
419 |
+
}
|
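`reshape_and_cache_kernel` maps each token's flat slot to a position in the paged key cache: the slot splits into a block index plus an offset inside the block, and the head dimension is further split into `head_size/x` groups of `x` contiguous elements. The host-side sketch below recomputes that index for one element; it only mirrors the arithmetic shown above, and the concrete sizes in `main` are illustrative.

```cpp
#include <cstdint>
#include <cstdio>

// Recomputes the key-cache index used by reshape_and_cache_kernel for the
// [num_blocks, num_heads, head_size/x, block_size, x] layout. `i` enumerates
// (head, head_offset) pairs exactly as the kernel's thread loop does.
int64_t key_cache_index(int64_t slot_idx, int i, int num_heads, int head_size,
                        int block_size, int x) {
  const int64_t block_idx = slot_idx / block_size;     // which paged block
  const int64_t block_offset = slot_idx % block_size;  // position inside it
  const int head_idx = i / head_size;
  const int head_offset = i % head_size;
  const int x_idx = head_offset / x;   // group of x contiguous elements
  const int x_offset = head_offset % x;
  return block_idx * num_heads * (head_size / x) * block_size * x +
         head_idx * (head_size / x) * block_size * x +
         x_idx * block_size * x + block_offset * x + x_offset;
}

int main() {
  // Token in slot 37 of a cache with block_size 16, 8 heads, head_size 128, x 8.
  long long idx = key_cache_index(/*slot_idx=*/37, /*i=*/5,
                                  /*num_heads=*/8, /*head_size=*/128,
                                  /*block_size=*/16, /*x=*/8);
  std::printf("flat key-cache index: %lld\n", idx);
  return 0;
}
```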
paged-attention/cuda_compat.h
ADDED
@@ -0,0 +1,49 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#ifdef USE_ROCM
|
4 |
+
#include <hip/hip_runtime.h>
|
5 |
+
#endif
|
6 |
+
|
7 |
+
#ifndef USE_ROCM
|
8 |
+
#define WARP_SIZE 32
|
9 |
+
#else
|
10 |
+
#define WARP_SIZE warpSize
|
11 |
+
#endif
|
12 |
+
|
13 |
+
#ifndef USE_ROCM
|
14 |
+
#define VLLM_LDG(arg) __ldg(arg)
|
15 |
+
#else
|
16 |
+
#define VLLM_LDG(arg) *(arg)
|
17 |
+
#endif
|
18 |
+
|
19 |
+
#ifndef USE_ROCM
|
20 |
+
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
|
21 |
+
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
22 |
+
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
|
23 |
+
__shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
|
24 |
+
#else
|
25 |
+
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
26 |
+
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
|
27 |
+
__shfl_xor(var, lane_mask, width)
|
28 |
+
#endif
|
29 |
+
|
30 |
+
#ifndef USE_ROCM
|
31 |
+
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
|
32 |
+
#else
|
33 |
+
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
|
34 |
+
#endif
|
35 |
+
|
36 |
+
#ifndef USE_ROCM
|
37 |
+
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
|
38 |
+
__shfl_down_sync(uint32_t(-1), var, lane_delta)
|
39 |
+
#else
|
40 |
+
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
|
41 |
+
#endif
|
42 |
+
|
43 |
+
#ifndef USE_ROCM
|
44 |
+
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
45 |
+
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
46 |
+
#else
|
47 |
+
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
48 |
+
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
49 |
+
#endif
|
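These macros paper over the CUDA/ROCm differences in warp size, `__ldg`, and the warp shuffle intrinsics so the kernels can be written once. A typical consumer is a warp-level reduction; the sketch below is illustrative (compiled as device code with nvcc or hipcc), not code taken from the ported kernels.

```cpp
#include "cuda_compat.h"

// Butterfly (XOR) warp reduction built on the portability macros above:
// after log2(WARP_SIZE) steps, every lane in the warp holds the full sum.
__device__ float warp_reduce_sum(float val) {
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;
}
```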
paged-attention/dispatch_utils.h
ADDED
@@ -0,0 +1,49 @@
1 |
+
/*
|
2 |
+
* Adapted from
|
3 |
+
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
4 |
+
*/
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include <torch/all.h>
|
8 |
+
|
9 |
+
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
10 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
11 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
12 |
+
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
13 |
+
|
14 |
+
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
15 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
16 |
+
|
17 |
+
// TODO(luka/varun): use FP8_TYPE macro after refactoring
|
18 |
+
#ifndef USE_ROCM
|
19 |
+
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
20 |
+
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
21 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
22 |
+
#else
|
23 |
+
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
24 |
+
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
|
25 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
26 |
+
#endif
|
27 |
+
|
28 |
+
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
|
29 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
|
30 |
+
|
31 |
+
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
32 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
33 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
34 |
+
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
35 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
36 |
+
|
37 |
+
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
38 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, \
|
39 |
+
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
40 |
+
|
41 |
+
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
42 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
43 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
44 |
+
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
45 |
+
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
46 |
+
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
47 |
+
|
48 |
+
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
49 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
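Each `VLLM_DISPATCH_*` macro wraps PyTorch's `AT_DISPATCH_SWITCH`/`AT_DISPATCH_CASE`: it switches on the tensor's runtime dtype and exposes the matching C++ type as `scalar_t` inside the callback, which is how `cache_kernels.cu` above instantiates `copy_blocks_kernel`. A minimal host-side sketch, with a hypothetical `report_element_size` helper:

```cpp
#include <cstdio>
#include <torch/all.h>
#include "dispatch_utils.h"

// Dispatch on float / half / bfloat16 and print the element size of the
// concrete scalar_t chosen for this tensor's dtype.
void report_element_size(const torch::Tensor& t) {
  VLLM_DISPATCH_FLOATING_TYPES(t.scalar_type(), "report_element_size", [&] {
    std::printf("element size: %zu bytes\n", sizeof(scalar_t));
  });
}
```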
paged-attention/quantization/fp8/amd/hip_float8.h
ADDED
@@ -0,0 +1,137 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#ifdef __HIPCC__
|
4 |
+
#include <hip/hip_runtime.h>
|
5 |
+
#else
|
6 |
+
#include <type_traits>
|
7 |
+
#include <stdint.h>
|
8 |
+
#include <math.h>
|
9 |
+
#include <iostream>
|
10 |
+
#endif
|
11 |
+
|
12 |
+
#include "hip_float8_impl.h"
|
13 |
+
|
14 |
+
struct alignas(1) hip_fp8 {
|
15 |
+
struct from_bits_t {};
|
16 |
+
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
17 |
+
return from_bits_t();
|
18 |
+
}
|
19 |
+
uint8_t data;
|
20 |
+
|
21 |
+
hip_fp8() = default;
|
22 |
+
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
23 |
+
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
24 |
+
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
25 |
+
: data(v) {}
|
26 |
+
|
27 |
+
#ifdef __HIP__MI300__
|
28 |
+
// NOTE: ON-DEVICE... always optimal bias
|
29 |
+
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
30 |
+
: data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
|
31 |
+
|
32 |
+
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
33 |
+
: hip_fp8(static_cast<float>(v)) {}
|
34 |
+
|
35 |
+
// Host only implementation using s/w simulation
|
36 |
+
explicit HIP_FP8_HOST
|
37 |
+
#else // __HIP__MI300__
|
38 |
+
// both Host and DEVICE for non-MI300 using s/w simulation
|
39 |
+
explicit HIP_FP8_HOST_DEVICE
|
40 |
+
#endif // __HIP__MI300__
|
41 |
+
hip_fp8(float v) {
|
42 |
+
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
|
43 |
+
true /*clip*/>(v);
|
44 |
+
}
|
45 |
+
|
46 |
+
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
47 |
+
: hip_fp8(static_cast<float>(v)) {}
|
48 |
+
|
49 |
+
#ifdef __HIP__MI300__
|
50 |
+
// upcast using device specific intrinsic
|
51 |
+
explicit inline HIP_FP8_DEVICE operator float() const {
|
52 |
+
float fval;
|
53 |
+
uint32_t i32val = static_cast<uint32_t>(data);
|
54 |
+
|
55 |
+
// upcast
|
56 |
+
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
|
57 |
+
: "=v"(fval)
|
58 |
+
: "v"(i32val));
|
59 |
+
|
60 |
+
return fval;
|
61 |
+
}
|
62 |
+
|
63 |
+
explicit inline HIP_FP8_HOST operator float() const
|
64 |
+
#else // __HIP__MI300__
|
65 |
+
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
66 |
+
#endif // __HIP__MI300__
|
67 |
+
{
|
68 |
+
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
|
69 |
+
data);
|
70 |
+
}
|
71 |
+
};
|
72 |
+
|
73 |
+
namespace std {
|
74 |
+
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
|
75 |
+
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
|
76 |
+
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
|
77 |
+
} // namespace std
|
78 |
+
|
79 |
+
// Special operator overloading
|
80 |
+
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
|
81 |
+
return os << float(f8);
|
82 |
+
}
|
83 |
+
|
84 |
+
// all + operator overloading with mixed types
|
85 |
+
// mixed types, always converts to f32, does computation in f32, and returns
|
86 |
+
// float
|
87 |
+
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
|
88 |
+
return (fa + float(b));
|
89 |
+
}
|
90 |
+
|
91 |
+
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
|
92 |
+
return (float(a) + fb);
|
93 |
+
}
|
94 |
+
|
95 |
+
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
|
96 |
+
return hip_fp8(float(a) + float(b));
|
97 |
+
}
|
98 |
+
|
99 |
+
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
|
100 |
+
return a = hip_fp8(float(a) + float(b));
|
101 |
+
}
|
102 |
+
|
103 |
+
// overloading multiplication, always returns float,
|
104 |
+
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
|
105 |
+
return float(a) * float(b);
|
106 |
+
}
|
107 |
+
|
108 |
+
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
|
109 |
+
return (a * float(b));
|
110 |
+
}
|
111 |
+
|
112 |
+
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
|
113 |
+
return (float(a) * b);
|
114 |
+
}
|
115 |
+
|
116 |
+
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
|
117 |
+
return ((float)a * float(b));
|
118 |
+
}
|
119 |
+
|
120 |
+
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
|
121 |
+
return ((float)a * float(b));
|
122 |
+
}
|
123 |
+
|
124 |
+
// overloading for compare
|
125 |
+
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
|
126 |
+
return (a.data == b.data);
|
127 |
+
}
|
128 |
+
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
|
129 |
+
return (a.data != b.data);
|
130 |
+
}
|
131 |
+
|
132 |
+
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
|
133 |
+
return static_cast<float>(a) >= static_cast<float>(b);
|
134 |
+
}
|
135 |
+
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
|
136 |
+
return static_cast<float>(a) > static_cast<float>(b);
|
137 |
+
}
|
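`hip_fp8` is a one-byte wrapper: construction from `float` quantizes (via the MI300 intrinsic on device, via the software path elsewhere), `from_bits()` reinterprets a raw byte, and the explicit `float` conversion dequantizes. A small host-side round trip is sketched below, assuming the header is reachable under the path shown in the diff and a GCC/Clang-compatible host compiler (or hipcc).

```cpp
#include <cstdio>
#include "quantization/fp8/amd/hip_float8.h"

int main() {
  hip_fp8 q(0.75f);                           // quantize 0.75f to fp8 (e4m3)
  float back = static_cast<float>(q);         // dequantize
  hip_fp8 raw(q.data, hip_fp8::from_bits());  // rebuild from the stored byte
  std::printf("bits=0x%02x back=%f same=%d\n", (unsigned)q.data, back,
              int(raw == q));
  return 0;
}
```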
paged-attention/quantization/fp8/amd/hip_float8_impl.h
ADDED
@@ -0,0 +1,316 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#if defined(__HIPCC__) && \
|
4 |
+
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
5 |
+
#define __HIP__MI300__
|
6 |
+
#endif
|
7 |
+
|
8 |
+
#ifdef __HIPCC__
|
9 |
+
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
10 |
+
#define HIP_FP8_HOST __host__
|
11 |
+
#define HIP_FP8_DEVICE __device__
|
12 |
+
#else
|
13 |
+
#define HIP_FP8_HOST_DEVICE
|
14 |
+
#define HIP_FP8_HOST
|
15 |
+
#define HIP_FP8_DEVICE
|
16 |
+
#endif
|
17 |
+
|
18 |
+
namespace hip_fp8_impl {
|
19 |
+
|
20 |
+
#ifdef __HIP__MI300__
|
21 |
+
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
|
22 |
+
uint8_t i8data;
|
23 |
+
union {
|
24 |
+
float fval;
|
25 |
+
uint32_t i32val;
|
26 |
+
uint8_t i8val[4]; // NOTE: not endian independent
|
27 |
+
} val;
|
28 |
+
|
29 |
+
uint32_t ival = 0;
|
30 |
+
val.fval = v;
|
31 |
+
|
32 |
+
if ((val.i32val & 0x7F800000) !=
|
33 |
+
0x7F800000) { /// propagate NAN/INF, no clipping
|
34 |
+
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
35 |
+
}
|
36 |
+
|
37 |
+
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
|
38 |
+
false); // false -> WORD0
|
39 |
+
val.i32val = ival;
|
40 |
+
i8data = val.i8val[0];
|
41 |
+
|
42 |
+
return i8data;
|
43 |
+
}
|
44 |
+
#endif // __HIP__MI300__
|
45 |
+
|
46 |
+
HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
47 |
+
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
48 |
+
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
49 |
+
#endif
|
50 |
+
|
51 |
+
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
52 |
+
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
|
53 |
+
uint32_t rng = 0) {
|
54 |
+
#ifdef __HIPCC__
|
55 |
+
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
56 |
+
#else
|
57 |
+
constexpr bool is_half = false;
|
58 |
+
#endif
|
59 |
+
constexpr bool is_float = std::is_same<T, float>::value;
|
60 |
+
static_assert(wm + we == 7, "wm+we==7");
|
61 |
+
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
62 |
+
|
63 |
+
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
64 |
+
uint32_t x;
|
65 |
+
if (sizeof(T) == 4) {
|
66 |
+
x = reinterpret_cast<uint32_t&>(_x);
|
67 |
+
} else {
|
68 |
+
x = reinterpret_cast<uint16_t&>(_x);
|
69 |
+
}
|
70 |
+
|
71 |
+
uint32_t head, mantissa;
|
72 |
+
int exponent, bias;
|
73 |
+
uint32_t sign;
|
74 |
+
|
75 |
+
if (sizeof(T) == 4) {
|
76 |
+
head = x & 0xFF800000;
|
77 |
+
mantissa = x & 0x7FFFFF;
|
78 |
+
exponent = (head >> 23) & 0xFF;
|
79 |
+
sign = head >> 31;
|
80 |
+
bias = 127;
|
81 |
+
} else {
|
82 |
+
head = x & 0xFC00;
|
83 |
+
mantissa = x & 0x3FF;
|
84 |
+
exponent = (head >> 10) & 0x1F;
|
85 |
+
sign = head >> 15;
|
86 |
+
bias = 15;
|
87 |
+
}
|
88 |
+
|
89 |
+
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
90 |
+
|
91 |
+
// Deal with inf and NaNs
|
92 |
+
if (negative_zero_nan) {
|
93 |
+
if (sizeof(T) == 4) {
|
94 |
+
if ((x & 0x7F800000) == 0x7F800000) {
|
95 |
+
return 0x80;
|
96 |
+
}
|
97 |
+
} else {
|
98 |
+
// if(__hisinf(x) || __hisnan(x))
|
99 |
+
if ((x & 0x7C00) == 0x7C00) {
|
100 |
+
return 0x80;
|
101 |
+
}
|
102 |
+
}
|
103 |
+
} else {
|
104 |
+
if (sizeof(T) == 4) {
|
105 |
+
if ((x & 0x7F800000) == 0x7F800000) {
|
106 |
+
return signed_inf + (mantissa != 0 ? 1 : 0);
|
107 |
+
}
|
108 |
+
} else {
|
109 |
+
if ((x & 0x7C00) == 0x7C00) {
|
110 |
+
return signed_inf + (mantissa != 0 ? 1 : 0);
|
111 |
+
}
|
112 |
+
}
|
113 |
+
}
|
114 |
+
if (x == 0) {
|
115 |
+
return 0;
|
116 |
+
}
|
117 |
+
|
118 |
+
// First need to check if it is normal or denorm as there is a difference of
|
119 |
+
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
120 |
+
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
121 |
+
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
122 |
+
// need to check whether there is carry and adjust exponent and mantissa again
|
123 |
+
|
124 |
+
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
125 |
+
// bits
|
126 |
+
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
127 |
+
const int f8_denormal_act_exponent =
|
128 |
+
1 - f8_bias; // actual exponent of f8 denormal
|
129 |
+
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
130 |
+
// f8_exponent is the converted f8 exponent with bias encoding
|
131 |
+
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
132 |
+
// the difference needs to be adjusted and mantissa shifted
|
133 |
+
int act_exponent, f8_exponent, exponent_diff;
|
134 |
+
|
135 |
+
if (exponent == 0) { // fp32/fp16 is in denormal.
|
136 |
+
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
137 |
+
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
138 |
+
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
139 |
+
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
140 |
+
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
141 |
+
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
142 |
+
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
143 |
+
act_exponent = exponent - bias + 1;
|
144 |
+
exponent_diff =
|
145 |
+
f8_denormal_act_exponent -
|
146 |
+
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
147 |
+
} else { // fp32/fp16 is normal with implicit 1
|
148 |
+
act_exponent = exponent - bias;
|
149 |
+
if (act_exponent <= f8_denormal_act_exponent) {
|
150 |
+
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
151 |
+
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
152 |
+
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
153 |
+
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
154 |
+
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
155 |
+
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
156 |
+
} else { // both fp32/fp16 and f8 are in normal range
|
157 |
+
exponent_diff = 0; // exponent_diff=0 does not mean there is no
|
158 |
+
// difference for this case, act_exponent could be
|
159 |
+
// larger. Just that it does not need shift mantissa
|
160 |
+
}
|
161 |
+
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
162 |
+
}
|
163 |
+
|
164 |
+
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
165 |
+
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
166 |
+
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
167 |
+
done before we shift right as shift right could rip off some residual part
|
168 |
+
and make something not midpoint look like midpoint. For example, the fp16
|
169 |
+
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
|
170 |
+
shift right by 4 bits, it would look like midpoint.
|
171 |
+
*/
|
172 |
+
|
173 |
+
if (exponent_diff > 0) {
|
174 |
+
mantissa >>= exponent_diff;
|
175 |
+
} else if (exponent_diff == -1) {
|
176 |
+
mantissa <<= -exponent_diff;
|
177 |
+
}
|
178 |
+
bool implicit_one = mantissa & (1 << mfmt);
|
179 |
+
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
180 |
+
// to denorm exponent
|
181 |
+
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
|
182 |
+
f8_bias - (implicit_one ? 0 : 1);
|
183 |
+
|
184 |
+
// Now we have the exponent and mantissa adjusted
|
185 |
+
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
186 |
+
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
|
187 |
+
// that is not truncated is 1
|
188 |
+
mantissa +=
|
189 |
+
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
|
190 |
+
drop_mask;
|
191 |
+
|
192 |
+
// Now we deal with overflow
|
193 |
+
if (f8_exponent == 0) {
|
194 |
+
if ((1 << mfmt) & mantissa) {
|
195 |
+
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
196 |
+
}
|
197 |
+
} else {
|
198 |
+
if ((1 << (mfmt + 1)) & mantissa) {
|
199 |
+
mantissa >>= 1;
|
200 |
+
f8_exponent++;
|
201 |
+
}
|
202 |
+
}
|
203 |
+
|
204 |
+
mantissa >>= (mfmt - wm);
|
205 |
+
|
206 |
+
// above range: quantize to maximum possible float of the same sign
|
207 |
+
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
208 |
+
if (f8_exponent > max_exp) {
|
209 |
+
if (clip) {
|
210 |
+
mantissa = (1 << wm) - 1;
|
211 |
+
f8_exponent = max_exp;
|
212 |
+
} else {
|
213 |
+
return signed_inf;
|
214 |
+
}
|
215 |
+
}
|
216 |
+
|
217 |
+
if (f8_exponent == 0 && mantissa == 0) {
|
218 |
+
return negative_zero_nan ? 0 : (sign << 7);
|
219 |
+
}
|
220 |
+
mantissa &= (1 << wm) - 1;
|
221 |
+
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
222 |
+
}
|
223 |
+
|
224 |
+
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
225 |
+
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
|
226 |
+
#ifdef __HIPCC__
|
227 |
+
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
228 |
+
#else
|
229 |
+
constexpr bool is_half = false;
|
230 |
+
#endif
|
231 |
+
constexpr bool is_float = std::is_same<T, float>::value;
|
232 |
+
static_assert(is_half || is_float, "only half and float are supported");
|
233 |
+
|
234 |
+
constexpr int weo = is_half ? 5 : 8;
|
235 |
+
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
236 |
+
|
237 |
+
T fInf, fNegInf, fNaN, fNeg0;
|
238 |
+
|
239 |
+
#ifdef __HIPCC__
|
240 |
+
if (is_half) {
|
241 |
+
const uint16_t ihInf = 0x7C00;
|
242 |
+
const uint16_t ihNegInf = 0xFC00;
|
243 |
+
const uint16_t ihNaN = 0x7C01;
|
244 |
+
const uint16_t ihNeg0 = 0x8000;
|
245 |
+
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
246 |
+
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
247 |
+
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
248 |
+
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
249 |
+
} else
|
250 |
+
#endif
|
251 |
+
if (is_float) {
|
252 |
+
const uint32_t ifInf = 0x7F800000;
|
253 |
+
const uint32_t ifNegInf = 0xFF800000;
|
254 |
+
const uint32_t ifNaN = 0x7F800001;
|
255 |
+
const uint32_t ifNeg0 = 0x80000000;
|
256 |
+
fInf = reinterpret_cast<const float&>(ifInf);
|
257 |
+
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
258 |
+
fNaN = reinterpret_cast<const float&>(ifNaN);
|
259 |
+
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
260 |
+
}
|
261 |
+
|
262 |
+
if (x == 0) {
|
263 |
+
return 0;
|
264 |
+
}
|
265 |
+
|
266 |
+
uint32_t sign = x >> 7;
|
267 |
+
uint32_t mantissa = x & ((1 << wm) - 1);
|
268 |
+
int exponent = (x & 0x7F) >> wm;
|
269 |
+
if (negative_zero_nan) {
|
270 |
+
if (x == 0x80) {
|
271 |
+
return fNaN;
|
272 |
+
}
|
273 |
+
} else {
|
274 |
+
if (x == 0x80) {
|
275 |
+
return fNeg0;
|
276 |
+
}
|
277 |
+
if (exponent == ((1 << we) - 1)) {
|
278 |
+
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
279 |
+
}
|
280 |
+
}
|
281 |
+
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
282 |
+
if (we == 5 && is_half && !negative_zero_nan) {
|
283 |
+
retval = x << 8;
|
284 |
+
return reinterpret_cast<const T&>(retval);
|
285 |
+
}
|
286 |
+
|
287 |
+
const int exp_low_cutoff =
|
288 |
+
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
289 |
+
|
290 |
+
// subnormal input
|
291 |
+
if (exponent == 0) {
|
292 |
+
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
293 |
+
int sh = 1 + clz(mantissa) - (32 - wm);
|
294 |
+
mantissa <<= sh;
|
295 |
+
exponent += 1 - sh;
|
296 |
+
mantissa &= ((1 << wm) - 1);
|
297 |
+
}
|
298 |
+
exponent += exp_low_cutoff - 1;
|
299 |
+
mantissa <<= wmo - wm;
|
300 |
+
|
301 |
+
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
302 |
+
if (exponent <= 0) {
|
303 |
+
mantissa |= 1 << wmo;
|
304 |
+
mantissa >>= 1 - exponent;
|
305 |
+
exponent = 0;
|
306 |
+
}
|
307 |
+
|
308 |
+
if (sizeof(T) == 2) {
|
309 |
+
retval = (sign << 15) | (exponent << 10) | mantissa;
|
310 |
+
} else {
|
311 |
+
retval = (sign << 31) | (exponent << 23) | mantissa;
|
312 |
+
}
|
313 |
+
return reinterpret_cast<const T&>(retval);
|
314 |
+
}
|
315 |
+
|
316 |
+
} // namespace hip_fp8_impl
|
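The wrapper above ultimately calls these two templates: `we`/`wm` give the exponent and mantissa widths (4/3 for e4m3), `negative_zero_nan` selects the "nanoo" encoding in which 0x80 means NaN rather than -0, and `clip` saturates out-of-range inputs to the largest finite value (240 for this encoding) instead of returning the signed-infinity pattern. A direct host-side use with illustrative inputs, under the same include-path assumption as above:

```cpp
#include <cstdint>
#include <cstdio>
#include "quantization/fp8/amd/hip_float8_impl.h"

int main() {
  // 1000.0f is outside the e4m3 range, so with clip=true it saturates to the
  // largest finite value of the nanoo encoding (240.0).
  uint8_t bits = hip_fp8_impl::to_float8<4, 3, float,
                                         /*negative_zero_nan=*/true,
                                         /*clip=*/true>(1000.0f);
  float back = hip_fp8_impl::from_float8<4, 3, float, true>(bits);
  std::printf("bits=0x%02x back=%f\n", (unsigned)bits, back);
  return 0;
}
```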
paged-attention/quantization/fp8/amd/quant_utils.cuh
ADDED
@@ -0,0 +1,577 @@
1 |
+
#pragma once
|
2 |
+
#include "hip_float8.h"
|
3 |
+
|
4 |
+
#include <hip/hip_fp16.h>
|
5 |
+
#include <hip/hip_bf16.h>
|
6 |
+
#include <hip/hip_bfloat16.h>
|
7 |
+
|
8 |
+
#include "../../../attention/dtype_fp8.cuh"
|
9 |
+
#include "../../../attention/dtype_float32.cuh"
|
10 |
+
#include "../../../attention/dtype_bfloat16.cuh"
|
11 |
+
|
12 |
+
namespace vllm {
|
13 |
+
#ifdef USE_ROCM
|
14 |
+
|
15 |
+
namespace fp8 {
|
16 |
+
#ifdef ENABLE_FP8
|
17 |
+
|
18 |
+
template <typename Tout, typename Tin>
|
19 |
+
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
20 |
+
return x;
|
21 |
+
}
|
22 |
+
|
23 |
+
template <typename Tout, typename Tin>
|
24 |
+
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
25 |
+
const float scale) {
|
26 |
+
return x;
|
27 |
+
}
|
28 |
+
|
29 |
+
// fp8 -> half
|
30 |
+
template <>
|
31 |
+
__inline__ __device__ uint16_t
|
32 |
+
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
33 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
34 |
+
__half_raw res;
|
35 |
+
res.data = static_cast<float>(f8);
|
36 |
+
return res.x;
|
37 |
+
}
|
38 |
+
|
39 |
+
// fp8x2 -> half2
|
40 |
+
template <>
|
41 |
+
__inline__ __device__ uint32_t
|
42 |
+
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
43 |
+
#if defined(__HIP__MI300__) && \
|
44 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
45 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
46 |
+
union {
|
47 |
+
__half2_raw h2r;
|
48 |
+
uint32_t ui32;
|
49 |
+
} tmp;
|
50 |
+
tmp.h2r.x.data = f2[0];
|
51 |
+
tmp.h2r.y.data = f2[1];
|
52 |
+
return tmp.ui32;
|
53 |
+
#else
|
54 |
+
union {
|
55 |
+
uint16_t u16[2];
|
56 |
+
uint32_t u32;
|
57 |
+
} tmp;
|
58 |
+
|
59 |
+
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
60 |
+
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
61 |
+
return tmp.u32;
|
62 |
+
#endif
|
63 |
+
}
|
64 |
+
|
65 |
+
// fp8x4 -> half2x2
|
66 |
+
template <>
|
67 |
+
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
|
68 |
+
union {
|
69 |
+
uint2 u32x2;
|
70 |
+
uint32_t u32[2];
|
71 |
+
} tmp;
|
72 |
+
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
73 |
+
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
74 |
+
return tmp.u32x2;
|
75 |
+
}
|
76 |
+
|
77 |
+
// fp8x8 -> half2x4
|
78 |
+
template <>
|
79 |
+
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
|
80 |
+
union {
|
81 |
+
uint4 u64x2;
|
82 |
+
uint2 u64[2];
|
83 |
+
} tmp;
|
84 |
+
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
85 |
+
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
86 |
+
return tmp.u64x2;
|
87 |
+
}
|
88 |
+
|
89 |
+
using __nv_bfloat16 = __hip_bfloat16;
|
90 |
+
|
91 |
+
// fp8 -> __nv_bfloat16
|
92 |
+
template <>
|
93 |
+
__inline__ __device__ __nv_bfloat16
|
94 |
+
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
95 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
96 |
+
float f{f8};
|
97 |
+
return __float2bfloat16(f);
|
98 |
+
}
|
99 |
+
|
100 |
+
using __nv_bfloat162 = __hip_bfloat162;
|
101 |
+
|
102 |
+
// fp8x2 -> __nv_bfloat162
|
103 |
+
template <>
|
104 |
+
__inline__ __device__ __nv_bfloat162
|
105 |
+
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
|
106 |
+
__nv_bfloat162 res;
|
107 |
+
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
108 |
+
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
109 |
+
return res;
|
110 |
+
}
|
111 |
+
|
112 |
+
// fp8x4 -> bf16_4_t
|
113 |
+
template <>
|
114 |
+
__inline__ __device__ bf16_4_t
|
115 |
+
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
|
116 |
+
bf16_4_t res;
|
117 |
+
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
118 |
+
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
119 |
+
return res;
|
120 |
+
}
|
121 |
+
|
122 |
+
// fp8x8 -> bf16_8_t
|
123 |
+
template <>
|
124 |
+
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
125 |
+
bf16_4_t tmp1, tmp2;
|
126 |
+
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
127 |
+
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
128 |
+
bf16_8_t res;
|
129 |
+
res.x = tmp1.x;
|
130 |
+
res.y = tmp1.y;
|
131 |
+
res.z = tmp2.x;
|
132 |
+
res.w = tmp2.y;
|
133 |
+
return res;
|
134 |
+
}
|
135 |
+
|
136 |
+
// fp8 -> float
|
137 |
+
template <>
|
138 |
+
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
139 |
+
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
140 |
+
return static_cast<float>(fp8);
|
141 |
+
}
|
142 |
+
|
143 |
+
// fp8x2 -> float2
|
144 |
+
template <>
|
145 |
+
__inline__ __device__ float2
|
146 |
+
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
147 |
+
#if defined(__HIP__MI300__) && \
|
148 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
149 |
+
float2 res;
|
150 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
151 |
+
res.x = f2[0];
|
152 |
+
res.y = f2[1];
|
153 |
+
return res;
|
154 |
+
#else
|
155 |
+
float2 res;
|
156 |
+
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
157 |
+
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
158 |
+
return res;
|
159 |
+
#endif
|
160 |
+
}
|
161 |
+
|
162 |
+
// fp8x4 -> float4
|
163 |
+
template <>
|
164 |
+
__inline__ __device__ Float4_
|
165 |
+
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
166 |
+
Float4_ res;
|
167 |
+
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
168 |
+
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
169 |
+
return res;
|
170 |
+
}
|
171 |
+
|
172 |
+
// fp8x8 -> float8
|
173 |
+
template <>
|
174 |
+
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
  __half_raw tmp;
  tmp.x = a;

  hip_fp8 f8{static_cast<float>(tmp.data)};
  return f8.data;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
  hip_fp8 res{__bfloat162float(a)};
  return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
  hip_fp8 f8(a);
  return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

// float2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
  union {
    half2 float16;
    uint32_t uint32;
  };

  float16 = __float22half2_rn(a);
  return uint32;
}

// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val);

  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val);
  return b;
}

// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x);
  b.y = vec_conversion<uint32_t, float2>(a.y);
  b.z = vec_conversion<uint32_t, float2>(a.z);
  b.w = vec_conversion<uint32_t, float2>(a.w);
  return b;
}

// float2 -> bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
  __nv_bfloat162 b = __float22bfloat162_rn(a);
  return b;
}

// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
  bf16_4_t b;
  b.x = __float22bfloat162_rn(a.x);
  b.y = __float22bfloat162_rn(a.y);
  return b;
}

// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
  bf16_8_t b;
  b.x = __float22bfloat162_rn(a.x);
  b.y = __float22bfloat162_rn(a.y);
  b.z = __float22bfloat162_rn(a.z);
  b.w = __float22bfloat162_rn(a.w);
  return b;
}

/* Scaled and vectorized conversions, for data exchange between high and low
   precision domains

   Convention of the scale in API, e.g: FP8_data = Quantization(
   High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8; Dequant(FP8) *
   scale => HP
 */

// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  __half_raw res;
  res.data = static_cast<float>(f8) * scale;
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
  tmp.h2r.x.data = f2[0] * scale;
  tmp.h2r.y.data = f2[1] * scale;
  return tmp.ui32;
#else
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;

  tmp.u16[0] =
      scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
  tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
      static_cast<uint8_t>(a >> 8U), scale);
  return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
  tmp.u32[1] =
      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  return tmp.u64x2;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
                                              const float scale) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  float f{f8};
  return __float2bfloat16(f * scale);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
                                                const float scale) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
  res.y =
      scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale) {
  hip_fp8 fp8{a, hip_fp8::from_bits()};
  return static_cast<float>(fp8) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  float2 res;
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  res.x = f2[0] * scale;
  res.y = f2[1] * scale;
  return res;
#else
  float2 res;
  res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
  res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
                                                scale);
  return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

/* Quantize(HP / scale) => FP8 */

// TODO(Hai): vectorized to add

// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
  __half_raw tmp;
  tmp.x = a;

  hip_fp8 f8{static_cast<float>(tmp.data) / scale};
  return f8.data;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale) {
  hip_fp8 res{__bfloat162float(a) / scale};
  return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
  hip_fp8 f8(a / scale);
  return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
#endif  // ENABLE_FP8

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x);
  }
#endif
  assert(false);
  return {};  // Squash missing return statement warning
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale);
  }
#endif
  assert(false);
  return {};  // Squash missing return statement warning
}

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                    \
  if (KV_DTYPE == "auto") {                                                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                       \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                 \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);       \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else {                                                                     \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                         \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);             \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);        \
      } else {                                                                 \
        TORCH_CHECK(false,                                                     \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);        \
      }                                                                        \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);     \
    }                                                                          \
  }

}  // namespace fp8
#endif  // USE_ROCM
}  // namespace vllm
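The scaled conversions above follow the convention spelled out in the comment block: FP8 = Quantize(HP / scale) and HP ≈ Dequant(FP8) * scale. As a minimal host-side sketch of that round trip (illustrative only, not part of this extension; it assumes a PyTorch build with float8_e4m3fn support):

import torch

# Scale convention mirrored from scaled_vec_conversion:
#   fp8 = quantize(hp / scale), hp ~= dequantize(fp8) * scale
hp = torch.randn(8, dtype=torch.float32)
scale = hp.abs().max() / 448.0                 # 448 is the E4M3 finite max
fp8 = (hp / scale).to(torch.float8_e4m3fn)     # Quantize(HP / scale) => FP8
hp_back = fp8.to(torch.float32) * scale        # Dequant(FP8) * scale => HP
print((hp - hp_back).abs().max())              # small quantization error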
paged-attention/quantization/fp8/nvidia/quant_utils.cuh
ADDED
@@ -0,0 +1,573 @@
#pragma once

#include "../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
#include <stdint.h>
#include <type_traits>

namespace vllm {
#ifndef USE_ROCM

namespace fp8 {
#ifdef ENABLE_FP8

#if 0  // Disable the following code to reduce the binary size.
template <typename Tout, typename Tin>
__inline__ __device__ Tout
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
    const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = res.x;
  tmp.u16[1] = res.y;
  return tmp.u32;
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
  tmp.u32[1] =
      vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
  tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp);
}

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
  res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
  res.y =
      vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
  tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float
vec_conversion<float, uint8_t>(const uint8_t &a,
                               const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
  // half -> float
  return half_to_float(tmp);
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2
  uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
  // half2 -> float2
  return half2_to_float2(tmp);
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp;
  tmp.x = a;
  __nv_fp8_storage_t res =
      __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
      __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
    const float &a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
    const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    half2 float16;
    uint32_t uint32;
  };

  float16 = __float22half2_rn(a);
  return uint32;
}

template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val, fp8_type);

  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val, fp8_type);

  return b;
}

template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
    const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
  b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
  b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
  b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
  return b;
}

template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
    const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 b;
  from_float(b, a);
  return b;
}

template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t b;
  from_float(b, a);
  return b;
}

template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
    const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_8_t b;
  from_float(b, a);
  return b;
}
#endif

/* Scaled and vectorized conversions, for data exchange between high and low
   precision domains. Convention of the scale in API, e.g: FP8_data =
   Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8;
   Dequant(FP8) * scale => HP
 */

template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
    const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return float_to_half(half_to_float(tmp.x) * scale);
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
  tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
  return tmp.u32;
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] =
      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
  tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
                                                         scale, fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
                                    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp * scale);
}

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
                                                        fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
                                                        scale, fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
                                                          fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale, fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  uint16_t tmp = res.x;

  // half -> float
  return half_to_float(tmp) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2
  uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
  // half2 -> float2
  return half2_to_float2(tmp);
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
                                                  fp8_type);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
                                                 __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
    const float& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
#endif  // ENABLE_FP8

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#if 0  // Disable the following code to reduce the binary size.
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  }
#endif
  assert(false);
  __builtin_unreachable();  // Suppress missing return statement warning
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  }
#endif
  assert(false);
  __builtin_unreachable();  // Suppress missing return statement warning
}

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                    \
  if (KV_DTYPE == "auto") {                                                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                       \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                 \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);       \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else {                                                                     \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                         \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);             \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);        \
      } else {                                                                 \
        TORCH_CHECK(false,                                                     \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);        \
      }                                                                        \
    } else if (KV_DTYPE == "fp8_e5m2") {                                       \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);                \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);             \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);        \
      } else {                                                                 \
        TORCH_CHECK(false,                                                     \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);        \
      }                                                                        \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);     \
    }                                                                          \
  }

}  // namespace fp8
#endif  // not USE_ROCM
}  // namespace vllm
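The DISPATCH_BY_KV_CACHE_DTYPE macro above turns the user-facing kv_cache_dtype string and the source tensor dtype into a (scalar_t, cache_t, kv_dt) template instantiation. As an illustrative sketch only (a Python table mirroring the macro, not part of the extension):

# (kv_cache_dtype, source dtype) -> (scalar_t, cache_t, Fp8KVCacheDataType)
# "fp8_e4m3" dispatches the same as "fp8".
DISPATCH_TABLE = {
    ("auto", "float"):        ("float",         "float",         "kAuto"),
    ("auto", "half"):         ("uint16_t",      "uint16_t",      "kAuto"),
    ("auto", "bfloat16"):     ("__nv_bfloat16", "__nv_bfloat16", "kAuto"),
    ("fp8", "float"):         ("float",         "uint8_t",       "kFp8E4M3"),
    ("fp8", "half"):          ("uint16_t",      "uint8_t",       "kFp8E4M3"),
    ("fp8", "bfloat16"):      ("__nv_bfloat16", "uint8_t",       "kFp8E4M3"),
    ("fp8_e5m2", "float"):    ("float",         "uint8_t",       "kFp8E5M2"),
    ("fp8_e5m2", "half"):     ("uint16_t",      "uint8_t",       "kFp8E5M2"),
    ("fp8_e5m2", "bfloat16"): ("__nv_bfloat16", "uint8_t",       "kFp8E5M2"),
}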
tests/kernels/__init__.py
ADDED
File without changes
tests/kernels/allclose_default.py
ADDED
@@ -0,0 +1,14 @@
import torch

# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}


def get_default_atol(output) -> float:
    return default_atol[output.dtype]


def get_default_rtol(output) -> float:
    return default_rtol[output.dtype]
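These helpers pick per-dtype tolerances from the output tensor. A minimal usage sketch (the import path here assumes the tests are run as a package; adjust to your layout):

import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol

out = torch.randn(4, 8, dtype=torch.float16)
ref = out + 1e-4 * torch.randn_like(out)   # small perturbation to mimic kernel noise
torch.testing.assert_close(
    out, ref, atol=get_default_atol(out), rtol=get_default_rtol(out)
)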
tests/kernels/conftest.py
ADDED
@@ -0,0 +1,158 @@
from typing import List, Optional, Tuple, Union

import attention as ops
import pytest
import torch


@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches_with_random


@pytest.fixture()
def kv_cache_factory_flashinfer():
    return create_kv_caches_with_random_flash


STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:

    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    from attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(
            size=value_cache_shape, dtype=torch_dtype, device=device
        )
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    from attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    scale = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(
            size=key_value_cache_shape, dtype=torch_dtype, device=device
        )
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches


def get_kv_cache_torch_dtype(
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
) -> torch.dtype:
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype


def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    # -----|-------------|-------------------
    # Inf  | N/A         | s.11111.00
    # NaN  | s.1111.111  | s.11111.{01,10,11}
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    del tensor_tmp
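The factory above produces per-layer paged caches in the layout the kernels expect. A short usage sketch (illustrative only; it needs a CUDA device and the built `attention` extension, and the import path depends on how the tests are laid out):

from tests.kernels.conftest import create_kv_caches_with_random

key_caches, value_caches = create_kv_caches_with_random(
    num_blocks=16, block_size=16, num_layers=1,
    num_heads=8, head_size=128, cache_dtype="auto", model_dtype="half",
)
# key cache:   (num_blocks, num_heads, head_size // x, block_size, x),
#              where x = 16 // element_size of the cache dtype
# value cache: (num_blocks, num_heads, head_size, block_size)
print(key_caches[0].shape, value_caches[0].shape)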
tests/kernels/test_attention.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import List, Optional, Tuple
|
3 |
+
|
4 |
+
import attention as ops
|
5 |
+
import pytest
|
6 |
+
import torch
|
7 |
+
from attention.platforms import current_platform
|
8 |
+
|
9 |
+
from .allclose_default import get_default_atol, get_default_rtol
|
10 |
+
from .utils import get_max_shared_memory_bytes, opcheck
|
11 |
+
|
12 |
+
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
13 |
+
# This will change depending on the compute capability.
|
14 |
+
# - 512 as a buffer
|
15 |
+
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
16 |
+
# There may not be enough gpu memory due to large NUM_BLOCKS.
|
17 |
+
# Reduce NUM_BLOCKS when it happens.
|
18 |
+
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
19 |
+
PARTITION_SIZE = 512
|
20 |
+
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
21 |
+
DTYPES = (
|
22 |
+
[torch.half, torch.bfloat16, torch.float]
|
23 |
+
if not current_platform.is_rocm()
|
24 |
+
else [torch.half, torch.bfloat16]
|
25 |
+
)
|
26 |
+
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
27 |
+
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
|
28 |
+
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
29 |
+
|
30 |
+
# This should be sync with get_supported_head_sizes() in
|
31 |
+
# vllm.attention.ops.paged_attn.PagedAttention
|
32 |
+
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
33 |
+
|
34 |
+
BLOCK_SIZES = [16, 32]
|
35 |
+
USE_ALIBI = [False, True]
|
36 |
+
KV_CACHE_DTYPE = ["auto", "fp8"]
|
37 |
+
SEEDS = [0]
|
38 |
+
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
39 |
+
|
40 |
+
|
41 |
+
def ref_masked_attention(
|
42 |
+
query: torch.Tensor,
|
43 |
+
key: torch.Tensor,
|
44 |
+
value: torch.Tensor,
|
45 |
+
scale: float,
|
46 |
+
attn_mask: Optional[torch.Tensor] = None,
|
47 |
+
) -> torch.Tensor:
|
48 |
+
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
49 |
+
if attn_mask is not None:
|
50 |
+
attn_weights = attn_weights + attn_mask.float()
|
51 |
+
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
52 |
+
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
def ref_single_query_cached_kv_attention(
|
57 |
+
output: torch.Tensor,
|
58 |
+
query: torch.Tensor,
|
59 |
+
num_queries_per_kv: int,
|
60 |
+
key_cache: torch.Tensor,
|
61 |
+
value_cache: torch.Tensor,
|
62 |
+
block_tables: torch.Tensor,
|
63 |
+
seq_lens: torch.Tensor,
|
64 |
+
scale: float,
|
65 |
+
alibi_slopes: Optional[torch.Tensor],
|
66 |
+
) -> None:
|
67 |
+
num_query_heads = query.shape[1]
|
68 |
+
num_kv_heads = value_cache.shape[1]
|
69 |
+
head_size = value_cache.shape[2]
|
70 |
+
block_size = value_cache.shape[3]
|
71 |
+
num_seqs = query.shape[0]
|
72 |
+
|
73 |
+
block_tables_lst = block_tables.cpu().tolist()
|
74 |
+
seq_lens_lst = seq_lens.cpu().tolist()
|
75 |
+
for i in range(num_seqs):
|
76 |
+
q = query[i].unsqueeze(0)
|
77 |
+
block_table = block_tables_lst[i]
|
78 |
+
seq_len = int(seq_lens_lst[i])
|
79 |
+
|
80 |
+
keys_lst: List[torch.Tensor] = []
|
81 |
+
values_lst: List[torch.Tensor] = []
|
82 |
+
for j in range(seq_len):
|
83 |
+
block_number = int(block_table[j // block_size])
|
84 |
+
block_offset = j % block_size
|
85 |
+
|
86 |
+
k = key_cache[block_number, :, :, block_offset, :]
|
87 |
+
k = k.reshape(num_kv_heads, head_size)
|
88 |
+
keys_lst.append(k)
|
89 |
+
|
90 |
+
v = value_cache[block_number, :, :, block_offset]
|
91 |
+
values_lst.append(v)
|
92 |
+
keys = torch.stack(keys_lst, dim=0)
|
93 |
+
values = torch.stack(values_lst, dim=0)
|
94 |
+
if num_queries_per_kv > 1:
|
95 |
+
# Handle MQA and GQA
|
96 |
+
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
97 |
+
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
|
98 |
+
|
99 |
+
alibi_bias = None
|
100 |
+
if alibi_slopes is not None:
|
101 |
+
# Create the ALiBi bias used in the paged attention kernel.
|
102 |
+
position_ids = torch.arange(seq_len).int()
|
103 |
+
alibi_bias = (position_ids - seq_len + 1).float()
|
104 |
+
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
|
105 |
+
|
106 |
+
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
|
107 |
+
out = out.view(num_query_heads, head_size)
|
108 |
+
output[i].copy_(out, non_blocking=True)
|
109 |
+
|
110 |
+
|
111 |
+
@pytest.mark.parametrize(
|
112 |
+
"version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
|
113 |
+
)
|
114 |
+
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
115 |
+
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
116 |
+
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
117 |
+
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
118 |
+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
119 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
120 |
+
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
121 |
+
@pytest.mark.parametrize("seed", SEEDS)
|
122 |
+
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
123 |
+
def test_paged_attention(
|
124 |
+
kv_cache_factory,
|
125 |
+
version: str,
|
126 |
+
num_seqs: int,
|
127 |
+
num_heads: Tuple[int, int],
|
128 |
+
head_size: int,
|
129 |
+
use_alibi: bool,
|
130 |
+
block_size: int,
|
131 |
+
dtype: torch.dtype,
|
132 |
+
kv_cache_dtype: str,
|
133 |
+
seed: int,
|
134 |
+
device: str,
|
135 |
+
) -> None:
|
136 |
+
if (kv_cache_dtype == "fp8" and head_size % 16) or (
|
137 |
+
version == "rocm" and head_size not in (64, 128)
|
138 |
+
):
|
139 |
+
pytest.skip()
|
140 |
+
|
141 |
+
current_platform.seed_everything(seed)
|
142 |
+
torch.set_default_device(device)
|
143 |
+
scale = float(1.0 / (head_size**0.5))
|
144 |
+
num_query_heads, num_kv_heads = num_heads
|
145 |
+
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
|
146 |
+
query.uniform_(-scale, scale)
|
147 |
+
|
148 |
+
assert num_query_heads % num_kv_heads == 0
|
149 |
+
num_queries_per_kv = num_query_heads // num_kv_heads
|
150 |
+
alibi_slopes = None
|
151 |
+
if use_alibi:
|
152 |
+
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
|
153 |
+
|
154 |
+
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
155 |
+
seq_lens[-1] = MAX_SEQ_LEN
|
156 |
+
max_seq_len = max(seq_lens)
|
157 |
+
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
|
158 |
+
|
159 |
+
# Create the block tables.
|
160 |
+
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
161 |
+
block_tables_lst: List[List[int]] = []
|
162 |
+
for _ in range(num_seqs):
|
163 |
+
block_table = [
|
164 |
+
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
|
165 |
+
]
|
166 |
+
block_tables_lst.append(block_table)
|
167 |
+
|
168 |
+
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
|
169 |
+
|
170 |
+
# Create the KV caches.
|
171 |
+
key_caches, value_caches = kv_cache_factory(
|
172 |
+
NUM_BLOCKS,
|
173 |
+
block_size,
|
174 |
+
1,
|
175 |
+
num_kv_heads,
|
176 |
+
head_size,
|
177 |
+
kv_cache_dtype,
|
178 |
+
dtype,
|
179 |
+
seed,
|
180 |
+
device,
|
181 |
+
)
|
182 |
+
key_cache, value_cache = key_caches[0], value_caches[0]
|
183 |
+
|
184 |
+
# Using default kv_scale
|
185 |
+
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
186 |
+
|
187 |
+
# Call the paged attention kernel.
|
188 |
+
output = torch.empty_like(query)
|
189 |
+
if version == "v1":
|
190 |
+
ops.paged_attention_v1(
|
191 |
+
output,
|
192 |
+
query,
|
193 |
+
key_cache,
|
194 |
+
value_cache,
|
195 |
+
num_kv_heads,
|
196 |
+
scale,
|
197 |
+
block_tables,
|
198 |
+
seq_lens,
|
199 |
+
block_size,
|
200 |
+
max_seq_len,
|
201 |
+
alibi_slopes,
|
202 |
+
kv_cache_dtype,
|
203 |
+
k_scale,
|
204 |
+
v_scale,
|
205 |
+
)
|
206 |
+
|
207 |
+
opcheck(
|
208 |
+
ops.ops.paged_attention_v1,
|
209 |
+
(
|
210 |
+
output,
|
211 |
+
query,
|
212 |
+
key_cache,
|
213 |
+
value_cache,
|
214 |
+
num_kv_heads,
|
215 |
+
scale,
|
216 |
+
block_tables,
|
217 |
+
seq_lens,
|
218 |
+
block_size,
|
219 |
+
max_seq_len,
|
220 |
+
alibi_slopes,
|
221 |
+
kv_cache_dtype,
|
222 |
+
k_scale,
|
223 |
+
v_scale,
|
224 |
+
0,
|
225 |
+
0,
|
226 |
+
0,
|
227 |
+
64,
|
228 |
+
0,
|
229 |
+
),
|
230 |
+
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
231 |
+
)
|
232 |
+
|
233 |
+
    elif version in ("v2", "rocm"):
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output, exp_sums, max_logits, tmp_output, query, key_cache,
                value_cache, num_kv_heads, scale, block_tables, seq_lens,
                block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                k_scale, v_scale,
            )

            opcheck(
                ops.ops.paged_attention_v2,
                (output, exp_sums, max_logits, tmp_output, query, key_cache,
                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                 k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

        else:
            ops.paged_attention_rocm(
                output, exp_sums, max_logits, tmp_output, query, key_cache,
                value_cache, num_kv_heads, scale, block_tables, seq_lens,
                block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                k_scale, v_scale,
            )

            opcheck(
                torch.ops._rocm_C.paged_attention,
                (output, exp_sums, max_logits, tmp_output, query, key_cache,
                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                 k_scale, v_scale),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
        dequantized_key_cache = torch.empty(
            size=key_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(
            size=value_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output, query, num_queries_per_kv, key_cache, value_cache,
        block_tables, seq_lens, scale, alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
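For orientation, the reference path the kernels are checked against boils down to a plain masked softmax attention over `[seq_len, num_heads, head_size]` tensors. The snippet below is a minimal, self-contained sketch of that computation, not the `ref_masked_attention` helper defined earlier in this test file; the layout and the additive `finfo(dtype).min` mask follow the test code above.

```python
import torch


def masked_attention_sketch(
    query: torch.Tensor,      # [seq_len, num_heads, head_size]
    key: torch.Tensor,        # [seq_len, num_heads, head_size]
    value: torch.Tensor,      # [seq_len, num_heads, head_size]
    scale: float,
    attn_mask: torch.Tensor,  # [seq_len, seq_len], additive mask
) -> torch.Tensor:
    # Scaled dot-product scores per head: [num_heads, seq_len, seq_len].
    scores = scale * torch.einsum("qhd,khd->hqk", query.float(), key.float())
    # Masking is additive: disallowed positions carry finfo(dtype).min.
    scores = scores + attn_mask.float()
    probs = torch.softmax(scores, dim=-1)
    # Weighted sum over values, back to [seq_len, num_heads, head_size].
    return torch.einsum("hqk,khd->qhd", probs, value.float()).to(value.dtype)


seq_len, num_heads, head_size = 4, 2, 8
q = torch.randn(seq_len, num_heads, head_size)
k = torch.randn(seq_len, num_heads, head_size)
v = torch.randn(seq_len, num_heads, head_size)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask * torch.finfo(torch.float32).min
out = masked_attention_sketch(q, k, v, head_size**-0.5, mask)
print(out.shape)  # torch.Size([4, 2, 8])
```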
tests/kernels/test_cache.py
ADDED
@@ -0,0 +1,486 @@
import random
from typing import List, Tuple

import attention as ops
import pytest
import torch
from attention.platforms import current_platform

from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck

COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping.append((src, dst1))
        block_mapping.append((src, dst2))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks, block_size, num_layers, num_heads, head_size,
        kv_cache_dtype, dtype, seed, device,
    )

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device=device
    ).view(-1, 2)

    opcheck(
        ops.ops.copy_blocks,
        (key_caches, value_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size,
        kv_cache_dtype, dtype, seed, device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the reshape_and_cache kernel.
    opcheck(
        ops.ops.reshape_and_cache,
        (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
         k_scale, v_scale),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(
        key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
        k_scale, v_scale,
    )

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indicies_lst = block_indicies.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indicies_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks, block_size, 1, num_heads, head_size,
        kv_cache_dtype, dtype, device=device,
    )
    key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
    del key_caches
    del value_caches

    k_scale = (key.amax() / 256.0).to(torch.float32)
    v_scale = (value.amax() / 256.0).to(torch.float32)

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale, kv_cache_dtype)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    opcheck(
        ops.ops.reshape_and_cache_flash,
        (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
         k_scale, v_scale),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache_flash(
        key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
        k_scale, v_scale,
    )

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
        )

    # Run the reference implementation.
    block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indicies_lst = block_indicies.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indicies_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    current_platform.seed_everything(seed)

    src_device = device if direction[0] == "cuda" else "cpu"
    dst_device = device if direction[1] == "cuda" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size,
        kv_cache_dtype, dtype, seed, src_device,
    )

    # Create the KV caches on the second device.
    dist_key_caches, dist_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size,
        kv_cache_dtype, dtype, seed, dst_device,
    )

    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    do_opcheck = head_size == HEAD_SIZES[0]
    opcheck(
        ops.ops.swap_blocks,
        (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )
    opcheck(
        ops.ops.swap_blocks,
        (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )

    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
        )
        torch.testing.assert_close(
            src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
        )


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    current_platform.seed_everything(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
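The reference loops in these cache tests all rely on the same indexing rule: a flat slot index identifies a block and an offset inside that block. A small stand-alone sketch of that arithmetic (the concrete numbers are illustrative only):

```python
import torch

block_size = 16

# A slot is a flat position in the paged cache: block_idx * block_size + offset.
slot_mapping = torch.tensor([3, 17, 40, 95, 120], dtype=torch.long)

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

for slot, blk, off in zip(
    slot_mapping.tolist(), block_indices.tolist(), block_offsets.tolist()
):
    # The kernel writes each token's key/value at cache[blk, ..., off, ...],
    # which is exactly what the reference loops above emulate.
    assert blk * block_size + off == slot
    print(f"slot {slot} -> block {blk}, offset {off}")
```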
tests/kernels/utils.py
ADDED
@@ -0,0 +1,92 @@
"""Kernel test utils"""

import itertools
import random
import unittest
from functools import lru_cache
from numbers import Number
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import pytest
import torch
from torch._prims_common import TensorLikeType

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    return bool(
        torch.all(
            torch.isclose(
                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        ).item()
    )


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        torch._library.custom_ops.CustomOpDef,
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True
) -> Dict[str, str]:
    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
        return (
            torch.library.opcheck(
                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
            )
            if cond
            else {}
        )


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    from attention import ops

    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
    # will fail
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)
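To see what this `opcheck` wrapper exercises without building the attention extension, the snippet below registers a throwaway custom op and runs `torch.library.opcheck` on it with the same restricted test set. The `demo::scale_add` op is purely illustrative (not part of this port) and the sketch assumes PyTorch 2.4+.

```python
import torch
from torch.library import custom_op, opcheck


# A toy custom op, registered through torch.library so that opcheck has a
# real schema, fake-tensor rule, and dispatcher entry to validate.
@custom_op("demo::scale_add", mutates_args=())
def scale_add(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x * scale + 1.0


@scale_add.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Shape/dtype propagation used by the "test_faketensor" check.
    return torch.empty_like(x)


# Restricted test set mirroring DEFAULT_OPCHECK_TEST_UTILS above:
# "test_aot_dispatch_dynamic" is deliberately skipped.
results = opcheck(
    torch.ops.demo.scale_add,
    (torch.randn(4), 2.0),
    test_utils=("test_schema", "test_autograd_registration", "test_faketensor"),
)
print(results)
```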
torch-ext/attention/__init__.py
ADDED
@@ -0,0 +1,21 @@
from ._custom_ops import (
    convert_fp8,
    copy_blocks,
    paged_attention_v1,
    paged_attention_v2,
    reshape_and_cache,
    reshape_and_cache_flash,
    swap_blocks,
)
from ._ops import ops

__all__ = [
    "convert_fp8",
    "copy_blocks",
    "ops",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "swap_blocks",
]
torch-ext/attention/_custom_ops.py
ADDED
@@ -0,0 +1,173 @@
from typing import List, Optional

import torch

from ._ops import ops


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    ops.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    ops.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step,
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    ops.reshape_and_cache(
        key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
        k_scale, v_scale,
    )


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    ops.reshape_and_cache_flash(
        key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
        k_scale, v_scale,
    )


def copy_blocks(
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    block_mapping: torch.Tensor,
) -> None:
    ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(
    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
) -> None:
    ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(
    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
) -> None:
    ops.convert_fp8(output, input, scale, kv_dtype)


__all__ = [
    "convert_fp8",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "copy_blocks",
]
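As a reading aid for the wrapper signatures above, here is a rough, shape-annotated sketch of a `paged_attention_v1` call. It assumes the built `attention` package is importable on a CUDA machine and follows the paged cache layout exercised by the tests (`x = 16 // element_size`); the concrete sizes and the tensor-valued scales are placeholders taken from the test code, not requirements.

```python
import torch
import attention as ops

num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 8, 64
block_size, num_blocks, max_seq_len = 16, 128, 256
dtype, device = torch.float16, "cuda"

# Vectorized packing factor used by the key cache layout in the tests.
x = 16 // torch.tensor([], dtype=dtype).element_size()

query = torch.randn(num_seqs, num_heads, head_size, dtype=dtype, device=device)
out = torch.empty_like(query)
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x,
                        dtype=dtype, device=device)
value_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size,
                          dtype=dtype, device=device)

max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0, num_blocks, (num_seqs, max_blocks_per_seq),
                             dtype=torch.int, device=device)
seq_lens = torch.tensor([100, 256], dtype=torch.int, device=device)

# Unit scales as float32 tensors, matching what the tests pass through.
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

ops.paged_attention_v1(
    out, query, key_cache, value_cache, num_kv_heads, head_size**-0.5,
    block_tables, seq_lens, block_size, max_seq_len,
    None,    # alibi_slopes
    "auto",  # kv_cache_dtype: keep the cache in `dtype`
    k_scale, v_scale,
)
print(out.shape)  # torch.Size([2, 8, 64])
```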
torch-ext/attention/platforms.py
ADDED
@@ -0,0 +1,62 @@
import os
import random
from abc import ABC, abstractmethod
from functools import lru_cache, wraps
from typing import Callable, ParamSpec, TypeVar

import numpy as np
import torch

IS_ROCM = torch.version.hip is not None


class Platform(ABC):
    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @abstractmethod
    def get_device_name(self, device_id: int = 0) -> str: ...

    @abstractmethod
    def is_cuda(self) -> bool: ...

    @abstractmethod
    def is_rocm(self) -> bool: ...


class CudaPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(0)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False


class RocmPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return True


current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
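A small usage sketch of this platform helper, assuming the `torch-ext` package is importable as `attention`; it mirrors how the tests seed all RNGs in one call and then branch on the detected backend to pick tolerances.

```python
import torch
from attention.platforms import current_platform

# Seed Python's random, NumPy and torch before generating test inputs.
current_platform.seed_everything(0)

# Branch on the detected platform, e.g. to relax tolerances on ROCm.
if current_platform.is_rocm():
    atol, rtol = 1e-2, 1e-5
else:
    atol, rtol = 1e-3, 1e-5

if torch.cuda.is_available():
    print(current_platform.get_device_name(0))
print("rocm" if current_platform.is_rocm() else "cuda", atol, rtol)
```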
torch-ext/registration.h
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                               \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                           \
  }
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,95 @@
#include <torch/library.h>

#include "registration.h"

#include "torch_binding.h"

// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  ops.def(
      "paged_attention_v1("
      "  Tensor! out, Tensor query, Tensor key_cache,"
      "  Tensor value_cache, int num_kv_heads, float scale,"
      "  Tensor block_tables, Tensor seq_lens, int block_size,"
      "  int max_seq_len, Tensor? alibi_slopes,"
      "  str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      "  int tp_rank, int blocksparse_local_blocks,"
      "  int blocksparse_vert_stride, int blocksparse_block_size,"
      "  int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

  // PagedAttention V2.
  ops.def(
      "paged_attention_v2("
      "  Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      "  Tensor! tmp_out, Tensor query, Tensor key_cache,"
      "  Tensor value_cache, int num_kv_heads, float scale,"
      "  Tensor block_tables, Tensor seq_lens, int block_size,"
      "  int max_seq_len, Tensor? alibi_slopes,"
      "  str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      "  int tp_rank, int blocksparse_local_blocks,"
      "  int blocksparse_vert_stride, int blocksparse_block_size,"
      "  int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

  // Swap in (out) the cache blocks from src to dst.
  ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // Copy the cache blocks from src to dst.
  ops.def(
      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
      "Tensor block_mapping) -> ()");
  ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

  // Reshape the key and value tensors and cache them.
  ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      "                  Tensor! key_cache, Tensor! value_cache,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype,"
      "                  Tensor k_scale, Tensor v_scale) -> ()");
  ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

  // Reshape the key and value tensors and cache them.
  ops.def(
      "reshape_and_cache_flash(Tensor key, Tensor value,"
      "                        Tensor! key_cache,"
      "                        Tensor! value_cache,"
      "                        Tensor slot_mapping,"
      "                        str kv_cache_dtype,"
      "                        Tensor k_scale, Tensor v_scale) -> ()");
  ops.impl("reshape_and_cache_flash", torch::kCUDA,
           &reshape_and_cache_flash);

  // Gets the specified device attribute.
  ops.def("get_device_attribute(int attribute, int device_id) -> int");
  ops.impl("get_device_attribute", &get_device_attribute);

  // Gets the maximum shared memory per block device attribute.
  ops.def(
      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
  ops.impl("get_max_shared_memory_per_block_device_attribute",
           &get_max_shared_memory_per_block_device_attribute);

  // Convert the key and value cache to fp8 data type.
  ops.def(
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
      "str kv_cache_dtype) -> ()");
  ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,56 @@
#pragma once

#include <torch/torch.h>

void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);

// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping);

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale);

void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                             torch::Tensor& key_cache,
                             torch::Tensor& value_cache,
                             torch::Tensor& slot_mapping,
                             const std::string& kv_cache_dtype,
                             torch::Tensor& k_scale, torch::Tensor& v_scale);

int64_t get_device_attribute(int64_t attribute, int64_t device_id);

int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);