Skip to the content.

混合精度训练

Contact me

本系列博客主页及相关见此处


原理见MIXED PRECISION TRAINING论文,代码来自apex文档

# Declare model and optimizer as usual, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Allow Amp to perform casts as required by the opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
...
# loss.backward() becomes:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...

如果使用fastai库,那么简单一些,参见fastai 文档

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
model = simple_cnn((3,16,16,2))
learn = Learner(data, model, metrics=[accuracy]).to_fp16() #<<<<<<<<这里
learn.fit_one_cycle(1)

如果使用pytorch_lightning,也比较容易做到,参见pytorch_lightning文档,注意这个库依赖于上面的apex,需要安装,文档有详细信息:

# DEFAULT
trainer = Trainer(amp_level='O2', use_amp=True)