Commit `bff4577` (parent `7a34ec2`): upd README (README.md)
## Description

*bert2bert* model, initialized with the `DeepPavlov/rubert-base-cased` pretrained weights and
fine-tuned on the first 90% of the ["Rossiya Segodnya" news dataset](https://github.com/RossiyaSegodnya/ria_news_dataset) for 1.6 epochs.
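A *bert2bert* setup wires two BERT stacks into an encoder-decoder, with cross-attention added to the decoder. As a minimal sketch of how such a model is assembled in `transformers` (config-only toy sizes so it builds without downloading weights; the actual model was warm-started from `DeepPavlov/rubert-base-cased` instead):

```python
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# Toy dimensions so the skeleton builds quickly; the real model uses
# bert-base dimensions taken from DeepPavlov/rubert-base-cased.
enc_cfg = BertConfig(hidden_size=64, num_hidden_layers=2,
                     num_attention_heads=2, intermediate_size=128)
dec_cfg = BertConfig(hidden_size=64, num_hidden_layers=2,
                     num_attention_heads=2, intermediate_size=128,
                     is_decoder=True, add_cross_attention=True)

config = EncoderDecoderConfig.from_encoder_decoder_configs(enc_cfg, dec_cfg)
model = EncoderDecoderModel(config=config)  # randomly initialized skeleton
```

To warm-start from pretrained weights as done here, `EncoderDecoderModel.from_encoder_decoder_pretrained("DeepPavlov/rubert-base-cased", "DeepPavlov/rubert-base-cased")` builds the same structure with both halves initialized from the checkpoint.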
## Usage example

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "dmitry-vorobiev/rubert_ria_headlines"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

text = "Скопируйте текст статьи / новости"  # "Paste the article / news text here"

encoded_batch = tokenizer.prepare_seq2seq_batch(
    [text],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512)

output_ids = model.generate(
    input_ids=encoded_batch["input_ids"],
    max_length=32,
    no_repeat_ngram_size=3,
    num_beams=5,
    top_k=0
)

headline = tokenizer.decode(output_ids[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=False)
print(headline)
```
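Note that `prepare_seq2seq_batch` was deprecated and later removed from `transformers`; on recent versions, calling the tokenizer directly builds the same batch (a sketch assuming a transformers 4.x install, not part of the original card):

```python
from transformers import AutoTokenizer

MODEL_NAME = "dmitry-vorobiev/rubert_ria_headlines"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Same arguments as prepare_seq2seq_batch, passed to the tokenizer itself.
encoded_batch = tokenizer(
    ["Скопируйте текст статьи / новости"],  # "paste the article text here"
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
)
```

The resulting `encoded_batch["input_ids"]` feeds into `model.generate` exactly as in the example above.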

## Datasets

- [ria_news](https://github.com/RossiyaSegodnya/ria_news_dataset)
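The "first 90% of the dataset" split mentioned above can be sketched as a plain leading-slice split (a hypothetical helper, not the author's code; field names and file layout of ria_news are not assumed here):

```python
def first_90_percent(records):
    """Take the leading 90% of records for training, the rest for eval."""
    split = int(len(records) * 0.9)
    return records[:split], records[split:]

# Toy stand-in for the parsed ria_news records:
train, val = first_90_percent(list(range(100)))
# train holds the first 90 items, val the remaining 10
```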

## How was it trained?

Short answer: it's a mess :D

1. [0.4 ep](https://www.kaggle.com/dvorobiev/train-seq2seq?scriptVersionId=52758945)
2. [0.8 ep](https://www.kaggle.com/dvorobiev/train-seq2seq?scriptVersionId=52794838)
3. [1.2 ep](https://www.kaggle.com/dvorobiev/train-seq2seq?scriptVersionId=52838778)
4. [1.6 ep](https://www.kaggle.com/dvorobiev/train-seq2seq?scriptVersionId=52876230)

Common train params:
```shell
python nlp_headline_rus/src/train_seq2seq.py \
    --do_train \
    --fp16 \
    --tie_encoder_decoder \
    --max_source_length 512 \
    --max_target_length 32 \
    --val_max_target_length 48 \
    --per_device_train_batch_size 14 \
    --gradient_accumulation_steps 4 \
    --warmup_steps 2000 \
    --learning_rate 3e-4 \
    --adam_epsilon 1e-6 \
    --weight_decay 1e-5
```
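For orientation, the batch-size flags above combine into an effective training batch per optimizer step (simple arithmetic inferred from the flags; the single-GPU assumption matches typical Kaggle runs but is not stated in the card):

```python
# Effective examples per optimizer step implied by the flags above.
per_device_train_batch_size = 14
gradient_accumulation_steps = 4
n_gpus = 1  # assumption: a single-GPU Kaggle session

effective_batch = per_device_train_batch_size * gradient_accumulation_steps * n_gpus
print(effective_batch)  # 56
```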