quant_weights is a CPU function
Browse files
torch-ext/torch_binding.cpp
CHANGED
@@ -13,7 +13,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
13 |
ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
|
14 |
ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
|
15 |
"bool return_unprocessed_quantized_tensor) -> Tensor[]");
|
16 |
-
ops.impl("quant_weights", torch::kCUDA, &symmetric_quantize_last_axis_of_tensor);
|
17 |
}
|
18 |
|
19 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
13 |
ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
|
14 |
ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
|
15 |
"bool return_unprocessed_quantized_tensor) -> Tensor[]");
|
16 |
+
ops.impl("quant_weights", torch::kCPU, &symmetric_quantize_last_axis_of_tensor);
|
17 |
}
|
18 |
|
19 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|