https://github.com/biubug6/Pytorch_Retinaface
训练这个模型很耗GPU,mobilenet0.25还好(训练8G以内,测试1G以内),resnet50至少要43G。
0.环境
ubuntu16.04
python3.6
cuda9.0
torch==1.1.0
ipython
1.准备数据与模型
(1)准备数据
原始数据:train+val+test http://shuoyang1213.me/WIDERFACE/WiderFace_Results.html
标签数据:https://pan.baidu.com/s/1Laby0EctfuJGgGMgRRgykA
按照官方给的格式放:
./data/widerface/
train/
images/
label.txt
val/
images/
label.txt
test/
images/
label.txt
由于val只有label.txt,所以我们要自己转一下(这一步一定要转一下,不然会测试时会出错的):
# -*- coding: UTF-8 -*-
'''
@author: mengting gu
@contact: 1065504814@qq.com
@time: 2020/11/2 上午11:47
@file: widerValFile.py
@desc:
'''
import os
import argparse
parser = argparse.ArgumentParser(description='Retinaface')
parser.add_argument('--dataset_folder', default='./data/widerface/val/images/', type=str, help='dataset path')
args = parser.parse_args()
if __name__ == '__main__':
# testing dataset
testset_folder = args.dataset_folder
testset_list = args.dataset_folder[:-7] + "label.txt"
with open(testset_list, 'r') as fr:
test_dataset = fr.read().split()
num_images = len(test_dataset)
for i, img_name in enumerate(test_dataset):
print("line i :{}".format(i))
if img_name.endswith('.jpg'):
print(" img_name :{}".format(img_name))
f = open(args.dataset_folder[:-7] + 'wider_val.txt', 'a')
f.write(img_name+'
')
f.close()
最后将数据按照作者的要求,写成下面这种:
./data/widerface/
train/
images/
label.txt
val/
images/
label.txt
wider_val.txt
test/
images/
label.txt
(2)准备模型
google cloud and baidu cloud Password: fstq
2.训练
CUDA_VISIBLE_DEVICES=0 python train.py --network mobile0.25
要想训练Resnet50的话,在train.py代码的139行后, 添加:
torch.cuda.empty_cache()
3.测试评估
这份代码原来是测试与评估是分开的。
(1)测试
python test_widerface.py --trained_model weight_file --network mobile0.25
# eg
python test_widerface.py --trained_model ./weights/mobilenet0.25_epoch_245.pth --network mobile0.25
# eg 测试结果保存 去掉test_widerface.py中参数设置parser(save_image) action="store_true"
python test_widerface.py --trained_model ./weights/mobilenet0.25_epoch_245.pth --network mobile0.25 -s=True
检测结果jpg与txt分别在results与widerface_evaluate/widerface_txt目录下。
检测画框结果:
检测txt结果:
0_Parade_marchingband_1_20
261
542 357 36 42 0.9983047
29 404 27 33 0.9953412
254 371 29 33 0.98255116
629 394 19 21 0.96621066
464 355 27 29 0.9414734
81 391 22 25 0.9405682
386 368 17 20 0.9325298
730 346 17 19 0.88296247
(2)评估
cd ./widerface_evaluate
python setup.py build_ext --inplace
python evaluation.py
下面是作者自己跑出来的结果:
对比来看与第二行这项结果相差较小,基本能复现结果了。
参考
1.Pytorch_Retinaface
2.官方RetinaFace
原文链接:https://blog.csdn.net/qq_35975447/article/details/109447929