Gokulram2710 commited on
Commit
f1d1223
·
verified ·
1 Parent(s): 0b32e87

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -11
handler.py CHANGED
@@ -4,18 +4,22 @@ import subprocess
4
  import pkg_resources
5
 
6
  # Verify and print the transformers version
7
- transformers_version = pkg_resources.get_distribution("transformers").version
8
- print(f"Transformers version: {transformers_version}")
 
 
 
 
9
 
10
  # Update transformers to the development version if necessary
11
  if transformers_version != "4.40.2":
12
- subprocess.run('pip uninstall -y transformers', shell=True)
13
- subprocess.run('pip install git+https://github.com/huggingface/transformers', shell=True)
14
- transformers_version = pkg_resources.get_distribution("transformers").version
15
- print(f"Updated Transformers version: {transformers_version}")
16
-
17
- # Install flash-attn if needed
18
- # subprocess.run('pip install flash-attn', shell=True)
19
 
20
  class CustomModelHandler:
21
  def __init__(self, model_name_or_path: str):
@@ -30,8 +34,7 @@ class CustomModelHandler:
30
  self.model = AutoModelForCausalLM.from_pretrained(
31
  self.model_name_or_path,
32
  trust_remote_code=True,
33
- torch_dtype="auto",
34
- _attn_implementation="eager"
35
  )
36
  self.model.to(self.device)
37
  print(f"Model loaded and moved to {self.device}")
 
4
  import pkg_resources
5
 
6
  # Verify and print the transformers version
7
+ try:
8
+ transformers_version = pkg_resources.get_distribution("transformers").version
9
+ print(f"Transformers version: {transformers_version}")
10
+ except pkg_resources.DistributionNotFound:
11
+ transformers_version = None
12
+ print("Transformers not installed")
13
 
14
  # Update transformers to the development version if necessary
15
  if transformers_version != "4.40.2":
16
+ try:
17
+ subprocess.run('pip uninstall -y transformers', shell=True, check=True)
18
+ subprocess.run('pip install git+https://github.com/huggingface/transformers', shell=True, check=True)
19
+ transformers_version = pkg_resources.get_distribution("transformers").version
20
+ print(f"Updated Transformers version: {transformers_version}")
21
+ except subprocess.CalledProcessError as e:
22
+ print(f"Error occurred while updating transformers: {e}")
23
 
24
  class CustomModelHandler:
25
  def __init__(self, model_name_or_path: str):
 
34
  self.model = AutoModelForCausalLM.from_pretrained(
35
  self.model_name_or_path,
36
  trust_remote_code=True,
37
+ torch_dtype="auto"
 
38
  )
39
  self.model.to(self.device)
40
  print(f"Model loaded and moved to {self.device}")