Commit
·
a128f1f
1
Parent(s):
53adc0a
Update README.md
Browse files
README.md
CHANGED
@@ -28,9 +28,6 @@ A deep VAE model pretrained on Wudao dataset. Both encoder and decoder are based
|
|
28 |
|
29 |
参考论文:[Fuse It More Deeply! A Variational Transformer with Layer-Wise Latent Variable Inference for Text Generation](https://arxiv.org/abs/2207.06130)
|
30 |
|
31 |
-
基于[Randeng-Pegasus-523M-Chinese](https://huggingface.co/IDEA-CCNL/Randeng-Pegasus-523M-Chinese),我们在收集的7个中文领域的文本摘要数据集(约4M个样本)上微调了它,得到了summary版本。这7个数据集为:education, new2016zh, nlpcc, shence, sohu, thucnews和weibo。
|
32 |
-
|
33 |
-
Based on [Randeng-Pegasus-523M-Chinese](https://huggingface.co/IDEA-CCNL/Randeng-Pegasus-523M-Chinese), we fine-tuned a text summarization version (summary) on 7 Chinese text summarization datasets, with totaling around 4M samples. The datasets include: education, new2016zh, nlpcc, shence, sohu, thucnews and weibo.
|
34 |
|
35 |
### 下游效果 Performance
|
36 |
|
@@ -41,27 +38,65 @@ Based on [Randeng-Pegasus-523M-Chinese](https://huggingface.co/IDEA-CCNL/Randeng
|
|
41 |
## 使用 Usage
|
42 |
|
43 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
from transformers import PegasusForConditionalGeneration
|
46 |
-
# Need to download tokenizers_pegasus.py and other Python script from Fengshenbang-LM github repo in advance,
|
47 |
-
# or you can download tokenizers_pegasus.py and data_utils.py in https://huggingface.co/IDEA-CCNL/Randeng_Pegasus_523M/tree/main
|
48 |
-
# Strongly recommend you git clone the Fengshenbang-LM repo:
|
49 |
-
# 1. git clone https://github.com/IDEA-CCNL/Fengshenbang-LM
|
50 |
-
# 2. cd Fengshenbang-LM/fengshen/examples/pegasus/
|
51 |
-
# and then you will see the tokenizers_pegasus.py and data_utils.py which are needed by pegasus model
|
52 |
-
from tokenizers_pegasus import PegasusTokenizer
|
53 |
-
|
54 |
-
model = PegasusForConditionalGeneration.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
|
55 |
-
tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
|
56 |
|
57 |
-
text = "据微信公众号“界面”报道,4日上午10点左右,中国发改委反垄断调查小组突击查访奔驰上海办事处,调取数据材料,并对多名奔驰高管进行了约谈。截止昨日晚9点,包括北京梅赛德斯-奔驰销售服务有限公司东区总经理在内的多名管理人员仍留在上海办公室内"
|
58 |
-
inputs = tokenizer(text, max_length=1024, return_tensors="pt")
|
59 |
|
60 |
-
# Generate Summary
|
61 |
-
summary_ids = model.generate(inputs["input_ids"])
|
62 |
-
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
63 |
|
64 |
-
# model Output: 反垄断调查小组突击查访奔驰上海办事处,对多名奔��高管进行约谈
|
65 |
```
|
66 |
|
67 |
## 引用 Citation
|
|
|
28 |
|
29 |
参考论文:[Fuse It More Deeply! A Variational Transformer with Layer-Wise Latent Variable Inference for Text Generation](https://arxiv.org/abs/2207.06130)
|
30 |
|
|
|
|
|
|
|
31 |
|
32 |
### 下游效果 Performance
|
33 |
|
|
|
38 |
## 使用 Usage
|
39 |
|
40 |
```python
|
41 |
+
# Checkout the latest Fengshenbang-LM directory and run following script under Fengshenbang-LM root directory
|
42 |
+
import sys
|
43 |
+
import torch
|
44 |
+
import argparse
|
45 |
+
from torch.nn.utils.rnn import pad_sequence
|
46 |
+
from fengshen.models.deepVAE.vae_pl_module import DeepVAEModule
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
# TODO: Update this path to the downloaded directory
|
52 |
+
checkpoint_path = '..../Randeng-DELLA-226M-Chinese'
|
53 |
+
gpt2_model_path = '..../Randeng-DELLA-226M-Chinese'
|
54 |
+
|
55 |
+
args_parser = argparse.ArgumentParser()
|
56 |
+
args_parser.add_argument("--checkpoint_path", type=str, default=checkpoint_path)
|
57 |
+
args_parser.add_argument("--gpt2_model_path", type=str, default=gpt2_model_path)
|
58 |
+
args_parser.add_argument("--latent_dim", type=int, default=256)
|
59 |
+
args_parser.add_argument("--beta_kl_constraints_start", type=float, default=1e-5)
|
60 |
+
args_parser.add_argument("--beta_kl_constraints_stop", type=float, default=1.)
|
61 |
+
args_parser.add_argument("--beta_n_cycles", type=int, default=10)
|
62 |
+
args_parser.add_argument("--latent_lmf_rank", type=int, default=4)
|
63 |
+
args_parser.add_argument("--CVAE", action='store_true')
|
64 |
+
args_parser.add_argument("--share_param", action='store_false',
|
65 |
+
help="specify this argument if we want to share dec's and enc's params")
|
66 |
+
|
67 |
+
args, unknown_args = args_parser.parse_known_args()
|
68 |
+
|
69 |
+
# load model
|
70 |
+
model, tokenizer = DeepVAEModule.load_model(args, labels_dict=None)
|
71 |
+
# VAE generation
|
72 |
+
sentence = "本模型是在通用数据集下预训练的VAE模型,如要获得最佳效果请在特定领域微调后使用。"
|
73 |
+
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
|
74 |
+
decoder_target = [tokenizer.bos_token_id] + tokenized_text + [tokenizer.eos_token_id]
|
75 |
+
inputs = []
|
76 |
+
inputs.append(torch.tensor(decoder_target, dtype=torch.long))
|
77 |
+
inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
|
78 |
+
|
79 |
+
max_length = 256
|
80 |
+
top_p = 0.5
|
81 |
+
top_k = 0
|
82 |
+
temperature = .7
|
83 |
+
repetition_penalty = 1.0
|
84 |
+
sample = False
|
85 |
+
device = 0
|
86 |
+
model = model.eval()
|
87 |
+
model = model.to(device)
|
88 |
+
|
89 |
+
outputs = model.inference(inputs.to(device), top_p=top_p, top_k=top_k, max_length=max_length, sample=sample,
|
90 |
+
temperature=temperature, repetition_penalty=repetition_penalty)
|
91 |
+
|
92 |
+
for gen_sent, orig_sent in zip(outputs, inputs):
|
93 |
+
print('orig_sent:', tokenizer.decode(orig_sent).replace(' ', ''))
|
94 |
+
print('gen_sent:', tokenizer.decode(gen_sent).replace(' ', ''))
|
95 |
+
print("-"*20)
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
|
|
|
|
98 |
|
|
|
|
|
|
|
99 |
|
|
|
100 |
```
|
101 |
|
102 |
## 引用 Citation
|