Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -87,12 +87,65 @@ class Pix2PixDataset(torch.utils.data.Dataset):
|
|
87 |
return original, target, original_tokens, enhanced_tokens
|
88 |
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
class UNetWrapper:
|
91 |
-
|
|
|
|
|
92 |
self.model = unet_model
|
|
|
|
|
93 |
self.repo_id = repo_id
|
94 |
-
self.token = os.getenv('NEW_TOKEN') #
|
95 |
-
self.api = HfApi(token=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def push_to_hub(self):
|
98 |
try:
|
@@ -103,7 +156,11 @@ class UNetWrapper:
|
|
103 |
'big': isinstance(self.model, big_UNet),
|
104 |
'img_size': 1024 if isinstance(self.model, big_UNet) else 256
|
105 |
},
|
106 |
-
'model_architecture': str(self.model)
|
|
|
|
|
|
|
|
|
107 |
}
|
108 |
|
109 |
# Save model locally
|
@@ -120,14 +177,18 @@ class UNetWrapper:
|
|
120 |
except Exception as e:
|
121 |
print(f"Repository creation note: {e}")
|
122 |
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
131 |
|
132 |
# Create and upload model card
|
133 |
model_card = f"""---
|
@@ -222,14 +283,83 @@ def run_inference(image):
|
|
222 |
# Convert output to image
|
223 |
output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
|
224 |
output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
|
225 |
-
rp(output)
|
226 |
return output
|
227 |
|
228 |
-
def to_hub(model):
|
229 |
-
wrapper = UNetWrapper(model, model_repo_id)
|
230 |
wrapper.push_to_hub()
|
231 |
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
"""Training function"""
|
234 |
global global_model
|
235 |
|
@@ -282,7 +412,7 @@ def train_model(epochs):
|
|
282 |
output_text.append(status)
|
283 |
|
284 |
# Push model to Hugging Face Hub at the end of each epoch
|
285 |
-
to_hub(model)
|
286 |
|
287 |
global_model = model # Update the global model after training
|
288 |
return model, "\n".join(output_text)
|
@@ -295,7 +425,11 @@ def gradio_train(epochs):
|
|
295 |
|
296 |
def gradio_inference(input_image):
|
297 |
"""Gradio inference interface function"""
|
298 |
-
|
|
|
|
|
|
|
|
|
299 |
|
300 |
# Create Gradio interface with tabs
|
301 |
with gr.Blocks() as app:
|
|
|
87 |
return original, target, original_tokens, enhanced_tokens
|
88 |
|
89 |
|
90 |
+
|
91 |
+
class UNetWrapper:
    """Minimal wrapper exposing a checkpoint upload helper.

    NOTE(review): this definition declares no ``__init__``; ``push_to_hub``
    reads ``self.api``, ``self.repo_id`` and ``self.token``, so those must be
    assigned on the instance before the method is called. Another
    ``class UNetWrapper`` is defined further down in this file and rebinds
    the name -- confirm whether this copy is still needed.
    """

    def push_to_hub(self, pth_name):
        """Upload the checkpoint file ``pth_name`` to the Hub repository.

        The local path is reused as the destination path inside
        ``self.repo_id``. Any exception raised during the upload is caught
        and printed instead of being propagated to the caller.

        Args:
            pth_name: Path of the checkpoint file to upload.
        """
        try:
            upload = self.api.upload_file
            upload(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model",
            )
            print(f"Model checkpoint successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
class UNetWrapper:
|
112 |
+
def __init__(self, unet_model, repo_id, epoch, loss, optimizer, scheduler=None):
    """Bundle a model with its training state and Hub credentials.

    Args:
        unet_model: The model instance to wrap.
        repo_id: Hub repository identifier used for uploads.
        epoch: Epoch counter recorded alongside the model.
        loss: Loss value recorded alongside the model.
        optimizer: Optimizer whose state is saved with checkpoints.
        scheduler: Optional LR scheduler; ``None`` when not used.
    """
    # Model and its training components.
    self.model = unet_model
    self.optimizer = optimizer
    self.scheduler = scheduler

    # Training progress recorded with the model.
    self.epoch = epoch
    self.loss = loss

    # Hub destination; the token is read from the NEW_TOKEN environment
    # variable and is None when that variable is unset.
    self.repo_id = repo_id
    self.token = os.environ.get('NEW_TOKEN')
    self.api = HfApi(token=self.token)
|
121 |
+
|
122 |
+
def save_checkpoint(self, save_path):
    """Write model, optimizer and scheduler state to ``save_path``.

    The saved dictionary also carries a small ``model_config`` entry
    (whether the model is a ``big_UNet`` and the matching image size) and
    the wrapper's current ``epoch`` and ``loss``.

    Args:
        save_path: Destination passed to ``torch.save``.
    """
    payload = {}
    payload['model_state_dict'] = self.model.state_dict()
    payload['optimizer_state_dict'] = self.optimizer.state_dict()

    # The scheduler is optional; store None when there is none.
    if self.scheduler:
        payload['scheduler_state_dict'] = self.scheduler.state_dict()
    else:
        payload['scheduler_state_dict'] = None

    is_big = isinstance(self.model, big_UNet)
    payload['model_config'] = {
        'big': is_big,
        'img_size': 1024 if is_big else 256,
    }
    payload['epoch'] = self.epoch
    payload['loss'] = self.loss

    torch.save(payload, save_path)
    print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")
|
137 |
+
|
138 |
+
def load_checkpoint(self, checkpoint_path):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    Also restores the wrapper's ``epoch`` and ``loss`` from the file.

    Args:
        checkpoint_path: Path passed to ``torch.load``.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    state = torch.load(checkpoint_path, map_location=device)

    self.model.load_state_dict(state['model_state_dict'])
    self.optimizer.load_state_dict(state['optimizer_state_dict'])

    # The scheduler entry is only consulted when this wrapper has a scheduler.
    if self.scheduler:
        scheduler_state = state['scheduler_state_dict']
        if scheduler_state:
            self.scheduler.load_state_dict(scheduler_state)

    self.epoch = state['epoch']
    self.loss = state['loss']
    print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")
|
148 |
+
|
149 |
|
150 |
def push_to_hub(self):
|
151 |
try:
|
|
|
156 |
'big': isinstance(self.model, big_UNet),
|
157 |
'img_size': 1024 if isinstance(self.model, big_UNet) else 256
|
158 |
},
|
159 |
+
'model_architecture': str(self.model),
|
160 |
+
'model_state':{
|
161 |
+
'epoch': self.epoch,
|
162 |
+
'loss': self.loss
|
163 |
+
}
|
164 |
}
|
165 |
|
166 |
# Save model locally
|
|
|
177 |
except Exception as e:
|
178 |
print(f"Repository creation note: {e}")
|
179 |
|
180 |
+
""Push model checkpoint and metadata to the Hugging Face Hub."""
|
181 |
+
try:
|
182 |
+
self.api.upload_file(
|
183 |
+
path_or_fileobj=pth_name,
|
184 |
+
path_in_repo=pth_name,
|
185 |
+
repo_id=self.repo_id,
|
186 |
+
token=self.token,
|
187 |
+
repo_type="model"
|
188 |
+
)
|
189 |
+
print(f"Model checkpoint successfully uploaded to {self.repo_id}")
|
190 |
+
except Exception as e:
|
191 |
+
print(f"Error uploading model: {e}")
|
192 |
|
193 |
# Create and upload model card
|
194 |
model_card = f"""---
|
|
|
283 |
# Convert output to image
|
284 |
output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
|
285 |
output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
|
286 |
+
rp(output[0])
|
287 |
return output
|
288 |
|
289 |
+
def to_hub(model, epoch, loss, optimizer=None, scheduler=None):
    """Wrap ``model`` in a ``UNetWrapper`` and push it to the Hub.

    ``UNetWrapper.__init__`` takes ``optimizer`` as a required argument, so
    the previous four-argument construction raised ``TypeError`` on every
    call. The optimizer and scheduler are now accepted here (both optional,
    so existing three-argument callers keep working) and forwarded by keyword.

    Args:
        model: Model instance to upload.
        epoch: Epoch number recorded with the upload.
        loss: Loss value recorded with the upload.
        optimizer: Optional optimizer forwarded to the wrapper.
        scheduler: Optional LR scheduler forwarded to the wrapper.
    """
    # NOTE(review): a visible caller passes the raw loss tensor here; consider
    # passing ``loss.item()`` so a plain number is recorded -- confirm.
    wrapper = UNetWrapper(
        model,
        model_repo_id,
        epoch,
        loss,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    wrapper.push_to_hub()
|
292 |
|
293 |
+
|
294 |
+
def train_model(epochs, save_interval=1):
    """Training function with checkpoint saving and model uploading.

    Trains ``global_model`` with an L1 loss and Adam on the image pairs
    listed in ``combined_data.csv``. After every epoch the wrapper's
    epoch/loss are updated and the StepLR scheduler is advanced; every
    ``save_interval`` epochs a checkpoint is written locally and handed to
    the wrapper for upload.

    Args:
        epochs: Number of passes over the dataloader.
        save_interval: Save and upload a checkpoint every this many epochs.
            ``0`` disables checkpointing.

    Returns:
        A tuple ``(model, log)`` where ``log`` is the newline-joined status
        lines recorded on every 10th step.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    global global_model

    # Load combined data CSV
    data_path = 'combined_data.csv'
    combined_data = pd.read_csv(data_path)

    # Define the transformation
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    # Initialize dataset and dataloader
    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the epoch average is computed.
        raise ValueError(f"No training batches were produced from {data_path}")

    model = global_model
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Example scheduler
    wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)

    output_text = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        # The dataset also yields prompt tokens; the loss below does not use
        # them, so they are no longer moved to the device on every step.
        for i, (original, target, _original_tokens, _enhanced_tokens) in enumerate(dataloader):
            # Move data to device
            original, target = original.to(device), target.to(device)

            optimizer.zero_grad()

            # Forward pass: the network receives `target` and is scored
            # against `original`.
            output = model(target)
            img_loss = criterion(output, original)
            total_loss = img_loss
            total_loss.backward()
            optimizer.step()

            # Read the scalar once; it is reused for logging below.
            loss_value = total_loss.item()
            running_loss += loss_value

            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{num_batches}], Loss: {loss_value:.8f}"
                print(status)
                output_text.append(status)

        # Update the epoch and loss for checkpoint
        wrapper.epoch = epoch + 1
        wrapper.loss = running_loss / num_batches

        # Advance the scheduler before checkpointing so a checkpoint labelled
        # epoch N carries the scheduler state after N epochs.
        scheduler.step()

        # Save checkpoint at specified intervals (0 means never)
        if save_interval and (epoch + 1) % save_interval == 0:
            checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
            wrapper.save_checkpoint(checkpoint_path)
            # NOTE(review): the UNetWrapper.push_to_hub visible further down
            # this file is declared as push_to_hub(self) -- confirm it accepts
            # the checkpoint path.
            wrapper.push_to_hub(checkpoint_path)

    global_model = model  # Update global model after training
    return model, "\n".join(output_text)
|
360 |
+
|
361 |
+
|
362 |
+
def train_model_old(epochs):
|
363 |
"""Training function"""
|
364 |
global global_model
|
365 |
|
|
|
412 |
output_text.append(status)
|
413 |
|
414 |
# Push model to Hugging Face Hub at the end of each epoch
|
415 |
+
to_hub(model, epoch, total_loss)
|
416 |
|
417 |
global_model = model # Update the global model after training
|
418 |
return model, "\n".join(output_text)
|
|
|
425 |
|
426 |
def gradio_inference(input_image):
    """Gradio inference interface function.

    Runs ``run_inference`` on ``input_image``, passes the result to ``rp``,
    and returns that same result unchanged.
    """
    result = run_inference(input_image)
    rp(result)
    return result
|
432 |
+
|
433 |
|
434 |
# Create Gradio interface with tabs
|
435 |
with gr.Blocks() as app:
|