zoukankan      html  css  js  c++  java
  • Pytorch 技巧总结(持续更新)

    • 定义自己的数据集

    第一种方法:ImageFolder函数,具体参考官方文档:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

    第二种方法:定义数据类继承Dataset 

    My code :

    class SkData(Dataset):
        def __init__(self,img_path,transforms=None):
            self.imgs = pd.read_csv(img_path)
            self.transforms = data_transforms
        def __len__(self):
            return len(self.imgs)
        def __getitem__(self,index):
            # if self.way == 'sk_class':
            img_paths = self.imgs.loc[index].values[0]
            label1 = np.array(int(img_paths.split('/')[-2].split('A')[-1])-1)
            label1 = torch.from_numpy(label1)
            label2 = np.array(np.random.randint(0,2))
            label2 = torch.from_numpy(label2)
            img = cv2.imread(img_paths)
            img = Image.fromarray(img)
            if self.transforms:
                img = self.transforms(img)
            return img,label1,label2

    __init__中主要进行数据路径的读取和定义一些操作

    __getitem__中主要进行opencv或者其他库对图片的读取和标签制作,最后一般返回的是图片与标签(可以定义成字典)

    • 定义自己的网络反向传播与正向传播,以GRL(梯度反置层)为例
    class ReverseLayerF(torch.autograd.Function):
    
        @staticmethod
        def forward(ctx, x, alpha):
            ctx.alpha = alpha
    
            return x
    
        @staticmethod
        def backward(ctx, grad_output):
            output = grad_output.neg() * ctx.alpha
    
            return output, None

    继承Function,后使用修饰器定义自己的forward和backward

    • 定义网络结构

    一种可以使用堆叠的方式nn.Sequential方式,此种方式可以定义简单网络

    另一种是继承nn.Module类定义自己的网络

    class DiffNet(nn.Module):
        def __init__(self):
            super(DiffNet,self).__init__()
            self.base_model = torchvision.models.resnet50(pretrained=True)
            # self.base_model.aux_logits = False
            self.flatten = nn.Flatten()
            self.base_model = nn.Sequential(*list(self.base_model.children())[:-2])
            self.avgpooling = nn.AdaptiveAvgPool2d((1,1))
            self.fc1 = nn.Linear(2048,1024)
            self.fc2 = nn.Linear(1024,512)
            self.output1 = nn.Linear(512,2)
            self.output2 = nn.Linear(512,60)
            self.dropout = nn.Dropout(p=0.5)
        def forward(self,x):
            output_feature = self.base_model(x)
            output_feature = self.avgpooling(output_feature)
            output_feature = self.flatten(output_feature)
            output_feature_reverse = ReverseLayerF.apply(output_feature,0.9)
            # output_feature_reverse = output_feature
            dropout1 = self.dropout(output_feature_reverse)
            fc1 = F.relu(self.fc1(dropout1))
            dropout2 = self.dropout(fc1)
            fc2 = F.relu(self.fc2(dropout2))
            output1 = self.output1(fc2)
            # output1 = nn.LogSoftmax(output1)#rota_class
            dropout3 = self.dropout(output_feature)
            fc3 = self.fc1(dropout3)
            dropout4 = self.dropout(fc3)
            fc4 = self.fc2(dropout4)
            output2 = self.output2(fc4)#sk_class
            # output2 = nn.LogSoftmax(output2)
            return output1,output2

    Note:在定义loss的时候注意torch.nn.NLLLoss与torch.nn.CrossEntropyLoss的区别

    • 冻结某些层的参数
    1 for child in model.children():
    2     ct += 1
    3     if ct == 1:
    4            for param in childs.parameters():
    5                  param.requires_grad = False

    注意在定义优化器的时候过滤掉冻结的参数

     1 optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_init) 

  • 相关阅读:
    【转】ios输入框被键盘挡住的解决办法
    【转】操作系统Unix、Windows、Mac OS、Linux的故事
    mac 下删除非空文件夹
    解决Win7 64bit + VS2013 使用opencv时出现提“应用程序无法正常启动(0xc000007b)”错误
    图的邻接表表示
    图的邻接矩阵表示
    并查集
    05-树9 Huffman Codes及基本操作
    05-树7 堆中的路径
    堆的操作集
  • 原文地址:https://www.cnblogs.com/lizhe-cnblogs/p/14034997.html
Copyright © 2011-2022 走看看