Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,8 @@ from torch.utils.data import DataLoader
|
|
5 |
from torchvision import transforms
|
6 |
from datasets import load_dataset
|
7 |
from huggingface_hub import Repository
|
|
|
|
|
8 |
import gradio as gr
|
9 |
from PIL import Image
|
10 |
import os
|
@@ -53,7 +55,21 @@ class Pix2PixDataset(torch.utils.data.Dataset):
|
|
53 |
|
54 |
# Apply the necessary transforms
|
55 |
return self.transform(original), self.transform(target)
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
# Training function
|
58 |
def train_model(epochs):
|
59 |
# Load the dataset
|
@@ -100,8 +116,11 @@ def train_model(epochs):
|
|
100 |
|
101 |
# Push model to Hugging Face Hub
|
102 |
def push_model_to_hub(model, repo_name):
|
|
|
|
|
|
|
103 |
# Push the model to the Hugging Face hub
|
104 |
-
model.push_to_hub(repo_name)
|
105 |
|
106 |
# Gradio interface function
|
107 |
def gradio_train(epochs):
|
|
|
5 |
from torchvision import transforms
|
6 |
from datasets import load_dataset
|
7 |
from huggingface_hub import Repository
|
8 |
+
from huggingface_hub import HfApi, HfFolder, Repository, create_repo
|
9 |
+
|
10 |
import gradio as gr
|
11 |
from PIL import Image
|
12 |
import os
|
|
|
55 |
|
56 |
# Apply the necessary transforms
|
57 |
return self.transform(original), self.transform(target)
|
58 |
+
|
59 |
+
class UNetWrapper:
|
60 |
+
def __init__(self, unet_model, repo_id):
|
61 |
+
self.model = unet_model
|
62 |
+
self.repo_id = repo_id
|
63 |
+
|
64 |
+
def push_to_hub(self):
|
65 |
+
# Initialize the Hugging Face API
|
66 |
+
api = HfApi()
|
67 |
+
# Create a repository if it doesn't exist
|
68 |
+
create_repo(self.repo_id, exist_ok=True)
|
69 |
+
# Push the model's state dict to the Hugging Face Hub
|
70 |
+
self.model.save_pretrained(self.repo_id) # You may need to implement this method
|
71 |
+
|
72 |
+
|
73 |
# Training function
|
74 |
def train_model(epochs):
|
75 |
# Load the dataset
|
|
|
116 |
|
117 |
# Push model to Hugging Face Hub
|
118 |
def push_model_to_hub(model, repo_name):
|
119 |
+
# Usage example
|
120 |
+
model_wrapper = UNetWrapper(model, model_repo_id)
|
121 |
+
model_wrapper.push_to_hub()
|
122 |
# Push the model to the Hugging Face hub
|
123 |
+
#model.push_to_hub(repo_name)
|
124 |
|
125 |
# Gradio interface function
|
126 |
def gradio_train(epochs):
|