Initialization
Contact me
- Blog -> https://cugtyt.github.io/blog/index
- Email -> cugtyt@qq.com
- GitHub -> Cugtyt@GitHub
The homepage of this blog series and related posts can be found here.
The following comes from the official ResNet implementation:
```python
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
    for m in self.modules():
        if isinstance(m, Bottleneck):
            nn.init.constant_(m.bn3.weight, 0)
        elif isinstance(m, BasicBlock):
            nn.init.constant_(m.bn2.weight, 0)
```
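The same loop transfers directly to custom models. Below is a minimal sketch (the `SmallNet` module and its layer names are invented for illustration, not part of the ResNet source) that applies the identical initialization pattern:

```python
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Same initialization loop as in the ResNet source above.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

net = SmallNet()
# fan_out of the conv is 16 * 3 * 3 = 144, so the weight std should be near sqrt(2/144) ≈ 0.118.
print(net.conv.weight.std().item())
```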
There are many more initializers under `nn.init`, such as `nn.init.uniform_(tensor, a=0.0, b=1.0)`, `nn.init.normal_(tensor, mean=0.0, std=1.0)`, `nn.init.constant_(tensor, val)`, `nn.init.ones_(tensor)`, and so on; the short sketch below exercises a few of them.
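All of these operate in place on the tensor they are given (and also return it); a minimal demonstration (the tensor shape here is arbitrary):

```python
import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.uniform_(w, a=0.0, b=1.0)      # in place, uniform on [0, 1)
nn.init.normal_(w, mean=0.0, std=1.0)  # in place, standard normal
nn.init.constant_(w, 0.5)              # every element becomes 0.5
nn.init.ones_(w)                       # every element becomes 1
```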
The best-known are the following:
`torch.nn.init.xavier_uniform_(tensor, gain=1.0)`

Samples from the uniform distribution $U(-a, a)$, where

\[a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}\]

`torch.nn.init.xavier_normal_(tensor, gain=1.0)`
Samples from the normal distribution $N(0, \text{std}^2)$, where

\[\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}\]

`torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')`
Samples from the uniform distribution $U(-\text{bound}, \text{bound})$, where

\[\text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan\_in}}}\]

`torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')`
Samples from the normal distribution $N(0, \text{std}^2)$, where (note: with the default mode='fan_in', the denominator uses fan_in, not fan_out)

\[\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}\]
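As a quick sanity check of these formulas (the tensor shape below is arbitrary; for a 2D tensor PyTorch takes fan_in = size(1) and fan_out = size(0)):

```python
import math
import torch
import torch.nn as nn

# Sizes chosen large so sampling noise is small.
fan_out, fan_in = 800, 1000
w = torch.empty(fan_out, fan_in)

nn.init.xavier_normal_(w, gain=1.0)
print(w.std().item(), math.sqrt(2.0 / (fan_in + fan_out)))        # both ~0.033

nn.init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='leaky_relu')
print(w.std().item(), math.sqrt(2.0 / ((1 + 0 ** 2) * fan_in)))   # both ~0.045

nn.init.xavier_uniform_(w, gain=1.0)
# Bound of the uniform: a = sqrt(6 / (fan_in + fan_out)); max |w| sits just below it.
print(w.abs().max().item(), math.sqrt(6.0 / (fan_in + fan_out)))
```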