Pytorchで学習済みパラメータの保存

tzmi.hatenablog.com

GPUで保存→GPUで呼び出す場合
保存

model_path = 'model.pth'
torch.save(model.state_dict(), model_path)

呼び出し

model_path = 'model.pth'
model.load_state_dict(torch.load(model_path))


GPUで学習→CPUで保存

model_path = 'model.pth'
torch.save(model.to('cpu').state_dict(), model_path)