move model
Browse files
model/config.json → config.json
RENAMED
File without changes
|
model/model.safetensors → model.safetensors
RENAMED
File without changes
|
model/preprocessor_config.json → preprocessor_config.json
RENAMED
File without changes
|
src/train.py
CHANGED
@@ -5,7 +5,7 @@ from transformers import (
|
|
5 |
Trainer,
|
6 |
)
|
7 |
from datasets import load_dataset
|
8 |
-
from .utils import
|
9 |
|
10 |
|
11 |
def train():
|
@@ -67,5 +67,5 @@ def train():
|
|
67 |
trainer.train()
|
68 |
|
69 |
# Save model and processor
|
70 |
-
model.save_pretrained(
|
71 |
-
processor.save_pretrained(
|
|
|
5 |
Trainer,
|
6 |
)
|
7 |
from datasets import load_dataset
|
8 |
+
from .utils import ROOT_DIR
|
9 |
|
10 |
|
11 |
def train():
|
|
|
67 |
trainer.train()
|
68 |
|
69 |
# Save model and processor
|
70 |
+
model.save_pretrained(ROOT_DIR)
|
71 |
+
processor.save_pretrained(ROOT_DIR)
|
src/utils.py
CHANGED
@@ -2,8 +2,7 @@ from transformers import ViTForImageClassification
|
|
2 |
import os
|
3 |
|
4 |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
5 |
-
MODEL_DIR = os.path.join(ROOT_DIR, "model")
|
6 |
|
7 |
|
8 |
def load_model():
|
9 |
-
return ViTForImageClassification.from_pretrained(
|
|
|
2 |
import os
|
3 |
|
4 |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
|
5 |
|
6 |
|
7 |
def load_model():
|
8 |
+
return ViTForImageClassification.from_pretrained(ROOT_DIR)
|