K00B404 commited on
Commit
c7b0000
·
verified ·
1 Parent(s): 35fac14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -31,8 +31,9 @@ global_model = None
31
  def load_model():
32
  """Load the model at startup"""
33
  global global_model
 
34
  try:
35
- checkpoint = torch.load('model_weights.pth', map_location=device)
36
  model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
37
  model.load_state_dict(checkpoint['model_state_dict'])
38
  model.to(device)
@@ -86,7 +87,7 @@ class UNetWrapper:
86
  }
87
 
88
  # Save model locally
89
- pth_name = 'model_weights.pth'
90
  torch.save(save_dict, pth_name)
91
 
92
  # Create repo if it doesn't exist
@@ -115,14 +116,20 @@ tags:
115
  - pix2pix
116
  - pytorch
117
  library_name: pytorch
 
 
 
 
 
 
118
  ---
119
 
120
  # Pix2Pix UNet Model
121
 
122
  ## Model Description
123
  Custom UNet model for Pix2Pix image translation.
124
- - **Image Size:** {1024 if isinstance(self.model, big_UNet) else 256}
125
- - **Model Type:** {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}
126
 
127
  ## Usage
128
 
@@ -130,9 +137,10 @@ Custom UNet model for Pix2Pix image translation.
130
  import torch
131
  from small_256_model import UNet as small_UNet
132
  from big_1024_model import UNet as big_UNet
133
-
134
  # Load the model
135
- checkpoint = torch.load('model_weights.pth')
 
136
  model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
137
  model.load_state_dict(checkpoint['model_state_dict'])
138
  model.eval()
 
31
  def load_model():
32
  """Load the model at startup"""
33
  global global_model
34
+ weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
35
  try:
36
+ checkpoint = torch.load(weights_name, map_location=device)
37
  model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
38
  model.load_state_dict(checkpoint['model_state_dict'])
39
  model.to(device)
 
87
  }
88
 
89
  # Save model locally
90
+ pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
91
  torch.save(save_dict, pth_name)
92
 
93
  # Create repo if it doesn't exist
 
116
  - pix2pix
117
  - pytorch
118
  library_name: pytorch
119
+ license: wtfpl
120
+ datasets:
121
+ - K00B404/pix2pix_flux_set
122
+ language:
123
+ - en
124
+ pipeline_tag: image-to-image
125
  ---
126
 
127
  # Pix2Pix UNet Model
128
 
129
  ## Model Description
130
  Custom UNet model for Pix2Pix image translation.
131
+ - **Image Size:** 1024
132
+ - **Model Type:** Big (1024)
133
 
134
  ## Usage
135
 
 
137
  import torch
138
  from small_256_model import UNet as small_UNet
139
  from big_1024_model import UNet as big_UNet
140
+ big = True
141
  # Load the model
142
+ name='big_model_weights.pth' if big else 'small_model_weights.pth'
143
+ checkpoint = torch.load(name)
144
  model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
145
  model.load_state_dict(checkpoint['model_state_dict'])
146
  model.eval()