Update app.py

app.py
CHANGED
@@ -31,8 +31,8 @@ model_repo_id = "K00B404/pix2pix_flux"
 # Global model variable
 global_model = None
 
-#
-clip_model,clip_tokenizer = load_clip()
+# CLIP
+clip_model, clip_tokenizer = load_clip()
 
 def load_model():
     """Load the models at startup"""
@@ -53,10 +53,6 @@ def load_model():
     global_model = model
     return model
 
-
-import os
-import pandas as pd
-
 class Pix2PixDataset(torch.utils.data.Dataset):
     def __init__(self, combined_data, transform, clip_tokenizer):
         self.data = combined_data
@@ -91,97 +87,11 @@ class Pix2PixDataset(torch.utils.data.Dataset):
         return original, target, original_tokens, enhanced_tokens
 
 
-class Pix2PixDataset_older(torch.utils.data.Dataset):
-    def __init__(self, ds, transform, clip_tokenizer, csv_path='combined_data.csv'):
-        if not os.path.exists(csv_path):
-            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
-
-        self.data = pd.read_csv(csv_path)
-        self.clip_tokenizer = clip_tokenizer
-
-        self.originals = [x for x in ds["train"] if x['label'] == 0]
-        self.targets = [x for x in ds["train"] if x['label'] == 1]
-        assert len(self.originals) == len(self.targets)
-        print(f"Number of original images: {len(self.originals)}")
-        print(f"Number of target images: {len(self.targets)}")
-
-        # Debugging: Print out filenames from the dataset and CSV
-        print("Dataset Original Filenames:")
-        for original in self.originals:
-            print(original['image'].filename)
-
-        print("\nCSV Image Filenames:")
-        print(self.data['image_path'].unique())
-
-
-
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.originals)
-
-    def __getitem__(self, idx):
-        original_img = self.originals[idx]['image']
-        target_img = self.targets[idx]['image']
-
-        # Convert PIL images
-        original = original_img.convert('RGB')
-        target = target_img.convert('RGB')
-
-        # Extract the filename from the image_path in the CSV
-        original_img_path = self.data.iloc[idx]['image_path']
-        original_img_filename = os.path.basename(original_img_path)
-
-        # Match the image filename with the `image_path` column in the CSV
-        matched_row = self.data[self.data['image_path'].str.endswith(original_img_filename)]
-
-        if matched_row.empty:
-            raise ValueError(f"No matching entry found in the CSV for image {original_img_filename}")
-
-        # Get the prompts from the matched row
-        original_prompt = matched_row['original_prompt'].values[0]
-        enhanced_prompt = matched_row['enhanced_prompt'].values[0]
-
-        # Tokenize the prompts using CLIP tokenizer
-        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
-        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
-
-        # Return transformed images and tokenized prompts
-        return self.transform(original), self.transform(target), original_tokens, enhanced_tokens
-
-
-# Dataset class remains the same
-class Pix2PixDataset_old(torch.utils.data.Dataset):
-    def __init__(self, ds, transform, csv_path='combined_data.csv'):
-        if not os.path.exists(csv_path):
-            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
-
-        self.data = pd.read_csv(csv_path)
-        self.clip_tokenizer = clip_tokenizer
-
-        self.originals = [x for x in ds["train"] if x['label'] == 0]
-        self.targets = [x for x in ds["train"] if x['label'] == 1]
-        assert len(self.originals) == len(self.targets)
-        print(f"Number of original images: {len(self.originals)}")
-        print(f"Number of target images: {len(self.targets)}")
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.originals)
-
-    def __getitem__(self, idx):
-        original_img = self.originals[idx]['image']
-        # TODO: get original_img file name and match with image_path in self.data....then tokenize the prompts with clip_tokenizer
-        target_img = self.targets[idx]['image']
-        original = original_img.convert('RGB')
-        target = target_img.convert('RGB')
-        return self.transform(original), self.transform(target)
-
 class UNetWrapper:
     def __init__(self, unet_model, repo_id):
         self.model = unet_model
         self.repo_id = repo_id
-        self.token = os.getenv('NEW_TOKEN')
+        self.token = os.getenv('NEW_TOKEN')  # Make sure this environment variable is set
         self.api = HfApi(token=os.getenv('NEW_TOKEN'))
 
     def push_to_hub(self):
@@ -197,7 +107,7 @@ class UNetWrapper:
             }
 
             # Save model locally
-            pth_name = 'big_model_weights.pth' if big else
+            pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
             torch.save(save_dict, pth_name)
 
             # Create repo if it doesn't exist
@@ -260,9 +170,13 @@ model.eval()
 
 {str(self.model)} """
             rp(model_card)
+
             # Save and upload README
             with open("README.md", "w") as f:
-                f.write(
+                f.write(f"# Pix2Pix UNet Model\n\n"
+                        f"- **Image Size:** {save_dict['model_config']['img_size']}\n"
+                        f"- **Model Type:** {'big' if big else 'small'}_UNet ({save_dict['model_config']['img_size']})\n"
+                        f"## Model Architecture\n{str(self.model)}")
 
             self.api.upload_file(
                 path_or_fileobj="README.md",
@@ -280,7 +194,6 @@ model.eval()
 
         except Exception as e:
             print(f"Error uploading model: {e}")
-
 
 def prepare_input(image, device='cpu'):
     """Prepare image for inference"""
@@ -315,8 +228,6 @@ def to_hub(model):
     wrapper = UNetWrapper(model, model_repo_id)
     wrapper.push_to_hub()
 
-
-
 def train_model(epochs):
     """Training function"""
     global global_model
@@ -355,13 +266,10 @@ def train_model(epochs):
 
             # Compute image reconstruction loss
             img_loss = criterion(output, original)
-            rp(f"Image {i} Loss:{
-            #print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
-            # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
-            #prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
+            rp(f"Image {i} Loss:{img_loss}")
 
             # Combine losses
-            total_loss = img_loss
+            total_loss = img_loss  # Add any other losses if necessary
             total_loss.backward()
 
             # Optimizer step
@@ -377,62 +285,6 @@ def train_model(epochs):
 
     global_model = model  # Update the global model after training
     return model, "\n".join(output_text)
-
-
-def train_model_old(epochs):
-    """Training function"""
-    global global_model
-
-    ds = load_dataset(dataset_id)
-    transform = transforms.Compose([
-        transforms.Resize((IMG_SIZE, IMG_SIZE)),
-        transforms.ToTensor(),
-    ])
-
-    # Initialize the dataset and dataloader
-    dataset = Pix2PixDataset(ds, transform, clip_tokenizer)
-    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
-
-    model = global_model
-    criterion = nn.L1Loss()  # L1 loss for image reconstruction
-    optimizer = optim.Adam(model.parameters(), lr=LR)
-    output_text = []
-
-    for epoch in range(epochs):
-        model.train()
-        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
-            # Move images and prompt embeddings to the appropriate device (CPU or GPU)
-            original, target = original.to(device), target.to(device)
-            original_prompt_tokens = original_prompt_tokens.input_ids.to(device)
-            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device)
-
-            optimizer.zero_grad()
-
-            # Forward pass through the model
-            output = model(target)
-
-            # Compute image reconstruction loss
-            img_loss = criterion(output, original)
-
-            # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
-            prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
-
-            # Combine losses
-            total_loss = img_loss + 0.1 * prompt_loss  # Weight the prompt guidance loss with 0.1 to balance
-            total_loss.backward()
-
-            # Optimizer step
-            optimizer.step()
-
-            if i % 10 == 0:
-                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
-                rp(status)
-                output_text.append(status)
-
-        # Push model to Hugging Face Hub at the end of each epoch
-        to_hub(model)
-
-    global_model = model  # Update the global model after training
-    return model, "\n".join(output_text)
 
 def gradio_train(epochs):
     """Gradio training interface function"""
@@ -440,15 +292,6 @@ def gradio_train(epochs):
     to_hub(model)
     return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
 
-def gradio_inference(input_image, keywords):
-    """Gradio inference interface function"""
-    # Generate an enhanced prompt using the chat bot
-    enhanced_prompt = chat_with_bot(keywords)
-
-    # Run inference on the input image
-    output_image = run_inference(input_image, chat_with_bot(keywords))
-
-    return input_image, output_image, keywords, enhanced_prompt
 def gradio_inference(input_image):
     """Gradio inference interface function"""
     return input_image, run_inference(input_image)
@@ -457,23 +300,18 @@ def gradio_inference(input_image):
 with gr.Blocks() as app:
     gr.Markdown("# Pix2Pix Model Training and Inference")
 
-    with gr.
-
-
-
-
-        train_button.click(gradio_train, inputs=epochs_input, outputs=output_text)
-
-    with gr.TabItem("Inference"):
-        with gr.Row():
-            input_image = gr.Image(label="Input Image")
-            output_image = gr.Image(label="Model Output")
-        infer_button = gr.Button("Run Inference")
-        infer_button.click(gradio_inference, inputs=input_image, outputs=[input_image, output_image])
-
-if __name__ == '__main__':
-    # Load model at startup
-    load_model()
+    with gr.Tab("Train"):
+        epochs_input = gr.Number(value=EPOCHS, label="Number of epochs")
+        train_button = gr.Button("Train")
+        training_output = gr.Textbox(label="Training Log", interactive=False)
+        train_button.click(gradio_train, inputs=[epochs_input], outputs=[training_output])
 
-
-
+    with gr.Tab("Inference"):
+        image_input = gr.Image(type='numpy')
+        prompt_input = gr.Textbox(label="Prompt")
+        inference_button = gr.Button("Generate")
+        inference_output = gr.Image(type='numpy', label="Generated Image")
+        inference_button.click(gradio_inference, inputs=[image_input], outputs=[inference_output])
+
+load_model()
+app.launch()
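Note: after this commit, Pix2PixDataset no longer reads combined_data.csv itself; the caller is expected to pass the loaded table in. A minimal sketch of how it might now be wired up, assuming the CSV still has the image_path / original_prompt / enhanced_prompt columns used by the removed CSV-based classes, and that IMG_SIZE, BATCH_SIZE, load_clip() and clip_tokenizer are the module-level names already defined in app.py:

import os
import pandas as pd
from torch.utils.data import DataLoader
from torchvision import transforms

csv_path = 'combined_data.csv'
if not os.path.exists(csv_path):
    # Same CSV the removed Pix2PixDataset_older downloaded on demand
    os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')

combined_data = pd.read_csv(csv_path)          # prompt/path table, now loaded by the caller
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),   # IMG_SIZE defined near the top of app.py
    transforms.ToTensor(),
])

dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)      # clip_tokenizer from load_clip()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)   # BATCH_SIZE from app.py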
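The new training loop keeps only the L1 reconstruction term (total_loss = img_loss). The deleted train_model_old additionally penalised the L2 distance between the original and enhanced prompt token tensors with a 0.1 weight. A self-contained, hedged sketch of that removed variant on dummy tensors (token IDs are integers, so they must be cast to float before torch.norm, and both prompts are assumed to pad to the same length):

import torch
import torch.nn as nn

criterion = nn.L1Loss()
output = torch.rand(1, 3, 256, 256, requires_grad=True)      # stand-in for a model prediction
original = torch.rand(1, 3, 256, 256)                        # stand-in for the ground-truth image
original_prompt_tokens = torch.randint(0, 49408, (1, 77))    # dummy CLIP token IDs
enhanced_prompt_tokens = torch.randint(0, 49408, (1, 77))

img_loss = criterion(output, original)                       # image reconstruction loss (kept by the commit)
prompt_loss = torch.norm(original_prompt_tokens.float()
                         - enhanced_prompt_tokens.float(), p=2)   # crude prompt-guidance proxy (removed)
total_loss = img_loss + 0.1 * prompt_loss                    # 0.1 weighting used by the removed code
total_loss.backward()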
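Pushing to the Hub goes through UNetWrapper, which reads a write token from the NEW_TOKEN environment variable (on a Space this would normally be a repository secret, not hard-coded). A small usage sketch under that assumption; to_hub(model) in app.py does the same two calls:

import os

os.environ.setdefault('NEW_TOKEN', 'hf_xxx')   # placeholder token for illustration only

model = load_model()                            # loads the UNet and stores it in global_model
wrapper = UNetWrapper(model, model_repo_id)     # model_repo_id = "K00B404/pix2pix_flux"
wrapper.push_to_hub()                           # saves *.pth locally, writes README.md, uploads both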