Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -31,8 +31,9 @@ global_model = None
|
|
31 |
def load_model():
|
32 |
"""Load the model at startup"""
|
33 |
global global_model
|
|
|
34 |
try:
|
35 |
-
checkpoint = torch.load(
|
36 |
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
37 |
model.load_state_dict(checkpoint['model_state_dict'])
|
38 |
model.to(device)
|
@@ -86,7 +87,7 @@ class UNetWrapper:
|
|
86 |
}
|
87 |
|
88 |
# Save model locally
|
89 |
-
pth_name = '
|
90 |
torch.save(save_dict, pth_name)
|
91 |
|
92 |
# Create repo if it doesn't exist
|
@@ -115,14 +116,20 @@ tags:
|
|
115 |
- pix2pix
|
116 |
- pytorch
|
117 |
library_name: pytorch
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
---
|
119 |
|
120 |
# Pix2Pix UNet Model
|
121 |
|
122 |
## Model Description
|
123 |
Custom UNet model for Pix2Pix image translation.
|
124 |
-
- **Image Size:**
|
125 |
-
- **Model Type:**
|
126 |
|
127 |
## Usage
|
128 |
|
@@ -130,9 +137,10 @@ Custom UNet model for Pix2Pix image translation.
|
|
130 |
import torch
|
131 |
from small_256_model import UNet as small_UNet
|
132 |
from big_1024_model import UNet as big_UNet
|
133 |
-
|
134 |
# Load the model
|
135 |
-
|
|
|
136 |
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
137 |
model.load_state_dict(checkpoint['model_state_dict'])
|
138 |
model.eval()
|
|
|
31 |
def load_model():
|
32 |
"""Load the model at startup"""
|
33 |
global global_model
|
34 |
+
weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
|
35 |
try:
|
36 |
+
checkpoint = torch.load(weights_name, map_location=device)
|
37 |
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
38 |
model.load_state_dict(checkpoint['model_state_dict'])
|
39 |
model.to(device)
|
|
|
87 |
}
|
88 |
|
89 |
# Save model locally
|
90 |
+
pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
|
91 |
torch.save(save_dict, pth_name)
|
92 |
|
93 |
# Create repo if it doesn't exist
|
|
|
116 |
- pix2pix
|
117 |
- pytorch
|
118 |
library_name: pytorch
|
119 |
+
license: wtfpl
|
120 |
+
datasets:
|
121 |
+
- K00B404/pix2pix_flux_set
|
122 |
+
language:
|
123 |
+
- en
|
124 |
+
pipeline_tag: image-to-image
|
125 |
---
|
126 |
|
127 |
# Pix2Pix UNet Model
|
128 |
|
129 |
## Model Description
|
130 |
Custom UNet model for Pix2Pix image translation.
|
131 |
+
- **Image Size:** 1024
|
132 |
+
- **Model Type:** Big (1024)
|
133 |
|
134 |
## Usage
|
135 |
|
|
|
137 |
import torch
|
138 |
from small_256_model import UNet as small_UNet
|
139 |
from big_1024_model import UNet as big_UNet
|
140 |
+
big = True
|
141 |
# Load the model
|
142 |
+
name='big_model_weights.pth' if big else 'small_model_weights.pth'
|
143 |
+
checkpoint = torch.load(name)
|
144 |
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
145 |
model.load_state_dict(checkpoint['model_state_dict'])
|
146 |
model.eval()
|