模型保存与加载有两种方式,本文暂时只讨论模型参数方式
1> 单GPU
保存
1 torch.save(model.state_dict(), "model.pth")
加载
1 model = SimpleNet() 2 model.load_state_dict(torch.load("./model.pth"))
2> 多GPU
保存
1 torch.save(model.module.state_dict(), "./model.pth")
加载
1 mdoel = SimpleNet() 2 model.load_state_dict(torch.load("./model.pth"))