kevinwang676 commited on
Commit
6988d86
·
verified ·
1 Parent(s): e3bb1ba

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +6 -10
utils.py CHANGED
@@ -47,13 +47,8 @@ def load_wav(wav, target_sr):
47
  return speech
48
 
49
 
50
- def convert_onnx_to_trt(trt_model, onnx_model, fp16):
51
  import tensorrt as trt
52
- _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
53
- _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
54
- _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
55
- input_names = ["x", "mask", "mu", "t", "spks", "cond"]
56
-
57
  logging.info("Converting onnx to trt...")
58
  network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
59
  logger = trt.Logger(trt.Logger.INFO)
@@ -61,7 +56,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
61
  network = builder.create_network(network_flags)
62
  parser = trt.OnnxParser(network, logger)
63
  config = builder.create_builder_config()
64
- config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
65
  if fp16:
66
  config.set_flag(trt.BuilderFlag.FP16)
67
  profile = builder.create_optimization_profile()
@@ -72,8 +67,8 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
72
  print(parser.get_error(error))
73
  raise ValueError('failed to parse {}'.format(onnx_model))
74
  # set input shapes
75
- for i in range(len(input_names)):
76
- profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
77
  tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
78
  # set input and output data type
79
  for i in range(network.num_inputs):
@@ -86,4 +81,5 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
86
  engine_bytes = builder.build_serialized_network(network, config)
87
  # save trt engine
88
  with open(trt_model, "wb") as f:
89
- f.write(engine_bytes)
 
 
47
  return speech
48
 
49
 
50
+ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
51
  import tensorrt as trt
 
 
 
 
 
52
  logging.info("Converting onnx to trt...")
53
  network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
54
  logger = trt.Logger(trt.Logger.INFO)
 
56
  network = builder.create_network(network_flags)
57
  parser = trt.OnnxParser(network, logger)
58
  config = builder.create_builder_config()
59
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
60
  if fp16:
61
  config.set_flag(trt.BuilderFlag.FP16)
62
  profile = builder.create_optimization_profile()
 
67
  print(parser.get_error(error))
68
  raise ValueError('failed to parse {}'.format(onnx_model))
69
  # set input shapes
70
+ for i in range(len(trt_kwargs['input_names'])):
71
+ profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
72
  tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
73
  # set input and output data type
74
  for i in range(network.num_inputs):
 
81
  engine_bytes = builder.build_serialized_network(network, config)
82
  # save trt engine
83
  with open(trt_model, "wb") as f:
84
+ f.write(engine_bytes)
85
+ logging.info("Succesfully convert onnx to trt...")