kevinwang676 commited on
Commit
b1a0d87
·
verified ·
1 Parent(s): 05e87f6

Update flow.py

Browse files
Files changed (1) hide show
  1. flow.py +62 -15
flow.py CHANGED
@@ -15,7 +15,26 @@ import threading
15
  import torch
16
  import torch.nn.functional as F
17
  from matcha.models.components.flow_matching import BASECFM
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  class ConditionalCFM(BASECFM):
21
  def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
@@ -125,22 +144,50 @@ class ConditionalCFM(BASECFM):
125
  if isinstance(self.estimator, torch.nn.Module):
126
  return self.estimator.forward(x, mask, mu, t, spks, cond)
127
  else:
128
- with self.lock:
129
- self.estimator.set_input_shape('x', (2, 80, x.size(2)))
130
- self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
131
- self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
132
- self.estimator.set_input_shape('t', (2,))
133
- self.estimator.set_input_shape('spks', (2, 80))
134
- self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  # run trt engine
136
- self.estimator.execute_v2([x.contiguous().data_ptr(),
137
- mask.contiguous().data_ptr(),
138
- mu.contiguous().data_ptr(),
139
- t.contiguous().data_ptr(),
140
- spks.contiguous().data_ptr(),
141
- cond.contiguous().data_ptr(),
142
- x.data_ptr()])
143
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  def compute_loss(self, x1, mask, mu, spks=None, cond=None):
146
  """Computes diffusion loss
 
15
  import torch
16
  import torch.nn.functional as F
17
  from matcha.models.components.flow_matching import BASECFM
18
+ import queue
19
 
20
+ class EstimatorWrapper:
21
+ def __init__(self, estimator_engine, estimator_count=2,):
22
+ self.estimators = queue.Queue()
23
+ self.estimator_engine = estimator_engine
24
+ for _ in range(estimator_count):
25
+ estimator = estimator_engine.create_execution_context()
26
+ if estimator is not None:
27
+ self.estimators.put(estimator)
28
+
29
+ if self.estimators.empty():
30
+ raise Exception("No available estimator")
31
+
32
+ def acquire_estimator(self):
33
+ return self.estimators.get(), self.estimator_engine
34
+
35
+ def release_estimator(self, estimator):
36
+ self.estimators.put(estimator)
37
+ return
38
 
39
  class ConditionalCFM(BASECFM):
40
  def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
 
144
  if isinstance(self.estimator, torch.nn.Module):
145
  return self.estimator.forward(x, mask, mu, t, spks, cond)
146
  else:
147
+ if isinstance(self.estimator, EstimatorWrapper):
148
+ estimator, engine = self.estimator.acquire_estimator()
149
+
150
+ estimator.set_input_shape('x', (2, 80, x.size(2)))
151
+ estimator.set_input_shape('mask', (2, 1, x.size(2)))
152
+ estimator.set_input_shape('mu', (2, 80, x.size(2)))
153
+ estimator.set_input_shape('t', (2,))
154
+ estimator.set_input_shape('spks', (2, 80))
155
+ estimator.set_input_shape('cond', (2, 80, x.size(2)))
156
+
157
+ data_ptrs = [x.contiguous().data_ptr(),
158
+ mask.contiguous().data_ptr(),
159
+ mu.contiguous().data_ptr(),
160
+ t.contiguous().data_ptr(),
161
+ spks.contiguous().data_ptr(),
162
+ cond.contiguous().data_ptr(),
163
+ x.data_ptr()]
164
+
165
+ for idx, data_ptr in enumerate(data_ptrs):
166
+ estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
167
+
168
  # run trt engine
169
+ estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream)
170
+
171
+ torch.cuda.current_stream().synchronize()
172
+ self.estimator.release_estimator(estimator)
173
+ return x
174
+ else:
175
+ with self.lock:
176
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
177
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
178
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
179
+ self.estimator.set_input_shape('t', (2,))
180
+ self.estimator.set_input_shape('spks', (2, 80))
181
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
182
+ # run trt engine
183
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
184
+ mask.contiguous().data_ptr(),
185
+ mu.contiguous().data_ptr(),
186
+ t.contiguous().data_ptr(),
187
+ spks.contiguous().data_ptr(),
188
+ cond.contiguous().data_ptr(),
189
+ x.data_ptr()])
190
+ return x
191
 
192
  def compute_loss(self, x1, mask, mu, spks=None, cond=None):
193
  """Computes diffusion loss