zoukankan      html  css  js  c++  java
  • 统计模型计算量（PyTorch）

    import time
    from options.train_options import TrainOptions
    from data import create_dataset
    from models import create_model
    from util.visualizer import Visualizer
    from torchsummaryX import summary
    
    if __name__ == '__main__':
        # Build the model from the training options, push a single batch through
        # it and print torchsummaryX's per-layer report (used here to measure the
        # model's size / computation cost).
        #
        # summary() is called exactly once. Calling it inside the epoch/batch
        # loops would reprint the same report for every iteration. The
        # architecture-level statistics presumably do not depend on which batch
        # is used, as long as input shapes are fixed.
        opt = TrainOptions().parse()   # get training options
        dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
        dataset_size = len(dataset)    # get the number of images in the dataset.
        print('The number of training images = %d' % dataset_size)
        model = create_model(opt)      # create a model given opt.model and other options
        model.setup(opt)               # regular setup: load and print networks; create schedulers

        # Only the first batch is needed for profiling.
        data = next(iter(dataset), None)
        if data is None:
            raise SystemExit('dataset is empty; nothing to profile')

        model.set_input(data)          # unpack data from dataset and apply preprocessing
        # NOTE(review): torchsummaryX's summary(model, x) runs model(x) with
        # forward hooks attached to its submodules. This assumes `model` is (or
        # behaves like) an nn.Module whose forward accepts a [label, image] list.
        # TODO: confirm against what create_model() actually returns.
        summary(model, [data['label'], data['image']])
    

      

  • 相关阅读:
    mysql配置图解(mysql 5.5)
    C++中的enum
    vc6.0中的dsp,dsw,ncb,opt,clw,plg,aps等文件的简单说明
    using namespace std
    C#中Cache的使用 迎客
    数据库里的存储过程和事务有什么区别? 迎客
    WINDOWS远程默认端口3389的正确修改方式 迎客
    DES加密和解密PHP,Java,ObjectC统一的方法 迎客
    转:15点 老外聊iPhone游戏开发注意事项 迎客
    windows server 2003 删除默认共享 迎客
  • 原文地址:https://www.cnblogs.com/wjjcjj/p/14601009.html
Copyright © 2011-2022 走看看