shoffman committed on
Commit 2054227 · verified · 1 Parent(s): d4b4e5a

Upload GP-MoLFormer code and pre-trained model

config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "architectures": [
+     "MolformerForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_molformer.MolformerConfig",
+     "AutoModelForCausalLM": "modeling_molformer.MolformerForCausalLM"
+   },
+   "bos_token_id": 0,
+   "deterministic_eval": false,
+   "embedding_dropout_prob": 0.2,
+   "eos_token_id": 1,
+   "feature_map_kernel": "relu",
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 768,
+   "is_decoder": true,
+   "layer_norm_eps": 1e-12,
+   "linear_attention_eps": 1e-06,
+   "max_position_embeddings": 202,
+   "model_type": "molformer",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "num_random_features": 32,
+   "pad_token_id": 2,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.32.1",
+   "use_cache": true,
+   "vocab_size": 2362
+ }
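
Because `auto_map` points the `Auto*` classes at the custom `configuration_molformer.py` and `modeling_molformer.py` modules added below, the checkpoint has to be loaded with `trust_remote_code=True`. A minimal loading sketch (the repo id `ibm/GP-MoLFormer-Uniq` is taken from the archive maps in the files below; no tokenizer is part of this commit):

```python
# Minimal sketch: load the custom config and causal LM defined in this repo.
# trust_remote_code=True is required because config.json maps the Auto classes
# to configuration_molformer.py / modeling_molformer.py.
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "ibm/GP-MoLFormer-Uniq"  # repo id as used in the archive maps below
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
print(config.model_type, model.config.vocab_size)  # "molformer", 2362
```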
configuration_molformer.py ADDED
@@ -0,0 +1,174 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Molformer model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.onnx import OnnxConfig
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ MOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28
+ "ibm/GP-MoLFormer-Uniq": "https://huggingface.co/ibm/GP-MoLFormer-Uniq/resolve/main/config.json",
29
+ }
30
+
31
+
32
+ class MolformerConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`MolformerModel`]. It is used to instantiate a
35
+ Molformer model according to the specified arguments, defining the model architecture. Instantiating a
36
+ configuration with the defaults will yield a similar configuration to that of the Molformer
37
+ [ibm/MoLFormer-XL-both-10pct](https://huggingface.co/ibm/MoLFormer-XL-both-10pct) architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 2362):
45
+ Vocabulary size of the Molformer model. Defines the number of different tokens that can be represented by
46
+ the `inputs_ids` passed when calling [`MolformerModel`] or [`TFMolformerModel`].
47
+ hidden_size (`int`, *optional*, defaults to 768):
48
+ Dimension of the encoder layers and the pooler layer.
49
+ num_hidden_layers (`int`, *optional*, defaults to 12):
50
+ Number of hidden layers in the Transformer encoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 12):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ intermediate_size (`int`, *optional*, defaults to 768):
54
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
56
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
57
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
58
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
59
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
60
+ embedding_dropout_prob (`float`, *optional*, defaults to 0.2):
61
+ The dropout probability for the word embeddings.
62
+ max_position_embeddings (`int`, *optional*, defaults to 202):
63
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
64
+ just in case (e.g., 512 or 1024 or 1536).
65
+ initializer_range (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
68
+ The epsilon used by the layer normalization layers.
69
+ linear_attention_eps (`float`, *optional*, defaults to 1e-06):
70
+ The epsilon used by the linear attention layers normalization step.
71
+ num_random_features (`int`, *optional*, defaults to 32):
72
+ Random feature map dimension used in linear attention.
73
+ feature_map_kernel (`str` or `function`, *optional*, defaults to `"relu"`):
74
+ The non-linear activation function (function or string) in the generalized random features. If string,
75
+ `"gelu"`, `"relu"`, `"selu"`, and `"gelu_new"` are supported.
76
+ deterministic_eval (`bool`, *optional*, defaults to `False`):
77
+ Whether the random features should only be redrawn when training or not. If `True` and `model.training` is
78
+ `False`, linear attention random feature weights will be constant, i.e., deterministic.
79
+ bos_token_id (`int`, *optional*, defaults to 0):
80
+ Beginning of stream token id.
81
+ eos_token_id (`int`, *optional*, defaults to 1):
82
+ End of stream token id.
83
+ pad_token_id (`int`, *optional*, defaults to 2):
84
+ The id of the _padding_ token.
85
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
86
+ Whether to tie the input and output word embeddings.
87
+ is_decoder (`bool`, *optional*, defaults to `True`):
88
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
89
+ use_cache (`bool`, *optional*, defaults to `True`):
90
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
91
+ relevant if `config.is_decoder=True`.
92
+
93
+ Example:
94
+
95
+ ```python
96
+ >>> from transformers import MolformerModel, MolformerConfig
97
+
98
+ >>> # Initializing a Molformer ibm/MoLFormer-XL-both-10pct style configuration
99
+ >>> configuration = MolformerConfig()
100
+
101
+ >>> # Initializing a model from the ibm/MoLFormer-XL-both-10pct style configuration
102
+ >>> model = MolformerModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+ model_type = "molformer"
108
+
109
+ def __init__(
110
+ self,
111
+ vocab_size=2362,
112
+ hidden_size=768,
113
+ num_hidden_layers=12,
114
+ num_attention_heads=12,
115
+ intermediate_size=768,
116
+ hidden_act="gelu",
117
+ hidden_dropout_prob=0.1,
118
+ embedding_dropout_prob=0.2,
119
+ max_position_embeddings=202,
120
+ initializer_range=0.02,
121
+ layer_norm_eps=1e-12,
122
+ linear_attention_eps=1e-6,
123
+ num_random_features=32,
124
+ feature_map_kernel="relu",
125
+ deterministic_eval=False,
126
+ bos_token_id=0,
127
+ eos_token_id=1,
128
+ pad_token_id=2,
129
+ tie_word_embeddings=False,
130
+ is_decoder=True,
131
+ use_cache=True,
132
+ **kwargs,
133
+ ):
134
+ super().__init__(
135
+ bos_token_id=bos_token_id,
136
+ eos_token_id=eos_token_id,
137
+ pad_token_id=pad_token_id,
138
+ tie_word_embeddings=tie_word_embeddings,
139
+ **kwargs
140
+ )
141
+
142
+ self.vocab_size = vocab_size
143
+ self.hidden_size = hidden_size
144
+ self.num_hidden_layers = num_hidden_layers
145
+ self.num_attention_heads = num_attention_heads
146
+ self.hidden_act = hidden_act
147
+ self.intermediate_size = intermediate_size
148
+ self.hidden_dropout_prob = hidden_dropout_prob
149
+ self.embedding_dropout_prob = embedding_dropout_prob
150
+ self.max_position_embeddings = max_position_embeddings
151
+ self.initializer_range = initializer_range
152
+ self.layer_norm_eps = layer_norm_eps
153
+ self.linear_attention_eps = linear_attention_eps
154
+ self.num_random_features = num_random_features
155
+ self.feature_map_kernel = feature_map_kernel
156
+ self.deterministic_eval = deterministic_eval
157
+ self.is_decoder = is_decoder
158
+ self.use_cache = use_cache
159
+
160
+
161
+ # Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->Molformer
162
+ class MolformerOnnxConfig(OnnxConfig):
163
+ @property
164
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
165
+ if self.task == "multiple-choice":
166
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
167
+ else:
168
+ dynamic_axis = {0: "batch", 1: "sequence"}
169
+ return OrderedDict(
170
+ [
171
+ ("input_ids", dynamic_axis),
172
+ ("attention_mask", dynamic_axis),
173
+ ]
174
+ )
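
For reference, instantiating `MolformerConfig` with no arguments reproduces the values shipped in `config.json` above. A small sketch (this assumes `configuration_molformer.py` is importable from the working directory, e.g. after cloning the repo):

```python
# Sketch only: check that the defaults mirror the committed config.json.
from configuration_molformer import MolformerConfig

config = MolformerConfig()
assert config.vocab_size == 2362
assert config.hidden_size == 768 and config.intermediate_size == 768
assert config.num_random_features == 32 and config.feature_map_kernel == "relu"
assert config.is_decoder and config.use_cache
print(config.to_json_string())  # serialized form, comparable to config.json
```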
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2754f36d8cebc399a54cc486285b81105862159cda42d5af8715d12fd49ebd3c
+ size 187248784
modeling_molformer.py ADDED
@@ -0,0 +1,1075 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Molformer model."""
16
+
17
+
18
+ import math
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPoolingAndCrossAttentions,
30
+ CausalLMOutputWithCrossAttentions,
31
+ )
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
34
+ from transformers.utils import (
35
+ add_code_sample_docstrings,
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ logging,
39
+ replace_return_docstrings,
40
+ )
41
+ from .configuration_molformer import MolformerConfig
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CHECKPOINT_FOR_DOC = "ibm/GP-MoLFormer-Uniq"
47
+ _CONFIG_FOR_DOC = "MolformerConfig"
48
+
49
+ MOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
50
+ "ibm/GP-MoLFormer-Uniq",
51
+ # See all MoLFormer models at https://huggingface.co/models?filter=molformer
52
+ ]
53
+
54
+
55
+ # Copied from transformers.models.esm.modeling_esm.rotate_half
56
+ def rotate_half(x):
57
+ x1, x2 = x.chunk(2, dim=-1)
58
+ return torch.cat((-x2, x1), dim=-1)
59
+
60
+
61
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
62
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
63
+ cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
64
+ sin = sin[position_ids].unsqueeze(1)
65
+ q_embed = (q * cos) + (rotate_half(q) * sin)
66
+ k_embed = (k * cos) + (rotate_half(k) * sin)
67
+ return q_embed, k_embed
68
+
69
+
70
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Molformer
71
+ class MolformerRotaryEmbedding(nn.Module):
72
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
73
+ super().__init__()
74
+
75
+ self.dim = dim
76
+ self.max_position_embeddings = max_position_embeddings
77
+ self.base = base
78
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
79
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
80
+
81
+ # Build here to make `torch.jit.trace` work.
82
+ self._set_cos_sin_cache(
83
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
84
+ )
85
+
86
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
87
+ self.max_seq_len_cached = seq_len
88
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
89
+
90
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
91
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
92
+ emb = torch.cat((freqs, freqs), dim=-1)
93
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
94
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
95
+
96
+ def forward(self, x, seq_len=None):
97
+ # x: [bs, num_attention_heads, seq_len, head_size]
98
+ if seq_len > self.max_seq_len_cached:
99
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
100
+
101
+ return (
102
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
103
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
104
+ )
105
+
106
+
107
+ class MolformerEmbeddings(nn.Module):
108
+ """Construct the embeddings from word embeddings."""
109
+
110
+ def __init__(self, config):
111
+ super().__init__()
112
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
113
+ self.dropout = nn.Dropout(config.embedding_dropout_prob)
114
+
115
+ def forward(
116
+ self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None
117
+ ) -> torch.Tensor:
118
+ if inputs_embeds is None:
119
+ inputs_embeds = self.word_embeddings(input_ids)
120
+
121
+ embeddings = inputs_embeds
122
+ embeddings = self.dropout(embeddings)
123
+ return embeddings
124
+
125
+
126
+ class MolformerFeatureMap(nn.Module):
127
+ def __init__(self, config):
128
+ super().__init__()
129
+ self.query_size = config.hidden_size // config.num_attention_heads
130
+ self.num_components = config.num_random_features
131
+ self.orthogonal_random_weights()
132
+ if isinstance(config.feature_map_kernel, str):
133
+ self.kernel = ACT2FN[config.feature_map_kernel]
134
+ else:
135
+ self.kernel = config.feature_map_kernel
136
+ self.deterministic = config.deterministic_eval
137
+
138
+ def orthogonal_random_weights(self, device=None):
139
+ # make sure query size evenly divides feature size (round up)
140
+ num_batches = math.ceil(self.num_components / self.query_size)
141
+
142
+ def orthogonal_batch(size):
143
+ block = torch.randn(size, size, device=device)
144
+ norms = torch.linalg.norm(block, dim=1).unsqueeze(0)
145
+ Q, _ = torch.linalg.qr(block)
146
+ return Q * norms
147
+
148
+ random_weights = torch.cat([orthogonal_batch(self.query_size) for _ in range(num_batches)], dim=1)
149
+ random_weights = random_weights[:, : self.num_components]
150
+ self.register_buffer("weight", random_weights)
151
+
152
+ def forward(self, query, key, redraw=True):
153
+ if (not self.deterministic or self.training) and redraw is not False:
154
+ self.orthogonal_random_weights(query.device)
155
+ # generalized random fourier features
156
+ query = torch.matmul(query, self.weight)
157
+ key = torch.matmul(key, self.weight)
158
+ return self.kernel(query), self.kernel(key)
159
+
160
+
161
+ class MolformerSelfAttention(nn.Module):
162
+ def __init__(self, config):
163
+ super().__init__()
164
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
165
+ raise ValueError(
166
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
167
+ f"heads ({config.num_attention_heads})"
168
+ )
169
+
170
+ self.num_attention_heads = config.num_attention_heads
171
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
172
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
173
+
174
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
175
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
176
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
177
+
178
+ self.eps = config.linear_attention_eps
179
+
180
+ self.rotary_embeddings = MolformerRotaryEmbedding(
181
+ dim=self.attention_head_size, max_position_embeddings=config.max_position_embeddings
182
+ )
183
+ self.feature_map = MolformerFeatureMap(config)
184
+
185
+ self.is_decoder = config.is_decoder
186
+
187
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
188
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
189
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
190
+ x = x.view(new_x_shape)
191
+ return x.permute(0, 2, 1, 3)
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ attention_mask: Optional[torch.FloatTensor] = None,
197
+ position_ids: Optional[torch.LongTensor] = None,
198
+ head_mask: Optional[torch.FloatTensor] = None,
199
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
200
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
201
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
202
+ output_attentions: Optional[bool] = False,
203
+ use_cache: Optional[bool] = None,
204
+ ) -> Tuple[torch.Tensor]:
205
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
206
+
207
+ # If this is instantiated as a cross-attention module, the keys
208
+ # and values come from an encoder; the attention mask needs to be
209
+ # such that the encoder's padding tokens are not attended to.
210
+ is_cross_attention = encoder_hidden_states is not None
211
+
212
+ if is_cross_attention and past_key_value is not None:
213
+ # reuse k,v, cross_attentions
214
+ key_layer = past_key_value[0]
215
+ value_layer = past_key_value[1]
216
+ attention_mask = encoder_attention_mask
217
+ elif is_cross_attention:
218
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
219
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
220
+ # save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
221
+ # Further calls to cross_attention layer can then reuse all cross-attention key/value_states
222
+ past_key_value = (key_layer, value_layer)
223
+ attention_mask = encoder_attention_mask
224
+ else:
225
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
226
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
227
+
228
+ kv_seq_len = key_layer.shape[-2]
229
+ if past_key_value is not None:
230
+ kv_seq_len += past_key_value[0].shape[-2]
231
+
232
+ cos, sin = self.rotary_embeddings(value_layer, seq_len=kv_seq_len)
233
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
234
+ # Apply the feature map to the queries and keys
235
+ query_layer, key_layer = self.feature_map(query_layer, key_layer, redraw=past_key_value is None)
236
+
237
+ if attention_mask is not None:
238
+ # since we don't use softmax, we need to reconvert this mask to 1/0
239
+ attention_mask = (attention_mask == 0).to(attention_mask.dtype)
240
+ # separate original mask from causal mask
241
+ per_query_attn = attention_mask[:, 0, -1]
242
+ if self.is_decoder and not is_cross_attention:
243
+ batch_size, _, seq_length, _ = key_layer.shape
244
+ per_query_extended = MolformerPreTrainedModel.create_extended_attention_mask_for_decoder(
245
+ (batch_size, seq_length), per_query_attn
246
+ )
247
+ else:
248
+ per_query_extended = per_query_attn[:, None, None, :]
249
+ if not torch.equal(attention_mask, per_query_extended):
250
+ raise ValueError(
251
+ "MolformerSelfAttention does not support arbitrary 3D attention. attention_mask must be 2D (i.e., [batch size, sequence length])"
252
+ )
253
+
254
+ key_layer = key_layer * per_query_attn[:, None, -key_layer.shape[2] :, None]
255
+
256
+ if self.is_decoder and not is_cross_attention:
257
+ # causal linear attention
258
+ if use_cache:
259
+ seq_length = key_layer.shape[2]
260
+ key_value_outer = torch.einsum("bhlr,bhld->bhlrd", key_layer, value_layer)
261
+ if seq_length > 1:
262
+ key_value_outer = key_value_outer.cumsum(dim=2)
263
+ key_layer = key_layer.cumsum(dim=2)
264
+ if past_key_value is not None:
265
+ key_layer = key_layer + past_key_value[0][:, :, [-1]]
266
+ key_value_outer = key_value_outer + past_key_value[1]
267
+ seq_length += past_key_value[0].shape[-2]
268
+ # NOTE: query and key/value length must match (i.e., square attention matrix)
269
+ context_layer = torch.einsum("bhlr,bhlrd->bhld", query_layer, key_value_outer)
270
+ # store running sum of key_layer, key_value_outer, and seq_length -- constant memory!
271
+ # further calls to uni-directional self-attention can sum previous decoder states with current states
272
+ past_key_value = (key_layer[:, :, [-1]].expand(-1, -1, seq_length, -1), key_value_outer[:, :, [-1]])
273
+ else:
274
+ try:
275
+ from fast_transformers.causal_product import causal_dot_product
276
+ context_layer = causal_dot_product(query_layer, key_layer, value_layer)
277
+ except ImportError as e:
278
+ logger.warning(f"{e}: Falling back to (slow) pytorch implementation.")
279
+ key_value_outer = torch.einsum("bhlr,bhld->bhlrd", key_layer, value_layer)
280
+ key_value_outer = key_value_outer.cumsum(dim=2)
281
+ context_layer = torch.einsum("bhlr,bhlrd->bhld", query_layer, key_value_outer)
282
+ key_layer = key_layer.cumsum(dim=2)
283
+ norm = torch.einsum("bhlr,bhlr->bhl", query_layer, key_layer).unsqueeze(-1).clamp(min=self.eps)
284
+ context_layer = context_layer / norm
285
+ else:
286
+ # linear attention
287
+ key_value = torch.matmul(key_layer.transpose(-1, -2), value_layer)
288
+ norm = torch.matmul(query_layer, key_layer.sum(dim=-2).unsqueeze(-1)).clamp(min=self.eps)
289
+ context_layer = torch.matmul(query_layer, key_value) / norm
290
+
291
+ if head_mask is not None:
292
+ context_layer = context_layer * head_mask
293
+
294
+ if output_attentions:
295
+ logger.warning(
296
+ "Outputting attentions in linear attention negates the efficiency gains! Only use for visualization/debugging."
297
+ )
298
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
299
+ if attention_mask is not None:
300
+ attention_scores = attention_scores * attention_mask
301
+ attention_probs = nn.functional.normalize(attention_scores, p=1, dim=-1, eps=self.eps)
302
+ if head_mask is not None:
303
+ attention_probs = attention_probs * head_mask
304
+ # recompute context_layer for grad
305
+ context_layer = torch.matmul(attention_probs, value_layer)
306
+
307
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
308
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
309
+ context_layer = context_layer.view(*new_context_layer_shape)
310
+
311
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
312
+
313
+ if self.is_decoder:
314
+ outputs = outputs + (past_key_value,)
315
+ return outputs
316
+
317
+
318
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
319
+ class MolformerSelfOutput(nn.Module):
320
+ def __init__(self, config):
321
+ super().__init__()
322
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
323
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
324
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
325
+
326
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
327
+ hidden_states = self.dense(hidden_states)
328
+ hidden_states = self.dropout(hidden_states)
329
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
330
+ return hidden_states
331
+
332
+
333
+ class MolformerAttention(nn.Module):
334
+ def __init__(self, config):
335
+ super().__init__()
336
+ self.self = MolformerSelfAttention(config)
337
+ self.output = MolformerSelfOutput(config)
338
+ self.pruned_heads = set()
339
+
340
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
341
+ def prune_heads(self, heads):
342
+ if len(heads) == 0:
343
+ return
344
+ heads, index = find_pruneable_heads_and_indices(
345
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
346
+ )
347
+
348
+ # Prune linear layers
349
+ self.self.query = prune_linear_layer(self.self.query, index)
350
+ self.self.key = prune_linear_layer(self.self.key, index)
351
+ self.self.value = prune_linear_layer(self.self.value, index)
352
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
353
+
354
+ # Update hyper params and store pruned heads
355
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
356
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
357
+ self.pruned_heads = self.pruned_heads.union(heads)
358
+
359
+ def forward(
360
+ self,
361
+ hidden_states: torch.Tensor,
362
+ attention_mask: Optional[torch.FloatTensor] = None,
363
+ position_ids: Optional[torch.LongTensor] = None,
364
+ head_mask: Optional[torch.FloatTensor] = None,
365
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
366
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
367
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
368
+ output_attentions: Optional[bool] = False,
369
+ use_cache: Optional[bool] = None,
370
+ ) -> Tuple[torch.Tensor]:
371
+ self_outputs = self.self(
372
+ hidden_states,
373
+ attention_mask,
374
+ position_ids,
375
+ head_mask,
376
+ encoder_hidden_states,
377
+ encoder_attention_mask,
378
+ past_key_value,
379
+ output_attentions,
380
+ use_cache,
381
+ )
382
+ attention_output = self.output(self_outputs[0], hidden_states)
383
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
384
+ return outputs
385
+
386
+
387
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
388
+ class MolformerIntermediate(nn.Module):
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
392
+ if isinstance(config.hidden_act, str):
393
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
394
+ else:
395
+ self.intermediate_act_fn = config.hidden_act
396
+
397
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
398
+ hidden_states = self.dense(hidden_states)
399
+ hidden_states = self.intermediate_act_fn(hidden_states)
400
+ return hidden_states
401
+
402
+
403
+ # Copied from transformers.models.bert.modeling_bert.BertOutput
404
+ class MolformerOutput(nn.Module):
405
+ def __init__(self, config):
406
+ super().__init__()
407
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
408
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
409
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
410
+
411
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
412
+ hidden_states = self.dense(hidden_states)
413
+ hidden_states = self.dropout(hidden_states)
414
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
415
+ return hidden_states
416
+
417
+
418
+ class MolformerLayer(nn.Module):
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
422
+ self.seq_len_dim = 1
423
+ self.attention = MolformerAttention(config)
424
+ self.is_decoder = config.is_decoder
425
+ self.add_cross_attention = config.add_cross_attention
426
+ if self.add_cross_attention:
427
+ if not self.is_decoder:
428
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
429
+ self.crossattention = MolformerAttention(config)
430
+ self.intermediate = MolformerIntermediate(config)
431
+ self.output = MolformerOutput(config)
432
+
433
+ def forward(
434
+ self,
435
+ hidden_states: torch.Tensor,
436
+ attention_mask: Optional[torch.FloatTensor] = None,
437
+ position_ids: Optional[torch.LongTensor] = None,
438
+ head_mask: Optional[torch.FloatTensor] = None,
439
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
440
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
441
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
442
+ output_attentions: Optional[bool] = False,
443
+ use_cache: Optional[bool] = None,
444
+ ) -> Tuple[torch.Tensor]:
445
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
446
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
447
+ self_attention_outputs = self.attention(
448
+ hidden_states,
449
+ attention_mask,
450
+ position_ids,
451
+ head_mask,
452
+ output_attentions=output_attentions,
453
+ past_key_value=self_attn_past_key_value,
454
+ use_cache=use_cache,
455
+ )
456
+ attention_output = self_attention_outputs[0]
457
+
458
+ # if decoder, the last output is tuple of self-attn cache
459
+ if self.is_decoder:
460
+ outputs = self_attention_outputs[1:-1]
461
+ present_key_value = self_attention_outputs[-1]
462
+ else:
463
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
464
+
465
+ cross_attn_present_key_value = None
466
+ if self.is_decoder and encoder_hidden_states is not None:
467
+ if not hasattr(self, "crossattention"):
468
+ raise ValueError(
469
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
470
+ " by setting `config.add_cross_attention=True`"
471
+ )
472
+
473
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
474
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
475
+ cross_attention_outputs = self.crossattention(
476
+ attention_output,
477
+ attention_mask,
478
+ position_ids,
479
+ head_mask,
480
+ encoder_hidden_states,
481
+ encoder_attention_mask,
482
+ cross_attn_past_key_value,
483
+ output_attentions,
484
+ use_cache,
485
+ )
486
+ attention_output = cross_attention_outputs[0]
487
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
488
+
489
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
490
+ cross_attn_present_key_value = cross_attention_outputs[-1]
491
+ present_key_value = present_key_value + cross_attn_present_key_value
492
+
493
+ layer_output = apply_chunking_to_forward(
494
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
495
+ )
496
+ outputs = (layer_output,) + outputs
497
+
498
+ # if decoder, return the attn key/values as the last output
499
+ if self.is_decoder:
500
+ outputs = outputs + (present_key_value,)
501
+
502
+ return outputs
503
+
504
+ def feed_forward_chunk(self, attention_output):
505
+ intermediate_output = self.intermediate(attention_output)
506
+ layer_output = self.output(intermediate_output, attention_output)
507
+ return layer_output
508
+
509
+
510
+ class MolformerEncoder(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.config = config
514
+ self.layer = nn.ModuleList([MolformerLayer(config) for _ in range(config.num_hidden_layers)])
515
+ self.gradient_checkpointing = False
516
+
517
+ def forward(
518
+ self,
519
+ hidden_states: torch.Tensor,
520
+ attention_mask: Optional[torch.FloatTensor] = None,
521
+ position_ids: Optional[torch.LongTensor] = None,
522
+ head_mask: Optional[torch.FloatTensor] = None,
523
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
524
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
525
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
526
+ use_cache: Optional[bool] = None,
527
+ output_attentions: Optional[bool] = False,
528
+ output_hidden_states: Optional[bool] = False,
529
+ return_dict: Optional[bool] = True,
530
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
531
+ all_hidden_states = () if output_hidden_states else None
532
+ all_self_attentions = () if output_attentions else None
533
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
534
+
535
+ if self.gradient_checkpointing and self.training:
536
+ if use_cache:
537
+ logger.warning_once(
538
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
539
+ )
540
+ use_cache = False
541
+
542
+ next_decoder_cache = () if use_cache else None
543
+ for i, layer_module in enumerate(self.layer):
544
+ if output_hidden_states:
545
+ all_hidden_states = all_hidden_states + (hidden_states,)
546
+
547
+ layer_head_mask = head_mask[i] if head_mask is not None else None
548
+ past_key_value = past_key_values[i] if past_key_values is not None else None
549
+
550
+ if self.gradient_checkpointing and self.training:
551
+
552
+ def create_custom_forward(module):
553
+ def custom_forward(*inputs):
554
+ return module(*inputs, past_key_value, output_attentions)
555
+
556
+ return custom_forward
557
+
558
+ layer_outputs = torch.utils.checkpoint.checkpoint(
559
+ create_custom_forward(layer_module),
560
+ hidden_states,
561
+ attention_mask,
562
+ position_ids,
563
+ layer_head_mask,
564
+ encoder_hidden_states,
565
+ encoder_attention_mask,
566
+ )
567
+ else:
568
+ layer_outputs = layer_module(
569
+ hidden_states,
570
+ attention_mask,
571
+ position_ids,
572
+ layer_head_mask,
573
+ encoder_hidden_states,
574
+ encoder_attention_mask,
575
+ past_key_value,
576
+ output_attentions,
577
+ use_cache,
578
+ )
579
+
580
+ hidden_states = layer_outputs[0]
581
+ if use_cache:
582
+ next_decoder_cache += (layer_outputs[-1],)
583
+ if output_attentions:
584
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
585
+ if self.config.add_cross_attention:
586
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
587
+
588
+ if output_hidden_states:
589
+ all_hidden_states = all_hidden_states + (hidden_states,)
590
+
591
+ if not return_dict:
592
+ return tuple(
593
+ v
594
+ for v in [
595
+ hidden_states,
596
+ next_decoder_cache,
597
+ all_hidden_states,
598
+ all_self_attentions,
599
+ all_cross_attentions,
600
+ ]
601
+ if v is not None
602
+ )
603
+ return BaseModelOutputWithPastAndCrossAttentions(
604
+ last_hidden_state=hidden_states,
605
+ past_key_values=next_decoder_cache,
606
+ hidden_states=all_hidden_states,
607
+ attentions=all_self_attentions,
608
+ cross_attentions=all_cross_attentions,
609
+ )
610
+
611
+
612
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform
613
+ class MolformerPredictionHeadTransform(nn.Module):
614
+ def __init__(self, config):
615
+ super().__init__()
616
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
617
+ if isinstance(config.hidden_act, str):
618
+ self.transform_act_fn = ACT2FN[config.hidden_act]
619
+ else:
620
+ self.transform_act_fn = config.hidden_act
621
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
622
+
623
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
624
+ hidden_states = self.dense(hidden_states)
625
+ hidden_states = self.transform_act_fn(hidden_states)
626
+ hidden_states = self.LayerNorm(hidden_states)
627
+ return hidden_states
628
+
629
+
630
+ class MolformerLMPredictionHead(nn.Module):
631
+ def __init__(self, config):
632
+ super().__init__()
633
+ self.transform = MolformerPredictionHeadTransform(config)
634
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
635
+
636
+ def forward(self, hidden_states):
637
+ hidden_states = self.transform(hidden_states)
638
+ hidden_states = self.decoder(hidden_states)
639
+ return hidden_states
640
+
641
+
642
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->Molformer,roberta->molformer
643
+ class MolformerPreTrainedModel(PreTrainedModel):
644
+ """
645
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
646
+ models.
647
+ """
648
+
649
+ config_class = MolformerConfig
650
+ base_model_prefix = "molformer"
651
+ supports_gradient_checkpointing = True
652
+ _no_split_modules = ["MolformerEmbeddings", "MolformerSelfAttention"]
653
+
654
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
655
+ def _init_weights(self, module):
656
+ """Initialize the weights"""
657
+ if isinstance(module, nn.Linear):
658
+ # Slightly different from the TF version which uses truncated_normal for initialization
659
+ # cf https://github.com/pytorch/pytorch/pull/5617
660
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
661
+ if module.bias is not None:
662
+ module.bias.data.zero_()
663
+ elif isinstance(module, nn.Embedding):
664
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
665
+ if module.padding_idx is not None:
666
+ module.weight.data[module.padding_idx].zero_()
667
+ elif isinstance(module, nn.LayerNorm):
668
+ module.bias.data.zero_()
669
+ module.weight.data.fill_(1.0)
670
+
671
+ def _set_gradient_checkpointing(self, module, value=False):
672
+ if isinstance(module, MolformerEncoder):
673
+ module.gradient_checkpointing = value
674
+
675
+
676
+ def masked_avg_pool1d(hidden_states, attention_mask, eps=1e-9):
677
+ attention_mask = attention_mask.unsqueeze(-1).expand_as(hidden_states).float()
678
+ sum_embeddings = torch.sum(hidden_states * attention_mask, dim=1)
679
+ sum_mask = torch.clamp(attention_mask.sum(dim=1), min=eps)
680
+ embedding = sum_embeddings / sum_mask
681
+ return embedding
682
+
683
+
684
+ MOLFORMER_START_DOCSTRING = r"""
685
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
686
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
687
+ behavior.
688
+
689
+ Parameters:
690
+ config ([`MolformerConfig`]): Model configuration class with all the parameters of the model.
691
+ Initializing with a config file does not load the weights associated with the model, only the
692
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
693
+ """
694
+
695
+ MOLFORMER_INPUTS_DOCSTRING = r"""
696
+ Args:
697
+ input_ids (`torch.LongTensor` of shape `({0})`):
698
+ Indices of input sequence tokens in the vocabulary.
699
+
700
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
701
+ [`PreTrainedTokenizer.__call__`] for details.
702
+
703
+ [What are input IDs?](../glossary#input-ids)
704
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
705
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
706
+
707
+ - 1 for tokens that are **not masked**,
708
+ - 0 for tokens that are **masked**.
709
+
710
+ [What are attention masks?](../glossary#attention-mask)
711
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
712
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
713
+ config.n_positions - 1]`.
714
+
715
+ [What are position IDs?](../glossary#position-ids)
716
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
717
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
718
+
719
+ - 1 indicates the head is **not masked**,
720
+ - 0 indicates the head is **masked**.
721
+
722
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
723
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
724
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
725
+ model's internal embedding lookup matrix.
726
+ output_attentions (`bool`, *optional*):
727
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
728
+ tensors for more detail.
729
+ output_hidden_states (`bool`, *optional*):
730
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
731
+ more detail.
732
+ return_dict (`bool`, *optional*):
733
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
734
+ """
735
+
736
+
737
+ @add_start_docstrings(
738
+ "The bare Molformer Model transformer outputting raw hidden-states without any specific head on top.",
739
+ MOLFORMER_START_DOCSTRING,
740
+ )
741
+ class MolformerModel(MolformerPreTrainedModel):
742
+ """
743
+
744
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
745
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
746
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
747
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
748
+
749
+ To behave as a decoder, the model needs to be initialized with the `is_decoder` argument of the configuration set
750
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
751
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
752
+ """
753
+
754
+ def __init__(self, config, add_pooling_layer=True):
755
+ super().__init__(config)
756
+ self.config = config
757
+
758
+ self.embeddings = MolformerEmbeddings(config)
759
+ self.encoder = MolformerEncoder(config)
760
+
761
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
762
+ self.pooler = masked_avg_pool1d if add_pooling_layer else None
763
+
764
+ # Initialize weights and apply final processing
765
+ self.post_init()
766
+
767
+ def get_input_embeddings(self):
768
+ return self.embeddings.word_embeddings
769
+
770
+ def set_input_embeddings(self, value):
771
+ self.embeddings.word_embeddings = value
772
+
773
+ def _prune_heads(self, heads_to_prune):
774
+ """
775
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
776
+ class PreTrainedModel
777
+ """
778
+ for layer, heads in heads_to_prune.items():
779
+ self.encoder.layer[layer].attention.prune_heads(heads)
780
+
781
+ @add_start_docstrings_to_model_forward(MOLFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
782
+ @add_code_sample_docstrings(
783
+ checkpoint=_CHECKPOINT_FOR_DOC,
784
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
785
+ config_class=_CONFIG_FOR_DOC,
786
+ )
787
+ def forward(
788
+ self,
789
+ input_ids: Optional[torch.LongTensor] = None,
790
+ attention_mask: Optional[torch.FloatTensor] = None,
791
+ position_ids: Optional[torch.LongTensor] = None,
792
+ head_mask: Optional[torch.FloatTensor] = None,
793
+ inputs_embeds: Optional[torch.FloatTensor] = None,
794
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
795
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
796
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
797
+ use_cache: Optional[bool] = None,
798
+ output_attentions: Optional[bool] = None,
799
+ output_hidden_states: Optional[bool] = None,
800
+ return_dict: Optional[bool] = None,
801
+ ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple[torch.Tensor]]:
802
+ r"""
803
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
804
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
805
+ the model is configured as a decoder.
806
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
807
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
808
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
809
+
810
+ - 1 for tokens that are **not masked**,
811
+ - 0 for tokens that are **masked**.
812
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
813
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
814
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
815
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
816
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
817
+ use_cache (`bool`, *optional*):
818
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
819
+ `past_key_values`).
820
+ """
821
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
822
+ output_hidden_states = (
823
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
824
+ )
825
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
826
+
827
+ if self.config.is_decoder:
828
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
829
+ else:
830
+ use_cache = False
831
+
832
+ if input_ids is not None and inputs_embeds is not None:
833
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
834
+ elif input_ids is not None:
835
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
836
+ input_shape = input_ids.size()
837
+ elif inputs_embeds is not None:
838
+ input_shape = inputs_embeds.size()[:-1]
839
+ else:
840
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
841
+
842
+ batch_size, seq_length = input_shape
843
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
844
+
845
+ # past_key_values_length
846
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
847
+
848
+ if position_ids is None:
849
+ position_ids = torch.arange(
850
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
851
+ )
852
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
853
+ else:
854
+ position_ids = position_ids.view(-1, seq_length).long()
855
+
856
+ if attention_mask is None:
857
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
858
+
859
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
860
+ # ourselves in which case we just need to make it broadcastable to all heads.
861
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
862
+
863
+ # If a 2D or 3D attention mask is provided for the cross-attention
864
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
865
+ if self.config.is_decoder and encoder_hidden_states is not None:
866
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
867
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
868
+ if encoder_attention_mask is None:
869
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
870
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
871
+ else:
872
+ encoder_extended_attention_mask = None
873
+
874
+ # Prepare head mask if needed
875
+ # 1.0 in head_mask indicate we keep the head
876
+ # attention_probs has shape bsz x n_heads x N x N
877
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
878
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
879
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
880
+
881
+ embedding_output = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
882
+
883
+ encoder_outputs = self.encoder(
884
+ embedding_output,
885
+ attention_mask=extended_attention_mask,
886
+ position_ids=position_ids,
887
+ head_mask=head_mask,
888
+ encoder_hidden_states=encoder_hidden_states,
889
+ encoder_attention_mask=encoder_extended_attention_mask,
890
+ past_key_values=past_key_values,
891
+ use_cache=use_cache,
892
+ output_attentions=output_attentions,
893
+ output_hidden_states=output_hidden_states,
894
+ return_dict=return_dict,
895
+ )
896
+ sequence_output = encoder_outputs[0]
897
+ sequence_output = self.LayerNorm(sequence_output)
898
+ pooled_output = self.pooler(sequence_output, attention_mask) if self.pooler is not None else None
899
+
900
+ if not return_dict:
901
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
902
+
903
+ return BaseModelOutputWithPoolingAndCrossAttentions(
904
+ last_hidden_state=sequence_output,
905
+ pooler_output=pooled_output,
906
+ past_key_values=encoder_outputs.past_key_values,
907
+ hidden_states=encoder_outputs.hidden_states,
908
+ attentions=encoder_outputs.attentions,
909
+ cross_attentions=encoder_outputs.cross_attentions,
910
+ )
911
+
912
+
913
+ @add_start_docstrings(
914
+ """Molformer Model with a `language modeling` head on top for CLM fine-tuning.""", MOLFORMER_START_DOCSTRING
915
+ )
916
+ class MolformerForCausalLM(MolformerPreTrainedModel):
917
+ _tied_weights_keys = ["lm_head.decoder.weight"]
918
+
919
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Molformer,roberta->molformer,LMHeadModel->ForCausalLM,LMHead->LMPredictionHead
920
+ def __init__(self, config):
921
+ super().__init__(config)
922
+
923
+ if not config.is_decoder:
924
+ logger.warning("If you want to use `MolformerForCausalLM` as a standalone, add `is_decoder=True.`")
925
+
926
+ self.molformer = MolformerModel(config, add_pooling_layer=False)
927
+ self.lm_head = MolformerLMPredictionHead(config)
928
+
929
+ # Initialize weights and apply final processing
930
+ self.post_init()
931
+
932
+ def get_output_embeddings(self):
933
+ return self.lm_head.decoder
934
+
935
+ def set_output_embeddings(self, new_embeddings):
936
+ self.lm_head.decoder = new_embeddings
937
+
938
+ @add_start_docstrings_to_model_forward(MOLFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
939
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
940
+ def forward(
941
+ self,
942
+ input_ids: Optional[torch.LongTensor] = None,
943
+ attention_mask: Optional[torch.FloatTensor] = None,
944
+ position_ids: Optional[torch.LongTensor] = None,
945
+ head_mask: Optional[torch.FloatTensor] = None,
946
+ inputs_embeds: Optional[torch.FloatTensor] = None,
947
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
948
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
949
+ labels: Optional[torch.LongTensor] = None,
950
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
951
+ use_cache: Optional[bool] = None,
952
+ output_attentions: Optional[bool] = None,
953
+ output_hidden_states: Optional[bool] = None,
954
+ return_dict: Optional[bool] = None,
955
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
956
+ r"""
957
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
958
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
959
+ the model is configured as a decoder.
960
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
961
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
962
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
963
+
964
+ - 1 for tokens that are **not masked**,
965
+ - 0 for tokens that are **masked**.
966
+
967
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
968
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
969
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
970
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
971
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
972
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
973
+
974
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
975
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
976
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
977
+ use_cache (`bool`, *optional*):
978
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
979
+ `past_key_values`).
980
+
981
+ Returns:
982
+ """
983
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
984
+ if labels is not None:
985
+ use_cache = False
986
+
987
+ outputs = self.molformer(
988
+ input_ids,
989
+ attention_mask=attention_mask,
990
+ position_ids=position_ids,
991
+ head_mask=head_mask,
992
+ inputs_embeds=inputs_embeds,
993
+ encoder_hidden_states=encoder_hidden_states,
994
+ encoder_attention_mask=encoder_attention_mask,
995
+ past_key_values=past_key_values,
996
+ use_cache=use_cache,
997
+ output_attentions=output_attentions,
998
+ output_hidden_states=output_hidden_states,
999
+ return_dict=return_dict,
1000
+ )
1001
+
1002
+ sequence_output = outputs[0]
1003
+ prediction_scores = self.lm_head(sequence_output)
1004
+
1005
+ lm_loss = None
1006
+ if labels is not None:
1007
+ # move labels to correct device to enable model parallelism
1008
+ labels = labels.to(prediction_scores.device)
1009
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1010
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1011
+ labels = labels[:, 1:].contiguous()
1012
+ loss_fct = CrossEntropyLoss()
1013
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1014
+
1015
+ if not return_dict:
1016
+ output = (prediction_scores,) + outputs[2:]
1017
+ return ((lm_loss,) + output) if lm_loss is not None else output
1018
+
1019
+ return CausalLMOutputWithCrossAttentions(
1020
+ loss=lm_loss,
1021
+ logits=prediction_scores,
1022
+ past_key_values=outputs.past_key_values,
1023
+ hidden_states=outputs.hidden_states,
1024
+ attentions=outputs.attentions,
1025
+ cross_attentions=outputs.cross_attentions,
1026
+ )
1027
+
1028
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1029
+ def prepare_inputs_for_generation(
1030
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1031
+ ):
1032
+ if past_key_values is not None:
1033
+ past_length = past_key_values[0][0].shape[2]
1034
+
1035
+ # Some generation methods already pass only the last input ID
1036
+ if input_ids.shape[1] > past_length:
1037
+ remove_prefix_length = past_length
1038
+ else:
1039
+ # Default to old behavior: keep only final ID
1040
+ remove_prefix_length = input_ids.shape[1] - 1
1041
+
1042
+ input_ids = input_ids[:, remove_prefix_length:]
1043
+
1044
+ position_ids = kwargs.get("position_ids", None)
1045
+ if attention_mask is not None and position_ids is None:
1046
+ # create position_ids on the fly for batch generation
1047
+ position_ids = attention_mask.long().cumsum(-1) - 1
1048
+ position_ids.masked_fill_(attention_mask == 0, 1)
1049
+ if past_key_values:
1050
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1051
+
1052
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1053
+ if inputs_embeds is not None and past_key_values is None:
1054
+ model_inputs = {"inputs_embeds": inputs_embeds}
1055
+ else:
1056
+ model_inputs = {"input_ids": input_ids}
1057
+
1058
+ model_inputs.update(
1059
+ {
1060
+ "position_ids": position_ids,
1061
+ "past_key_values": past_key_values,
1062
+ "use_cache": kwargs.get("use_cache"),
1063
+ "attention_mask": attention_mask,
1064
+ }
1065
+ )
1066
+ return model_inputs
1067
+
1068
+ # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache
1069
+ def _reorder_cache(self, past_key_values, beam_idx):
1070
+ reordered_past = ()
1071
+ for layer_past in past_key_values:
1072
+ reordered_past += (
1073
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1074
+ )
1075
+ return reordered_past
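
Taken together, the files above define a decoder-only linear-attention model that can be driven through the standard `generate` API. An illustrative sketch of unconditional sampling (a tokenizer is not included in this commit, so the `AutoTokenizer` call is an assumption about the final repo layout):

```python
# Illustrative sketch: sample sequences from the causal LM head.
# AutoTokenizer.from_pretrained(...) assumes a tokenizer is added to the repo later.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ibm/GP-MoLFormer-Uniq"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()

# start generation from the BOS token (id 0 in config.json)
bos = torch.full((4, 1), model.config.bos_token_id, dtype=torch.long)
with torch.no_grad():
    generated = model.generate(bos, do_sample=True, max_length=64)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```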