1 | import torch.nn as nn |
register_buffer
的作用是将 一个tensor变量注册到模型的 buffers() 属性中, 该变量 不会有梯度传播给它,但是能被模型的state_dict记录下来。可以理解为模型的常数。既然register_buffer的对象是模型中的常数,那为什么不直接使用下面的方法一,还不更直接吗 ?
1 | class net(nn.Module): |
- 我们可能会遇到这样的场景:那个常数不是这么简单的常数,而是外部传入的。
1 | class net(nn.Module): |
1 | #如果是方法一,你又要运行一遍获得x的过程。 |
1 | #如果是方法二,不需要获得x,因为register_buffer会将常数x保存在state_dict中,载入就行了。 |
特性 | 说明 |
---|---|
非可训练参数 | 不会被梯度更新,但参与前向传播 |
设备感知 | 自动随模型切换设备(CPU/GPU) |
状态持久化 | 会被保存到state_dict 中,随模型参数一起保存/加载 |
广播优化 | 通过预设维度对齐,实现张量自动广播 |