K00B404 committed
Commit 97c17f0 · verified · 1 Parent(s): 833a3af

Update app.py

Files changed (1)
  1. app.py +25 -187
app.py CHANGED
@@ -31,8 +31,8 @@ model_repo_id = "K00B404/pix2pix_flux"
 # Global model variable
 global_model = None
 
-# clip
-clip_model,clip_tokenizer = load_clip()
+# CLIP
+clip_model, clip_tokenizer = load_clip()
 
 def load_model():
     """Load the models at startup"""
@@ -53,10 +53,6 @@ def load_model():
     global_model = model
     return model
 
-
-import os
-import pandas as pd
-
 class Pix2PixDataset(torch.utils.data.Dataset):
     def __init__(self, combined_data, transform, clip_tokenizer):
         self.data = combined_data
@@ -91,97 +87,11 @@ class Pix2PixDataset(torch.utils.data.Dataset):
         return original, target, original_tokens, enhanced_tokens
 
 
-class Pix2PixDataset_older(torch.utils.data.Dataset):
-    def __init__(self, ds, transform, clip_tokenizer, csv_path='combined_data.csv'):
-        if not os.path.exists(csv_path):
-            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
-
-        self.data = pd.read_csv(csv_path)
-        self.clip_tokenizer = clip_tokenizer
-
-        self.originals = [x for x in ds["train"] if x['label'] == 0]
-        self.targets = [x for x in ds["train"] if x['label'] == 1]
-        assert len(self.originals) == len(self.targets)
-        print(f"Number of original images: {len(self.originals)}")
-        print(f"Number of target images: {len(self.targets)}")
-
-        # Debugging: Print out filenames from the dataset and CSV
-        print("Dataset Original Filenames:")
-        for original in self.originals:
-            print(original['image'].filename)
-
-        print("\nCSV Image Filenames:")
-        print(self.data['image_path'].unique())
-
-
-
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.originals)
-
-    def __getitem__(self, idx):
-        original_img = self.originals[idx]['image']
-        target_img = self.targets[idx]['image']
-
-        # Convert PIL images
-        original = original_img.convert('RGB')
-        target = target_img.convert('RGB')
-
-        # Extract the filename from the image_path in the CSV
-        original_img_path = self.data.iloc[idx]['image_path']
-        original_img_filename = os.path.basename(original_img_path)
-
-        # Match the image filename with the `image_path` column in the CSV
-        matched_row = self.data[self.data['image_path'].str.endswith(original_img_filename)]
-
-        if matched_row.empty:
-            raise ValueError(f"No matching entry found in the CSV for image {original_img_filename}")
-
-        # Get the prompts from the matched row
-        original_prompt = matched_row['original_prompt'].values[0]
-        enhanced_prompt = matched_row['enhanced_prompt'].values[0]
-
-        # Tokenize the prompts using CLIP tokenizer
-        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
-        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
-
-        # Return transformed images and tokenized prompts
-        return self.transform(original), self.transform(target), original_tokens, enhanced_tokens
-
-
-# Dataset class remains the same
-class Pix2PixDataset_old(torch.utils.data.Dataset):
-    def __init__(self, ds, transform, csv_path='combined_data.csv'):
-        if not os.path.exists(csv_path):
-            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
-
-        self.data = pd.read_csv(csv_path)
-        self.clip_tokenizer = clip_tokenizer
-
-        self.originals = [x for x in ds["train"] if x['label'] == 0]
-        self.targets = [x for x in ds["train"] if x['label'] == 1]
-        assert len(self.originals) == len(self.targets)
-        print(f"Number of original images: {len(self.originals)}")
-        print(f"Number of target images: {len(self.targets)}")
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.originals)
-
-    def __getitem__(self, idx):
-        original_img = self.originals[idx]['image']
-        # TODO: get original_img file name and match with image_path in self.data....then tokenize the prompts with clip_tokenizer
-        target_img = self.targets[idx]['image']
-        original = original_img.convert('RGB')
-        target = target_img.convert('RGB')
-        return self.transform(original), self.transform(target)
-
 class UNetWrapper:
     def __init__(self, unet_model, repo_id):
         self.model = unet_model
         self.repo_id = repo_id
-        self.token = os.getenv('NEW_TOKEN') # Make sure this environment variable is set
+        self.token = os.getenv('NEW_TOKEN') # Make sure this environment variable is set
         self.api = HfApi(token=os.getenv('NEW_TOKEN'))
 
     def push_to_hub(self):
@@ -197,7 +107,7 @@ class UNetWrapper:
             }
 
             # Save model locally
-            pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
+            pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
             torch.save(save_dict, pth_name)
 
             # Create repo if it doesn't exist
@@ -260,9 +170,13 @@ model.eval()
 
 {str(self.model)} """
             rp(model_card)
+
            # Save and upload README
            with open("README.md", "w") as f:
-                f.write(model_card)
+                f.write(f"# Pix2Pix UNet Model\n\n"
+                        f"- **Image Size:** {save_dict['model_config']['img_size']}\n"
+                        f"- **Model Type:** {'big' if big else 'small'}_UNet ({save_dict['model_config']['img_size']})\n"
+                        f"## Model Architecture\n{str(self.model)}")
 
            self.api.upload_file(
                path_or_fileobj="README.md",
@@ -280,7 +194,6 @@ model.eval()
 
        except Exception as e:
            print(f"Error uploading model: {e}")
-
 
 def prepare_input(image, device='cpu'):
    """Prepare image for inference"""
@@ -315,8 +228,6 @@ def to_hub(model):
    wrapper = UNetWrapper(model, model_repo_id)
    wrapper.push_to_hub()
 
-
-
 def train_model(epochs):
    """Training function"""
    global global_model
@@ -355,13 +266,10 @@ def train_model(epochs):
 
            # Compute image reconstruction loss
            img_loss = criterion(output, original)
-            rp(f"Image {i} Loss:{imag_loss}")
-            #print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
-            # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
-            #prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
+            rp(f"Image {i} Loss:{img_loss}")
 
            # Combine losses
-            total_loss = img_loss #+ 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
+            total_loss = img_loss # Add any other losses if necessary
            total_loss.backward()
 
            # Optimizer step
@@ -377,62 +285,6 @@ def train_model(epochs):
 
    global_model = model # Update the global model after training
    return model, "\n".join(output_text)
-
-def train_model_old(epochs):
-    """Training function"""
-    global global_model
-
-    ds = load_dataset(dataset_id)
-    transform = transforms.Compose([
-        transforms.Resize((IMG_SIZE, IMG_SIZE)),
-        transforms.ToTensor(),
-    ])
-
-    # Initialize the dataset and dataloader
-    dataset = Pix2PixDataset(ds, transform, clip_tokenizer)
-    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
-
-    model = global_model
-    criterion = nn.L1Loss() # L1 loss for image reconstruction
-    optimizer = optim.Adam(model.parameters(), lr=LR)
-    output_text = []
-
-    for epoch in range(epochs):
-        model.train()
-        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
-            # Move images and prompt embeddings to the appropriate device (CPU or GPU)
-            original, target = original.to(device), target.to(device)
-            original_prompt_tokens = original_prompt_tokens.input_ids.to(device)
-            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device)
-
-            optimizer.zero_grad()
-
-            # Forward pass through the model
-            output = model(target)
-
-            # Compute image reconstruction loss
-            img_loss = criterion(output, original)
-
-            # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
-            prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
-
-            # Combine losses
-            total_loss = img_loss + 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
-            total_loss.backward()
-
-            # Optimizer step
-            optimizer.step()
-
-            if i % 10 == 0:
-                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
-                rp(status)
-                output_text.append(status)
-
-        # Push model to Hugging Face Hub at the end of each epoch
-        to_hub(model)
-
-    global_model = model # Update the global model after training
-    return model, "\n".join(output_text)
 
 def gradio_train(epochs):
    """Gradio training interface function"""
@@ -440,15 +292,6 @@ def gradio_train(epochs):
    to_hub(model)
    return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
 
-def gradio_inference(input_image, keywords):
-    """Gradio inference interface function"""
-    # Generate an enhanced prompt using the chat bot
-    enhanced_prompt = chat_with_bot(keywords)
-
-    # Run inference on the input image
-    output_image = run_inference(input_image, chat_with_bot(keywords))
-
-    return input_image, output_image, keywords, enhanced_prompt
 def gradio_inference(input_image):
    """Gradio inference interface function"""
    return input_image, run_inference(input_image)
@@ -457,23 +300,18 @@ def gradio_inference(input_image):
 with gr.Blocks() as app:
    gr.Markdown("# Pix2Pix Model Training and Inference")
 
-    with gr.Tabs():
-        with gr.TabItem("Training"):
-            epochs_input = gr.Number(label="Number of Epochs")
-            train_button = gr.Button("Train Model")
-            output_text = gr.Textbox(label="Training Progress", lines=10)
-            train_button.click(gradio_train, inputs=epochs_input, outputs=output_text)
-
-        with gr.TabItem("Inference"):
-            with gr.Row():
-                input_image = gr.Image(label="Input Image")
-                output_image = gr.Image(label="Model Output")
-            infer_button = gr.Button("Run Inference")
-            infer_button.click(gradio_inference, inputs=input_image, outputs=[input_image, output_image])
-
-if __name__ == '__main__':
-    # Load model at startup
-    load_model()
+    with gr.Tab("Train"):
+        epochs_input = gr.Number(value=EPOCHS, label="Number of epochs")
+        train_button = gr.Button("Train")
+        training_output = gr.Textbox(label="Training Log", interactive=False)
+        train_button.click(gradio_train, inputs=[epochs_input], outputs=[training_output])
 
-    # Launch the Gradio app
-    app.launch()
+    with gr.Tab("Inference"):
+        image_input = gr.Image(type='numpy')
+        prompt_input = gr.Textbox(label="Prompt")
+        inference_button = gr.Button("Generate")
+        inference_output = gr.Image(type='numpy', label="Generated Image")
+        inference_button.click(gradio_inference, inputs=[image_input], outputs=[inference_output])
+
+load_model()
+app.launch()