Update modeling_reasonir_8b.py
modeling_reasonir_8b.py  +166 -5
CHANGED
@@ -51,6 +51,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.models.llama.configuration_llama import LlamaConfig
+from typing import Dict, List, Union, cast
+import numpy as np
+from tqdm import tqdm
+from transformers import AutoTokenizer

 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func

@@ -428,7 +432,7 @@ class LlamaFlashAttention2(LlamaAttention):
         dropout=0.0,
         softmax_scale=None,
         use_sliding_windows=False,
-        is_causal=
+        is_causal=True,
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token

@@ -529,7 +533,7 @@ class LlamaFlashAttention2(LlamaAttention):
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
-        is_causal: bool =
+        is_causal: bool = True,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(

@@ -656,7 +660,7 @@ class LlamaSdpaAttention(LlamaAttention):
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
-        is_causal: bool =
+        is_causal: bool = True,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:

@@ -763,7 +767,7 @@ class LlamaDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
-        is_causal: bool =
+        is_causal: bool = True,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@@ -948,6 +952,8 @@ LLAMA_INPUTS_DOCSTRING = r"""
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
 )
+
+
 class LlamaModel(LlamaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

@@ -991,7 +997,7 @@ class LlamaModel(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        is_causal: Optional[bool] =
+        is_causal: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
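The hunks above thread a new `is_causal` argument (default `True`) from `LlamaModel.forward` down through `LlamaDecoderLayer` and the attention implementations, so the model keeps causal masking for generation but can be switched to bi-directional attention per call; the final hunk below adds the `ReasonIRModel` encoder that flips this switch. A minimal sketch of the flag in use (illustrative only, not part of the commit; it assumes the repository wires `AutoModel` to this modeling file via `trust_remote_code`):

import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative sketch: load the patched model so LlamaModel.forward accepts `is_causal`.
tokenizer = AutoTokenizer.from_pretrained("reasonir/ReasonIR-8B")
model = AutoModel.from_pretrained("reasonir/ReasonIR-8B", torch_dtype="auto", trust_remote_code=True)

inputs = tokenizer("an example passage", return_tensors="pt")
with torch.inference_mode():
    # is_causal=True (the default) keeps standard causal masking for generation;
    # is_causal=False lets every token attend to the full sequence for embedding.
    outputs = model(**inputs, is_causal=False)
last_hidden_state = outputs[0]  # [batch, seq_len, hidden_size]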
@@ -1663,3 +1669,158 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class ReasonIRModel(LlamaModel):
+    """
+    ReasonIRModel is a wrapper around LlamaModel with bi-directional attention for retrieval tasks
+    """
+
+    def __init__(self, config: LlamaConfig):
+        """
+        Initializes the ReasonIRModel with the given configuration.
+        """
+        super().__init__(config)
+        self.pooling_method = "mean"
+        self.normalized = True
+        self.embed_eos = ""
+        self.reasonir_config = config
+        self.tokenizer = AutoTokenizer.from_pretrained('reasonir/ReasonIR-8B')
+
+    def encode_queries(self, queries: Union[List[str], str], **kwargs) -> np.ndarray:
+        """Used for encoding the queries of retrieval or reranking tasks"""
+        return self.encode(queries, **kwargs)
+
+    def encode_corpus(self, corpus: Union[List[str], str, List[Dict[str, str]]], **kwargs) -> np.ndarray:
+        """Used for encoding the corpus of retrieval tasks"""
+        if isinstance(corpus, dict):
+            corpus = [corpus]
+        if isinstance(corpus, list) and isinstance(corpus[0], dict):
+            corpus = [
+                doc["title"] + " " + doc["text"] if "title" in doc
+                else doc["text"] for doc in corpus
+            ]
+        return self.encode(corpus, **kwargs)
+
+    @torch.inference_mode()
+    def encode(
+        self,
+        sentences: Union[List[str], str],
+        batch_size: int = 256,
+        max_length: int = 512,
+        instruction: str = "",
+        embed_instruction: bool = False,
+        get_cache: bool = False,
+        convert_to_tensor: bool = False,
+        recast: bool = False,
+        add_special_tokens: bool = True,
+        **kwargs,
+    ) -> np.ndarray:
+
+        # get number of gpus
+        num_gpus = torch.cuda.device_count()
+        if num_gpus > 0:
+            batch_size *= num_gpus
+
+        input_was_string = False
+        if isinstance(sentences, str):
+            sentences = [sentences]
+            input_was_string = True
+
+        all_embeddings, all_kv_caches = [], []
+        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256):
+            sentences_batch = [
+                instruction + s + self.embed_eos for s in sentences[start_index:start_index + batch_size]
+            ]
+            # This will prepend the bos token if the tokenizer has `add_bos_token=True`
+            inputs = self.tokenizer(
+                sentences_batch,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                max_length=max_length,
+                add_special_tokens=add_special_tokens,
+            ).to(self.device)
+
+            inputs["is_causal"] = False
+            if get_cache:
+                inputs['use_cache'] = True
+            outputs = self(**inputs)
+            last_hidden_state = outputs[0]
+            if get_cache:
+                # Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`
+                assert len(all_kv_caches) == 0, "Can only get cache for one batch at a time"
+                all_kv_caches = outputs[1]
+
+            if (instruction) and (embed_instruction is False) and ("mean" in self.pooling_method):
+                # Remove instruction tokens from the embeddings by masking them
+                instruction_tokens = self.tokenizer(
+                    instruction,
+                    padding=False,
+                    truncation=True,
+                    max_length=max_length,
+                    add_special_tokens=add_special_tokens,
+                )["input_ids"]
+                inputs['attention_mask'][:, :len(instruction_tokens)] = 0
+            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'], recast=recast)
+            # Normalize can change the dtype (https://discuss.pytorch.org/t/tensor-in-float16-is-transformed-into-float32-after-torch-norm/110891)
+            if self.normalized:
+                in_dtype = embeddings.dtype
+                embeddings = torch.nn.functional.normalize(embeddings, dim=-1).to(in_dtype)
+            embeddings = cast(torch.Tensor, embeddings)
+            if convert_to_tensor:
+                all_embeddings.append(embeddings)
+            else:
+                # NumPy does not support bfloat16
+                all_embeddings.append(embeddings.cpu().to(torch.float32).numpy())
+
+        all_embeddings = (
+            torch.cat(all_embeddings, dim=0) if convert_to_tensor else np.concatenate(all_embeddings, axis=0)
+        )
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+        if get_cache:
+            return all_embeddings, all_kv_caches
+        return all_embeddings
+
+    def pooling(
+        self, hidden_state: torch.Tensor, attention_mask: torch.Tensor = None, recast: bool = False
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_state: [b, n, d]
+            attention_mask: [b, n]
+        """
+        # In case the model is distributed across multiple devices; hidden_state may end up on diff device
+        hidden_state = hidden_state.to(attention_mask.device)
+        if self.pooling_method == 'cls':
+            embedding = hidden_state[:, 0]
+        elif self.pooling_method == 'lasttoken':
+            b, n, d = hidden_state.size()
+            # Get the last `1` in the attention mask of each item
+            # Often it is just `gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1`
+            # except when 1) There's all 1's 2) There's 0's before the 1's
+            reversed_mask = torch.flip(attention_mask, dims=(1,))
+            argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False)
+            gather_indices = attention_mask.size(1) - argmax_reverse - 1
+            # If there are empty sequences, where the index would become -1 it will crash so set them to 0
+            gather_indices = torch.clamp(gather_indices, min=0)
+            # Turn indices from shape [b] -> [b, 1, d]
+            gather_indices = gather_indices.unsqueeze(-1).repeat(1, d)
+            gather_indices = gather_indices.unsqueeze(1)
+            assert gather_indices.shape == (b, 1, d)
+            # Gather along the seq len: [b, n, d] -> [b, d]
+            # Actually no need for the attention mask as we gather the last token where attn_mask=1 but
+            # as some indices (which shouldn't be attended to) may be 0 due to clamp, use mask to ignore them again
+            input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
+            embedding = torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
+        elif self.pooling_method in ['mean', 'weightedmean']:
+            if self.pooling_method == 'weightedmean':
+                attention_mask *= attention_mask.cumsum(dim=1) # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
+            s = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
+            d = attention_mask.sum(dim=1, keepdim=True).float()
+            embedding = s / d
+        else: raise NotImplementedError(f"Unknown pooling method: {self.pooling_method}")
+        # Recasting performs slightly worse but saves 50% space
+        if recast: return embedding.to(hidden_state.dtype)
+        return embedding
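For reference, a usage sketch of the encoder API added above (illustrative, not part of the commit; it assumes `AutoModel` with `trust_remote_code=True` resolves to `ReasonIRModel`, and relies on `encode` returning L2-normalized embeddings so that dot products are cosine similarities):

import numpy as np
from transformers import AutoModel

# Hypothetical usage sketch of ReasonIRModel.encode_queries / encode_corpus.
model = AutoModel.from_pretrained("reasonir/ReasonIR-8B", torch_dtype="auto", trust_remote_code=True)
model.eval()

query = "Which attention pattern is used for retrieval embeddings?"
docs = [
    {"title": "Bi-directional attention", "text": "Every token attends to the full passage."},
    {"title": "Causal attention", "text": "Each token only attends to previous tokens."},
]

query_emb = model.encode_queries(query)   # np.ndarray of shape [hidden_size]
doc_emb = model.encode_corpus(docs)       # titles and texts are concatenated, shape [2, hidden_size]

# Embeddings are normalized (self.normalized = True), so the dot product is cosine similarity.
scores = doc_emb @ query_emb
print(np.argsort(-scores))                # document indices ranked by relevance

Passing `instruction=...` to `encode` prepends it to each input, and with the default `embed_instruction=False` the instruction tokens are masked out of the mean pooling.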