import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import transforms from datasets import load_dataset from huggingface_hub import Repository from huggingface_hub import HfApi, HfFolder, Repository, create_repo import os import pandas as pd import gradio as gr from PIL import Image import numpy as np from small_256_model import UNet as small_UNet from big_1024_model import UNet as big_UNet from CLIP import load as load_clip # Device configuration device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') big = True if device == torch.device('cpu') else False # Parameters IMG_SIZE = 1024 if big else 256 BATCH_SIZE = 1 if big else 4 EPOCHS = 12 LR = 0.0002 dataset_id = "K00B404/pix2pix_flux_set" model_repo_id = "K00B404/pix2pix_flux" # Global model variable global_model = None # clip clip_model,clip_tokenizer = load_clip() def load_model(): """Load the models at startup""" global global_model weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth' try: checkpoint = torch.load(weights_name, map_location=device) model = big_UNet() if checkpoint['model_config']['big'] else small_UNet() model.load_state_dict(checkpoint['model_state_dict']) model.to(device) model.eval() global_model = model print("Model loaded successfully!") return model except Exception as e: print(f"Error loading model: {e}") model = big_UNet().to(device) if big else small_UNet().to(device) global_model = model return model import os import pandas as pd class Pix2PixDataset(torch.utils.data.Dataset): def __init__(self, combined_data, transform, clip_tokenizer): self.data = combined_data self.transform = transform self.clip_tokenizer = clip_tokenizer self.original_folder = 'images_dataset/original/' self.target_folder = 'images_dataset/target/' def __len__(self): return len(self.data) def __getitem__(self, idx): original_img_filename = os.path.basename(self.data.iloc[idx]['image_path']) original_img_path = os.path.join(self.original_folder, original_img_filename) target_img_path = os.path.join(self.target_folder, original_img_filename) original_img = Image.open(original_img_path).convert('RGB') target_img = Image.open(target_img_path).convert('RGB') # Transform images original = self.transform(original_img) target = self.transform(target_img) # Get prompts from the DataFrame original_prompt = self.data.iloc[idx]['original_prompt'] enhanced_prompt = self.data.iloc[idx]['enhanced_prompt'] # Tokenize the prompts using CLIP tokenizer original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77) enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77) return original, target, original_tokens, enhanced_tokens class Pix2PixDataset_older(torch.utils.data.Dataset): def __init__(self, ds, transform, clip_tokenizer, csv_path='combined_data.csv'): if not os.path.exists(csv_path): os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv') self.data = pd.read_csv(csv_path) self.clip_tokenizer = clip_tokenizer self.originals = [x for x in ds["train"] if x['label'] == 0] self.targets = [x for x in ds["train"] if x['label'] == 1] assert len(self.originals) == len(self.targets) print(f"Number of original images: {len(self.originals)}") print(f"Number of target images: {len(self.targets)}") # Debugging: Print out filenames from the dataset and CSV print("Dataset Original Filenames:") for original in self.originals: print(original['image'].filename) print("\nCSV Image Filenames:") print(self.data['image_path'].unique()) self.transform = transform def __len__(self): return len(self.originals) def __getitem__(self, idx): original_img = self.originals[idx]['image'] target_img = self.targets[idx]['image'] # Convert PIL images original = original_img.convert('RGB') target = target_img.convert('RGB') # Extract the filename from the image_path in the CSV original_img_path = self.data.iloc[idx]['image_path'] original_img_filename = os.path.basename(original_img_path) # Match the image filename with the `image_path` column in the CSV matched_row = self.data[self.data['image_path'].str.endswith(original_img_filename)] if matched_row.empty: raise ValueError(f"No matching entry found in the CSV for image {original_img_filename}") # Get the prompts from the matched row original_prompt = matched_row['original_prompt'].values[0] enhanced_prompt = matched_row['enhanced_prompt'].values[0] # Tokenize the prompts using CLIP tokenizer original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77) enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77) # Return transformed images and tokenized prompts return self.transform(original), self.transform(target), original_tokens, enhanced_tokens # Dataset class remains the same class Pix2PixDataset_old(torch.utils.data.Dataset): def __init__(self, ds, transform, csv_path='combined_data.csv'): if not os.path.exists(csv_path): os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv') self.data = pd.read_csv(csv_path) self.clip_tokenizer = clip_tokenizer self.originals = [x for x in ds["train"] if x['label'] == 0] self.targets = [x for x in ds["train"] if x['label'] == 1] assert len(self.originals) == len(self.targets) print(f"Number of original images: {len(self.originals)}") print(f"Number of target images: {len(self.targets)}") self.transform = transform def __len__(self): return len(self.originals) def __getitem__(self, idx): original_img = self.originals[idx]['image'] # TODO: get original_img file name and match with image_path in self.data....then tokenize the prompts with clip_tokenizer target_img = self.targets[idx]['image'] original = original_img.convert('RGB') target = target_img.convert('RGB') return self.transform(original), self.transform(target) class UNetWrapper: def __init__(self, unet_model, repo_id): self.model = unet_model self.repo_id = repo_id self.token = os.getenv('NEW_TOKEN') # Make sure this environment variable is set self.api = HfApi(token=os.getenv('NEW_TOKEN')) def push_to_hub(self): try: # Save model state and configuration save_dict = { 'model_state_dict': self.model.state_dict(), 'model_config': { 'big': isinstance(self.model, big_UNet), 'img_size': 1024 if isinstance(self.model, big_UNet) else 256 }, 'model_architecture': str(self.model) } # Save model locally pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth' torch.save(save_dict, pth_name) # Create repo if it doesn't exist try: create_repo( repo_id=self.repo_id, token=self.token, exist_ok=True ) except Exception as e: print(f"Repository creation note: {e}") # Upload the model file self.api.upload_file( path_or_fileobj=pth_name, path_in_repo=pth_name, repo_id=self.repo_id, token=self.token, repo_type="model" ) # Create and upload model card model_card = f"""--- tags: - unet - pix2pix - pytorch library_name: pytorch license: wtfpl datasets: - K00B404/pix2pix_flux_set language: - en pipeline_tag: image-to-image --- # Pix2Pix UNet Model ## Model Description Custom UNet model for Pix2Pix image translation. - **Image Size:** 1024 - **Model Type:** Big (1024) ## Usage ```python import torch from small_256_model import UNet as small_UNet from big_1024_model import UNet as big_UNet big = True # Load the model name='big_model_weights.pth' if big else 'small_model_weights.pth' checkpoint = torch.load(name) model = big_UNet() if checkpoint['model_config']['big'] else small_UNet() model.load_state_dict(checkpoint['model_state_dict']) model.eval() Model Architecture {str(self.model)} """ # Save and upload README with open("README.md", "w") as f: f.write(model_card) self.api.upload_file( path_or_fileobj="README.md", path_in_repo="README.md", repo_id=self.repo_id, token=self.token, repo_type="model" ) # Clean up local files os.remove(pth_name) os.remove("README.md") print(f"Model successfully uploaded to {self.repo_id}") except Exception as e: print(f"Error uploading model: {e}") def prepare_input(image, device='cpu'): """Prepare image for inference""" transform = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), ]) if isinstance(image, np.ndarray): image = Image.fromarray(image) input_tensor = transform(image).unsqueeze(0).to(device) return input_tensor def run_inference(image, prompt): """Run inference on a single image""" global global_model if global_model is None: return "Error: Model not loaded" global_model.eval() input_tensor = prepare_input(image, device) with torch.no_grad(): output = global_model(input_tensor) # Convert output to image output = output.cpu().squeeze(0).permute(1, 2, 0).numpy() output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8) return output def to_hub(model): wrapper = UNetWrapper(model, model_repo_id) wrapper.push_to_hub() def train_model(epochs): """Training function""" global global_model # Load combined data CSV data_path = 'path/to/your/combined_data.csv' # Adjust this path combined_data = pd.read_csv(data_path) # Define the transformation transform = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), ]) # Initialize the dataset and dataloader dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) model = global_model criterion = nn.L1Loss() # L1 loss for image reconstruction optimizer = optim.Adam(model.parameters(), lr=LR) output_text = [] for epoch in range(epochs): model.train() for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader): # Move images and prompt embeddings to the appropriate device (CPU or GPU) original, target = original.to(device), target.to(device) original_prompt_tokens = original_prompt_tokens.input_ids.to(device) enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device) optimizer.zero_grad() # Forward pass through the model output = model(target) # Compute image reconstruction loss img_loss = criterion(output, original) # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings) prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2) # Combine losses total_loss = img_loss + 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance total_loss.backward() # Optimizer step optimizer.step() if i % 10 == 0: status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}" print(status) output_text.append(status) # Push model to Hugging Face Hub at the end of each epoch to_hub(model) global_model = model # Update the global model after training return model, "\n".join(output_text) def train_model_old(epochs): """Training function""" global global_model ds = load_dataset(dataset_id) transform = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), ]) # Initialize the dataset and dataloader dataset = Pix2PixDataset(ds, transform, clip_tokenizer) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) model = global_model criterion = nn.L1Loss() # L1 loss for image reconstruction optimizer = optim.Adam(model.parameters(), lr=LR) output_text = [] for epoch in range(epochs): model.train() for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader): # Move images and prompt embeddings to the appropriate device (CPU or GPU) original, target = original.to(device), target.to(device) original_prompt_tokens = original_prompt_tokens.input_ids.to(device) enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device) optimizer.zero_grad() # Forward pass through the model output = model(target) # Compute image reconstruction loss img_loss = criterion(output, original) # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings) prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2) # Combine losses total_loss = img_loss + 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance total_loss.backward() # Optimizer step optimizer.step() if i % 10 == 0: status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}" print(status) output_text.append(status) # Push model to Hugging Face Hub at the end of each epoch to_hub(model) global_model = model # Update the global model after training return model, "\n".join(output_text) def gradio_train(epochs): """Gradio training interface function""" model, training_log = train_model(int(epochs)) to_hub(model) return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}" def gradio_inference(input_image, keywords): """Gradio inference interface function""" # Generate an enhanced prompt using the chat bot enhanced_prompt = chat_with_bot(keywords) # Run inference on the input image output_image = run_inference(input_image, chat_with_bot(keywords)) return input_image, output_image, keywords, enhanced_prompt def gradio_inference(input_image): """Gradio inference interface function""" return input_image, run_inference(input_image) # Create Gradio interface with tabs with gr.Blocks() as app: gr.Markdown("# Pix2Pix Model Training and Inference") with gr.Tabs(): with gr.TabItem("Training"): epochs_input = gr.Number(label="Number of Epochs") train_button = gr.Button("Train Model") output_text = gr.Textbox(label="Training Progress", lines=10) train_button.click(gradio_train, inputs=epochs_input, outputs=output_text) with gr.TabItem("Inference"): with gr.Row(): input_image = gr.Image(label="Input Image") output_image = gr.Image(label="Model Output") infer_button = gr.Button("Run Inference") infer_button.click(gradio_inference, inputs=input_image, outputs=[input_image, output_image]) if __name__ == '__main__': # Load model at startup load_model() # Launch the Gradio app app.launch()