---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{}
---

# Model Card for gan-pix2pix-night2day

A Pix2Pix model (U-Net generator) that translates night images into day images, with weights hosted on the Hugging Face Hub. Copy, paste, and save the following script as `pix2pixinference.py`:

```python
import argparse
import sys

import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, ToPILImage
from torchvision.utils import save_image
from huggingface_hub import hf_hub_download

# Import the model architecture if it is available locally;
# if not, define it inline below.
try:
    from modeling_pix2pix import GeneratorUNet
except ImportError:
    print("Couldn't import model architecture, defining it here...")
    # Define the U-Net architecture as it appears in the original code
    import torch.nn as nn

    def weights_init_normal(m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm2d") != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    class UNetDown(nn.Module):
        def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
            super(UNetDown, self).__init__()
            layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2))
            if dropout:
                layers.append(nn.Dropout(dropout))
            self.model = nn.Sequential(*layers)

        def forward(self, x):
            return self.model(x)

    class UNetUp(nn.Module):
        def __init__(self, in_size, out_size, dropout=0.0):
            super(UNetUp, self).__init__()
            layers = [
                nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(out_size),
                nn.ReLU(inplace=True),
            ]
            if dropout:
                layers.append(nn.Dropout(dropout))
            self.model = nn.Sequential(*layers)

        def forward(self, x, skip_input):
            x = self.model(x)
            x = torch.cat((x, skip_input), 1)
            return x

    class GeneratorUNet(nn.Module):
        def __init__(self, in_channels=3, out_channels=3):
            super(GeneratorUNet, self).__init__()
            self.down1 = UNetDown(in_channels, 64, normalize=False)
            self.down2 = UNetDown(64, 128)
            self.down3 = UNetDown(128, 256)
            self.down4 = UNetDown(256, 512, dropout=0.5)
            self.down5 = UNetDown(512, 512, dropout=0.5)
            self.down6 = UNetDown(512, 512, dropout=0.5)
            self.down7 = UNetDown(512, 512, dropout=0.5)
            self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

            self.up1 = UNetUp(512, 512, dropout=0.5)
            self.up2 = UNetUp(1024, 512, dropout=0.5)
            self.up3 = UNetUp(1024, 512, dropout=0.5)
            self.up4 = UNetUp(1024, 512, dropout=0.5)
            self.up5 = UNetUp(1024, 256)
            self.up6 = UNetUp(512, 128)
            self.up7 = UNetUp(256, 64)

            self.final = nn.Sequential(
                nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
                nn.Tanh(),
            )

        def forward(self, x):
            # U-Net generator with skip connections from encoder to decoder
            d1 = self.down1(x)
            d2 = self.down2(d1)
            d3 = self.down3(d2)
            d4 = self.down4(d3)
            d5 = self.down5(d4)
            d6 = self.down6(d5)
            d7 = self.down7(d6)
            d8 = self.down8(d7)
            u1 = self.up1(d8, d7)
            u2 = self.up2(u1, d6)
            u3 = self.up3(u2, d5)
            u4 = self.up4(u3, d4)
            u5 = self.up5(u4, d3)
            u6 = self.up6(u5, d2)
            u7 = self.up7(u6, d1)
            return self.final(u7)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Generate images using a Pix2Pix model from the Hugging Face Hub"
    )
    parser.add_argument(
        "--repo_id", type=str, required=True,
        help="Hugging Face Hub repository ID (e.g., 'username/model_name')",
    )
    parser.add_argument(
        "--model_file", type=str, default="model.pt",
        help="Name of the model file in the repository",
    )
    parser.add_argument(
        "--input_image", type=str, required=True,
        help="Path to input image (night image to transform to day)",
    )
    parser.add_argument(
        "--output_image", type=str, default="output.png",
        help="Path to save the generated image",
    )
    parser.add_argument(
        "--image_size", type=int, default=256,
        help="Size of the input/output images",
    )
    parser.add_argument(
        "--display", action="store_true",
        help="Display input and output images using matplotlib",
    )
    parser.add_argument(
        "--token", type=str, default=None,
        help="Hugging Face token for accessing private repositories",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Set up image transformations: resize, then scale pixels to [-1, 1]
    transform_input = Compose([
        Resize((args.image_size, args.image_size)),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Initialize model
    print("Initializing model...")
    generator = GeneratorUNet()
    generator.to(device)

    # Download model from the Hugging Face Hub
    print(f"Downloading model from {args.repo_id}...")
    try:
        model_path = hf_hub_download(
            repo_id=args.repo_id,
            filename=args.model_file,
            token=args.token,
        )
        print(f"Model downloaded to {model_path}")
    except Exception as e:
        print(f"Error downloading model: {e}")
        sys.exit(1)

    # Load model weights
    try:
        generator.load_state_dict(torch.load(model_path, map_location=device))
        generator.eval()
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model weights: {e}")
        sys.exit(1)

    # Load and preprocess input image
    try:
        image = Image.open(args.input_image).convert("RGB")
        original_image = image.copy()
        input_tensor = transform_input(image).unsqueeze(0).to(device)
        print(f"Input image loaded: {args.input_image}")
    except Exception as e:
        print(f"Error loading input image: {e}")
        sys.exit(1)

    # Generate output image
    print("Generating image...")
    with torch.no_grad():
        fake_B = generator(input_tensor)

    # Save the output image
    try:
        # Denormalize from [-1, 1] and write to disk
        output_image = fake_B.cpu()
        save_image(output_image, args.output_image, normalize=True)
        print(f"Output image saved to {args.output_image}")

        # Create a PIL image for display if needed
        to_pil = ToPILImage()
        output_pil = to_pil(output_image.squeeze(0) * 0.5 + 0.5)
    except Exception as e:
        print(f"Error saving output image: {e}")
        sys.exit(1)

    # Display images if requested
    if args.display:
        try:
            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.title("Input Image (Night)")
            plt.imshow(original_image)
            plt.axis("off")
            plt.subplot(1, 2, 2)
            plt.title("Generated Image (Day)")
            plt.imshow(output_pil)
            plt.axis("off")
            plt.tight_layout()
            plt.show()
        except Exception as e:
            print(f"Error displaying images: {e}")


if __name__ == "__main__":
    main()
```

Run the script against the hosted weights:

```bash
python pix2pixinference.py --repo_id "uisikdag/gan-pix2pix-night2day" --input_image "night_image.jpg" --output_image "day_image.png"
```
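If any of the imports fail, install the script's dependencies first. A typical setup (package names only, versions unpinned) is:

```bash
pip install torch torchvision huggingface_hub pillow matplotlib
```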
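For use inside a larger Python program, the generator can also be called directly instead of through the CLI. The sketch below is a minimal example, assuming `pix2pixinference.py` (above) is importable from the working directory, that the repository stores its weights under the script's default name `model.pt`, and that a local `night_image.jpg` exists:

```python
# Minimal programmatic sketch (assumptions: pix2pixinference.py is importable,
# the repo's weights file is named "model.pt", and night_image.jpg exists locally).
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.utils import save_image

from pix2pixinference import GeneratorUNet  # defined in the script above

# Download the weights and build the generator
weights = hf_hub_download(repo_id="uisikdag/gan-pix2pix-night2day", filename="model.pt")
generator = GeneratorUNet()
generator.load_state_dict(torch.load(weights, map_location="cpu"))
generator.eval()

# Same preprocessing as the script: resize, then scale pixels to [-1, 1]
preprocess = Compose([
    Resize((256, 256)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
night = preprocess(Image.open("night_image.jpg").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    day = generator(night)  # tanh output in [-1, 1], shape (1, 3, 256, 256)

save_image(day * 0.5 + 0.5, "day_image.png")  # undo the [-1, 1] normalization and save
```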