qiaoruiyt commited on
Commit
dc25f02
·
verified ·
1 Parent(s): 5abdbe4

Update modeling_reasonir_8b.py

Browse files
Files changed (1) hide show
  1. modeling_reasonir_8b.py +3 -2
modeling_reasonir_8b.py CHANGED
@@ -1683,8 +1683,7 @@ class ReasonIRModel(LlamaModel):
1683
  self.pooling_method = "mean"
1684
  self.normalized = True
1685
  self.embed_eos = ""
1686
- self.reasonir_config = config
1687
- self.tokenizer = AutoTokenizer.from_pretrained('reasonir/ReasonIR-8B')
1688
 
1689
  def encode_queries(self, queries: Union[List[str], str], **kwargs) -> np.ndarray:
1690
  """Used for encoding the queries of retrieval or reranking tasks"""
@@ -1716,6 +1715,8 @@ class ReasonIRModel(LlamaModel):
1716
  **kwargs,
1717
  ) -> np.ndarray:
1718
 
 
 
1719
  # get number of gpus
1720
  num_gpus = torch.cuda.device_count()
1721
  if num_gpus > 0:
 
1683
  self.pooling_method = "mean"
1684
  self.normalized = True
1685
  self.embed_eos = ""
1686
+ self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, padding_side='right', trust_remote_code=True)
 
1687
 
1688
  def encode_queries(self, queries: Union[List[str], str], **kwargs) -> np.ndarray:
1689
  """Used for encoding the queries of retrieval or reranking tasks"""
 
1715
  **kwargs,
1716
  ) -> np.ndarray:
1717
 
1718
+ self.eval()
1719
+
1720
  # get number of gpus
1721
  num_gpus = torch.cuda.device_count()
1722
  if num_gpus > 0: