Update README.md
Browse files
README.md
CHANGED
@@ -15,7 +15,7 @@ There is no structure of Zhouwenwang-1.3B in [Transformers](https://github.com/h
|
|
15 |
git clone https://github.com/IDEA-CCNL/Fengshenbang-LM.git
|
16 |
```
|
17 |
|
18 |
-
### Load
|
19 |
```python
|
20 |
from model.roformer.modeling_roformer import RoFormerModel
|
21 |
from model.roformer.configuration_roformer import RoFormerConfig
|
@@ -27,6 +27,37 @@ model = RoFormerModel.from_pretrained("IDEA-CCNL/Zhouwenwang-1.3B")
|
|
27 |
|
28 |
|
29 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
## Scores on downstream Chinese tasks (without any data augmentation)
|
31 |
| Model| afqmc | tnews | iflytek | ocnli | cmnli | wsc | csl |
|
32 |
| :--------: | :-----: | :----: | :-----: | :----: | :----: | :----: | :----: |
|
|
|
15 |
git clone https://github.com/IDEA-CCNL/Fengshenbang-LM.git
|
16 |
```
|
17 |
|
18 |
+
### Load model
|
19 |
```python
|
20 |
from model.roformer.modeling_roformer import RoFormerModel
|
21 |
from model.roformer.configuration_roformer import RoFormerConfig
|
|
|
27 |
|
28 |
|
29 |
```
|
30 |
+
### Generation task
|
31 |
+
You can use Zhouwenwang-1.3B for text continuation, i.e. generating the rest of a sentence from a given prompt:
|
32 |
+
|
33 |
+
```python
|
34 |
+
from model.roformer.modeling_roformer import RoFormerModel
|
35 |
+
from transformers import AutoTokenizer
|
36 |
+
import torch
|
37 |
+
import numpy as np
|
38 |
+
|
39 |
+
sentence = '清华大学位于'
|
40 |
+
max_length = 32
|
41 |
+
model_pretrained_weight_path = '/home/' # 预训练模型权重路径
|
42 |
+
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained(model_pretrained_weight_path)
|
44 |
+
model = RoFormerModel.from_pretrained(model_pretrained_weight_path)
|
45 |
+
|
46 |
+
for i in range(max_length):
|
47 |
+
encode = torch.tensor(
|
48 |
+
[[tokenizer.cls_token_id]+tokenizer.encode(sentence, add_special_tokens=False)]).long()
|
49 |
+
logits = model(encode)[0]
|
50 |
+
logits = torch.nn.functional.linear(
|
51 |
+
logits, model.embeddings.word_embeddings.weight)
|
52 |
+
logits = torch.nn.functional.softmax(
|
53 |
+
logits, dim=-1).cpu().detach().numpy()[0]
|
54 |
+
sentence = sentence + \
|
55 |
+
tokenizer.decode(int(np.random.choice(logits.shape[1], p=logits[-1])))
|
56 |
+
if sentence[-1] == '。':
|
57 |
+
break
|
58 |
+
print(sentence)
|
59 |
+
```
|
60 |
+
|
61 |
## Scores on downstream Chinese tasks (without any data augmentation)
|
62 |
| Model| afqmc | tnews | iflytek | ocnli | cmnli | wsc | csl |
|
63 |
| :--------: | :-----: | :----: | :-----: | :----: | :----: | :----: | :----: |
|