danieldk HF staff commited on
Commit
8294a79
·
1 Parent(s): f1c3798

quant_weights is a CPU function

Browse files
Files changed (1) hide show
  1. torch-ext/torch_binding.cpp +1 -1
torch-ext/torch_binding.cpp CHANGED
@@ -13,7 +13,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
13
  ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
14
  ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
15
  "bool return_unprocessed_quantized_tensor) -> Tensor[]");
16
- ops.impl("quant_weights", torch::kCUDA, &symmetric_quantize_last_axis_of_tensor);
17
  }
18
 
19
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
13
  ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
14
  ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
15
  "bool return_unprocessed_quantized_tensor) -> Tensor[]");
16
+ ops.impl("quant_weights", torch::kCPU, &symmetric_quantize_last_axis_of_tensor);
17
  }
18
 
19
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)