K00B404 committed
Commit c4457ca · verified · 1 Parent(s): 5e80942

Update app.py

Files changed (1)
app.py +97 -1
app.py CHANGED
@@ -57,6 +57,40 @@ import os
 import pandas as pd
 
 class Pix2PixDataset(torch.utils.data.Dataset):
+    def __init__(self, combined_data, transform, clip_tokenizer):
+        self.data = combined_data
+        self.transform = transform
+        self.clip_tokenizer = clip_tokenizer
+        self.original_folder = 'images_dataset/original/'
+        self.target_folder = 'images_dataset/target/'
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        original_img_filename = os.path.basename(self.data.iloc[idx]['image_path'])
+        original_img_path = os.path.join(self.original_folder, original_img_filename)
+        target_img_path = os.path.join(self.target_folder, original_img_filename)
+
+        original_img = Image.open(original_img_path).convert('RGB')
+        target_img = Image.open(target_img_path).convert('RGB')
+
+        # Transform images
+        original = self.transform(original_img)
+        target = self.transform(target_img)
+
+        # Get prompts from the DataFrame
+        original_prompt = self.data.iloc[idx]['original_prompt']
+        enhanced_prompt = self.data.iloc[idx]['enhanced_prompt']
+
+        # Tokenize the prompts using the CLIP tokenizer
+        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+
+        return original, target, original_tokens, enhanced_tokens
+
+
+class Pix2PixDataset_older(torch.utils.data.Dataset):
     def __init__(self, ds, transform, clip_tokenizer, csv_path='combined_data.csv'):
         if not os.path.exists(csv_path):
             os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
@@ -277,11 +311,73 @@ def run_inference(image, prompt):
 def to_hub(model):
     wrapper = UNetWrapper(model, model_repo_id)
     wrapper.push_to_hub()
-
+
+
+
 def train_model(epochs):
     """Training function"""
     global global_model
 
+    # Load combined data CSV
+    data_path = 'path/to/your/combined_data.csv'  # Adjust this path
+    combined_data = pd.read_csv(data_path)
+
+    # Define the transformation
+    transform = transforms.Compose([
+        transforms.Resize((IMG_SIZE, IMG_SIZE)),
+        transforms.ToTensor(),
+    ])
+
+    # Initialize the dataset and dataloader
+    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
+    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+    model = global_model
+    criterion = nn.L1Loss()  # L1 loss for image reconstruction
+    optimizer = optim.Adam(model.parameters(), lr=LR)
+    output_text = []
+
+    for epoch in range(epochs):
+        model.train()
+        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
+            # Move images and prompt token IDs to the appropriate device (CPU or GPU)
+            original, target = original.to(device), target.to(device)
+            original_prompt_tokens = original_prompt_tokens.input_ids.to(device)
+            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device)
+
+            optimizer.zero_grad()
+
+            # Forward pass through the model
+            output = model(target)
+
+            # Compute image reconstruction loss
+            img_loss = criterion(output, original)
+
+            # Compute prompt guidance loss (L2 norm between original and enhanced prompt token IDs)
+            prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
+
+            # Combine losses
+            total_loss = img_loss + 0.1 * prompt_loss  # Weight the prompt guidance loss by 0.1 to balance the two terms
+            total_loss.backward()
+
+            # Optimizer step
+            optimizer.step()
+
+            if i % 10 == 0:
+                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
+                print(status)
+                output_text.append(status)
+
+        # Push model to Hugging Face Hub at the end of each epoch
+        to_hub(model)
+
+    global_model = model  # Update the global model after training
+    return model, "\n".join(output_text)
+
+def train_model_old(epochs):
+    """Training function"""
+    global global_model
+
     ds = load_dataset(dataset_id)
     transform = transforms.Compose([
         transforms.Resize((IMG_SIZE, IMG_SIZE)),
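Note on the prompt guidance loss in the new train_model: original_prompt_tokens and enhanced_prompt_tokens hold integer vocabulary IDs, so their difference is not a semantic distance between prompts (torch.norm also raises on integer dtypes, and with per-sample padding the two shapes may not even match). No gradient flows from this term to the U-Net either, so it only shifts the printed loss. A sketch of a more meaningful formulation, assuming a CLIP text encoder is available; clip_text_encoder is hypothetical and not defined in app.py:

    # Compare prompts in CLIP embedding space rather than raw token-ID space.
    # clip_text_encoder is an assumed transformers.CLIPTextModel-style module;
    # it is not part of this commit.
    with torch.no_grad():
        orig_emb = clip_text_encoder(input_ids=original_prompt_tokens).last_hidden_state
        enh_emb = clip_text_encoder(input_ids=enhanced_prompt_tokens).last_hidden_state
    prompt_loss = torch.norm(orig_emb - enh_emb, p=2)

Even then the term is constant with respect to the model's parameters; making it trainable would require conditioning the U-Net on the prompt embeddings.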