一、引入pom.xml依赖
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>cn.dearcloud</groupId>
    <artifactId>train-yolo-for-java</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- All DL4J / ND4J artifacts should use the same release; change it in one place only. -->
        <dl4j.version>1.0.0-beta</dl4j.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-modelimport</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <!--GPU-->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-cuda-8.0-platform</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-cuda-8.0</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <!--CPU: to train on the CPU, comment out the two GPU dependencies above and enable this one.-->
        <!--<dependency>-->
        <!--<groupId>org.nd4j</groupId>-->
        <!--<artifactId>nd4j-native-platform</artifactId>-->
        <!--<version>1.0.0-beta</version>-->
        <!--</dependency>-->

        <!--Log-->
        <!-- Lombok is only needed at compile time (annotation processing), so it is not packaged. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.16.22</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-slf4j-impl</artifactId>
            <version>2.11.0</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.7.0</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                    <encoding>UTF-8</encoding>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
二、读取数据集
假设数据集文件夹的路径如下。该文件夹中存放图片,以及与每张图片同名的 .txt 标注文件;标注文件中每行记录一个标注对象,各字段以英文逗号分隔,依次为:Label,X,Y,Width,Height(X、Y 为目标框左上角坐标,Width、Height 为框的宽和高,均为整数),例如:pupil,120,80,40,40
D:\Project\AIProject\train-yolo-for-java\docs\pupil-datasets
三、编写标注加载代码
package cn.dearcloud.provider; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.datavec.image.recordreader.objdetect.ImageObject; import org.datavec.image.recordreader.objdetect.ImageObjectLabelProvider; import java.io.File; import java.net.URI; import java.util.ArrayList; import java.util.List; public class CnnLabelProvider implements ImageObjectLabelProvider { public CnnLabelProvider() { } @Override public List<ImageObject> getImageObjectsForPath(String path) { try { List<ImageObject> imageObjects = new ArrayList<>(); File labelFile = new File(FilenameUtils.getFullPath(path), FilenameUtils.getBaseName(path) + ".txt"); List<String> lines = FileUtils.readLines(labelFile, "UTF-8"); for (String line : lines) { //label,x,y,w,h String[] arr = line.split(","); if (arr.length == 5) { String labelName = arr[0]; int x = Integer.parseInt(arr[1]); int y = Integer.parseInt(arr[2]); int w = Integer.parseInt(arr[3]); int h = Integer.parseInt(arr[4]); imageObjects.add(new ImageObject(x, y, x + w, y + h, labelName)); } } return imageObjects; } catch (Exception ex) { throw new RuntimeException(ex); } } @Override public List<ImageObject> getImageObjectsForPath(URI uri) { return getImageObjectsForPath(new File(uri).getPath()); } }
四、编写YoloV2训练代码
package cn.dearcloud; import cn.dearcloud.provider.CnnLabelProvider; import lombok.extern.log4j.Log4j2; import org.bytedeco.javacpp.opencv_core; import org.bytedeco.javacpp.opencv_imgproc; import org.bytedeco.javacv.CanvasFrame; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.api.records.metadata.RecordMetaDataImageURI; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.objdetect.DetectedObject; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.zoo.model.YOLO2; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.learning.config.Adam; import java.io.File; import java.io.IOException; import java.util.List; import java.util.Random; import static org.bytedeco.javacpp.opencv_core.FONT_HERSHEY_DUPLEX; import static org.bytedeco.javacpp.opencv_imgproc.resize; import static 
org.opencv.core.CvType.CV_8U; @Log4j2 public class Yolo2Trainer { // parameters matching the pretrained TinyYOLO model int width = 480; int height = 320; int nChannels = 3; int gridWidth = 15; int gridHeight = 10; int nClasses = 1; int nBoxes = 5; double lambdaNoObj = 0.5; double lambdaCoord = 5.0; double[][] priorBoxes = {{1.08, 1.19}, {3.42, 4.41}, {6.63, 11.38}, {9.42, 5.11}, {16.62, 10.52}}; double detectionThreshold = 0.3; // parameters for the training phase int batchSize = 1; int nEpochs = 50; double learningRate = 1e-3; double lrMomentum = 0.9; public void read() throws IOException, InterruptedException { String datasetsDir = "D:\Project\AIProject\train-yolo-for-java\docs\pupil-datasets"; File imageDir = new File(datasetsDir); log.info("Load data..."); //切分数据集 Random rng = new Random(); FileSplit fileSplit = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng); InputSplit[] data = fileSplit.sample(null, 0.8, 0.2); InputSplit trainData = data[0]; InputSplit testData = data[1]; //自己实现ImageObjectLabelProvider接口 CnnLabelProvider labelProvider = new CnnLabelProvider(); ObjectDetectionRecordReader trainRecordReader = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, labelProvider); trainRecordReader.initialize(trainData);//returned values: 4d array, with dimensions [minibatch, 4+C, h, w] ObjectDetectionRecordReader testRecordReader = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, labelProvider); testRecordReader.initialize(testData); // ObjectDetectionRecordReader performs regression, so we need to specify it here RecordReaderDataSetIterator trainDataSetIterator = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, 1, true); trainDataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1)); RecordReaderDataSetIterator testDataSetIterator = new RecordReaderDataSetIterator(testRecordReader, 1, 1, 1, true); testDataSetIterator.setPreProcessor(new 
ImagePreProcessingScaler(0, 1)); ComputationGraph model; String modelFilename = "model_surface_YOLO2.zip"; if (new File(modelFilename).exists()) { log.info("Load model..."); model = ModelSerializer.restoreComputationGraph(modelFilename); } else { ComputationGraph pretrained = (ComputationGraph) YOLO2.builder().build().initPretrained(); INDArray priors = org.nd4j.linalg.factory.Nd4j.create(priorBoxes); FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder() .seed(1234) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold(1.0) .updater(new Adam.Builder().learningRate(1e-3).build()) .l2(0.00001) .activation(Activation.IDENTITY) .trainingWorkspaceMode(WorkspaceMode.ENABLED) .inferenceWorkspaceMode(WorkspaceMode.ENABLED) .build(); model = new TransferLearning.GraphBuilder(pretrained).fineTuneConfiguration(fineTuneConf).removeVertexKeepConnections("conv2d_23") .addLayer("convolution2d_23", new ConvolutionLayer.Builder(1, 1) .nIn(1024) .nOut(nBoxes * (5 + nClasses)) .stride(1, 1) .convolutionMode(ConvolutionMode.Same) .weightInit(WeightInit.UNIFORM) .hasBias(false) .activation(Activation.IDENTITY) .build(), "leaky_re_lu_22") .addLayer("outputs", new Yolo2OutputLayer.Builder() .boundingBoxPriors(priors) .lambbaNoObj(lambdaNoObj).lambdaCoord(lambdaCoord) .build(), "convolution2d_23") .setOutputs("outputs") .build(); System.out.println(model.summary(InputType.convolutional(width, height, nChannels))); //设置训练时输出 model.setListeners(new org.deeplearning4j.optimize.listeners.ScoreIterationListener(1)); //开始训练 for (int i = 0; i < nEpochs; i++) { trainDataSetIterator.reset(); while (trainDataSetIterator.hasNext()) { model.fit(trainDataSetIterator.next()); } log.info("*** Completed epoch {} ***", i); } ModelSerializer.writeModel(model, modelFilename, true); } // 可视化与测试 NativeImageLoader imageLoader = new NativeImageLoader(); CanvasFrame frame = new 
CanvasFrame("RedBloodCellDetection"); OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0); List<String> labels = trainDataSetIterator.getLabels(); testDataSetIterator.setCollectMetaData(true); while (testDataSetIterator.hasNext() && frame.isVisible()) { org.nd4j.linalg.dataset.DataSet ds = testDataSetIterator.next(); RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0); INDArray features = ds.getFeatures(); INDArray results = model.outputSingle(features); List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold); File file = new File(metadata.getURI()); log.info(file.getName() + ": " + objs); opencv_core.Mat mat = imageLoader.asMat(features); opencv_core.Mat convertedMat = new opencv_core.Mat(); mat.convertTo(convertedMat, CV_8U, 255, 0); int w = metadata.getOrigW() * 2; int h = metadata.getOrigH() * 2; opencv_core.Mat image = new opencv_core.Mat(); resize(convertedMat, image, new opencv_core.Size(w, h)); for (DetectedObject obj : objs) { double[] xy1 = obj.getTopLeftXY(); double[] xy2 = obj.getBottomRightXY(); String label = labels.get(obj.getPredictedClass()); int x1 = (int) Math.round(w * xy1[0] / gridWidth); int y1 = (int) Math.round(h * xy1[1] / gridHeight); int x2 = (int) Math.round(w * xy2[0] / gridWidth); int y2 = (int) Math.round(h * xy2[1] / gridHeight); opencv_imgproc.rectangle(image, new opencv_core.Point(x1, y1), new opencv_core.Point(x2, y2), opencv_core.Scalar.RED); opencv_imgproc.putText(image, label, new opencv_core.Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, opencv_core.Scalar.GREEN); } frame.setTitle(new File(metadata.getURI()).getName() + " - RedBloodCellDetection"); frame.setCanvasSize(w, h); frame.showImage(converter.convert(image)); frame.waitKey(); } frame.dispose(); } }
五、顺便给大家写写TinyYolo的训练代码
package cn.dearcloud; import lombok.extern.log4j.Log4j2; import org.bytedeco.javacpp.opencv_core; import org.bytedeco.javacpp.opencv_imgproc; import org.bytedeco.javacv.CanvasFrame; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.records.metadata.RecordMetaDataImageURI; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader; import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.objdetect.DetectedObject; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.zoo.model.TinyYOLO; import org.deeplearning4j.zoo.model.YOLO2; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.learning.config.Nesterovs; import java.io.File; import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.util.List; import java.util.Random; import static org.bytedeco.javacpp.opencv_core.FONT_HERSHEY_DUPLEX; import static 
org.bytedeco.javacpp.opencv_imgproc.resize; import static org.opencv.core.CvType.CV_8U; /** * 参考:https://blog.csdn.net/u011669700/article/details/79886619 实现 */ @Log4j2 public class TinyYoloTrainer { // parameters matching the pretrained TinyYOLO model int width = 416; int height = 416; int nChannels = 3; int gridWidth = 13; int gridHeight = 13; int numClasses = 1; // parameters for the Yolo2OutputLayer int nBoxes = 5; double lambdaNoObj = 0.5; double lambdaCoord = 5.0; double[][] priorBoxes = {{2, 2}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; double detectionThreshold = 0.3; // parameters for the training phase int batchSize = 2; int nEpochs = 50; double learningRate = 1e-3; double lrMomentum = 0.9; public void read() throws IOException, InterruptedException { String dataDir = new ClassPathResource("/datasets").getFile().getPath(); File imageDir = new File(dataDir, "JPEGImages"); log.info("Load data..."); //切分数据集 Random rng = new Random(); FileSplit fileSplit = new org.datavec.api.split.FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng); InputSplit[] data = fileSplit.sample(new RandomPathFilter(rng) { @Override protected boolean accept(String name) { boolean isXmlExist = false; try { isXmlExist = new File(new URI(name.replace("JPEGImages", "Annotations").replace(".jpg", ".xml"))).exists(); } catch (URISyntaxException e) { e.printStackTrace(); } return isXmlExist; } }, 0.8, 0.2); InputSplit trainData = data[0]; InputSplit testData = data[1]; //用于解析识别voc方式的label方式,也可以自己实现ImageObjectLabelProvider接口 VocLabelProvider labelProvider = new VocLabelProvider(dataDir); ObjectDetectionRecordReader trainRecordReader = new org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, labelProvider); trainRecordReader.initialize(trainData);//returned values: 4d array, with dimensions [minibatch, 4+C, h, w] ObjectDetectionRecordReader testRecordReader = new 
org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, labelProvider); testRecordReader.initialize(testData); // ObjectDetectionRecordReader performs regression, so we need to specify it here RecordReaderDataSetIterator trainDataSetIterator = new org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, 1, true); trainDataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1, 8)); RecordReaderDataSetIterator testDataSetIterator = new org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator(testRecordReader, batchSize, 1, 1, true); testDataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1, 8)); String modelFilename = "model_yolov2.zip"; ComputationGraph pretrained = (ComputationGraph) TinyYOLO.builder().build().initPretrained(); INDArray priors = org.nd4j.linalg.factory.Nd4j.create(priorBoxes); FineTuneConfiguration fineTuneConfiguration = new org.deeplearning4j.nn.transferlearning.FineTuneConfiguration.Builder() .seed(100) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold(1.0) .updater(Nesterovs.builder().learningRate(learningRate).momentum(lrMomentum).build()) .activation(Activation.IDENTITY) .trainingWorkspaceMode(WorkspaceMode.ENABLED) .inferenceWorkspaceMode(WorkspaceMode.ENABLED) .build(); ComputationGraph model = new TransferLearning.GraphBuilder(pretrained).fineTuneConfiguration(fineTuneConfiguration).removeVertexKeepConnections("conv2d_9") .addLayer("convolution2d_9", new ConvolutionLayer.Builder(1, 1) .nIn(1024) .nOut(nBoxes * (5 + numClasses)) .stride(1, 1) .convolutionMode(ConvolutionMode.Same) .weightInit(WeightInit.UNIFORM) .hasBias(false) .activation(Activation.IDENTITY) .build(), "leaky_re_lu_8") .addLayer("outputs", new 
Yolo2OutputLayer.Builder().lambbaNoObj(lambdaNoObj).lambdaCoord(lambdaCoord).boundingBoxPriors(priors).build(), "convolution2d_9") .setOutputs("outputs") .build(); //设置训练时输出 model.setListeners(new org.deeplearning4j.optimize.listeners.ScoreIterationListener(1)); //开始训练 for (int i = 0; i < nEpochs; i++) { trainDataSetIterator.reset(); while (trainDataSetIterator.hasNext()) { model.fit(trainDataSetIterator.next()); } log.info("*** Completed epoch {} ***", i); } ModelSerializer.writeModel(model, modelFilename, true); // 可视化与测试 NativeImageLoader imageLoader = new NativeImageLoader(); CanvasFrame frame = new CanvasFrame("RedBloodCellDetection"); OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0); List<String> labels = trainDataSetIterator.getLabels(); testDataSetIterator.setCollectMetaData(true); while (testDataSetIterator.hasNext() && frame.isVisible()) { org.nd4j.linalg.dataset.DataSet ds = testDataSetIterator.next(); RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0); INDArray features = ds.getFeatures(); INDArray results = model.outputSingle(features); List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold); File file = new File(metadata.getURI()); log.info(file.getName() + ": " + objs); opencv_core.Mat mat = imageLoader.asMat(features); opencv_core.Mat convertedMat = new opencv_core.Mat(); mat.convertTo(convertedMat, CV_8U, 255, 0); int w = metadata.getOrigW() * 2; int h = metadata.getOrigH() * 2; opencv_core.Mat image = new opencv_core.Mat(); resize(convertedMat, image, new opencv_core.Size(w, h)); for (DetectedObject obj : objs) { double[] xy1 = obj.getTopLeftXY(); double[] xy2 = obj.getBottomRightXY(); String label = labels.get(obj.getPredictedClass()); int x1 = (int) Math.round(w * xy1[0] / gridWidth); int y1 = (int) Math.round(h * 
xy1[1] / gridHeight); int x2 = (int) Math.round(w * xy2[0] / gridWidth); int y2 = (int) Math.round(h * xy2[1] / gridHeight); opencv_imgproc.rectangle(image, new opencv_core.Point(x1, y1), new opencv_core.Point(x2, y2), opencv_core.Scalar.RED); opencv_imgproc.putText(image, label, new opencv_core.Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, opencv_core.Scalar.GREEN); } frame.setTitle(new File(metadata.getURI()).getName() + " - RedBloodCellDetection"); frame.setCanvasSize(w, h); frame.showImage(converter.convert(image)); frame.waitKey(); } frame.dispose(); } }
日志如下: