zoukankan      html  css  js  c++  java
  • Pytorch实现Top1准确率和Top5准确率

    之前一直不清楚Top1和Top5是什么,其实搞清楚了很简单,就是两种衡量指标,其中,Top1就是普通的Accuracy,Top5比Top1衡量标准更“严格”,

    具体来讲,比如一共需要分10类,每次分类器的输出结果都是10个相加为1的概率值,Top1就是这十个值中最大的那个概率值对应的分类恰好正确的频率,而Top5则是在十个概率值中从大到小排序出前五个,然后看看这前五个分类中是否存在那个正确分类,再计算频率。Pytorch实现如下:

    def evaluteTop1(model, loader):
        model.eval()
        
        correct = 0
        total = len(loader.dataset)
    
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            with torch.no_grad():
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct += torch.eq(pred, y).sum().float().item()
            #correct += torch.eq(pred, y).sum().item()
        return correct / total
    
    def evaluteTop5(model, loader):
        model.eval()
        correct = 0
        total = len(loader.dataset)
        for x, y in loader:
            x,y = x.to(device),y.to(device)
            with torch.no_grad():
                logits = model(x)
                maxk = max((1,5))
            y_resize = y.view(-1,1) _, pred
    = logits.topk(maxk, 1, True, True) correct += torch.eq(pred, y_resize).sum().float().item() return correct / total

    注意:y_resize = y.view(-1,1)是非常关键的一步,在correct的运算中,关键就是要pred和y_resize维度匹配,而原来的y是[128],128是batch大小;

    pred的维度则是[128,10],假设这里是CIFAR10十分类;因此必须把y转化成[128,1]这种维度,但是不能直接是y.view(128,1),因为遍历整个数据集的时候,

    最后一个batch大小并不是128,所以view()里面第一个size就设为-1未知,而确保第二个size是1就行

    topk函数的具体用法参见https://blog.csdn.net/u014264373/article/details/86525621

  • 相关阅读:
    原型与原型链
    数据类型与计算
    JavaScript实现版本号比较
    vue依赖
    vue
    面试经验
    第十一节课 课堂总结
    第十一次作业
    第十课课堂总结
    第十次作业
  • 原文地址:https://www.cnblogs.com/yqpy/p/11391972.html
Copyright © 2011-2022 走看看