zoukankan      html  css  js  c++  java
  • Pytorch实现Top1准确率和Top5准确率

    之前一直不清楚Top1和Top5是什么,其实搞清楚了很简单,就是两种衡量指标,其中,Top1就是普通的Accuracy,Top5比Top1衡量标准更“严格”,

    具体来讲,比如一共需要分10类,每次分类器的输出结果都是10个相加为1的概率值,Top1就是这十个值中最大的那个概率值对应的分类恰好正确的频率,而Top5则是在十个概率值中从大到小排序出前五个,然后看看这前五个分类中是否存在那个正确分类,再计算频率。Pytorch实现如下:

    def evaluteTop1(model, loader):
        model.eval()
        
        correct = 0
        total = len(loader.dataset)
    
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            with torch.no_grad():
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct += torch.eq(pred, y).sum().float().item()
            #correct += torch.eq(pred, y).sum().item()
        return correct / total
    
    def evaluteTop5(model, loader):
        model.eval()
        correct = 0
        total = len(loader.dataset)
        for x, y in loader:
            x,y = x.to(device),y.to(device)
            with torch.no_grad():
                logits = model(x)
                maxk = max((1,5))
            y_resize = y.view(-1,1) _, pred
    = logits.topk(maxk, 1, True, True) correct += torch.eq(pred, y_resize).sum().float().item() return correct / total

    注意:y_resize = y.view(-1,1)是非常关键的一步,在correct的运算中,关键就是要pred和y_resize维度匹配,而原来的y是[128],128是batch大小;

    pred的维度则是[128,10],假设这里是CIFAR10十分类;因此必须把y转化成[128,1]这种维度,但是不能直接是y.view(128,1),因为遍历整个数据集的时候,

    最后一个batch大小并不是128,所以view()里面第一个size就设为-1未知,而确保第二个size是1就行

    topk函数的具体用法参见https://blog.csdn.net/u014264373/article/details/86525621

  • 相关阅读:
    2013=730 胆子要大,敢想敢做
    2013=7=30 自增量的浅谈
    2013=7=29 nyist 13题
    2013=726 整合,优化,利用自身资源。 让自己的时间更有意义,最大化利用
    2013=7=22
    2013=7=23 超级阶梯
    机器人写诗项目——递归神经网络(RNN)
    和程序员在一起是怎样的体验
    和程序员在一起是怎样的体验
    人工智能数学基础——线性代数
  • 原文地址:https://www.cnblogs.com/yqpy/p/11391972.html
Copyright © 2011-2022 走看看