Spaces:
Runtime error
Runtime error
Update utils.py
Browse files
utils.py
CHANGED
@@ -47,13 +47,8 @@ def load_wav(wav, target_sr):
|
|
47 |
return speech
|
48 |
|
49 |
|
50 |
-
def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
51 |
import tensorrt as trt
|
52 |
-
_min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
|
53 |
-
_opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
|
54 |
-
_max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
|
55 |
-
input_names = ["x", "mask", "mu", "t", "spks", "cond"]
|
56 |
-
|
57 |
logging.info("Converting onnx to trt...")
|
58 |
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
59 |
logger = trt.Logger(trt.Logger.INFO)
|
@@ -61,7 +56,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
|
61 |
network = builder.create_network(network_flags)
|
62 |
parser = trt.OnnxParser(network, logger)
|
63 |
config = builder.create_builder_config()
|
64 |
-
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 <<
|
65 |
if fp16:
|
66 |
config.set_flag(trt.BuilderFlag.FP16)
|
67 |
profile = builder.create_optimization_profile()
|
@@ -72,8 +67,8 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
|
72 |
print(parser.get_error(error))
|
73 |
raise ValueError('failed to parse {}'.format(onnx_model))
|
74 |
# set input shapes
|
75 |
-
for i in range(len(input_names)):
|
76 |
-
profile.set_shape(input_names[i],
|
77 |
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
78 |
# set input and output data type
|
79 |
for i in range(network.num_inputs):
|
@@ -86,4 +81,5 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
|
86 |
engine_bytes = builder.build_serialized_network(network, config)
|
87 |
# save trt engine
|
88 |
with open(trt_model, "wb") as f:
|
89 |
-
f.write(engine_bytes)
|
|
|
|
47 |
return speech
|
48 |
|
49 |
|
50 |
+
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
51 |
import tensorrt as trt
|
|
|
|
|
|
|
|
|
|
|
52 |
logging.info("Converting onnx to trt...")
|
53 |
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
54 |
logger = trt.Logger(trt.Logger.INFO)
|
|
|
56 |
network = builder.create_network(network_flags)
|
57 |
parser = trt.OnnxParser(network, logger)
|
58 |
config = builder.create_builder_config()
|
59 |
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
|
60 |
if fp16:
|
61 |
config.set_flag(trt.BuilderFlag.FP16)
|
62 |
profile = builder.create_optimization_profile()
|
|
|
67 |
print(parser.get_error(error))
|
68 |
raise ValueError('failed to parse {}'.format(onnx_model))
|
69 |
# set input shapes
|
70 |
+
for i in range(len(trt_kwargs['input_names'])):
|
71 |
+
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
72 |
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
73 |
# set input and output data type
|
74 |
for i in range(network.num_inputs):
|
|
|
81 |
engine_bytes = builder.build_serialized_network(network, config)
|
82 |
# save trt engine
|
83 |
with open(trt_model, "wb") as f:
|
84 |
+
f.write(engine_bytes)
|
85 |
+
logging.info("Succesfully convert onnx to trt...")
|