GPU加速
1. 定义GPU设备
import torch device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device)
2. 将模型、张量等放在GPU设备上
# loss.to(device) tensor.to(device) model.to(device)
3. 将数据等放回CPu
predict = model(data) predict = predict.cpu().detach().numpy() # detach() 和 data效果相似,但detach是深拷贝,data是浅拷贝
--------------------------------
随有随更 2021.6.9
--------------------------------