zoukankan      html  css  js  c++  java
  • TensorFlow 1.4利用Keras+Estimator API进行训练和预测

    Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中的另一个高级API -- Estimator模型,然后就可以调用Dataset API进行对tfrecords进行操作用来训练/评估模型。而keras本身也用到了Estimator API并且提供了tf.keras.estimator.model_to_estimator函数将keras模型可以很方便的转换成Estimator模型,因此用Keras API搭建模型框架然后用Dataset API操作IO,然后用Estimator训练模型是一套比较方便高效的操作流程。

    注:

    1. tf.keras.estimator.model_to_estimator这个函数只在tf.keras下面有在原生的keras中是没有这个函数的。
    2. Estimator训练的模型类型主要有regressorclassifier两类,如果需要用自定义的模型类型,可以通过自定有model_fn来构建,具体操作可以查看这里
    3. Estimator模型可以通过export_savedmodel()函数输出训练好的estimator模型,然后可以把模型创建服务接受输入数据并输出结果,这在大规模云端部署的时候会非常有用(具体操作流程可以看这里)。

    1. 利用Keras搭建模型框架并转换成estimator模型

    比如我们利用keras的ResNet50构建二分类模型:

    import tensorflow as tf
    import os
    resnet = tf.keras.applications.resnet50
    def my_model_fn():
        base_model = resnet.ResNet50(include_top=True,                   # include fully layers or not
                                          weights='imagenet',                # pre-trained weights
                                          input_shape=(224, 224, 3),    # default input shape
                                          classes=2)
        base_model.summary()
        optimizer = tf.keras.optimizers.RMSprop(lr=2e-3,
                                                decay=0.9)
        base_model.compile(optimizer=optimizer,
                           loss='categorical_crossentropy',
                           metrics=["accurary"])
        # convert keras model to estimator model
        model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), “train”)
        est_model = tf.keras.estimator.model_to_estimator(base_model, model_dir=model_dir)  # model save dir is 'train',
        return est_model
    

    注意:model_dir必须为全路径,使用相对路径的estimator在检索模型输入输出的时候可能会报错。

    2. 利用Dateset API从tfrecords中读取数据并构建estimator模型的输入input_fn

    比如tfrecords的图片和标签分别保存在“image/encoded”和“image/label”下:(如何写tfrecords可以参考这里

    def _tf_example_parser(record):
        feature = {"image/encoded": tf.FixedLenFeature([], tf.string),
                   "image/class_id": tf.FixedLenFeature([], tf.int64)}
        features = tf.parse_single_example(record, features=feature)
        image = tf.decode_raw(features["image/encoded"], out_type=tf.uint8)   # 写入tfrecords的时候是misc/cv2读取的ndarray
        image = tf.cast(image, dtype=tf.float32)
        image = tf.reshape(image, shape=(224, 224, 3))    # 如果输入图片不做resize,那么不同大小的图片是无法输入到同一个batch中的
        label = tf.cast(features["image/class_id"], dtype=tf.int64)
        return image, label
    
    def input_fn(data_path, batch_size=64, is_training=True):
        """
        Parse image and data in tfrecords file for training
        Args:
            data_path:  a list or single tf records path
            batch_size: size of images returned
            is_training: is training stage or not
        Returns:
            image and labels batches of randomly shuffling tensors
        """
        with K.name_scope("input_pipeline"):
            if not isinstance(data_path, (tuple, list)):
                data_path = [data_path]
            dataset = tf.data.TFRecordDataset(data_path)
            dataset = dataset.map(_tf_example_parser)
            dataset = dataset.repeat(25)              # num of epochs
            dataset = dataset.batch(64)               # batch size
            if is_training:
                dataset = dataset.shuffle(1000)     # 对输入进行shuffle,buffer_size越大,内存占用越大,shuffle的时间也越长,因此可以在写tfrecords的时候就实现用乱序写入,这样的话这里就不需要用shuffle
            iterator = dataset.make_one_shot_iterator()
            images, labels = iterator.get_next()
            # convert to onehot label
            labels = tf.one_hot(labels, 2)  # 二分类
            # preprocess image: scale pixel values from 0-255 to 0-1
            images = tf.image.convert_image_dtype(images, dtype=tf.float32)  # 将图片像素从0-255转换成0-1,tf提供的图像操作大多需要0-1之间的float32类型
            images /= 255.
            images -= 0.5
            images *= 2.
            return dict({"input_1": images}), labels
    

    3. 利用estimator API训练模型

    def train(source_dir):
        if tf.gfile.Exists(source_dir):
            train_data_paths = tf.gfile.Glob(source_dir+"/train*tfrecord")  # 所有train开头的tfrecords都用于模型训练
            val_data_paths = tf.gfile.Glob(source_dir+"/val*tfrecord")       # 所有val开头的tfrecords都用于模型评估
            train_data_paths = val_data_paths
            if not len(train_data_paths):
                raise Exception("[Train Error]: unable to find train*.tfrecord file")
            if not len(val_data_paths):
                raise Exception("[Eval Error]: unable to find val*.tfrecord file")
        else:
            raise Exception("[Train Error]: unable to find input directory!")
        est_model = my_model_fn()
        train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(data_path=train_data_paths,  
                                                                      batch_size=_BATCH_SIZE,
                                                                      is_training=True),
                                            max_steps=300000)
    
        eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(val_data_paths,
                                                                    batch_size=_BATCH_SIZE,
                                                                    is_training=False))
        # train and evaluate model
        tf.estimator.train_and_evaluate(estimator=est_model,
                                        train_spec=train_spec,
                                        eval_spec=eval_spec)
    

    PS: 这里用lambda表示输入为函数,而非函数的返回值;也可以用partial函数进行包裹;而当没有输入变量的时候就可以直接用。
    训练的时候,用途tensorboard监控train目录查看训练过程。

    4. 利用estimator模型进行预测

    由于estimator模型predict函数的输入与训练的时候一样为input_fn,但是此时直接从文件中读取,而非tfrecords,因此需要重新定义一个input_fn来用于predict。

    def predict_input_fn(image_path):
        images = misc.imread(image_path)
        # preprocess image: scale pixel values from 0-255 to 0-1
        images = tf.image.convert_image_dtype(images, dtype=tf.float32)
        images /= 255.
        images -= 0.5
        images *= 2.
        dataset = tf.data.Dataset.from_tensor_slices((images, ))   
        return dataset.batch(1).make_one_shot_iterator().get_next()
    
    def predict(image_path):
        est_model = my_model_fn()
        result = est_model.predict(input_fn=lambda: predict_input_fn(image_path=image_path))
        for r in result:
            print(r)
    

    参考:

    Estimator:

    Dataset

    Tfredords:
    http://www.machinelearninguru.com/deep_learning/tensorflow/basics/tfrecord/tfrecord.html

    Keras:
    https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html

  • 相关阅读:
    CV2图像操作
    Sobel边缘检测
    matlat之KDTreeSearcher()函数
    linux shell 将多行文件转换为一行
    (转)Shell脚本编程--Uniq命令
    (转)iptables简介
    (转)linux passwd批量修改用户密码
    (转)linux sort 命令详解
    (转)Linux命令之md5sum
    (转)shell实例浅谈之产生随机数七种方法
  • 原文地址:https://www.cnblogs.com/arkenstone/p/8118651.html
Copyright © 2011-2022 走看看