# custom-mymodel_v1 / test_model.py
# Uploaded by txt2audio ("Upload model", commit 8d8287a, 1.31 kB).
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from typing import List
import torch.nn as nn
import torch
class MyModelConfig(PretrainedConfig):
    """Hyper-parameters for ``MyModel``.

    Args:
        input_dim: Number of input features consumed by the first
            ``nn.Linear`` layer of ``MyModel``.
        layers_num: Total number of ``nn.Linear`` layers stacked by
            ``MyModel``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        input_dim: int = 100,
        layers_num: int = 5,
        **kwargs,
    ):
        # Every argument must have a default value; the original author notes
        # that an error is raised otherwise.
        # The custom fields are stored before delegating to the base class.
        self.input_dim, self.layers_num = input_dim, layers_num
        super().__init__(**kwargs)
class MyModel(PreTrainedModel):
    """A stack of ``nn.Linear`` layers mapping ``config.input_dim`` features
    to a single output value.

    Layout:
        * ``layers_num == 1``: ``Linear(input_dim, 1)``
        * otherwise: ``Linear(input_dim, hidden)``, then ``layers_num - 2``
          copies of ``Linear(hidden, hidden)``, then ``Linear(hidden, 1)``.

    ``hidden`` is taken from ``config.hidden_dim`` when the config carries
    that attribute and defaults to 30 otherwise.
    """

    config_class = MyModelConfig

    def __init__(self, config):
        """Build the layer stack described by *config*.

        Raises:
            ValueError: if ``config.layers_num`` is less than 1.
        """
        super().__init__(config)
        layers_num = config.layers_num
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if layers_num < 1:
            raise ValueError(f"layers_num must be >= 1, got {layers_num}")
        hidden_dim = getattr(config, "hidden_dim", 30)

        # Layers are created in the same order as before, so random weight
        # initialisation under a fixed seed is unchanged.
        if layers_num == 1:
            modules = [nn.Linear(config.input_dim, 1)]
        else:
            modules = [nn.Linear(config.input_dim, hidden_dim)]
            modules.extend(
                nn.Linear(hidden_dim, hidden_dim) for _ in range(layers_num - 2)
            )
            modules.append(nn.Linear(hidden_dim, 1))

        # nn.Sequential (not nn.ModuleList): ModuleList has no forward(), so
        # calling it raised. Both containers name their children "0", "1", ...,
        # so state_dict keys ("model.0.weight", ...) are unchanged and
        # previously saved weights still load.
        # NOTE(review): there is no non-linearity between the layers, so the
        # stack collapses to a single affine map. Inserting activations would
        # shift the child indices and break existing checkpoints.
        self.model = nn.Sequential(*modules)

    def forward(self, tensor):
        """Run *tensor* through the layer stack.

        ``nn.Linear`` acts on the last dimension, so the last dimension of
        *tensor* must equal ``config.input_dim``. The output has the same
        leading dimensions with a last dimension of 1.
        """
        return self.model(tensor)
if __name__ == '__main__':
    import os

    # Save the config and the weights into the same directory so they stay
    # together (previously the weights landed in the current directory while
    # the config went to "custom-mymodel").
    save_dir = "custom-mymodel"
    save_config = MyModelConfig(input_dim=10, layers_num=3)
    save_config.save_pretrained(save_dir)
    mymodel = MyModel(save_config)
    # "pytorch_model.bin" is the conventional name for PyTorch weights files.
    torch.save(mymodel.state_dict(), os.path.join(save_dir, 'pytorch_model.bin'))