K00B404 commited on
Commit
b19f010
·
verified ·
1 Parent(s): d71e34d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -5,6 +5,8 @@ from torch.utils.data import DataLoader
5
  from torchvision import transforms
6
  from datasets import load_dataset
7
  from huggingface_hub import Repository
 
 
8
  import gradio as gr
9
  from PIL import Image
10
  import os
@@ -53,7 +55,21 @@ class Pix2PixDataset(torch.utils.data.Dataset):
53
 
54
  # Apply the necessary transforms
55
  return self.transform(original), self.transform(target)
56
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Training function
58
  def train_model(epochs):
59
  # Load the dataset
@@ -100,8 +116,11 @@ def train_model(epochs):
100
 
101
  # Push model to Hugging Face Hub
102
  def push_model_to_hub(model, repo_name):
 
 
 
103
  # Push the model to the Hugging Face hub
104
- model.push_to_hub(repo_name)
105
 
106
  # Gradio interface function
107
  def gradio_train(epochs):
 
5
  from torchvision import transforms
6
  from datasets import load_dataset
7
  from huggingface_hub import Repository
8
+ from huggingface_hub import HfApi, HfFolder, Repository, create_repo
9
+
10
  import gradio as gr
11
  from PIL import Image
12
  import os
 
55
 
56
  # Apply the necessary transforms
57
  return self.transform(original), self.transform(target)
58
+
59
+ class UNetWrapper:
60
+ def __init__(self, unet_model, repo_id):
61
+ self.model = unet_model
62
+ self.repo_id = repo_id
63
+
64
+ def push_to_hub(self):
65
+ # Initialize the Hugging Face API
66
+ api = HfApi()
67
+ # Create a repository if it doesn't exist
68
+ create_repo(self.repo_id, exist_ok=True)
69
+ # Push the model's state dict to the Hugging Face Hub
70
+ self.model.save_pretrained(self.repo_id) # You may need to implement this method
71
+
72
+
73
  # Training function
74
  def train_model(epochs):
75
  # Load the dataset
 
116
 
117
  # Push model to Hugging Face Hub
118
  def push_model_to_hub(model, repo_name):
119
+ # Usage example
120
+ model_wrapper = UNetWrapper(model, model_repo_id)
121
+ model_wrapper.push_to_hub()
122
  # Push the model to the Hugging Face hub
123
+ #model.push_to_hub(repo_name)
124
 
125
  # Gradio interface function
126
  def gradio_train(epochs):