Update app.py
Browse files
app.py
CHANGED
@@ -176,7 +176,7 @@ transform = transforms.Compose([
|
|
176 |
transforms.ToTensor(),
|
177 |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
178 |
])
|
179 |
-
weight_dtype = torch.
|
180 |
|
181 |
# line model
|
182 |
line_model_path = os.path.join(model_global_path, 'LE', 'erika.pth')
|
@@ -201,7 +201,7 @@ global MultiResNetModel
|
|
201 |
global cur_style
|
202 |
|
203 |
cur_style = 'line + shadow'
|
204 |
-
weight_dtype = torch.
|
205 |
|
206 |
block_out_channels = [128, 128, 256, 512, 512]
|
207 |
MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
|
@@ -313,7 +313,7 @@ print('loaded pipeline')
|
|
313 |
|
314 |
@spaces.GPU
|
315 |
def change_ckpt(style):
|
316 |
-
weight_dtype = torch.
|
317 |
|
318 |
if style == 'line':
|
319 |
MultiResNetModel_path = os.path.join(model_global_path, 'line_GSRP', 'MultiResNetModel.bin')
|
|
|
176 |
transforms.ToTensor(),
|
177 |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
178 |
])
|
179 |
+
weight_dtype = torch.float16
|
180 |
|
181 |
# line model
|
182 |
line_model_path = os.path.join(model_global_path, 'LE', 'erika.pth')
|
|
|
201 |
global cur_style
|
202 |
|
203 |
cur_style = 'line + shadow'
|
204 |
+
weight_dtype = torch.float16
|
205 |
|
206 |
block_out_channels = [128, 128, 256, 512, 512]
|
207 |
MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
|
|
|
313 |
|
314 |
@spaces.GPU
|
315 |
def change_ckpt(style):
|
316 |
+
weight_dtype = torch.float16
|
317 |
|
318 |
if style == 'line':
|
319 |
MultiResNetModel_path = os.path.join(model_global_path, 'line_GSRP', 'MultiResNetModel.bin')
|