
Saving and Loading Models


The home page for this blog series and related posts can be found here.


Adapted from the official PyTorch website; reading the original tutorial is recommended.

Summary:

Recommended way to save and load (state_dict only)

torch.save(model.state_dict(), PATH)

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
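
To make the recommended pattern concrete, here is a minimal sketch of a full round trip. `TinyNet` is a hypothetical stand-in for `TheModelClass`, and the temporary file path is an assumption made only so the example runs anywhere.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


path = os.path.join(tempfile.mkdtemp(), "tiny_state.pt")

model = TinyNet()
torch.save(model.state_dict(), path)  # only parameters and buffers are written

restored = TinyNet()  # rebuild with the same constructor arguments
restored.load_state_dict(torch.load(path))
restored.eval()  # switch dropout/batch-norm layers to inference behaviour

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
```

Because the file holds only tensors keyed by parameter name, the class definition has to be available separately at load time, and the constructor arguments have to produce matching shapes.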

Saving and loading the entire model

torch.save(model, PATH)
Load:

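# Model class must be defined somewhere
# PyTorch 2.6+ defaults to weights_only=True, which refuses to unpickle a whole model
model = torch.load(PATH, weights_only=False)
model.eval()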
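
A small sketch of the whole-model route, under a few assumptions: the model is built from stock `torch.nn` layers so that pickle can locate the classes, the path is a temporary file, and the installed PyTorch accepts the `weights_only` argument of `torch.load` (to my knowledge, releases from 2.6 onward default it to `True`, so a pickled module needs it set to `False` explicitly).

```python
import os
import tempfile

import torch
import torch.nn as nn

# nn.Sequential and nn.Linear live inside torch.nn, so pickle can always
# find their class definitions again at load time.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

path = os.path.join(tempfile.mkdtemp(), "whole_model.pt")
torch.save(model, path)  # pickles the module object, not just its tensors

# weights_only=False permits arbitrary pickled objects; use it only for
# files you trust.
loaded = torch.load(path, weights_only=False)
loaded.eval()

x = torch.randn(5, 4)
with torch.no_grad():
    outputs_match = torch.equal(model(x), loaded(x))
```

The saved file is tied to the exact classes and module layout used at save time, which is the main reason the state_dict approach is the recommended one.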

Saving and loading a checkpoint with multiple items (the same approach works for multiple models)

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)


model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
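
The checkpoint pattern above can be sketched end to end as follows. The linear model, the SGD settings, the random data, and the temporary path are all illustrative assumptions; only the dictionary keys come from the snippet above.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving.
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({
    "epoch": 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),  # store a plain float rather than a graph-attached tensor
}, path)

# Later: rebuild the objects first, then restore their state.
new_model = nn.Linear(4, 1)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)

checkpoint = torch.load(path)
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # resume from the next epoch
last_loss = checkpoint["loss"]

new_model.train()  # resuming training; call eval() instead for inference
```

Saving the optimizer state matters when resuming training, because optimizers such as SGD with momentum or Adam keep per-parameter buffers that would otherwise restart from zero.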

Saving and loading across devices

1. Save on GPU, load on CPU

torch.save(model.state_dict(), PATH)

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
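
As a rough illustration of what `map_location` does, the sketch below saves from a GPU when one is present and then forces every tensor onto the CPU at load time. The fallback to a CPU-only save and the temporary path are assumptions added so the example also runs without CUDA.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Save from the GPU when one is present; otherwise this degrades to a
# CPU-to-CPU round trip, which exercises the same calls.
save_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2).to(save_device)

path = os.path.join(tempfile.mkdtemp(), "gpu_saved.pt")
torch.save(model.state_dict(), path)

# map_location remaps every stored tensor to the CPU while loading.
state = torch.load(path, map_location=torch.device("cpu"))
cpu_model = nn.Linear(4, 2)
cpu_model.load_state_dict(state)

devices = {t.device.type for t in state.values()}
```

Without `map_location`, `torch.load` tries to put each tensor back on the device it was saved from, which fails on a machine that has no GPU.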

2. Save on GPU, load on GPU

torch.save(model.state_dict(), PATH)
Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

3. Save on CPU, load on GPU

torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
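
The reminder about moving inputs is easy to overlook, so here is a hedged sketch of the load-onto-GPU case including the input step. The CPU fallback, the layer sizes, and the temporary path are assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

path = os.path.join(tempfile.mkdtemp(), "cpu_saved.pt")
torch.save(nn.Linear(4, 2).state_dict(), path)  # saved from a CPU model

# Fall back to the CPU so the sketch still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)
model.load_state_dict(torch.load(path, map_location=device))
model.to(device)  # moves the module's own parameters to the target device
model.eval()

batch = torch.randn(3, 4).to(device)  # inputs must live on the same device
with torch.no_grad():
    out = model(batch)
```

Note the asymmetry: `model.to(device)` moves the module in place, while `tensor.to(device)` returns a new tensor, which is why the result has to be assigned back.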

4. Saving torch.nn.DataParallel models

torch.save(model.module.state_dict(), PATH)

Load it as in the cases above, onto whichever device you need.
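
Why save `model.module.state_dict()` rather than `model.state_dict()`? A short sketch shows the difference in key names. The single linear layer is an illustrative assumption, and the wrapper is only constructed here, not run across multiple GPUs.

```python
import torch
import torch.nn as nn

base = nn.Linear(4, 2)
wrapped = nn.DataParallel(base)  # the original network is kept as wrapped.module

wrapped_keys = list(wrapped.state_dict().keys())        # prefixed with "module."
plain_keys = list(wrapped.module.state_dict().keys())   # no prefix

# A state_dict taken from .module loads into an unwrapped model directly.
fresh = nn.Linear(4, 2)
fresh.load_state_dict(wrapped.module.state_dict())
```

Saving through `.module` keeps the checkpoint independent of the wrapper, so it can be loaded into a plain model on a single GPU or on the CPU.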