Randolphzeng commited on
Commit
a128f1f
·
1 Parent(s): 53adc0a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -20
README.md CHANGED
@@ -28,9 +28,6 @@ A deep VAE model pretrained on Wudao dataset. Both encoder and decoder are based
28
 
29
  参考论文:[Fuse It More Deeply! A Variational Transformer with Layer-Wise Latent Variable Inference for Text Generation](https://arxiv.org/abs/2207.06130)
30
 
31
- 基于[Randeng-Pegasus-523M-Chinese](https://huggingface.co/IDEA-CCNL/Randeng-Pegasus-523M-Chinese),我们在收集的7个中文领域的文本摘要数据集(约4M个样本)上微调了它,得到了summary版本。这7个数据集为:education, new2016zh, nlpcc, shence, sohu, thucnews和weibo。
32
-
33
- Based on [Randeng-Pegasus-523M-Chinese](https://huggingface.co/IDEA-CCNL/Randeng-Pegasus-523M-Chinese), we fine-tuned a text summarization version (summary) on 7 Chinese text summarization datasets, with totaling around 4M samples. The datasets include: education, new2016zh, nlpcc, shence, sohu, thucnews and weibo.
34
 
35
  ### 下游效果 Performance
36
 
@@ -41,27 +38,65 @@ Based on [Randeng-Pegasus-523M-Chinese](https://huggingface.co/IDEA-CCNL/Randeng
41
  ## 使用 Usage
42
 
43
  ```python
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- from transformers import PegasusForConditionalGeneration
46
- # Need to download tokenizers_pegasus.py and other Python script from Fengshenbang-LM github repo in advance,
47
- # or you can download tokenizers_pegasus.py and data_utils.py in https://huggingface.co/IDEA-CCNL/Randeng_Pegasus_523M/tree/main
48
- # Strongly recommend you git clone the Fengshenbang-LM repo:
49
- # 1. git clone https://github.com/IDEA-CCNL/Fengshenbang-LM
50
- # 2. cd Fengshenbang-LM/fengshen/examples/pegasus/
51
- # and then you will see the tokenizers_pegasus.py and data_utils.py which are needed by pegasus model
52
- from tokenizers_pegasus import PegasusTokenizer
53
-
54
- model = PegasusForConditionalGeneration.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
55
- tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
56
 
57
- text = "据微信公众号“界面”报道,4日上午10点左右,中国发改委反垄断调查小组突击查访奔驰上海办事处,调取数据材料,并对多名奔驰高管进行了约谈。截止昨日晚9点,包括北京梅赛德斯-奔驰销售服务有限公司东区总经理在内的多名管理人员仍留在上海办公室内"
58
- inputs = tokenizer(text, max_length=1024, return_tensors="pt")
59
 
60
- # Generate Summary
61
- summary_ids = model.generate(inputs["input_ids"])
62
- tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
63
 
64
- # model Output: 反垄断调查小组突击查访奔驰上海办事处,对多名奔��高管进行约谈
65
  ```
66
 
67
  ## 引用 Citation
 
28
 
29
  参考论文:[Fuse It More Deeply! A Variational Transformer with Layer-Wise Latent Variable Inference for Text Generation](https://arxiv.org/abs/2207.06130)
30
 
 
 
 
31
 
32
  ### 下游效果 Performance
33
 
 
38
  ## 使用 Usage
39
 
40
  ```python
41
+ # Checkout the latest Fengshenbang-LM directory and run following script under Fengshenbang-LM root directory
42
+ import sys
43
+ import torch
44
+ import argparse
45
+ from torch.nn.utils.rnn import pad_sequence
46
+ from fengshen.models.deepVAE.vae_pl_module import DeepVAEModule
47
+
48
+
49
+
50
+ if __name__ == "__main__":
51
+ # TODO: Update this path to the downloaded directory
52
+ checkpoint_path = '..../Randeng-DELLA-226M-Chinese'
53
+ gpt2_model_path = '..../Randeng-DELLA-226M-Chinese'
54
+
55
+ args_parser = argparse.ArgumentParser()
56
+ args_parser.add_argument("--checkpoint_path", type=str, default=checkpoint_path)
57
+ args_parser.add_argument("--gpt2_model_path", type=str, default=gpt2_model_path)
58
+ args_parser.add_argument("--latent_dim", type=int, default=256)
59
+ args_parser.add_argument("--beta_kl_constraints_start", type=float, default=1e-5)
60
+ args_parser.add_argument("--beta_kl_constraints_stop", type=float, default=1.)
61
+ args_parser.add_argument("--beta_n_cycles", type=int, default=10)
62
+ args_parser.add_argument("--latent_lmf_rank", type=int, default=4)
63
+ args_parser.add_argument("--CVAE", action='store_true')
64
+ args_parser.add_argument("--share_param", action='store_false',
65
+ help="specify this argument if we want to share dec's and enc's params")
66
+
67
+ args, unknown_args = args_parser.parse_known_args()
68
+
69
+ # load model
70
+ model, tokenizer = DeepVAEModule.load_model(args, labels_dict=None)
71
+ # VAE generation
72
+ sentence = "本模型是在通用数据集下预训练的VAE模型,如要获得最佳效果请在特定领域微调后使用。"
73
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
74
+ decoder_target = [tokenizer.bos_token_id] + tokenized_text + [tokenizer.eos_token_id]
75
+ inputs = []
76
+ inputs.append(torch.tensor(decoder_target, dtype=torch.long))
77
+ inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
78
+
79
+ max_length = 256
80
+ top_p = 0.5
81
+ top_k = 0
82
+ temperature = .7
83
+ repetition_penalty = 1.0
84
+ sample = False
85
+ device = 0
86
+ model = model.eval()
87
+ model = model.to(device)
88
+
89
+ outputs = model.inference(inputs.to(device), top_p=top_p, top_k=top_k, max_length=max_length, sample=sample,
90
+ temperature=temperature, repetition_penalty=repetition_penalty)
91
+
92
+ for gen_sent, orig_sent in zip(outputs, inputs):
93
+ print('orig_sent:', tokenizer.decode(orig_sent).replace(' ', ''))
94
+ print('gen_sent:', tokenizer.decode(gen_sent).replace(' ', ''))
95
+ print("-"*20)
96
 
 
 
 
 
 
 
 
 
 
 
 
97
 
 
 
98
 
 
 
 
99
 
 
100
  ```
101
 
102
  ## 引用 Citation