  • Deep Learning (PyTorch) - 3. A detailed walkthrough of sphereface-pytorch's lfw_eval.py

    The original author's repository for the PyTorch version of SphereFace: https://github.com/clcarwin/sphereface_pytorch

    Since I have only recently started working with deep learning, it took me quite a while to read through the source code. Below is a detailed explanation of the project's lfw_eval.py file.

    (Whether because of version differences or mistakes in the author's code, the original code contains quite a few bugs that have to be corrected one by one. Also, since it is run on Windows here, GPU acceleration and multi-threading have been removed.)

#-*- coding:utf-8 -*-
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
torch.backends.cudnn.benchmark = True

import os,sys,cv2,random,datetime
import argparse
import numpy as np
import zipfile

from dataset import ImageDataset
from matlab_cp2tform import get_similarity_transform_for_cv2
import net_sphere
from matplotlib import pyplot as plt

# Face alignment and cropping
def alignment(src_img,src_pts):
    # Warp the image so that its landmarks match a set of reference face coordinates
    ref_pts = [ [30.2946, 51.6963],[65.5318, 51.5014],
        [48.0252, 71.7366],[33.5493, 92.3655],[62.7299, 92.2041] ]
    crop_size = (96, 112)
    src_pts = np.array(src_pts).reshape(5,2)

    s = np.array(src_pts).astype(np.float32)
    r = np.array(ref_pts).astype(np.float32)

    # Estimate the similarity transform from the detected landmarks to the reference ones,
    # then apply it and crop to 96x112
    tfm = get_similarity_transform_for_cv2(s, r)
    face_img = cv2.warpAffine(src_img, tfm, crop_size)
    return face_img
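
# A minimal usage sketch for alignment() (the landmark values below are made up,
# not taken from lfw_landmark.txt):
#   src_pts = [106,113, 151,111, 129,139, 110,166, 148,164]   # x1,y1 ... x5,y5
#   face = alignment(org_img, src_pts)
#   face.shape -> (112, 96, 3), a 96x112 BGR crop ready for the network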

# k-fold cross validation
# Split the n samples into n_folds folds; in turn, fold i is used as the test set
# and the remaining samples are used as the training set
def KFold(n=200, n_folds=10, shuffle=False):
    folds = []
    base = list(range(n))
    for i in range(n_folds):
        test = base[(i*n//n_folds):((i+1)*n//n_folds)]
        train = list(set(base)-set(test))
        folds.append([train,test])
    return folds
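
# A quick illustration of what KFold returns (not part of the original script):
#   KFold(n=6, n_folds=3)
#   -> [[[2, 3, 4, 5], [0, 1]],
#       [[0, 1, 4, 5], [2, 3]],
#       [[0, 1, 2, 3], [4, 5]]]
# Each entry is [train_indices, test_indices] (the train order may vary, since it
# comes from a set); note that the shuffle argument is never actually used.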

# Compute the accuracy for a given threshold
def eval_acc(threshold, diff):
    y_true = []
    y_predict = []
    for d in diff:
        # d[2] is the cosine distance; predict "same person" (1) when it exceeds the threshold
        same = 1 if float(d[2]) > threshold else 0
        y_predict.append(same)
        y_true.append(int(d[3]))
    y_true = np.array(y_true)
    y_predict = np.array(y_predict)
    accuracy = 1.0*np.count_nonzero(y_true==y_predict)/len(y_true)
    return accuracy

# eval_acc and find_best_threshold work together to search for the best threshold
def find_best_threshold(thresholds, predicts):
    best_threshold = best_acc = 0
    for threshold in thresholds:
        accuracy = eval_acc(threshold, predicts)
        if accuracy >= best_acc:
            best_acc = accuracy
            best_threshold = threshold
    return best_threshold
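
# An illustrative run of the threshold search on synthetic data (not from the script);
# each row has the form [name1, name2, cos_distance, same_flag]:
#   rows = [['a','b','0.80','1'], ['c','d','0.20','0'],
#           ['e','f','0.55','1'], ['g','h','0.35','0']]
#   eval_acc(0.5, rows)                                     -> 1.0 (all four pairs correct)
#   find_best_threshold(np.arange(-1.0, 1.0, 0.005), rows)
#   -> the largest grid value below 0.55 at which all four pairs are still correct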

# Command-line arguments
parser = argparse.ArgumentParser(description='PyTorch sphereface lfw')
parser.add_argument('--net','-n', default='sphere20a', type=str)
parser.add_argument('--lfw', default='../DataSet/lfw.zip', type=str)
parser.add_argument('--model','-m', default='./sphere20a_20171020.pth', type=str)
args = parser.parse_args()

predicts=[]

# Build the network
net = getattr(net_sphere,args.net)()
# Load the pretrained model weights
net.load_state_dict(torch.load(args.model))
# Switch the network to evaluation mode
net.eval()
# Make the network output the 512-d feature vector instead of class scores
net.feature = True

# Open the zip archive that contains the LFW images
zfile = zipfile.ZipFile(args.lfw)

# Load the landmarks: each image has five facial feature points, i.e. five coordinate pairs
landmark = {}
with open('data/lfw_landmark.txt') as f:
    landmark_lines = f.readlines()
# Process each line
for line in landmark_lines:
    l = line.replace('\n','').split('\t')
    # Store each record in the dictionary: image name -> list of 10 integers
    landmark[l[0]] = [int(k) for k in l[1:]]
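
# For reference, each line of lfw_landmark.txt looks roughly like this
# (tab-separated; the coordinate values below are made up for illustration):
#   Woody_Allen/Woody_Allen_0002.jpg  106 113 151 111 129 139 110 166 148 164
# i.e. the image path followed by the ten landmark coordinates x1 y1 ... x5 y5.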

# Load the pairs file
with open('data/pairs.txt') as f:
    pairs_lines = f.readlines()[1:]

# The range controls how many image pairs are tested
for i in range(600):
    print(str(i)+" start")
    p = pairs_lines[i].replace('\n','').split('\t')
    # pairs.txt has 6000 data lines in two forms, each describing the two images to compare.
    # Form 1 is the same person, form 2 is two different people:
    #   name idx1 idx2
    #   name1 idx1 name2 idx2
    if 3==len(p):
        sameflag = 1
        # The resulting file name looks like: Woody_Allen/Woody_Allen_0002.jpg
        name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1]))
        name2 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[2]))
    if 4==len(p):
        sameflag = 0
        name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1]))
        name2 = p[2]+'/'+p[2]+'_'+'{:04}.jpg'.format(int(p[3]))
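
    # Illustrative pairs.txt lines (tab-separated; the names and indices are hypothetical):
    #   Woody_Allen 2 4                  -> same person, images 0002 and 0004 (sameflag=1)
    #   Woody_Allen 2 George_Clooney 1   -> two different people (sameflag=0)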
    # Load the two images from the zip archive and align each of them
    org_img1=cv2.imdecode(np.frombuffer(zfile.read("lfw/lfw/"+name1),np.uint8),1)
    org_img2=cv2.imdecode(np.frombuffer(zfile.read("lfw/lfw/"+name2),np.uint8),1)
    img1 = alignment(org_img1,landmark[name1])
    img2 = alignment(org_img2,landmark[name2])
    # 1. Display the original and aligned images with cv2
    # cv2.imshow("org_img1", org_img1)
    # cv2.imshow("org_img2", org_img2)
    # cv2.imshow("img1",img1)
    # cv2.imshow("img2", img2)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
    # 2. Display the original and aligned images with matplotlib
    fig_new=plt.figure()
    img_list=[[org_img1,221],[org_img2,222],[img1,223],[img2,224]]
    for p,q in img_list:
        ax=fig_new.add_subplot(q)
        # OpenCV images are BGR; reverse the channel order to RGB for matplotlib
        p = p[:, :, (2, 1, 0)]
        ax.imshow(p)
    plt.show()

    # cv2.flip flips the image; second argument: 1 = horizontal, 0 = vertical, -1 = both
    imglist = [img1,cv2.flip(img1,1),img2,cv2.flip(img2,1)]
    # Preprocess each image: HWC -> CHW, add a batch dimension, and scale pixels to roughly [-1, 1]
    for m in range(len(imglist)):
        imglist[m] = imglist[m].transpose(2, 0, 1).reshape((1,3,112,96))
        imglist[m] = (imglist[m]-127.5)/128.0

    # np.vstack: stack arrays vertically (row-wise)
    # ****** example ******
    # import numpy as np
    # a = [1, 2, 3]
    # b = [4, 5, 6]
    # print(np.vstack((a, b)))
    #
    # output:
    # [[1 2 3]
    #  [4 5 6]]
    img = np.vstack(imglist)
    # Convert the numpy array to a Variable
    # (volatile=True disables gradient tracking; newer PyTorch versions use torch.no_grad() instead)
    img = Variable(torch.from_numpy(img).float(),volatile=True)
    output = net(img)
    # Extract the results: f1 and f2 are the 512-d feature vectors of the two (un-flipped) images
    f = output.data
    f1,f2 = f[0],f[2]
    # Compute their cosine similarity; the small constant keeps the denominator from being zero
    # For a short introduction to cosine similarity, see e.g.:
    # http://blog.csdn.net/huangfei711/article/details/78469614
    # cos(a,b) = a.b / (|a||b|)
    cosdistance = f1.dot(f2)/(f1.norm()*f2.norm()+1e-5)
    predicts.append('{}\t{}\t{}\t{}\n'.format(name1,name2,cosdistance,sameflag))
    print(str(i) + " end")

# Per-fold accuracies
accuracy = []
# Per-fold best thresholds
thd = []
# k-fold cross validation
# folds has the form [[train,test],[train,test],.....]
folds = KFold(n=600, n_folds=10, shuffle=False)
# Candidate thresholds from -1 to 1 with a step of 0.005
thresholds = np.arange(-1.0, 1.0, 0.005)
# The line below is the original author's code; it appears to be wrong (it maps onto an undefined frd), so it has been rewritten:
# predicts = np.array(map(lambda line:frd.append(line.strip('\n').split()), predicts))
predicts = np.array([k.strip('\n').split() for k in predicts])
for idx, (train, test) in enumerate(folds):
    # predicts[train]/predicts[test] looks like:
    # [['Doris_Roberts/Doris_Roberts_0001.jpg'
    #   'Doris_Roberts/Doris_Roberts_0003.jpg' '0.6532696413605743' '1'],.....]
    # Find the best threshold on the training folds
    best_thresh = find_best_threshold(thresholds, predicts[train])
    # Evaluate the test fold with that threshold to get its accuracy
    accuracy.append(eval_acc(best_thresh, predicts[test]))
    # Record the threshold
    thd.append(best_thresh)
# np.mean: mean, np.std: standard deviation
# The printed values are: mean accuracy, accuracy standard deviation, and mean threshold
print('LFWACC={:.4f} std={:.4f} thd={:.4f}'.format(np.mean(accuracy), np.std(accuracy), np.mean(thd)))
# For example, LFWACC=0.9800 std=0.0600 thd=0.3490 means the accuracy is 98%,
# its standard deviation is 0.06, and the mean threshold is 0.3490,
# so two images with a cosine similarity above 0.3490 can be considered the same person.
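
    To make the cosine-similarity step above more concrete, here is a minimal, self-contained sketch of the same computation in plain NumPy; the two vectors are random stand-ins for the 512-d features that sphere20a produces:

import numpy as np

# Random stand-ins for the two 512-d face feature vectors
f1 = np.random.randn(512).astype(np.float32)
f2 = np.random.randn(512).astype(np.float32)

# cos(a,b) = a.b / (|a||b|); the 1e-5 guards against a zero denominator
cosdistance = f1.dot(f2) / (np.linalg.norm(f1) * np.linalg.norm(f2) + 1e-5)
print(cosdistance)  # a value in roughly [-1, 1]; larger means the faces are more similar
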
  • Original post: https://www.cnblogs.com/lomooo/p/8523232.html