K00B404 committed on
Commit
82ee3f8
·
verified ·
1 Parent(s): dadfc60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -18
app.py CHANGED
@@ -87,12 +87,65 @@ class Pix2PixDataset(torch.utils.data.Dataset):
87
  return original, target, original_tokens, enhanced_tokens
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  class UNetWrapper:
91
- def __init__(self, unet_model, repo_id):
 
 
92
  self.model = unet_model
 
 
93
  self.repo_id = repo_id
94
- self.token = os.getenv('NEW_TOKEN') # Make sure this environment variable is set
95
- self.api = HfApi(token=os.getenv('NEW_TOKEN'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def push_to_hub(self):
98
  try:
@@ -103,7 +156,11 @@ class UNetWrapper:
103
  'big': isinstance(self.model, big_UNet),
104
  'img_size': 1024 if isinstance(self.model, big_UNet) else 256
105
  },
106
- 'model_architecture': str(self.model)
 
 
 
 
107
  }
108
 
109
  # Save model locally
@@ -120,14 +177,18 @@ class UNetWrapper:
120
  except Exception as e:
121
  print(f"Repository creation note: {e}")
122
 
123
- # Upload the model file
124
- self.api.upload_file(
125
- path_or_fileobj=pth_name,
126
- path_in_repo=pth_name,
127
- repo_id=self.repo_id,
128
- token=self.token,
129
- repo_type="model"
130
- )
 
 
 
 
131
 
132
  # Create and upload model card
133
  model_card = f"""---
@@ -222,14 +283,83 @@ def run_inference(image):
222
  # Convert output to image
223
  output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
224
  output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
225
- rp(output)
226
  return output
227
 
228
- def to_hub(model):
229
- wrapper = UNetWrapper(model, model_repo_id)
230
  wrapper.push_to_hub()
231
 
232
- def train_model(epochs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  """Training function"""
234
  global global_model
235
 
@@ -282,7 +412,7 @@ def train_model(epochs):
282
  output_text.append(status)
283
 
284
  # Push model to Hugging Face Hub at the end of each epoch
285
- to_hub(model)
286
 
287
  global_model = model # Update the global model after training
288
  return model, "\n".join(output_text)
@@ -295,7 +425,11 @@ def gradio_train(epochs):
295
 
296
  def gradio_inference(input_image):
297
  """Gradio inference interface function"""
298
- return input_image, run_inference(input_image)
 
 
 
 
299
 
300
  # Create Gradio interface with tabs
301
  with gr.Blocks() as app:
 
87
  return original, target, original_tokens, enhanced_tokens
88
 
89
 
90
+
91
+ class UNetWrapper:
92
+
93
+ def push_to_hub(self, pth_name):
94
+ """Push model checkpoint and metadata to the Hugging Face Hub."""
95
+ try:
96
+ self.api.upload_file(
97
+ path_or_fileobj=pth_name,
98
+ path_in_repo=pth_name,
99
+ repo_id=self.repo_id,
100
+ token=self.token,
101
+ repo_type="model"
102
+ )
103
+ print(f"Model checkpoint successfully uploaded to {self.repo_id}")
104
+ except Exception as e:
105
+ print(f"Error uploading model: {e}")
106
+
107
+
108
+
109
+
110
+
111
  class UNetWrapper:
112
+ def __init__(self, unet_model, repo_id, epoch, loss, optimizer, scheduler=None):
113
+ self.loss = loss
114
+ self.epoch = epoch
115
  self.model = unet_model
116
+ self.optimizer = optimizer
117
+ self.scheduler = scheduler
118
  self.repo_id = repo_id
119
+ self.token = os.getenv('NEW_TOKEN') # Ensure the token is set in the environment
120
+ self.api = HfApi(token=self.token)
121
+
122
+ def save_checkpoint(self, save_path):
123
+ """Save checkpoint with model, optimizer, and scheduler states."""
124
+ save_dict = {
125
+ 'model_state_dict': self.model.state_dict(),
126
+ 'optimizer_state_dict': self.optimizer.state_dict(),
127
+ 'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
128
+ 'model_config': {
129
+ 'big': isinstance(self.model, big_UNet),
130
+ 'img_size': 1024 if isinstance(self.model, big_UNet) else 256
131
+ },
132
+ 'epoch': self.epoch,
133
+ 'loss': self.loss
134
+ }
135
+ torch.save(save_dict, save_path)
136
+ print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")
137
+
138
+ def load_checkpoint(self, checkpoint_path):
139
+ """Load model, optimizer, and scheduler states from the checkpoint."""
140
+ checkpoint = torch.load(checkpoint_path, map_location=device)
141
+ self.model.load_state_dict(checkpoint['model_state_dict'])
142
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
143
+ if self.scheduler and checkpoint['scheduler_state_dict']:
144
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
145
+ self.epoch = checkpoint['epoch']
146
+ self.loss = checkpoint['loss']
147
+ print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")
148
+
149
 
150
  def push_to_hub(self):
151
  try:
 
156
  'big': isinstance(self.model, big_UNet),
157
  'img_size': 1024 if isinstance(self.model, big_UNet) else 256
158
  },
159
+ 'model_architecture': str(self.model),
160
+ 'model_state':{
161
+ 'epoch': self.epoch,
162
+ 'loss': self.loss
163
+ }
164
  }
165
 
166
  # Save model locally
 
177
  except Exception as e:
178
  print(f"Repository creation note: {e}")
179
 
180
+ """Push model checkpoint and metadata to the Hugging Face Hub."""
181
+ try:
182
+ self.api.upload_file(
183
+ path_or_fileobj=pth_name,
184
+ path_in_repo=pth_name,
185
+ repo_id=self.repo_id,
186
+ token=self.token,
187
+ repo_type="model"
188
+ )
189
+ print(f"Model checkpoint successfully uploaded to {self.repo_id}")
190
+ except Exception as e:
191
+ print(f"Error uploading model: {e}")
192
 
193
  # Create and upload model card
194
  model_card = f"""---
 
283
  # Convert output to image
284
  output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
285
  output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
286
+ rp(output[0])
287
  return output
288
 
289
+ def to_hub(model, epoch, loss):
290
+ wrapper = UNetWrapper(model, model_repo_id, epoch, loss)
291
  wrapper.push_to_hub()
292
 
293
+
294
+ def train_model(epochs, save_interval=1):
295
+ """Training function with checkpoint saving and model uploading."""
296
+ global global_model
297
+
298
+ # Load combined data CSV
299
+ data_path = 'combined_data.csv'
300
+ combined_data = pd.read_csv(data_path)
301
+
302
+ # Define the transformation
303
+ transform = transforms.Compose([
304
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
305
+ transforms.ToTensor(),
306
+ ])
307
+
308
+ # Initialize dataset and dataloader
309
+ dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
310
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
311
+
312
+ model = global_model
313
+ criterion = nn.L1Loss()
314
+ optimizer = optim.Adam(model.parameters(), lr=LR)
315
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # Example scheduler
316
+ wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)
317
+
318
+ output_text = []
319
+
320
+ for epoch in range(epochs):
321
+ model.train()
322
+ running_loss = 0.0
323
+
324
+ for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
325
+ # Move data to device
326
+ original, target = original.to(device), target.to(device)
327
+ original_prompt_tokens = original_prompt_tokens.input_ids.to(device).float()
328
+ enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float()
329
+
330
+ optimizer.zero_grad()
331
+
332
+ # Forward pass
333
+ output = model(target)
334
+ img_loss = criterion(output, original)
335
+ total_loss = img_loss
336
+ total_loss.backward()
337
+ optimizer.step()
338
+
339
+ running_loss += total_loss.item()
340
+
341
+ if i % 10 == 0:
342
+ status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
343
+ print(status)
344
+ output_text.append(status)
345
+
346
+ # Update the epoch and loss for checkpoint
347
+ wrapper.epoch = epoch + 1
348
+ wrapper.loss = running_loss / len(dataloader)
349
+
350
+ # Save checkpoint at specified intervals
351
+ if (epoch + 1) % save_interval == 0:
352
+ checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
353
+ wrapper.save_checkpoint(checkpoint_path)
354
+ wrapper.push_to_hub(checkpoint_path)
355
+
356
+ scheduler.step() # Update learning rate scheduler
357
+
358
+ global_model = model # Update global model after training
359
+ return model, "\n".join(output_text)
360
+
361
+
362
+ def train_model_old(epochs):
363
  """Training function"""
364
  global global_model
365
 
 
412
  output_text.append(status)
413
 
414
  # Push model to Hugging Face Hub at the end of each epoch
415
+ to_hub(model, epoch, total_loss)
416
 
417
  global_model = model # Update the global model after training
418
  return model, "\n".join(output_text)
 
425
 
426
  def gradio_inference(input_image):
427
  """Gradio inference interface function"""
428
+ output_image = run_inference(input_image) # `run_inference` ends with `return output`, i.e. a single image array rather than a tuple
429
+ rp(output_image)
430
+ # No tuple unpacking is needed here: the value above is already the image
431
+ return output_image # Ensure you're only returning the processed output image
432
+
433
 
434
  # Create Gradio interface with tabs
435
  with gr.Blocks() as app: