  • Using model.train in MindSpore: automatically invoking a custom function after each training step (from mindspore.train.callback import Callback)

    When training a network with model.train in MindSpore, it is hard to hook in intermittent tasks such as periodic evaluation or logging. For this, MindSpore provides the Callback mechanism.


    A Callback lets you run custom operations at the end of every training step of model.train.



    The Callback class:
    from mindspore.train.callback import Callback
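
    A minimal sketch of the idea (PrintLossCallback is a name made up for illustration; the fields it reads follow the r1.2 callback API used throughout this post):

    from mindspore.train.callback import Callback

    class PrintLossCallback(Callback):
        # called automatically by model.train at the end of every step
        def step_end(self, run_context):
            cb_params = run_context.original_args()
            # net_outputs is the output of the just-finished step, i.e. the loss
            print("epoch:", cb_params.cur_epoch_num, "loss:", cb_params.net_outputs)

    # usage: model.train(epoch_size, ds_train, callbacks=[PrintLossCallback()], dataset_sink_mode=False)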





    In the official documentation, a Callback is typically used to record the loss at every step, or to evaluate the model after a set number of training steps:
    Official tutorial:
    https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/quick_start/quick_start.html
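
    Besides custom callbacks, MindSpore ships built-in ones for the common cases; the full example below uses two of them. A short sketch of how they are wired up, with the same parameters as the code later in this post:

    from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig

    # print the loss every 125 steps
    loss_cb = LossMonitor(125)
    # save a checkpoint every 375 steps and keep at most 16 checkpoint files
    config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
    ckpt_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory="./models/ckpt/", config=config_ck)
    # model.train(epoch_size, ds_train, callbacks=[loss_cb, ckpt_cb])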






    The full example code:
    Reference: https://www.cnblogs.com/devilmaycry812839668/p/14971668.html
    import matplotlib.pyplot as plt
    import matplotlib
    import numpy as np
    import os
    
    import mindspore.nn as nn
    from mindspore.nn import Accuracy
    from mindspore.nn import SoftmaxCrossEntropyWithLogits
    from mindspore import dtype as mstype
    import mindspore.dataset as ds
    import mindspore.dataset.vision.c_transforms as CV
    import mindspore.dataset.transforms.c_transforms as C
    from mindspore.dataset.vision import Inter
    from mindspore.common.initializer import Normal
    from mindspore import Tensor, Model
    from mindspore.train.callback import Callback
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
    
    
    def create_dataset(data_path, batch_size=32, repeat_size=1,
                       num_parallel_workers=1):
        """
        create dataset for train or test
    
        Args:
            data_path (str): Data path
            batch_size (int): The number of data records in each group
            repeat_size (int): The number of replicated data records
            num_parallel_workers (int): The number of parallel workers
        """
        # define dataset
        mnist_ds = ds.MnistDataset(data_path)
    
        # define some parameters needed for data augmentation and normalization
        resize_height, resize_width = 32, 32
        rescale = 1.0 / 255.0
        shift = 0.0
        rescale_nml = 1 / 0.3081
        shift_nml = -1 * 0.1307 / 0.3081
    
        # according to the parameters, generate the corresponding augmentation operations
        resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
        rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
        rescale_op = CV.Rescale(rescale, shift)
        hwc2chw_op = CV.HWC2CHW()
        type_cast_op = C.TypeCast(mstype.int32)
    
        # using map to apply operations to a dataset
        mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
        mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
        mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
        mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
        mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    
        # process the generated dataset
        buffer_size = 10000
        mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
        mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
        mnist_ds = mnist_ds.repeat(repeat_size)
    
        return mnist_ds
    
    
    class LeNet5(nn.Cell):
        """Lenet network structure."""
        # define the operator required
        def __init__(self, num_class=10, num_channel=1):
            super(LeNet5, self).__init__()
            self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
            self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
            self.relu = nn.ReLU()
            self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            self.flatten = nn.Flatten()
    
        # use the preceding operators to construct networks
        def construct(self, x):
            x = self.max_pool2d(self.relu(self.conv1(x)))
            x = self.max_pool2d(self.relu(self.conv2(x)))
            x = self.flatten(x)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    
    # custom callback function
    class StepLossAccInfo(Callback):
        def __init__(self, model, eval_dataset, steps_loss, steps_eval):
            self.model = model
            self.eval_dataset = eval_dataset
            self.steps_loss = steps_loss
            self.steps_eval = steps_eval
            self.steps = 0
    
        def step_end(self, run_context):
            cb_params = run_context.original_args()
            cur_epoch = cb_params.cur_epoch_num
            #cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
            self.steps = self.steps+1
            cur_step = self.steps
    
            self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
            self.steps_loss["step"].append(str(cur_step))
            if cur_step % 125 == 0:
                acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
                self.steps_eval["step"].append(cur_step)
                self.steps_eval["acc"].append(acc["Accuracy"])
    
    
    def train_model(_model, _epoch_size, _repeat_size, _mnist_path, _model_path):
        ds_train = create_dataset(os.path.join(_mnist_path, "train"), 32, _repeat_size)
        eval_dataset = create_dataset(os.path.join(_mnist_path, "test"), 32)
    
        # save the network model and parameters for subsequent fine-tuning
        config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
        # save checkpoints of the network during training
        ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=_model_path, config=config_ck)
    
        steps_loss = {"step": [], "loss_value": []}
        steps_eval = {"step": [], "acc": []}
        # collect the step, loss and accuracy information
        step_loss_acc_info = StepLossAccInfo(_model, eval_dataset, steps_loss, steps_eval)
    
        _model.train(_epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info], dataset_sink_mode=False)
    
        return steps_loss, steps_eval
    
    
    
    
    
    
    epoch_size = 1
    repeat_size = 1
    mnist_path = "./datasets/MNIST_Data"
    model_path = "./models/ckpt/mindspore_quick_start/"
    
    # clean up files left over from previous runs (Linux shell command)
    os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))
    
    lr = 0.01
    momentum = 0.9
    
    # create the network
    network = LeNet5()
    
    # define the optimizer
    net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
    
    # define the loss function
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    
    # define the model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    
    steps_loss, steps_eval = train_model(model, epoch_size, repeat_size, mnist_path, model_path)
    
    print(steps_loss, steps_eval)

    Output: (screenshots of the training log appeared here in the original post)

    The core code:
    from mindspore.train.callback import Callback
    
    # custom callback function
    class StepLossAccInfo(Callback):
        def __init__(self, model, eval_dataset, steps_loss, steps_eval):
            self.model = model
            self.eval_dataset = eval_dataset
            self.steps_loss = steps_loss
            self.steps_eval = steps_eval
            self.steps = 0
    
        def step_end(self, run_context):
            cb_params = run_context.original_args()
            cur_epoch = cb_params.cur_epoch_num
            #cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
            self.steps = self.steps+1
            cur_step = self.steps
    
            self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
            self.steps_loss["step"].append(str(cur_step))
            if cur_step % 125 == 0:
                acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
                self.steps_eval["step"].append(cur_step)
                self.steps_eval["acc"].append(acc["Accuracy"])
    As you can see, by subclassing Callback we can define our own callback class; all that is required is to implement the step_end method.
    From the run_context argument that is passed to step_end, the step that has just finished and the current epoch can be obtained as follows:

    cb_params = run_context.original_args()
    cur_epoch = cb_params.cur_epoch_num
    cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num



    Here, cb_params.cur_epoch_num is the current epoch number and cb_params.cur_step_num is the current step number within that epoch.
    Note that cb_params.cur_step_num is not the total number of steps computed so far, but the number of steps computed within the current epoch; this is why the callback above keeps its own cumulative counter, self.steps.
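
    If you prefer to derive the cumulative step from cb_params instead of keeping your own counter, the hard-coded 1875 in the commented-out line (60000 MNIST training images / batch size 32) can be replaced by cb_params.batch_num, the number of batches per epoch. A small sketch of the same step_end method, assuming the r1.2 callback parameters:

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        # batch_num is the number of steps in one epoch, so no magic constant is needed
        cur_step = (cb_params.cur_epoch_num - 1) * cb_params.batch_num + cb_params.cur_step_num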



    The loss value of the current training step can also be obtained:
    cb_params.net_outputs holds the loss value of the current step.
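
    The example above stores str(cb_params.net_outputs) and parses it back with float() when plotting; you can also convert the value directly. A hedged sketch, assuming net_outputs is the scalar loss Tensor (in some configurations it is a tuple whose first element is the loss):

    loss = cb_params.net_outputs
    if isinstance(loss, (tuple, list)):  # e.g. when the train network returns several outputs
        loss = loss[0]
    loss_value = float(loss.asnumpy())   # scalar Tensor -> Python float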









    =========================================================


    The same code as above, with plotting of the collected loss and accuracy added:
    import matplotlib.pyplot as plt
    import matplotlib
    import numpy as np
    import os
    
    import mindspore.nn as nn
    from mindspore.nn import Accuracy
    from mindspore.nn import SoftmaxCrossEntropyWithLogits
    from mindspore import dtype as mstype
    import mindspore.dataset as ds
    import mindspore.dataset.vision.c_transforms as CV
    import mindspore.dataset.transforms.c_transforms as C
    from mindspore.dataset.vision import Inter
    from mindspore.common.initializer import Normal
    from mindspore import Tensor, Model
    from mindspore.train.callback import Callback
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
    
    
    def create_dataset(data_path, batch_size=32, repeat_size=1,
                       num_parallel_workers=1):
        """
        create dataset for train or test
    
        Args:
            data_path (str): Data path
            batch_size (int): The number of data records in each group
            repeat_size (int): The number of replicated data records
            num_parallel_workers (int): The number of parallel workers
        """
        # define dataset
        mnist_ds = ds.MnistDataset(data_path)
    
        # define some parameters needed for data augmentation and normalization
        resize_height, resize_width = 32, 32
        rescale = 1.0 / 255.0
        shift = 0.0
        rescale_nml = 1 / 0.3081
        shift_nml = -1 * 0.1307 / 0.3081
    
        # according to the parameters, generate the corresponding augmentation operations
        resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
        rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
        rescale_op = CV.Rescale(rescale, shift)
        hwc2chw_op = CV.HWC2CHW()
        type_cast_op = C.TypeCast(mstype.int32)
    
        # using map to apply operations to a dataset
        mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
        mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
        mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
        mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
        mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    
        # process the generated dataset
        buffer_size = 10000
        mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
        mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
        mnist_ds = mnist_ds.repeat(repeat_size)
    
        return mnist_ds
    
    
    class LeNet5(nn.Cell):
        """Lenet network structure."""
        # define the operator required
        def __init__(self, num_class=10, num_channel=1):
            super(LeNet5, self).__init__()
            self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
            self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
            self.relu = nn.ReLU()
            self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            self.flatten = nn.Flatten()
    
        # use the preceding operators to construct networks
        def construct(self, x):
            x = self.max_pool2d(self.relu(self.conv1(x)))
            x = self.max_pool2d(self.relu(self.conv2(x)))
            x = self.flatten(x)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    
    # custom callback function
    class StepLossAccInfo(Callback):
        def __init__(self, model, eval_dataset, steps_loss, steps_eval):
            self.model = model
            self.eval_dataset = eval_dataset
            self.steps_loss = steps_loss
            self.steps_eval = steps_eval
            self.steps = 0
    
        def step_end(self, run_context):
            cb_params = run_context.original_args()
            cur_epoch = cb_params.cur_epoch_num
            #cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
            self.steps = self.steps+1
            cur_step = self.steps
    
            self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
            self.steps_loss["step"].append(str(cur_step))
            if cur_step % 125 == 0:
                acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
                self.steps_eval["step"].append(cur_step)
                self.steps_eval["acc"].append(acc["Accuracy"])
    
    
    def train_model(_model, _epoch_size, _repeat_size, _mnist_path, _model_path):
        ds_train = create_dataset(os.path.join(_mnist_path, "train"), 32, _repeat_size)
        eval_dataset = create_dataset(os.path.join(_mnist_path, "test"), 32)
    
        # save the network model and parameters for subsequent fine-tuning
        config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
        # save checkpoints of the network during training
        ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=_model_path, config=config_ck)
    
        steps_loss = {"step": [], "loss_value": []}
        steps_eval = {"step": [], "acc": []}
        # collect the step, loss and accuracy information
        step_loss_acc_info = StepLossAccInfo(_model, eval_dataset, steps_loss, steps_eval)
    
        # keep dataset_sink_mode False so that step_end runs after every step
        _model.train(_epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info], dataset_sink_mode=False)
    
        return steps_loss, steps_eval
    
    
    
    
    
    
    epoch_size = 1
    repeat_size = 1
    mnist_path = "./datasets/MNIST_Data"
    model_path = "./models/ckpt/mindspore_quick_start/"
    
    # clean up files left over from previous runs (Linux shell command)
    os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))
    
    lr = 0.01
    momentum = 0.9
    
    # create the network
    network = LeNet5()
    
    # define the optimizer
    net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
    
    # define the loss function
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    
    # define the model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    
    steps_loss, steps_eval = train_model(model, epoch_size, repeat_size, mnist_path, model_path)
    
    
    steps = steps_loss["step"]
    loss_value = steps_loss["loss_value"]
    steps = list(map(int, steps))
    loss_value = list(map(float, loss_value))
    plt.plot(steps, loss_value, color="red")
    plt.xlabel("Steps")
    plt.ylabel("Loss_value")
    plt.title("Change chart of model loss value")
    plt.show()
    
    
    def eval_show(steps_eval):
        plt.xlabel("step number")
        plt.ylabel("Model accuracy")
        plt.title("Model accuracy variation chart")
        plt.plot(steps_eval["step"], steps_eval["acc"], color="red")
        plt.show()
    
    eval_show(steps_eval)
    
    
    

     




  • Original post: https://www.cnblogs.com/devilmaycry812839668/p/14995006.html