qiaoruiyt committed on
Commit 80d0fb2 · verified · 1 Parent(s): 8d377a5

Update modeling_reasonir_8b.py

Files changed (1):
  1. modeling_reasonir_8b.py +166 -5

modeling_reasonir_8b.py CHANGED
@@ -51,6 +51,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.models.llama.configuration_llama import LlamaConfig
+from typing import Dict, List, Union, cast
+import numpy as np
+from tqdm import tqdm
+from transformers import AutoTokenizer
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -428,7 +432,7 @@ class LlamaFlashAttention2(LlamaAttention):
         dropout=0.0,
         softmax_scale=None,
         use_sliding_windows=False,
-        is_causal=False,
+        is_causal=True,
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -529,7 +533,7 @@ class LlamaFlashAttention2(LlamaAttention):
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        is_causal: bool = False,
+        is_causal: bool = True,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
@@ -656,7 +660,7 @@ class LlamaSdpaAttention(LlamaAttention):
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        is_causal: bool = False,
+        is_causal: bool = True,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -763,7 +767,7 @@ class LlamaDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        is_causal: bool = False,
+        is_causal: bool = True,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -948,6 +952,8 @@ LLAMA_INPUTS_DOCSTRING = r"""
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
 )
+
+
 class LlamaModel(LlamaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
@@ -991,7 +997,7 @@ class LlamaModel(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        is_causal: Optional[bool] = False,
+        is_causal: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1663,3 +1669,158 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class ReasonIRModel(LlamaModel):
+    """
+    ReasonIRModel is a wrapper around LlamaModel with bi-directional attention for retrieval tasks.
+    """
+
+    def __init__(self, config: LlamaConfig):
+        """
+        Initializes the ReasonIRModel with the given configuration.
+        """
+        super().__init__(config)
+        self.pooling_method = "mean"
+        self.normalized = True
+        self.embed_eos = ""
+        self.reasonir_config = config
+        self.tokenizer = AutoTokenizer.from_pretrained('reasonir/ReasonIR-8B')
+
+    def encode_queries(self, queries: Union[List[str], str], **kwargs) -> np.ndarray:
+        """Used for encoding the queries of retrieval or reranking tasks"""
+        return self.encode(queries, **kwargs)
+
+    def encode_corpus(self, corpus: Union[List[str], str, List[Dict[str, str]]], **kwargs) -> np.ndarray:
+        """Used for encoding the corpus of retrieval tasks"""
+        if isinstance(corpus, dict):
+            corpus = [corpus]
+        if isinstance(corpus, list) and isinstance(corpus[0], dict):
+            corpus = [
+                doc["title"] + " " + doc["text"] if "title" in doc
+                else doc["text"] for doc in corpus
+            ]
+        return self.encode(corpus, **kwargs)
+
+    @torch.inference_mode()
+    def encode(
+        self,
+        sentences: Union[List[str], str],
+        batch_size: int = 256,
+        max_length: int = 512,
+        instruction: str = "",
+        embed_instruction: bool = False,
+        get_cache: bool = False,
+        convert_to_tensor: bool = False,
+        recast: bool = False,
+        add_special_tokens: bool = True,
+        **kwargs,
+    ) -> np.ndarray:
+
+        # Scale the batch size by the number of available GPUs
+        num_gpus = torch.cuda.device_count()
+        if num_gpus > 0:
+            batch_size *= num_gpus
+
+        input_was_string = False
+        if isinstance(sentences, str):
+            sentences = [sentences]
+            input_was_string = True
+
+        all_embeddings, all_kv_caches = [], []
+        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences) < 256):
+            sentences_batch = [
+                instruction + s + self.embed_eos for s in sentences[start_index:start_index + batch_size]
+            ]
+            # This will prepend the bos token if the tokenizer has `add_bos_token=True`
+            inputs = self.tokenizer(
+                sentences_batch,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                max_length=max_length,
+                add_special_tokens=add_special_tokens,
+            ).to(self.device)
+
+            inputs["is_causal"] = False
+            if get_cache:
+                inputs['use_cache'] = True
+            outputs = self(**inputs)
+            last_hidden_state = outputs[0]
+            if get_cache:
+                # Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`
+                assert len(all_kv_caches) == 0, "Can only get cache for one batch at a time"
+                all_kv_caches = outputs[1]
+
+            if (instruction) and (embed_instruction is False) and ("mean" in self.pooling_method):
+                # Remove instruction tokens from the embeddings by masking them out
+                instruction_tokens = self.tokenizer(
+                    instruction,
+                    padding=False,
+                    truncation=True,
+                    max_length=max_length,
+                    add_special_tokens=add_special_tokens,
+                )["input_ids"]
+                inputs['attention_mask'][:, :len(instruction_tokens)] = 0
+            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'], recast=recast)
+            # Normalization can change the dtype (https://discuss.pytorch.org/t/tensor-in-float16-is-transformed-into-float32-after-torch-norm/110891)
+            if self.normalized:
+                in_dtype = embeddings.dtype
+                embeddings = torch.nn.functional.normalize(embeddings, dim=-1).to(in_dtype)
+            embeddings = cast(torch.Tensor, embeddings)
+            if convert_to_tensor:
+                all_embeddings.append(embeddings)
+            else:
+                # NumPy does not support bfloat16
+                all_embeddings.append(embeddings.cpu().to(torch.float32).numpy())
+
+        all_embeddings = (
+            torch.cat(all_embeddings, dim=0) if convert_to_tensor else np.concatenate(all_embeddings, axis=0)
+        )
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+        if get_cache:
+            return all_embeddings, all_kv_caches
+        return all_embeddings
+
+    def pooling(
+        self, hidden_state: torch.Tensor, attention_mask: torch.Tensor = None, recast: bool = False
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_state: [b, n, d]
+            attention_mask: [b, n]
+        """
+        # In case the model is distributed across multiple devices, hidden_state may end up on a different device
+        hidden_state = hidden_state.to(attention_mask.device)
+        if self.pooling_method == 'cls':
+            embedding = hidden_state[:, 0]
+        elif self.pooling_method == 'lasttoken':
+            b, n, d = hidden_state.size()
+            # Get the last `1` in the attention mask of each item
+            # Often it is just `gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1`
+            # except when 1) there are all 1's or 2) there are 0's before the 1's
+            reversed_mask = torch.flip(attention_mask, dims=(1,))
+            argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False)
+            gather_indices = attention_mask.size(1) - argmax_reverse - 1
+            # If there are empty sequences, where the index would become -1, it will crash, so set them to 0
+            gather_indices = torch.clamp(gather_indices, min=0)
+            # Turn indices from shape [b] -> [b, 1, d]
+            gather_indices = gather_indices.unsqueeze(-1).repeat(1, d)
+            gather_indices = gather_indices.unsqueeze(1)
+            assert gather_indices.shape == (b, 1, d)
+            # Gather along the seq len: [b, n, d] -> [b, d]
+            # Actually no need for the attention mask as we gather the last token where attn_mask=1, but
+            # as some indices (which shouldn't be attended to) may be 0 due to clamp, use the mask to ignore them again
+            input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
+            embedding = torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
+        elif self.pooling_method in ['mean', 'weightedmean']:
+            if self.pooling_method == 'weightedmean':
+                attention_mask *= attention_mask.cumsum(dim=1)  # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
+            s = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
+            d = attention_mask.sum(dim=1, keepdim=True).float()
+            embedding = s / d
+        else: raise NotImplementedError(f"Unknown pooling method: {self.pooling_method}")
+        # Recasting performs slightly worse but saves 50% space
+        if recast: return embedding.to(hidden_state.dtype)
+        return embedding
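
The added ReasonIRModel encodes text with bi-directional attention (encode forces is_causal=False), mean-pools the last hidden state over the attention mask, and L2-normalizes the result. Below is a minimal usage sketch, assuming the reasonir/ReasonIR-8B checkpoint dispatches AutoModel to this class via trust_remote_code=True; the instruction string and example documents are purely illustrative.

# Minimal usage sketch (assumption: AutoModel with trust_remote_code=True loads ReasonIRModel from this file)
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "reasonir/ReasonIR-8B",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model.eval()

# Hypothetical retrieval instruction; it is prepended to the query and, since
# embed_instruction defaults to False, its tokens are masked out of the mean pooling.
instruction = "Given a question, retrieve passages that answer the question."

query_emb = model.encode_queries(
    "What causes the seasons on Earth?",
    instruction=instruction,
    convert_to_tensor=True,
)
doc_embs = model.encode_corpus(
    [
        {"title": "Axial tilt", "text": "Earth's 23.4 degree axial tilt drives seasonal change."},
        {"title": "Tides", "text": "Ocean tides are caused mainly by the Moon's gravity."},
    ],
    convert_to_tensor=True,
)

# Embeddings are L2-normalized (self.normalized = True), so the dot product is a cosine similarity.
scores = query_emb @ doc_embs.T
print(scores)

Because encode normalizes the pooled embeddings, the dot product above ranks documents by cosine similarity; passing a string rather than a list returns a single embedding vector, which is why query_emb is one-dimensional here.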