uisikdag committed on
Commit
8935f2f
·
verified ·
1 Parent(s): 1a8a785

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +255 -3
README.md CHANGED
@@ -1,4 +1,256 @@
1
- -----------
2
- #metadata
3
- -----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
1
+ ---
2
+ # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
3
+ # Doc / guide: https://huggingface.co/docs/hub/model-cards
4
+ {}
5
+ ---
6
+ # Model Card for uisikdag/gan-pix2pix-night2day
7
+
8
+ Copy the script below and save it as `pix2pixinference.py`:
9
+ ```python
10
+ import argparse
11
+ import torch
12
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, ToPILImage
13
+ from torchvision.utils import save_image
14
+ from PIL import Image
15
+ import os
16
+ import io
17
+ from huggingface_hub import hf_hub_download
18
+ import sys
19
+ import matplotlib.pyplot as plt
20
+
21
+ # Import the model architecture - assuming it's locally available
22
+ # If not, we'll need to define it here
23
try:
    from modeling_pix2pix import GeneratorUNet
except ImportError:
    print("Couldn't import model architecture, defining it here...")
    # Define the UNet architecture as it appears in the original code
    import torch.nn as nn
    import torch.nn.functional as F  # not used by the definitions below

    def weights_init_normal(m):
        """Initialise conv / batch-norm parameters of module ``m`` in place.

        Modules whose class name contains ``"Conv"`` get weights drawn from
        N(0, 0.02); modules whose class name contains ``"BatchNorm2d"`` get
        weights drawn from N(1, 0.02) and a zero bias.  Parameters that are
        missing or ``None`` (e.g. ``affine=False`` or ``bias=False``) are
        skipped instead of raising ``AttributeError``.

        Intended to be passed to ``nn.Module.apply``.
        """
        classname = m.__class__.__name__
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        if "Conv" in classname:
            if weight is not None:
                torch.nn.init.normal_(weight.data, 0.0, 0.02)
        elif "BatchNorm2d" in classname:
            if weight is not None:
                torch.nn.init.normal_(weight.data, 1.0, 0.02)
            if bias is not None:
                torch.nn.init.constant_(bias.data, 0.0)

    class UNetDown(nn.Module):
        """Encoder stage: 4x4 stride-2 conv, optional InstanceNorm, LeakyReLU(0.2).

        Args:
            in_size: number of input channels.
            out_size: number of output channels.
            normalize: append ``nn.InstanceNorm2d(out_size)`` after the conv.
            dropout: dropout probability; ``0.0`` adds no dropout layer.
        """

        def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
            super().__init__()
            layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2))
            if dropout:
                layers.append(nn.Dropout(dropout))
            self.model = nn.Sequential(*layers)

        def forward(self, x):
            """Apply the stage to ``x`` and return the result."""
            return self.model(x)

    class UNetUp(nn.Module):
        """Decoder stage: 4x4 stride-2 transposed conv, InstanceNorm, ReLU.

        The stage output is concatenated with a skip tensor along dim 1, so
        the returned tensor has ``out_size`` plus the skip's channel count.

        Args:
            in_size: number of input channels.
            out_size: number of channels produced before concatenation.
            dropout: dropout probability; ``0.0`` adds no dropout layer.
        """

        def __init__(self, in_size, out_size, dropout=0.0):
            super().__init__()
            layers = [
                nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(out_size),
                nn.ReLU(inplace=True),
            ]
            if dropout:
                layers.append(nn.Dropout(dropout))
            self.model = nn.Sequential(*layers)

        def forward(self, x, skip_input):
            """Upsample ``x`` and concatenate ``skip_input`` on the channel dim."""
            x = self.model(x)
            x = torch.cat((x, skip_input), 1)
            return x

    class GeneratorUNet(nn.Module):
        """U-Net generator: eight ``UNetDown`` stages, seven ``UNetUp`` stages
        with skip connections, and a final transposed conv followed by Tanh.

        Args:
            in_channels: channels of the input image tensor.
            out_channels: channels of the generated image tensor.
        """

        def __init__(self, in_channels=3, out_channels=3):
            super().__init__()

            self.down1 = UNetDown(in_channels, 64, normalize=False)
            self.down2 = UNetDown(64, 128)
            self.down3 = UNetDown(128, 256)
            self.down4 = UNetDown(256, 512, dropout=0.5)
            self.down5 = UNetDown(512, 512, dropout=0.5)
            self.down6 = UNetDown(512, 512, dropout=0.5)
            self.down7 = UNetDown(512, 512, dropout=0.5)
            self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

            # From up2 onward the input width is doubled because each UNetUp
            # output is concatenated with the matching encoder feature map.
            self.up1 = UNetUp(512, 512, dropout=0.5)
            self.up2 = UNetUp(1024, 512, dropout=0.5)
            self.up3 = UNetUp(1024, 512, dropout=0.5)
            self.up4 = UNetUp(1024, 512, dropout=0.5)
            self.up5 = UNetUp(1024, 256)
            self.up6 = UNetUp(512, 128)
            self.up7 = UNetUp(256, 64)

            self.final = nn.Sequential(
                nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
                nn.Tanh(),
            )

        def forward(self, x):
            """Run the encoder/decoder and return the Tanh-activated output."""
            # U-Net generator with skip connections from encoder to decoder
            d1 = self.down1(x)
            d2 = self.down2(d1)
            d3 = self.down3(d2)
            d4 = self.down4(d3)
            d5 = self.down5(d4)
            d6 = self.down6(d5)
            d7 = self.down7(d6)
            d8 = self.down8(d7)
            u1 = self.up1(d8, d7)
            u2 = self.up2(u1, d6)
            u3 = self.up3(u2, d5)
            u4 = self.up4(u3, d4)
            u5 = self.up5(u4, d3)
            u6 = self.up6(u5, d2)
            u7 = self.up7(u6, d1)
            return self.final(u7)
114
+
115
+
116
def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: optional list of argument strings.  When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        An ``argparse.Namespace`` with ``repo_id``, ``model_file``,
        ``input_image``, ``output_image``, ``image_size``, ``display`` and
        ``token``.
    """
    parser = argparse.ArgumentParser(description="Generate images using Pix2Pix model from HuggingFace Hub")
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="HuggingFace Hub repository ID (e.g., 'username/model_name')"
    )
    parser.add_argument(
        "--model_file",
        type=str,
        default="model.pt",
        help="Name of the model file in the repository"
    )
    parser.add_argument(
        "--input_image",
        type=str,
        required=True,
        help="Path to input image (night image to transform to day)"
    )
    parser.add_argument(
        "--output_image",
        type=str,
        default="output.png",
        help="Path to save the generated image"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=256,
        help="Size of the input/output images"
    )
    parser.add_argument(
        "--display",
        action="store_true",
        help="Display input and output images using matplotlib"
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="HuggingFace token for accessing private repositories"
    )
    return parser.parse_args(argv)
160
+
161
+
162
def main():
    """Download the generator weights, translate one image, save the result.

    Steps: parse the CLI arguments, build ``GeneratorUNet``, fetch the weight
    file with ``hf_hub_download``, load it, run the input image through the
    generator under ``torch.no_grad()``, write the output image and, when
    ``--display`` is given, show input and output side by side.

    Download, weight-loading, image-loading and saving failures print a
    message and terminate the process with exit status 1.  A display failure
    is only reported.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")

    # Set up image transformations: resize, then map pixel values to [-1, 1]
    transform_input = Compose([
        Resize((args.image_size, args.image_size)),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Initialize model
    print("Initializing model...")
    generator = GeneratorUNet()
    generator.to(device)

    # Download model from Hugging Face Hub
    print(f"Downloading model from {args.repo_id}...")
    try:
        model_path = hf_hub_download(
            repo_id=args.repo_id,
            filename=args.model_file,
            token=args.token
        )
        print(f"Model downloaded to {model_path}")
    except Exception as e:
        print(f"Error downloading model: {e}")
        sys.exit(1)

    # Load model weights
    try:
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a file fetched from the Hub cannot run arbitrary code on load.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        generator.load_state_dict(state_dict)
        generator.eval()
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model weights: {e}")
        sys.exit(1)

    # Load and preprocess input image
    try:
        image = Image.open(args.input_image).convert("RGB")
        original_image = image.copy()
        input_tensor = transform_input(image).unsqueeze(0).to(device)
        print(f"Input image loaded: {args.input_image}")
    except Exception as e:
        print(f"Error loading input image: {e}")
        sys.exit(1)

    # Generate output image
    print("Generating image...")
    with torch.no_grad():
        fake_B = generator(input_tensor)

    # Save the output image
    try:
        # Denormalize and convert back to image.  A fixed value_range maps the
        # Tanh output from [-1, 1] to [0, 1]; without it save_image would
        # min-max stretch each image and the saved file would not match the
        # displayed one.
        output_image = fake_B.cpu()
        save_image(output_image, args.output_image, normalize=True, value_range=(-1, 1))
        print(f"Output image saved to {args.output_image}")

        # Create a PIL image for display if needed
        to_pil = ToPILImage()
        output_pil = to_pil((output_image.squeeze(0) * 0.5 + 0.5).clamp(0.0, 1.0))
    except Exception as e:
        print(f"Error saving output image: {e}")
        sys.exit(1)

    # Display images if requested
    if args.display:
        try:
            plt.figure(figsize=(10, 5))

            plt.subplot(1, 2, 1)
            plt.title("Input Image (Night)")
            plt.imshow(original_image)
            plt.axis("off")

            plt.subplot(1, 2, 2)
            plt.title("Generated Image (Day)")
            plt.imshow(output_pil)
            plt.axis("off")

            plt.tight_layout()
            plt.show()
        except Exception as e:
            print(f"Error displaying images: {e}")
250
+
251
+
252
# Run the inference pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
254
+ ```
255
+ Then run it, for example: `python pix2pixinference.py --repo_id "uisikdag/gan-pix2pix-night2day" --input_image "night_image.jpg" --output_image "day_image.png"`
256