zoukankan      html  css  js  c++  java
  • 统计模型计算量~pytorch

    import time
    from options.train_options import TrainOptions
    from data import create_dataset
    from models import create_model
    from util.visualizer import Visualizer
    from torchsummaryX import summary
    
    if __name__ == '__main__':
        # Build the dataset and model from the training options, then print a
        # torchsummaryX report for the model using one batch as example input.
        opt = TrainOptions().parse()   # get training options
        dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
        print('The number of training images = %d' % len(dataset))
        model = create_model(opt)      # create a model given opt.model and other options
        model.setup(opt)               # regular setup: load and print networks; create schedulers
        # NOTE(review): nothing below uses the visualizer; it is kept only because
        # constructing it may have side effects this script relied on -- confirm
        # and drop it if not.
        visualizer = Visualizer(opt)
    
        # The original script called summary() for every batch of every epoch
        # without ever training, so the same report was produced over and over.
        # Summarising the first batch once is sufficient (presumably all batches
        # share the same layout -- verify if the dataset yields variable shapes).
        for data in dataset:
            model.set_input(data)      # unpack data from dataset and apply preprocessing
            # The batch dict is expected to provide 'label' and 'image' entries;
            # they are handed to summary() as a single list argument, as before.
            summary(model, [data['label'], data['image']])
            break
    

      

  • 相关阅读:
    mysql 错误 1067: 进程意外终止
    VPS主机MySQL意外中断,重启就好,但10来个小时后又中断了,如此反复
    使用hibernate连接mysql自动中断的问题
    40个国外联盟
    从服务里删除mysql
    外国广告联盟[16个]
    stm32学习笔记:GPIO外部中断的使用
    NO.2 设计包含min 函数的栈
    GPS数据,实测
    LATEX使用总结
  • 原文地址:https://www.cnblogs.com/wjjcjj/p/14601009.html
Copyright © 2011-2022 走看看