  • Learning to Compare: Relation Network source code debugging

    A few-shot learning paper from CVPR 2018:

    Learning to Compare: Relation Network for Few-Shot Learning

    Source code: https://github.com/floodsung/LearningToCompare_FSL

    I ran this code on my old laptop: Windows, PyCharm + Anaconda3 + pytorch-cpu 1.0.1.

    It threw a pile of errors, summarized below:

    'cp' error in proc_images.py

    When processing the miniImagenet dataset with proc_images.py:

    Error message:
    /LearningToCompare_FSL-master/datas/miniImagenet/proc_images.py
    'cp' is not recognized as an internal or external command, operable program or batch file.

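    The exact location is

    /datas/miniImagenet/proc_images.py  Line 48:
    os.system('cp images/' + image_name + ' ' + cur_dir)

    'cp' is a Linux shell command, so it does not exist on Windows.

    On Windows, change it to:

    os.rename('images/' + image_name, cur_dir + image_name)

    Note that os.rename moves the file out of images/ rather than copying it; shutil.copy('images/' + image_name, cur_dir) keeps the copy behavior of cp.

    In addition, every os.system('mkdir ' + filename)

    should be changed to os.mkdir(filename), although the original does not necessarily raise an error.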
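
    A portable version of the copy step could look like the sketch below. It is not the repository's code: copy_image and the directory names are assumptions made for illustration. It relies on shutil.copy and os.makedirs from the standard library, so nothing is passed to the shell.

```python
import os
import shutil


def copy_image(image_name, src_dir, dst_dir):
    """Copy src_dir/image_name into dst_dir, creating dst_dir if needed.

    Hypothetical helper: it behaves the same on Windows and Linux
    because it never shells out to 'cp' or 'mkdir'.
    """
    # Replaces os.system('mkdir ' + dst_dir); no error if it already exists.
    os.makedirs(dst_dir, exist_ok=True)
    src = os.path.join(src_dir, image_name)
    # Replaces os.system('cp ...'); the source file stays in place.
    return shutil.copy(src, dst_dir)
```

    Unlike os.rename, this leaves the images/ folder untouched, so the script can be re-run after a failure.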

    cpu RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False.

    My torch build is CPU-only, so I removed every .cuda(GPU) call. In addition,

    add map_location='cpu' when calling torch.load.

    Taking miniimagenet_train_few_shot.py as an example:
    Line 150:
    feature_encoder.load_state_dict(torch.load(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
    change to
    feature_encoder.load_state_dict(torch.load(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"), map_location='cpu'))
    Line 153:
    relation_network.load_state_dict(torch.load(str("./models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
    change to
    relation_network.load_state_dict(torch.load(str("./models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"), map_location='cpu'))
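
    Instead of hard-coding 'cpu', the device can be chosen at run time. The sketch below is only an illustration: the nn.Linear model stands in for feature_encoder or relation_network, and an in-memory buffer stands in for the .pkl file, so both are assumptions rather than the repository's code.

```python
import io

import torch
import torch.nn as nn

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in model; in the repo this would be feature_encoder or relation_network.
model = nn.Linear(4, 2)

# Save the weights to an in-memory buffer instead of a .pkl file.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# map_location remaps every stored tensor onto `device` while loading,
# so a checkpoint written on a GPU machine also loads on a CPU-only one.
state_dict = torch.load(buffer, map_location=device)
restored = nn.Linear(4, 2).to(device)
restored.load_state_dict(state_dict)
```

    With this pattern the same script runs unchanged on machines with and without CUDA, provided the .cuda(GPU) calls are likewise replaced by .to(device).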

    KeyError: '..\datas\omniglot_resized'

    Error message:
      File "LearningToCompare_FSL-master/omniglot/omniglot_train_few_shot.py", line 163, in main
        task = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
      File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 72, in <listcomp>
        self.train_labels = [labels[self.get_class(x)] for x in self.train_roots]
    KeyError: '..\datas\omniglot_resized'

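    The key part is actually in:

     task_generator.py, line 74:
      def get_class(self, sample):
            return os.path.join(*sample.split('/')[:-1])

    On Windows, os.path.join inserts backslashes, so the image paths mix '/' and '\', and splitting on '/' cuts the path off too early.

    print(os.path.join(*sample.split('/')[:-1])) outputs

    ..\datas\omniglot_resized

    whereas the keys of labels look like

      {'../datas/omniglot_resized/Malay_(Jawi_-_Arabic)\character25': 0, '../datas/omniglot_resized/Japanese_(hiragana)\character15': 1, '…}

    while print(os.path.join(*sample.split('\\')[:-1])) outputs exactly

    ../datas/omniglot_resized/Malay_(Jawi_-_Arabic)\character25

    Fix: change '/' to '\\' (the backslash has to be escaped inside a Python string literal)
    def get_class(self, sample): return os.path.join(*sample.split('\\')[:-1])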
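
    A variant that does not depend on which separator is hard-coded is to let the path module strip the last component. The sketch below is an assumption, not the repository's code: it uses ntpath (the Windows implementation behind os.path) so that it behaves identically on any OS, and the file name 0001.png is made up for the example. On Windows itself, os.path.dirname would do the same job.

```python
import ntpath  # Windows path rules, importable on any OS


def get_class(sample):
    """Return the class folder of an image path (hypothetical replacement)."""
    # dirname drops the last component and accepts both '/' and '\\'.
    return ntpath.dirname(sample)


# A mixed-separator path like the ones built by os.path.join on Windows.
sample = '../datas/omniglot_resized/Malay_(Jawi_-_Arabic)\\character25\\0001.png'
print(get_class(sample))
```

    The returned string keeps the original mix of separators, which is what makes it match the keys of the labels dict.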

    RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'

    Error message:
    File "/LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 193, in main
        torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1))
    RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'

    Fix: add this line before it

     batch_labels = batch_labels.long()

    RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'

    Error message:
    File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_test_few_shot.py", line 247, in <listcomp>
        rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(batch_size)]
    RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'

    Fix: add these lines before it

    predict_labels = predict_labels.long()
    test_labels = test_labels.long()

    Both errors seem to be tensor dtype issues in torch: the label tensors are 32-bit Int, while these operations expect 64-bit Long.
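
    The one-hot step from the first of these two errors can be reproduced in isolation. This is a minimal sketch with made-up label values, assuming the labels arrive as 32-bit integers as the error message suggests; it only shows that converting the index to Long before scatter_ is enough.

```python
import torch

CLASS_NUM = 5

# Labels as they might come out of a data loader: 32-bit integers.
batch_labels = torch.tensor([0, 2, 4, 1], dtype=torch.int32)

# scatter_ needs an int64 (Long) index tensor, so convert first.
index = batch_labels.long().view(-1, 1)

# Write a 1 into the column given by each label, one row per sample.
one_hot_labels = torch.zeros(len(batch_labels), CLASS_NUM).scatter_(1, index, 1)
print(one_hot_labels)
```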

    IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

    Error message:
    File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
        print("episode:",episode+1,"loss",loss.data[0])
    IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

    As the message suggests, change it to
    print("episode:", episode + 1, "loss", loss.item())
    and the error goes away.

    RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

    Error message:
    File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 107, in __getitem__
        image = self.transform(image)
      File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
        img = t(img)
      File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 163, in __call__
        return F.normalize(tensor, self.mean, self.std, self.inplace)
      File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\functional.py", line 208, in normalize
        tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
    RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

    This error appears when using the Omniglot dataset. The main cause is in

    "omniglot\task_generator.py", line 139:
    
    def get_data_loader(task, num_per_class=1, split='train',shuffle=True,rotation=0):    
        normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
        dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation),transforms.ToTensor(),normalize]))

    The Normalize transform from torchvision.transforms is configured for 3 channels, but the Omniglot images actually have shape [1, 28, 28], i.e. a single channel.

    Fix:

    change
     normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
    to
     normalize = transforms.Normalize(mean=[0.92206], std=[0.08426]) 
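
    The failure can be reproduced without torchvision by repeating the in-place arithmetic shown in the traceback. The sketch below is an illustration on a random tensor, not the repository's data pipeline: a 3-element mean would have to broadcast the 1-channel image up to 3 channels, which an in-place operation cannot do, whereas a 1-element mean matches the shape.

```python
import torch

image = torch.rand(1, 28, 28)  # one grayscale, Omniglot-sized image

# Three per-channel means against a single-channel image.
mean3 = torch.tensor([0.92206, 0.92206, 0.92206])
try:
    # In-place ops cannot grow the output from 1 channel to 3.
    image.clone().sub_(mean3[:, None, None])
    failed = False
except RuntimeError:
    failed = True

# One mean and one std, matching the single channel.
mean1 = torch.tensor([0.92206])
std1 = torch.tensor([0.08426])
normalized = image.clone().sub_(mean1[:, None, None]).div_(std1[:, None, None])
print(failed, normalized.shape)
```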

    UserWarning: nn.functional.sigmoid is deprecated.

    A similar warning is

    UserWarning : torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.

    Just change the calls as the warnings ask:

    torch.nn.utils.clip_grad_norm(feature_encoder.parameters(), 0.5)
    change to
    torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
    
    and in def forward,
    out = F.sigmoid(self.fc2(out))
    change to
    out = torch.sigmoid(self.fc2(out))
  • Original article: https://www.cnblogs.com/smartweed/p/10750065.html