MrOvkill commited on
Commit
cbb08d7
·
verified ·
1 Parent(s): 8eded20

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -4,7 +4,7 @@ from obsidian import Vision
4
 
5
  class EndpointHandler():
6
  def __init__(self, path="", vision_model="obsidian3b"):
7
- self.pipeline = pipeline("text-classification", model=path)
8
  self.vision = Vision(vision_model)
9
 
10
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -29,12 +29,12 @@ class EndpointHandler():
29
  combined_captions = [inputs, image_caption]
30
 
31
  # run text classification on combined captions
32
- prediction = self.pipeline(combined_captions, temperature=0.33, num_beams=5, stop=[])
33
 
34
  return prediction
35
 
36
  else:
37
  # run text classification on plain text input
38
- prediction = self.pipeline(inputs, temperature=0.33, num_beams=5, stop=[])
39
 
40
  return prediction
 
4
 
5
  class EndpointHandler():
6
  def __init__(self, path="", vision_model="obsidian3b"):
7
+ self.pipeline = pipeline("text-generation", model=path)
8
  self.vision = Vision(vision_model)
9
 
10
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
29
  combined_captions = [inputs, image_caption]
30
 
31
  # run text classification on combined captions
32
+ prediction = self.pipeline(combined_captions, temperature=0.33, num_beams=5, stop=[], do_sample=True)
33
 
34
  return prediction
35
 
36
  else:
37
  # run text classification on plain text input
38
+ prediction = self.pipeline(inputs, temperature=0.33, num_beams=5, stop=[], do_sample=True)
39
 
40
  return prediction