Update app.py
app.py CHANGED
@@ -15,7 +15,7 @@ from big_1024_model import UNet as big_UNet
 
 # Device configuration
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-big =
+big = True if device == torch.device('cpu') else False
 
 # Parameters
 IMG_SIZE = 1024 if big else 256
@@ -190,7 +190,11 @@ def run_inference(image):
     output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
     output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
     return output
-
+
+def to_hub(model):
+    wrapper = UNetWrapper(model, model_repo_id)
+    wrapper.push_to_hub()
+
 def train_model(epochs):
     """Training function"""
     global global_model
@@ -223,15 +227,17 @@ def train_model(epochs):
             status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
             print(status)
             output_text.append(status)
-
+
+    to_hub(model)
+
     global_model = model
     return model, "\n".join(output_text)
 
+
 def gradio_train(epochs):
     """Gradio training interface function"""
     model, training_log = train_model(int(epochs))
-
-    wrapper.push_to_hub()
+    to_hub(model)
     return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
 
 def gradio_inference(input_image):
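The new to_hub() helper relies on a UNetWrapper class that is defined elsewhere in app.py and not shown in this diff. For reference only, here is a minimal sketch of what such a wrapper could look like, assuming it builds on huggingface_hub's PyTorchModelHubMixin; the class body, names, and comments below are assumptions, not the Space's actual implementation:

# Hypothetical sketch only: the real UNetWrapper lives elsewhere in app.py
# and may differ. Assumes huggingface_hub's PyTorchModelHubMixin, which adds
# save_pretrained() / push_to_hub() to an nn.Module subclass.
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class UNetWrapper(nn.Module, PyTorchModelHubMixin):
    def __init__(self, model: nn.Module, repo_id: str):
        super().__init__()
        self.model = model        # the trained UNet (big or small variant)
        self.repo_id = repo_id    # target repo, e.g. model_repo_id in app.py

    def forward(self, x):
        return self.model(x)

    def push_to_hub(self, **kwargs):
        # Delegate to the mixin with the stored repo id, so callers can
        # simply do wrapper.push_to_hub() as in to_hub() above.
        return super().push_to_hub(self.repo_id, **kwargs)

With a wrapper along these lines, push_to_hub() needs no arguments at the call site because the repo id is captured at construction time; pushing from a Space also requires a write-scoped token, typically exposed to the runtime as an HF_TOKEN secret.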