2 使用块的网络(VGG)
2.1 VGG块
- 经典卷积神经网络的基本组成部分是如下序列:
- 带填充以保持分辨率的卷积层
- 非线性激活函数,如ReLU
- 汇聚层,如最大汇聚层
1 | import torch |
2.2 VGG网络
- 从AlexNet到VGG,它们本质上都是块设计。
- 原始VGG网络有5个卷积块,前两个块各有一个卷积层,后三个块各包含两个卷积层,因此共有8个卷积层。
- 第一个模块有64输出通道,每个后续模块将输出通道翻倍,直到该数字达到512.
- 由于该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11。
1 | # 指定了每个VGG块里卷积层个数和输出通道数。 |
- 查看输出形状
1 | X = torch.randn(size=(1, 1, 224, 224)) |
Sequential output shape: torch.Size([1, 64, 112, 112])
Sequential output shape: torch.Size([1, 128, 56, 56])
Sequential output shape: torch.Size([1, 256, 28, 28])
Sequential output shape: torch.Size([1, 512, 14, 14])
Sequential output shape: torch.Size([1, 512, 7, 7])
Flatten output shape: torch.Size([1, 25088])
Linear output shape: torch.Size([1, 4096])
ReLU output shape: torch.Size([1, 4096])
Dropout output shape: torch.Size([1, 4096])
Linear output shape: torch.Size([1, 4096])
ReLU output shape: torch.Size([1, 4096])
Dropout output shape: torch.Size([1, 4096])
Linear output shape: torch.Size([1, 10])
2.3 训练模型
- 由于VGG比AlexNet计算量更大,因此我们构建了一个通道数较少的网络,用于训练Fashion-MNIST数据集
1 | import os |