Update app.py
app.py CHANGED
@@ -57,6 +57,40 @@ import os
 import pandas as pd
 
 class Pix2PixDataset(torch.utils.data.Dataset):
+    def __init__(self, combined_data, transform, clip_tokenizer):
+        self.data = combined_data
+        self.transform = transform
+        self.clip_tokenizer = clip_tokenizer
+        self.original_folder = 'images_dataset/original/'
+        self.target_folder = 'images_dataset/target/'
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        original_img_filename = os.path.basename(self.data.iloc[idx]['image_path'])
+        original_img_path = os.path.join(self.original_folder, original_img_filename)
+        target_img_path = os.path.join(self.target_folder, original_img_filename)
+
+        original_img = Image.open(original_img_path).convert('RGB')
+        target_img = Image.open(target_img_path).convert('RGB')
+
+        # Transform images
+        original = self.transform(original_img)
+        target = self.transform(target_img)
+
+        # Get prompts from the DataFrame
+        original_prompt = self.data.iloc[idx]['original_prompt']
+        enhanced_prompt = self.data.iloc[idx]['enhanced_prompt']
+
+        # Tokenize the prompts using the CLIP tokenizer
+        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+
+        return original, target, original_tokens, enhanced_tokens
+
+
+class Pix2PixDataset_older(torch.utils.data.Dataset):
     def __init__(self, ds, transform, clip_tokenizer, csv_path='combined_data.csv'):
         if not os.path.exists(csv_path):
             os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
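
A reviewer aside on batching (not part of the commit): with padding=True, the tokenizer pads each prompt only to its own length, so the input_ids tensors differ in shape across samples and PyTorch's default DataLoader collate cannot stack them into a batch. A minimal sketch of one fix, assuming clip_tokenizer is a transformers CLIP tokenizer as elsewhere in this file, is to pad every prompt to CLIP's fixed 77-token context inside __getitem__:

# Sketch, not from the commit: padding="max_length" gives every sample
# input_ids of shape (1, 77), so the default collate can stack a batch.
original_tokens = self.clip_tokenizer(
    original_prompt,
    return_tensors="pt",
    padding="max_length",  # pad to max_length rather than longest-in-sample
    truncation=True,
    max_length=77,
)

The same change would apply to enhanced_tokens.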
@@ -277,11 +311,73 @@ def run_inference(image, prompt):
 def to_hub(model):
     wrapper = UNetWrapper(model, model_repo_id)
     wrapper.push_to_hub()
-
+
+
+
 def train_model(epochs):
     """Training function"""
     global global_model
 
+    # Load combined data CSV
+    data_path = 'path/to/your/combined_data.csv'  # Adjust this path
+    combined_data = pd.read_csv(data_path)
+
+    # Define the transformation
+    transform = transforms.Compose([
+        transforms.Resize((IMG_SIZE, IMG_SIZE)),
+        transforms.ToTensor(),
+    ])
+
+    # Initialize the dataset and dataloader
+    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
+    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+    model = global_model
+    criterion = nn.L1Loss()  # L1 loss for image reconstruction
+    optimizer = optim.Adam(model.parameters(), lr=LR)
+    output_text = []
+
+    for epoch in range(epochs):
+        model.train()
+        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
+            # Move images and prompt tokens to the appropriate device (CPU or GPU)
+            original, target = original.to(device), target.to(device)
+            original_prompt_tokens = original_prompt_tokens.input_ids.to(device)
+            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device)
+
+            optimizer.zero_grad()
+
+            # Forward pass through the model
+            output = model(target)
+
+            # Compute image reconstruction loss
+            img_loss = criterion(output, original)
+
+            # Compute prompt guidance loss (L2 norm between original and enhanced prompt token IDs)
+            prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
+
+            # Combine losses
+            total_loss = img_loss + 0.1 * prompt_loss  # Weight the prompt guidance loss by 0.1 to balance the two terms
+            total_loss.backward()
+
+            # Optimizer step
+            optimizer.step()
+
+            if i % 10 == 0:
+                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
+                print(status)
+                output_text.append(status)
+
+        # Push model to Hugging Face Hub at the end of each epoch
+        to_hub(model)
+
+    global_model = model  # Update the global model after training
+    return model, "\n".join(output_text)
+
+def train_model_old(epochs):
+    """Training function"""
+    global global_model
+
     ds = load_dataset(dataset_id)
     transform = transforms.Compose([
         transforms.Resize((IMG_SIZE, IMG_SIZE)),