K00B404 committed
Commit 3783d54 · verified · 1 Parent(s): 38b513f

Create app.py

Files changed (1):
  app.py: +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, HfFolder, Repository, create_repo
import os
import gradio as gr
from PIL import Image
import numpy as np
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'  # use the large 1024px model only when a GPU is available

# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 16 if big else 4
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"

# Global model variable
global_model = None

def load_model():
    """Load the model at startup."""
    global global_model
    try:
        checkpoint = torch.load('model_weights.pth', map_location=device)
        model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        global_model = model
        print("Model loaded successfully!")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        # Fall back to a freshly initialized model if no checkpoint is available
        model = big_UNet().to(device) if big else small_UNet().to(device)
        global_model = model
        return model

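# --- Illustrative sketch, not part of the committed file: load_model() above expects
# 'model_weights.pth' to be a dict containing 'model_state_dict' and
# 'model_config' -> {'big': bool}. The helper below (name assumed) shows one way to
# write a compatible checkpoint.
def save_checkpoint(model, path='model_weights.pth'):
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {'big': big},  # records which UNet variant the weights belong to
    }, path)
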
# Dataset of paired images: label 0 = original, label 1 = target
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, ds, transform):
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]
        assert len(self.originals) == len(self.targets), "originals and targets must pair up one-to-one"
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")
        self.transform = transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        original_img = self.originals[idx]['image']
        target_img = self.targets[idx]['image']
        original = original_img.convert('RGB')
        target = target_img.convert('RGB')
        return self.transform(original), self.transform(target)

# UNetWrapper: pushes a trained model to the Hub. Its implementation is elided in this
# commit; gradio_train() below expects UNetWrapper(model, repo_id) and wrapper.push_to_hub().
class UNetWrapper:
    # ... [Previous UNetWrapper implementation remains unchanged]
    pass

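# --- Illustrative sketch, not part of the committed file and not the elided UNetWrapper:
# one possible way to push weights with the huggingface_hub helpers already imported
# above (function name push_weights_to_hub is assumed; requires a logged-in Hub user).
def push_weights_to_hub(model, repo_id, filename='model_weights.pth'):
    torch.save({'model_state_dict': model.state_dict(),
                'model_config': {'big': big}}, filename)
    create_repo(repo_id, exist_ok=True)  # no-op if the repo already exists
    HfApi().upload_file(path_or_fileobj=filename, path_in_repo=filename, repo_id=repo_id)
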
def prepare_input(image, device='cpu'):
    """Prepare image for inference"""
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    input_tensor = transform(image).unsqueeze(0).to(device)
    return input_tensor

def run_inference(image):
    """Run inference on a single image"""
    global global_model
    if global_model is None:
        return "Error: Model not loaded"

    global_model.eval()
    input_tensor = prepare_input(image, device)

    with torch.no_grad():
        output = global_model(input_tensor)

    # Convert the output tensor to an HxWxC uint8 image, min-max scaled to [0, 255]
    output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
    output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
    return output

def train_model(epochs):
    """Training function"""
    global global_model

    ds = load_dataset(dataset_id)
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    dataset = Pix2PixDataset(ds, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = global_model
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []

    for epoch in range(epochs):
        model.train()
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()
            # As written, the model takes the target image and is trained to reconstruct the original
            output = model(target)
            loss = criterion(output, original)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
                print(status)
                output_text.append(status)

    global_model = model
    return model, "\n".join(output_text)

def gradio_train(epochs):
    """Gradio training interface function"""
    model, training_log = train_model(int(epochs))
    wrapper = UNetWrapper(model, model_repo_id)
    wrapper.push_to_hub()
    return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"

def gradio_inference(input_image):
    """Gradio inference interface function"""
    return input_image, run_inference(input_image)

# Create Gradio interface with tabs
with gr.Blocks() as app:
    gr.Markdown("# Pix2Pix Model Training and Inference")

    with gr.Tabs():
        with gr.TabItem("Training"):
            epochs_input = gr.Number(label="Number of Epochs")
            train_button = gr.Button("Train Model")
            output_text = gr.Textbox(label="Training Progress", lines=10)
            train_button.click(gradio_train, inputs=epochs_input, outputs=output_text)

        with gr.TabItem("Inference"):
            with gr.Row():
                input_image = gr.Image(label="Input Image")
                output_image = gr.Image(label="Model Output")
            infer_button = gr.Button("Run Inference")
            infer_button.click(gradio_inference, inputs=input_image, outputs=[input_image, output_image])

if __name__ == '__main__':
    # Load model at startup
    load_model()

    # Launch the Gradio app
    app.launch()