在 Python 中使用 PyTorch（torch）
权重初始化（Xavier 均匀初始化）：
import torch.nn as nn


def xavier_init_weights(m):
    """Apply Xavier (Glorot) uniform initialization to a module's weights.

    Designed to be passed to ``nn.Module.apply``, which calls it on every
    submodule of a network.

    * ``nn.Linear`` (including subclasses): ``m.weight`` is re-initialized;
      the bias is left unchanged.
    * ``nn.GRU`` (including subclasses): every parameter whose name contains
      ``"weight"`` is re-initialized; biases are left unchanged.
    * Any other module is left untouched.

    Args:
        m: the module to initialize; its parameters are modified in place.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.GRU):
        # Use the public named_parameters() API instead of the private
        # ``_flat_weights_names`` / ``_parameters`` attributes.
        # recurse=False limits the scan to the GRU's own parameters.
        for name, param in m.named_parameters(recurse=False):
            if "weight" in name:
                nn.init.xavier_uniform_(param)
net.apply(xavier_init_weights)
生成批量单位矩阵（结果形状为 (2, 2, 10, 10)）：
import torch

# Tile the 10x10 identity matrix into a (2, 2, 10, 10) tensor.
# The two unsqueeze(0) calls reshape eye(10) from (10, 10) to (1, 1, 10, 10);
# repeat(2, 2, 1, 1) then tiles it twice along each new leading dimension.
# repeat() copies the data, so every [i, j] slice owns its own memory and can
# be edited independently (expand() would instead return a shared-memory view).
identity_grid = torch.eye(10).unsqueeze(0).unsqueeze(0).repeat(2, 2, 1, 1)