zoukankan      html  css  js  c++  java
  • 学习笔记14:模型保存

    保存训练过程中使得测试集上准确率最高的参数

    import copy
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    for epoch in range(extend_epoch):
        epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
        if epoch_test_acc > best_acc:
            best_model_wts = copy.deepcopy(model.state_dict())
            best_acc = epoch_test_acc
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)
        test_loss.append(epoch_test_loss)
        test_acc.append(epoch_test_acc)
    model.load_state_dict(best_model_wts)
    

    保存模型

    PATH = 'E:/my_model.pth'
    torch.save(model.state_dict(), PATH)
    

    重新加载模型

    new_model = models.resnet101(pretrained = True)
    in_f = new_model.fc.in_features
    new_model.fc = nn.Linear(in_f, 4)
    new_model.load_state_dict(torch.load(PATH))
    

    测试是否加载成功

    new_model.to(device)
    test_correct = 0
    test_total = 0
    new_model.eval()
    with torch.no_grad():
        for x, y in test_dl:
            x, y = x.to(device), y.to(device)
            y_pred = new_model(x)
            loss = loss_func(y_pred, y)
            y_pred = torch.argmax(y_pred, dim = 1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
    epoch_test_acc = test_correct / test_total
    print(epoch_test_acc)
    
  • 相关阅读:
    个人学期总结
    管理信息系统 第三部分 作业
    模型分离(选做)
    密码保护
    实现搜索功能
    完成个人中心—导航标签
    个人中心标签页导航
    评论列表显示及排序,个人中心显示
    完成评论功能
    ASP.NET Core开发者指南
  • 原文地址:https://www.cnblogs.com/miraclepbc/p/14361926.html
Copyright © 2011-2022 走看看