K00B404 commited on
Commit
3010c48
·
verified ·
1 Parent(s): 071bd98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -64
app.py CHANGED
@@ -1,67 +1,34 @@
1
- # Define the Pix2Pix model (UNet)
2
-
3
  import torch
4
  import torch.nn as nn
5
  import torch.optim as optim
6
  from torch.utils.data import DataLoader
7
  from torchvision import transforms
8
  from datasets import load_dataset
9
- from huggingface_hub import Repository, create_repo
10
  import gradio as gr
11
  from PIL import Image
12
  import os
13
 
14
- # Parameters
15
- IMG_SIZE = 256
16
- BATCH_SIZE = 1
17
- EPOCHS = 12
18
- LR = 0.0002
19
 
20
  # Device configuration
21
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
 
23
- # Define the Pix2Pix model (Simplified UNet)
24
- class UNet(nn.Module):
25
- def __init__(self):
26
- super(UNet, self).__init__()
27
-
28
- # Encoder
29
- self.encoder = nn.Sequential(
30
- nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), # 256 -> 128
31
- nn.ReLU(inplace=True),
32
- nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 128 -> 64
33
- nn.ReLU(inplace=True),
34
- nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 64 -> 32
35
- nn.ReLU(inplace=True),
36
- nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), # 32 -> 16
37
- nn.ReLU(inplace=True),
38
- nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1), # 16 -> 8
39
- nn.ReLU(inplace=True)
40
- )
41
-
42
- # Decoder
43
- self.decoder = nn.Sequential(
44
- nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1), # 8 -> 16
45
- nn.ReLU(inplace=True),
46
- nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), # 16 -> 32
47
- nn.ReLU(inplace=True),
48
- nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 32 -> 64
49
- nn.ReLU(inplace=True),
50
- nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # 64 -> 128
51
- nn.ReLU(inplace=True),
52
- nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1), # 128 -> 256
53
- nn.Tanh() # Output range [-1, 1]
54
- )
55
-
56
- def forward(self, x):
57
- enc = self.encoder(x)
58
- dec = self.decoder(enc)
59
- return dec
60
 
61
  # Training function
62
  def train_model(epochs):
63
  # Load the dataset
64
- ds = load_dataset("K00B404/pix2pix_flux_set")
65
 
66
  # Transform function to resize and convert to tensor
67
  transform = transforms.Compose([
@@ -86,7 +53,11 @@ def train_model(epochs):
86
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
87
 
88
  # Initialize model, loss function, and optimizer
89
- model = UNet().to(device)
 
 
 
 
90
  criterion = nn.L1Loss()
91
  optimizer = optim.Adam(model.parameters(), lr=LR)
92
 
@@ -97,8 +68,8 @@ def train_model(epochs):
97
  optimizer.zero_grad()
98
 
99
  # Forward pass
100
- output = model(target)
101
- loss = criterion(output, original)
102
 
103
  # Backward pass
104
  loss.backward()
@@ -112,21 +83,14 @@ def train_model(epochs):
112
 
113
  # Push model to Hugging Face Hub
114
  def push_model_to_hub(model, repo_name):
115
- repo = Repository(repo_name)
116
- repo.push_to_hub()
117
-
118
- # Save the model state dict
119
- model_save_path = os.path.join(repo_name, "pix2pix_model.pth")
120
- torch.save(model.state_dict(), model_save_path)
121
-
122
- # Push the model to the repo
123
- repo.push_to_hub(commit_message="Initial commit with trained Pix2Pix model.")
124
 
125
  # Gradio interface function
126
  def gradio_train(epochs):
127
  model = train_model(int(epochs))
128
- push_model_to_hub(model, "K00B404/pix2pix_flux")
129
- return f"Model trained for {epochs} epochs and pushed to Hugging Face Hub repository 'K00B404/pix2pix_flux'."
130
 
131
  # Gradio Interface
132
  gr_interface = gr.Interface(
@@ -138,8 +102,9 @@ gr_interface = gr.Interface(
138
  )
139
 
140
  if __name__ == '__main__':
141
- # Create or clone the repository
142
- create_repo("K00B404/pix2pix_flux", exist_ok=True)
143
-
 
144
  # Launch the Gradio app
145
- gr_interface.launch()
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.optim as optim
4
  from torch.utils.data import DataLoader
5
  from torchvision import transforms
6
  from datasets import load_dataset
7
+ from huggingface_hub import Repository
8
  import gradio as gr
9
  from PIL import Image
10
  import os
11
 
12
+ from 256_model import UNet as small_UNet
13
+ from 1024_model import UNet as big_UNet
 
 
 
14
 
15
  # Device configuration
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
18
+ big = False if device == torch.device('cpu') else True
19
+
20
+ # Parameters
21
+ IMG_SIZE = 1024 if big else 256
22
+ BATCH_SIZE = 16 if big else 1
23
+ EPOCHS = 12
24
+ LR = 0.0002
25
+ dataset_id = "K00B404/pix2pix_flux_set"
26
+ model_repo_id = "K00B404/pix2pix_flux"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Training function
29
  def train_model(epochs):
30
  # Load the dataset
31
+ ds = load_dataset(dataset_id)
32
 
33
  # Transform function to resize and convert to tensor
34
  transform = transforms.Compose([
 
53
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
54
 
55
  # Initialize model, loss function, and optimizer
56
+ try:
57
+ model = UNet2DModel.from_pretrained(model_repo_id).to(device)
58
+ except Exception:
59
+ model = big_UNet().to(device) if big else small_UNet().to(device)
60
+
61
  criterion = nn.L1Loss()
62
  optimizer = optim.Adam(model.parameters(), lr=LR)
63
 
 
68
  optimizer.zero_grad()
69
 
70
  # Forward pass
71
+ output = model(target) # Generate cutout image
72
+ loss = criterion(output, original) # Compare with original image
73
 
74
  # Backward pass
75
  loss.backward()
 
83
 
84
  # Push model to Hugging Face Hub
85
  def push_model_to_hub(model, repo_name):
86
+ # Push the model to the Hugging Face hub
87
+ model.push_to_hub(repo_name)
 
 
 
 
 
 
 
88
 
89
  # Gradio interface function
90
  def gradio_train(epochs):
91
  model = train_model(int(epochs))
92
+ push_model_to_hub(model, model_repo_id)
93
+ return f"Model trained for {epochs} epochs on the {dataset_id} dataset and pushed to Hugging Face Hub {model_repo_id} repository."
94
 
95
  # Gradio Interface
96
  gr_interface = gr.Interface(
 
102
  )
103
 
104
  if __name__ == '__main__':
105
+ # Create or clone the repository if necessary
106
+ repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
107
+ repo.git_pull()
108
+
109
  # Launch the Gradio app
110
+ gr_interface.launch()