zoukankan      html  css  js  c++  java
  • torchvision.transforms.ToTensor(),torchvision.trasnsforms.Normalize()

    用PyTorch进行神经网络训练时,如果训练用的数据是图像数据,则需要在训练之前对图像进行预处理。以MNIST数据为例:

    train_data = torchvision.datasets.MNIST(
        root='./mnist/',   
        train=True,                                     
        transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                        # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
        download=True,                        
    )
    

    transform=torchvision.transforms.ToTensor()起到的作用是把PIL.Image或者numpy.narray数据类型转变为torch.FloatTensor类型,shape是C*H*W,数值范围缩小为[0.0, 1.0]。

    如果想把数值范围调整为[-1.0, 1.0],则可加torchvision.transforms.Normalize([mean_channel1,mean_channel2,mean_channel3], [std_channel1,std_channel2,std_channel3]),如果是黑白图像,比如MNIST里的图像,只有一个通道,则mean只需要一个,std也只需要一个。

    im_tfs = torchvision.trasnsforms.Compose([
        torchvision.trasnsforms.ToTensor(),
        torchvision.trasnsforms.Normalize([0.5], [0.5]) 
    ])
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',   
        train=True,                                     
        transform=torchvision.transforms.ToTensor(),    
        download=True,                        
    )
  • 相关阅读:
    System lock
    skip_slave_start
    慢查询日志分析
    wait_timeout 、interactive_timeout、slave_net_timeout、master_heartbeat_period
    reset slave,reset slave all,reset master都干了些啥?
    强制删除有外键约束的数据
    集群拓扑结构变更
    在线开启gtid与在线关闭gtid
    less
    pg流复制
  • 原文地址:https://www.cnblogs.com/picassooo/p/12584904.html
Copyright © 2011-2022 走看看