Transformers
PyTorch
Chinese
megatron-bert
suolyer commited on
Commit
a8400d3
·
1 Parent(s): b78f377

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +32 -1
README.md CHANGED
@@ -15,7 +15,7 @@ There is no structure of Zhouwenwang-1.3B in [Transformers](https://github.com/h
15
  git clone https://github.com/IDEA-CCNL/Fengshenbang-LM.git
16
  ```
17
 
18
- ### Load Model
19
  ```python
20
  from model.roformer.modeling_roformer import RoFormerModel
21
  from model.roformer.configuration_roformer import RoFormerConfig
@@ -27,6 +27,37 @@ model = RoFormerModel.from_pretrained("IDEA-CCNL/Zhouwenwang-1.3B")
27
 
28
 
29
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  ## Scores on downstream chinese tasks (without any data augmentation)
31
  | Model| afqmc | tnews | iflytek | ocnli | cmnli | wsc | csl |
32
  | :--------: | :-----: | :----: | :-----: | :----: | :----: | :----: | :----: |
 
15
  git clone https://github.com/IDEA-CCNL/Fengshenbang-LM.git
16
  ```
17
 
18
+ ### Load model
19
  ```python
20
  from model.roformer.modeling_roformer import RoFormerModel
21
  from model.roformer.configuration_roformer import RoFormerConfig
 
27
 
28
 
29
  ```
30
+ ### Generate task
31
+ You can use Zhouwenwang-1.3B to continue writing
32
+
33
+ ```python
34
+ from model.roformer.modeling_roformer import RoFormerModel
35
+ from transformers import AutoTokenizer
36
+ import torch
37
+ import numpy as np
38
+
39
+ sentence = '清华大学位于'
40
+ max_length = 32
41
+ model_pretrained_weight_path = '/home/' # 预训练模型权重路径
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(model_pretrained_weight_path)
44
+ model = RoFormerModel.from_pretrained(model_pretrained_weight_path)
45
+
46
+ for i in range(max_length):
47
+ encode = torch.tensor(
48
+ [[tokenizer.cls_token_id]+tokenizer.encode(sentence, add_special_tokens=False)]).long()
49
+ logits = model(encode)[0]
50
+ logits = torch.nn.functional.linear(
51
+ logits, model.embeddings.word_embeddings.weight)
52
+ logits = torch.nn.functional.softmax(
53
+ logits, dim=-1).cpu().detach().numpy()[0]
54
+ sentence = sentence + \
55
+ tokenizer.decode(int(np.random.choice(logits.shape[1], p=logits[-1])))
56
+ if sentence[-1] == '。':
57
+ break
58
+ print(sentence)
59
+ ```
60
+
61
  ## Scores on downstream chinese tasks (without any data augmentation)
62
  | Model| afqmc | tnews | iflytek | ocnli | cmnli | wsc | csl |
63
  | :--------: | :-----: | :----: | :-----: | :----: | :----: | :----: | :----: |