
Initialization


The homepage of this blog series and related posts can be found here.


From the official ResNet implementation (torchvision):

# Kaiming-normal init for every conv layer; BN/GN scale (weight) to 1, shift (bias) to 0
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
    for m in self.modules():
        if isinstance(m, Bottleneck):
            nn.init.constant_(m.bn3.weight, 0)
        elif isinstance(m, BasicBlock):
            nn.init.constant_(m.bn2.weight, 0)
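
Note that mode='fan_out' preserves the magnitude of the variance in the backward pass, while the default 'fan_in' preserves it in the forward pass. As a minimal sketch of reusing this pattern (SmallNet is a hypothetical module, not part of torchvision):

import torch.nn as nn

# Hypothetical toy module, initialized with the same scheme as torchvision's ResNet.
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)  # BN follows, so no conv bias
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))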

There are many more initialization functions under nn.init, such as nn.init.uniform_(tensor, a=0.0, b=1.0), nn.init.normal_(tensor, mean=0.0, std=1.0), nn.init.constant_(tensor, val), and nn.init.ones_(tensor).
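
As a quick illustration (the tensor shape is arbitrary), every nn.init function fills the given tensor in place and also returns it:

import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.uniform_(w, a=0.0, b=1.0)      # fill with samples from U(0, 1)
nn.init.normal_(w, mean=0.0, std=1.0)  # overwrite with samples from N(0, 1)
nn.init.constant_(w, 0.5)              # every entry becomes 0.5
nn.init.ones_(w)                       # every entry becomes 1.0

The best-known ones are the following: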

torch.nn.init.xavier_uniform_(tensor, gain=1.0)

Samples from the uniform distribution $U(-a, a)$, where $a$ is

\[a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}\]
torch.nn.init.xavier_normal_(tensor, gain=1.0)

Samples from the normal distribution $N(0, \text{std}^2)$, where std is

\[\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}\]
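
A quick sanity check, assuming an arbitrary 256×128 weight (so fan_out = 256 and fan_in = 128), matches both Xavier formulas:

import math
import torch
import torch.nn as nn

w = torch.empty(256, 128)                # fan_out = 256, fan_in = 128
nn.init.xavier_uniform_(w)
a = math.sqrt(6.0 / (128 + 256))         # bound from the formula above
print(w.abs().max().item() <= a)         # True: all samples lie within [-a, a]

nn.init.xavier_normal_(w)
std = math.sqrt(2.0 / (128 + 256))
print(abs(w.std().item() - std) < 0.01)  # empirical std is close to the formula's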
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

Samples from the uniform distribution $U(-\text{bound}, \text{bound})$, where bound is

\[\text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan\_in}}}\]
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

Samples from the normal distribution $N(0, \text{std}^2)$, where std is

\[\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}\]
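
The same kind of check, using the defaults mode='fan_in' and a=0 (so the leaky-ReLU gain is $\sqrt{2}$), agrees with both Kaiming formulas:

import math
import torch
import torch.nn as nn

w = torch.empty(256, 128)                      # fan_in = 128 under the default mode='fan_in'
nn.init.kaiming_uniform_(w)                    # defaults: a=0, nonlinearity='leaky_relu'
bound = math.sqrt(6.0 / ((1 + 0 ** 2) * 128))
print(w.abs().max().item() <= bound)           # True

nn.init.kaiming_normal_(w)
std = math.sqrt(2.0 / ((1 + 0 ** 2) * 128))
print(abs(w.std().item() - std) < 0.01)        # empirical std is close to the formula's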