rm deprecated assertion
Browse files- automodel.py +2 -5
automodel.py
CHANGED
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|
17 |
|
18 |
class ClinicalMosaicForEmbeddingGeneration(BertPreTrainedModel):
|
19 |
|
20 |
-
def __init__(self, config,
|
21 |
"""
|
22 |
Initializes the BertEmbeddings class.
|
23 |
|
@@ -26,11 +26,8 @@ class ClinicalMosaicForEmbeddingGeneration(BertPreTrainedModel):
|
|
26 |
add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
|
27 |
"""
|
28 |
super().__init__(config)
|
29 |
-
assert (
|
30 |
-
config.num_hidden_layers >= config.num_embedding_layers
|
31 |
-
), "num_hidden_layers should be greater than or equal to num_embedding_layers"
|
32 |
self.config = config
|
33 |
-
self.bert = BertModel(config, add_pooling_layer=
|
34 |
# this resets the weights
|
35 |
self.post_init()
|
36 |
|
|
|
17 |
|
18 |
class ClinicalMosaicForEmbeddingGeneration(BertPreTrainedModel):
|
19 |
|
20 |
+
def __init__(self, config, **kwargs):
|
21 |
"""
|
22 |
Initializes the BertEmbeddings class.
|
23 |
|
|
|
26 |
add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
|
27 |
"""
|
28 |
super().__init__(config)
|
|
|
|
|
|
|
29 |
self.config = config
|
30 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
31 |
# this resets the weights
|
32 |
self.post_init()
|
33 |
|