---
language:
- ru
tags:
- summarization
license: MIT
model:
- bert
- rubert
---

# rubert_ria_headlines

## Description

A *bert2bert* model, initialized with the `DeepPavlov/rubert-base-cased` pretrained weights and fine-tuned for 3 epochs on the first 90% of the ["Rossiya Segodnya" news dataset](https://github.com/RossiyaSegodnya/ria_news_dataset).
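
Such a model can be assembled from the pretrained checkpoint with Hugging Face's `EncoderDecoderModel`. A minimal sketch (not the exact initialization code used here; the weight tying matches the `--tie_encoder_decoder` flag in the training command below):

```python
from transformers import AutoTokenizer, EncoderDecoderModel

# Warm-start a bert2bert model from RuBERT; tie_encoder_decoder=True
# shares the weights between encoder and decoder.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "DeepPavlov/rubert-base-cased",
    "DeepPavlov/rubert-base-cased",
    tie_encoder_decoder=True,
)
tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

# BERT has no dedicated BOS token, so generation conventionally starts
# from [CLS] and pads with the tokenizer's pad token.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
```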

## Usage example

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "dmitry-vorobiev/rubert_ria_headlines"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

text = "Скопируйте текст статьи / новости"  # paste the article / news text (in Russian) here

# Tokenize the article; inputs longer than 512 tokens are truncated.
encoded_batch = tokenizer(
    [text],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512)

# Beam search over 5 candidates; no_repeat_ngram_size=3 blocks repeated trigrams.
output_ids = model.generate(
    input_ids=encoded_batch["input_ids"],
    attention_mask=encoded_batch["attention_mask"],
    max_length=32,
    no_repeat_ngram_size=3,
    num_beams=5)

headline = tokenizer.decode(output_ids[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=False)
print(headline)
```
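
Note that only the first 512 tokens of the article reach the encoder, so very long texts are summarized from their beginning. `max_length=32` caps the headline length in tokens, and larger `num_beams` values trade generation speed for quality.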

## Datasets

- [ria_news](https://github.com/RossiyaSegodnya/ria_news_dataset)
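
The dump is distributed as a gzipped JSON-lines file. A hypothetical loader, assuming a `ria.json.gz` file with one JSON object per line carrying `text` and `title` fields (check the dataset repository for the exact schema):

```python
import gzip
import json

def read_ria(path="ria.json.gz"):
    """Yield (article_text, title) pairs from the ria_news dump."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            yield item["text"], item["title"]
```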

## How was it trained?

I used a free TPUv3 on Kaggle. The model was trained for 3 epochs with an effective batch size of 192 and soft restarts (warmup of 1500 / 500 / 500 steps, with a fresh optimizer state at the start of each epoch; see the sketch after the notebook list below).

- [epoch 1 notebook](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53254694)
- [epoch 2 notebook](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53269040)
- [epoch 3 notebook](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53280797)
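
A minimal sketch of this soft-restart schedule (assumptions: a linear warmup-then-decay schedule and a bare training loop; `train_loader` and `steps_per_epoch` are hypothetical stand-ins, while the optimizer settings match the training command below):

```python
import torch
from transformers import EncoderDecoderModel, get_linear_schedule_with_warmup

# Warm-started bert2bert (see the Description section).
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "DeepPavlov/rubert-base-cased", "DeepPavlov/rubert-base-cased",
    tie_encoder_decoder=True)

warmup_by_epoch = [1500, 500, 500]  # warmup steps for epochs 1 / 2 / 3
steps_per_epoch = 10_000            # hypothetical; depends on data and batch size

for num_warmup_steps in warmup_by_epoch:
    # "Soft restart": a fresh optimizer state at the start of each epoch.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=5e-4, eps=1e-6, weight_decay=1e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=steps_per_epoch)
    for batch in train_loader:  # hypothetical DataLoader yielding input_ids / labels
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```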

Common training parameters (8 TPU cores × a per-device batch size of 24 × 1 gradient-accumulation step gives the effective batch size of 192 mentioned above):

```shell
export XLA_USE_BF16=1
export XLA_TENSOR_ALLOCATOR_MAXSIZE=100000000

python nlp_headline_rus/src/train_seq2seq.py \
    --do_train \
    --tie_encoder_decoder \
    --max_source_length 512 \
    --max_target_length 32 \
    --val_max_target_length 48 \
    --tpu_num_cores 8 \
    --per_device_train_batch_size 24 \
    --gradient_accumulation_steps 1 \
    --learning_rate 5e-4 \
    --adam_epsilon 1e-6 \
    --weight_decay 1e-5
```
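
`XLA_USE_BF16=1` makes torch/xla store floating-point tensors as bfloat16 on the TPU, which reduces memory pressure; the script flags mirror the sequence lengths and optimizer settings discussed above.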

## Validation results

- Using [last 1% of ria](https://drive.google.com/drive/folders/1xtCnkbGNNu5jGQ9H9Mg55Cx7RTcyhQw9) dataset
- Using [last 10% of ria](https://drive.google.com/drive/folders/1w6rAXhpFUO8I4A7xfHKUjMBPEKBHEO3h) dataset
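
The linked folders contain the evaluation outputs for the held-out splits. As a rough illustration of how generated headlines can be scored against reference titles (this is not the evaluation code used above), a word-level ROUGE-1 F1 can be computed as follows:

```python
from collections import Counter

def rouge1_f1(prediction: str, reference: str) -> float:
    """Word-level ROUGE-1 F1 between a generated and a reference headline."""
    pred = Counter(prediction.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((pred & ref).values())  # clipped unigram matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

print(rouge1_f1("суд арестовал обвиняемого по делу", "суд арестовал фигуранта дела"))
```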