Update handler.py
Browse files- handler.py +14 -11
handler.py
CHANGED
@@ -4,18 +4,22 @@ import subprocess
|
|
4 |
import pkg_resources
|
5 |
|
6 |
# Verify and print the transformers version
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Update transformers to the development version if necessary
|
11 |
if transformers_version != "4.40.2":
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
|
20 |
class CustomModelHandler:
|
21 |
def __init__(self, model_name_or_path: str):
|
@@ -30,8 +34,7 @@ class CustomModelHandler:
|
|
30 |
self.model = AutoModelForCausalLM.from_pretrained(
|
31 |
self.model_name_or_path,
|
32 |
trust_remote_code=True,
|
33 |
-
torch_dtype="auto"
|
34 |
-
_attn_implementation="eager"
|
35 |
)
|
36 |
self.model.to(self.device)
|
37 |
print(f"Model loaded and moved to {self.device}")
|
|
|
4 |
import pkg_resources
|
5 |
|
6 |
# Verify and print the transformers version
|
7 |
+
try:
|
8 |
+
transformers_version = pkg_resources.get_distribution("transformers").version
|
9 |
+
print(f"Transformers version: {transformers_version}")
|
10 |
+
except pkg_resources.DistributionNotFound:
|
11 |
+
transformers_version = None
|
12 |
+
print("Transformers not installed")
|
13 |
|
14 |
# Update transformers to the development version if necessary
|
15 |
if transformers_version != "4.40.2":
|
16 |
+
try:
|
17 |
+
subprocess.run('pip uninstall -y transformers', shell=True, check=True)
|
18 |
+
subprocess.run('pip install git+https://github.com/huggingface/transformers', shell=True, check=True)
|
19 |
+
transformers_version = pkg_resources.get_distribution("transformers").version
|
20 |
+
print(f"Updated Transformers version: {transformers_version}")
|
21 |
+
except subprocess.CalledProcessError as e:
|
22 |
+
print(f"Error occurred while updating transformers: {e}")
|
23 |
|
24 |
class CustomModelHandler:
|
25 |
def __init__(self, model_name_or_path: str):
|
|
|
34 |
self.model = AutoModelForCausalLM.from_pretrained(
|
35 |
self.model_name_or_path,
|
36 |
trust_remote_code=True,
|
37 |
+
torch_dtype="auto"
|
|
|
38 |
)
|
39 |
self.model.to(self.device)
|
40 |
print(f"Model loaded and moved to {self.device}")
|