K00B404 commited on
Commit
45f2d0c
·
verified ·
1 Parent(s): 3a1c257

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -15,7 +15,7 @@ from big_1024_model import UNet as big_UNet
15
 
16
  # Device configuration
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
- big = False if device == torch.device('cpu') else True
19
 
20
  # Parameters
21
  IMG_SIZE = 1024 if big else 256
@@ -190,7 +190,11 @@ def run_inference(image):
190
  output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
191
  output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
192
  return output
193
-
 
 
 
 
194
  def train_model(epochs):
195
  """Training function"""
196
  global global_model
@@ -223,15 +227,17 @@ def train_model(epochs):
223
  status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
224
  print(status)
225
  output_text.append(status)
226
-
 
 
227
  global_model = model
228
  return model, "\n".join(output_text)
229
 
 
230
  def gradio_train(epochs):
231
  """Gradio training interface function"""
232
  model, training_log = train_model(int(epochs))
233
- wrapper = UNetWrapper(model, model_repo_id)
234
- wrapper.push_to_hub()
235
  return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
236
 
237
  def gradio_inference(input_image):
 
15
 
16
  # Device configuration
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ big = True if device == torch.device('cpu') else False
19
 
20
  # Parameters
21
  IMG_SIZE = 1024 if big else 256
 
190
  output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
191
  output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
192
  return output
193
+
194
+ def to_hub(model):
195
+ wrapper = UNetWrapper(model, model_repo_id)
196
+ wrapper.push_to_hub()
197
+
198
  def train_model(epochs):
199
  """Training function"""
200
  global global_model
 
227
  status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
228
  print(status)
229
  output_text.append(status)
230
+
231
+ to_hub(model)
232
+
233
  global_model = model
234
  return model, "\n".join(output_text)
235
 
236
+
237
  def gradio_train(epochs):
238
  """Gradio training interface function"""
239
  model, training_log = train_model(int(epochs))
240
+ to_hub(model)
 
241
  return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
242
 
243
  def gradio_inference(input_image):