  • [PyTorch error fix] expected input to have 3 channels, but got 1 channels instead

    The problem

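    When the data are PNG images, reading them with PIL can return single-channel images rather than multi-channel ones (a PNG saved in grayscale or palette mode loads with one channel). Reading the images with OpenCV does return three-channel data, as follows: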

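        def __getitem__(self, idx):
            image_root = self.train_image_file_paths[idx]
            image_name = image_root.split(os.path.sep)[-1]
            # cv.imread returns a NumPy array of shape (H, W, 3), not a PIL Image
            image = cv.imread(image_root)

            if self.transform is not None:
                image = self.transform(image)
            # The label is the part of the file name before the first underscore
            label = ohe.encode(image_name.split('_')[0])
            return image, label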
    

    However, this approach raises errors:

    TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

      File "c:/Users/pprp/Desktop/pytorch-captcha-recognition-master/captcha_train.py", line 77, in <module>
        main(args)
      File "c:/Users/pprp/Desktop/pytorch-captcha-recognition-master/captcha_train.py", line 47, in main
        predict_labels = cnn(images)
      File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\models\resnet.py", line 192, in forward
        x = self.conv1(x)
      File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\conv.py", line 338, in forward
        self.padding, self.dilation, self.groups)
    RuntimeError: Given groups=1, weight of size 64 3 7 7, expected input[64, 60, 160, 3] to have 3 channels, but got 60 channels instead
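
    The shape in the RuntimeError, input[64, 60, 160, 3], suggests that the batch reached the network in channels-last layout (N, H, W, C), so the convolution read the height (60) as the channel count. The sketch below reproduces that mismatch in isolation. It is only an illustration under assumptions: the batch size of 4, the random data, and the standalone layer configured like the 64 x 3 x 7 x 7 weight in the traceback are stand-ins, not the original training code.

```python
import torch
import torch.nn as nn

# A layer with the same weight shape as in the traceback: 64 x 3 x 7 x 7.
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Assumed stand-in batch in channels-last layout [N, H, W, C],
# which is what stacking raw (H, W, 3) image arrays produces.
batch_nhwc = torch.rand(4, 60, 160, 3)

try:
    conv1(batch_nhwc)
    failed = False
except RuntimeError:
    # Conv2d reads dimension 1 (60) as the number of channels.
    failed = True

# Moving the channel axis to position 1 gives the [N, C, H, W] layout Conv2d expects.
batch_nchw = batch_nhwc.permute(0, 3, 1, 2).contiguous()
out = conv1(batch_nchw)
out_shape = tuple(out.shape)
```

    For a single image, transforms.ToTensor() performs this axis reordering (and scales uint8 values to [0, 1]), which is why the solution keeps it as the first step of the pipeline.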
    

    Final solution:

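    import os

    from PIL import Image
    from torch.utils.data import Dataset
    from torchvision import transforms

    # `ohe` is the project's own label-encoding module and must be imported separately.

    class mydataset(Dataset):
        def __init__(self, folder, transform=None):
            self.train_image_file_paths = [os.path.join(folder, image_file) for image_file in os.listdir(folder)]
            # The key change is here. Note that this pipeline replaces any `transform` passed in.
            self.transform = transforms.Compose([
                transforms.ToTensor(),  # convert to a PyTorch tensor; [1, H, W] for a single-channel image
                # The image is single-channel, so repeat it 3 times along the
                # channel dimension to get three-channel data of shape [3, H, W].
                transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])

        def __len__(self):
            return len(self.train_image_file_paths)

        def __getitem__(self, idx):
            image_root = self.train_image_file_paths[idx]
            image_name = image_root.split(os.path.sep)[-1]
            image = Image.open(image_root)  # PIL image, single-channel for these PNGs
            if self.transform is not None:
                image = self.transform(image)
            label = ohe.encode(image_name.split('_')[0])
            return image, label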
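
    The arguments to repeat matter here: each one says how many times to tile the corresponding dimension, so repeat(1, 1, 1) leaves a [1, H, W] tensor unchanged, while repeat(3, 1, 1) produces [3, H, W]. A minimal check, assuming a random 60 x 160 tensor as a stand-in for the ToTensor() output of one single-channel image:

```python
import torch

# Assumed stand-in for ToTensor() applied to a single-channel 60x160 image.
gray = torch.rand(1, 60, 160)

unchanged = gray.repeat(1, 1, 1)   # every dimension tiled once: still [1, 60, 160]
rgb_like = gray.repeat(3, 1, 1)    # channel dimension tiled 3 times: [3, 60, 160]

# expand gives the same shape as a view, without copying the data.
expanded = gray.expand(3, -1, -1)

# All three channels hold identical values.
channels_equal = torch.equal(rgb_like[0], rgb_like[1]) and torch.equal(rgb_like[1], rgb_like[2])
```

    Another option, not used in the original solution, is to convert the image before the transform with Pillow's Image.open(path).convert('RGB'), which also yields three channels.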
    

    Notes on PyTorch transforms: https://blog.csdn.net/u011995719/article/details/85107009
    Fixing the PNG channel issue in PIL: https://www.cnblogs.com/wzjbg/p/8516531.html

  • Original post: https://www.cnblogs.com/pprp/p/11705791.html