zoukankan      html  css  js  c++  java
  • PointNet++ MSG

    MSG 的训练巨慢,不过总算是复现出来了,之后再搞一下 MRG 的,或者裸的版本。

    configuration.py:

    import torch.cuda
    
    
    class config:
        """Static settings read by the model and the training script."""

        # Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
        device = 'cpu' if not torch.cuda.is_available() else 'cuda'

        # Where the input data lives and where checkpoints / logs are written.
        dataset_root = 'C:/Users/Dell/PycharmProjects/PointNet++/dataset'
        checkpoint_root = 'C:/Users/Dell/PycharmProjects/PointNet++/checkpoint'

        # Training schedule and the width of the per-point output layer.
        num_epochs, batch_size, num_seg = 10, 4, 40
    View Code

    Model.py:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.utils.data.dataloader as Dataloader
    import torch.utils.data as data
    from configuration import config
    
    cfg = config()
    
    def index_points(points, idx):
        """Gather rows of ``points`` selected by per-batch indices.

        Args:
            points: tensor of shape [B, N, C].
            idx: long tensor of shape [B, S] (or [B, S, K, ...]) whose values
                index dimension 1 of ``points``.

        Returns:
            Tensor of shape ``list(idx.shape) + [C]``; entry ``[b, ...]`` is
            ``points[b, idx[b, ...], :]``.
        """
        idx_shape = list(idx.shape)
        # One batch id per sample, shaped [B, 1, ..., 1] and then tiled to the
        # full shape of ``idx`` so both index tensors line up element-wise.
        batch_ids = torch.arange(points.shape[0], dtype=torch.long).to(points.device)
        batch_ids = batch_ids.view([idx_shape[0]] + [1] * (len(idx_shape) - 1))
        batch_ids = batch_ids.repeat([1] + idx_shape[1:])
        return points[batch_ids, idx, :]
    
    
    
    def farthest_sample(data, sample_n):
        """Pick ``sample_n`` points per cloud with farthest point sampling.

        Starting from a random point, each step adds the point whose distance to
        the already selected set (i.e. the minimum distance to any selected
        point) is largest.

        Args:
            data: point coordinates, tensor of shape [B, N, C].
            sample_n: number of points to select in every cloud.

        Returns:
            Tensor of shape [B, sample_n, C] holding the selected coordinates.
        """
        device = data.device
        B, N, C = data.size()
        batch_idx = torch.arange(B, dtype=torch.long, device=device)
        # min_dis[b, j]: squared distance from point j to the selected set.
        # Nothing is selected yet, so every point is infinitely far away.
        min_dis = torch.full((B, N), float('inf'), device=device)
        # The first point of every cloud is chosen at random.
        u = torch.randint(0, N, (B,), device=device)
        ret = torch.zeros((B, sample_n), dtype=torch.long, device=device)
        for i in range(sample_n):
            ret[:, i] = u
            # Coordinates of the point that was just selected.
            cen = data[batch_idx, u, :].view(B, 1, C)
            # Squared distance of every point to that new point.
            distance = torch.sum((data - cen) ** 2, -1)
            # Distance to a set is the minimum over its members; selected points
            # drop to 0 here, so they are not picked again while any
            # non-coincident point remains.
            min_dis = torch.min(min_dis, distance)
            u = torch.max(min_dis, -1)[1]
        # Gather the selected coordinates: [B, sample_n, C].
        return data[batch_idx.view(B, 1), ret, :]
    
    def square_distance(data, center):
        """Squared Euclidean distance from every centre to every point.

        Args:
            data: point coordinates, tensor of shape [B, N, C].
            center: centre coordinates, tensor of shape [B, S, C].

        Returns:
            Tensor of shape [B, S, N]; entry ``[i, j, k]`` is the squared
            distance between centre ``j`` and point ``k`` of cloud ``i``. The
            result lives on the same device as the inputs.
        """
        # [B, 1, N, C] - [B, S, 1, C] broadcasts to [B, S, N, C]; this replaces
        # the per-centre Python loop at the cost of one larger intermediate.
        diff = data.unsqueeze(1) - center.unsqueeze(2)
        return torch.sum(diff ** 2, -1)
    
    
    def query_ball_point(data, dis, radius, k):
        """Select up to ``k`` point indices inside a ball around every centre.

        Args:
            data: unused; kept so existing callers keep working.
            dis: squared distances, tensor of shape [B, n, N], where
                ``dis[b, j, i]`` is the squared distance from centre ``j`` to
                point ``i``.
            radius: ball radius (compared against ``dis`` as ``radius ** 2``).
            k: maximum number of neighbours kept per centre.

        Returns:
            Long tensor of shape [B, n, min(k, N)] with point indices. Within a
            ball the lowest indices are kept (not the closest points). Balls
            with fewer than ``k`` members are padded with their first kept
            index.
        """
        B, n, N = dis.size()
        # Every row starts as 0..N-1, created on the same device as ``dis``.
        group = torch.arange(N, dtype=torch.long, device=dis.device).view(1, 1, N).repeat(B, n, 1)
        # Points outside the ball get the sentinel N so that an ascending sort
        # pushes them behind every valid index.
        group[dis > radius ** 2] = N
        group = group.sort(dim=-1)[0][:, :, :k]
        # Replace remaining sentinels with the first (smallest) in-ball index.
        first = group[:, :, :1].expand_as(group)
        missing = group == N
        group[missing] = first[missing]

        return group
    
    
    # def cal_coor(data, group):
    #     B, n, k = group.size()
    #     B_list = torch.arange(B, dtype = torch.long)
    #     res = torch.zeros([B, n, k, 3])
    #     for i in range(n):
    #         for j in range(k):
    #             res[:, i, j, :] = data[B_list, group[:, i, j]]
    #     return res
    
    def sample_and_group(xyz, point, radius, nsample, dis):
        """Group coordinates and features around every centre.

        ``dis`` holds the precomputed squared centre-to-point distances; it is
        passed to ``query_ball_point`` to obtain the neighbour indices, which
        are then used to gather from both ``xyz`` and ``point``.

        Returns:
            ``(grouped_xyz, grouped_point)`` -- the neighbour coordinates and
            the neighbour features of every group.
        """
        neighbor_idx = query_ball_point(xyz, dis, radius, nsample)
        grouped_xyz = index_points(xyz, neighbor_idx)      # B * n * k * 3
        grouped_point = index_points(point, neighbor_idx)  # B * n * k * C
        return grouped_xyz, grouped_point
    
    
    
    
    class PointNetSetAbstractionMsg(nn.Module):
        """Set-abstraction layer with multi-scale grouping (MSG).

        Args:
            npoint: number of centres to sample; ``None`` turns the layer into a
                "group everything" layer that pools over all input points.
            radius_list: one ball radius per scale.
            nsample_list: number of neighbours kept per ball, one per scale.
            in_channel: number of input feature channels.
            mlp_list: one list of output widths per scale; each list defines a
                stack of 1x1 Conv2d + BatchNorm2d + ReLU layers.
        """

        def __init__(self, npoint = None, radius_list = None, nsample_list = None, in_channel = None, mlp_list = None):
            super().__init__()
            self.npoint = npoint
            self.radius_list = radius_list
            self.nsample_list = nsample_list
            self.conv_blocks = nn.ModuleList()
            self.bn_blocks = nn.ModuleList()
            for widths in mlp_list:
                convs = nn.ModuleList()
                bns = nn.ModuleList()
                last_channel = in_channel
                for out_channel in widths:
                    # Grouped features are B * C * k * n, so the shared per-point
                    # MLP is implemented with 1x1 2-D convolutions.
                    convs.append(nn.Conv2d(last_channel, out_channel, 1))
                    bns.append(nn.BatchNorm2d(out_channel))
                    last_channel = out_channel
                self.conv_blocks.append(convs)
                self.bn_blocks.append(bns)

        @staticmethod
        def _shared_mlp(feature, convs, bns):
            """Apply conv -> batch-norm -> ReLU for every layer of one scale."""
            for conv, bn in zip(convs, bns):
                feature = F.relu(bn(conv(feature)))
            return feature

        def forward(self, xyz, point):
            """Abstract a point set into fewer centres with richer features.

            Args:
                xyz: coordinates, [B, 3, N]; used only to build the groups.
                point: features, [B, C, N]; these are what gets grouped and
                    convolved.

            Returns:
                ``(center, new_point)``: centre coordinates [B, 3, npoint] and
                features [B, C', npoint], where C' sums the last width of every
                scale. With ``npoint is None`` the centre is all zeros with
                shape [B, 3, 1] and the features have shape [B, C', 1].
            """
            xyz = xyz.transpose(1, 2)
            if point is not None:
                point = point.transpose(1, 2)

            if self.npoint is None:
                # Group-all mode: a single group made of every input point. The
                # placeholder centre is created on the input's device and dtype.
                center = torch.zeros(xyz.size(0), 1, xyz.size(2), dtype=xyz.dtype, device=xyz.device)
                new_point = point.unsqueeze(1).permute(0, 3, 2, 1)  # B * C * N * 1
                new_point = self._shared_mlp(new_point, self.conv_blocks[0], self.bn_blocks[0])
                new_point = torch.max(new_point, dim = 2)[0]
                return center.transpose(1, 2), new_point

            center = farthest_sample(xyz, self.npoint)
            dis = square_distance(xyz, center)
            point_list = []
            for radius, nsample, convs, bns in zip(self.radius_list, self.nsample_list, self.conv_blocks, self.bn_blocks):
                _, new_point = sample_and_group(xyz, point, radius, nsample, dis)
                new_point = new_point.permute(0, 3, 2, 1)  # B * C * k * n
                new_point = self._shared_mlp(new_point, convs, bns)
                # Max over the k neighbours: each group is represented by its
                # strongest response per channel -> B * C' * n.
                point_list.append(torch.max(new_point, dim = 2)[0])
            # MSG: concatenate the descriptors produced at the different radii.
            new_point = torch.cat(point_list, dim = 1)
            return center.transpose(2, 1), new_point
    
    class PointNetSetAbstraction(nn.Module):
        """Single-scale set-abstraction layer.

        Args:
            npoint: number of centres to sample (ignored when ``group_all``).
            radius: ball radius used for grouping (ignored when ``group_all``).
            nsample: neighbours kept per ball (ignored when ``group_all``).
            in_channel: number of input feature channels.
            mlp: output widths of the 1x1 Conv2d + BatchNorm2d + ReLU stack.
            group_all: when true, pool over all input points as one group.
        """

        def __init__(self, npoint = None, radius = None, nsample = None, in_channel = None, mlp = None, group_all = False):
            super(PointNetSetAbstraction, self).__init__()
            self.group_all = group_all
            self.npoint = npoint
            self.radius = radius
            self.nsample = nsample
            self.conv_block = nn.ModuleList()
            self.bn_block = nn.ModuleList()
            last_channel = in_channel
            for out_channel in mlp:
                self.conv_block.append(nn.Conv2d(last_channel, out_channel, 1))
                self.bn_block.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel

        def _shared_mlp(self, feature):
            """Apply conv -> batch-norm -> ReLU for every layer of the stack."""
            for conv, bn in zip(self.conv_block, self.bn_block):
                feature = F.relu(bn(conv(feature)))
            return feature

        def forward(self, xyz, point):
            """Abstract a point set into fewer centres with richer features.

            Args:
                xyz: coordinates, [B, 3, N]; used only to build the groups.
                point: features, [B, C, N].

            Returns:
                ``(center, new_point)``: centre coordinates [B, 3, npoint] and
                features [B, mlp[-1], npoint]; with ``group_all`` the centre is
                all zeros with shape [B, 3, 1] and the features have shape
                [B, mlp[-1], 1].
            """
            xyz = xyz.transpose(1, 2)
            point = point.transpose(1, 2)

            if not self.group_all:
                center = farthest_sample(xyz, self.npoint)
                dis = square_distance(xyz, center)
                _, new_point = sample_and_group(xyz, point, self.radius, self.nsample, dis)
                new_point = new_point.permute(0, 3, 2, 1)  # B * C * k * n
            else:
                # Placeholder centre, created on the input's device and dtype.
                center = torch.zeros(xyz.size(0), 1, xyz.size(2), dtype=xyz.dtype, device=xyz.device)
                new_point = point.unsqueeze(1).permute(0, 3, 2, 1)  # B * C * N * 1

            new_point = self._shared_mlp(new_point)
            # Max-pool over the members of every group.
            new_point = torch.max(new_point, dim = 2)[0]
            return center.transpose(1, 2), new_point
    
    
    
    
    class PointNetFeaturePropagation(nn.Module):
        """Feature-propagation (up-sampling) layer.

        Features known on a coarse point set are interpolated onto a denser set,
        concatenated with the dense set's own features and passed through a
        stack of 1x1 Conv1d + BatchNorm1d + ReLU layers.

        Args:
            inchannel: channels after concatenation (dense + coarse features).
            mlp_list: output widths of the convolution stack.
        """

        def __init__(self, inchannel, mlp_list):
            super().__init__()
            self.conv_blocks = nn.ModuleList()
            self.bn_blocks = nn.ModuleList()
            last_channel = inchannel
            for out_channel in mlp_list:
                self.conv_blocks.append(nn.Conv1d(last_channel, out_channel, 1))
                self.bn_blocks.append(nn.BatchNorm1d(out_channel))
                last_channel = out_channel

        def forward(self, xyz1, xyz2, point1, point2):
            """Propagate ``point2`` (defined on ``xyz2``) onto the points ``xyz1``.

            Args:
                xyz1: dense coordinates, [B, 3, N].
                xyz2: coarse coordinates, [B, 3, n].
                point1: features of the dense set, [B, C1, N], or ``None``.
                point2: features of the coarse set, [B, C2, n].

            Returns:
                Tensor of shape [B, mlp_list[-1], N].
            """
            xyz1 = xyz1.transpose(1, 2)
            xyz2 = xyz2.transpose(1, 2)
            point2 = point2.transpose(1, 2)
            B, N, _ = xyz1.size()
            n = xyz2.size(1)
            if n == 1:
                # A single coarse point: every dense point copies its feature.
                point2 = point2.repeat(1, N, 1)
            else:
                # Interpolate from the (up to) three NEAREST coarse points,
                # starting at the closest one; fewer are used when the coarse
                # set has fewer than three points.
                k = min(3, n)
                dis, idx = square_distance(xyz2, xyz1).sort(dim = -1)
                dis = dis[:, :, :k]  # B * N * k squared distances
                idx = idx[:, :, :k]  # B * N * k indices into the coarse set
                feature = index_points(point2, idx)  # B * N * k * D
                # Inverse-distance weights, normalised to sum to one.
                w = 1.0 / (dis + 1e-8)
                ratio = w / torch.sum(w, dim = -1, keepdim = True)
                point2 = torch.sum(feature * ratio.unsqueeze(-1), dim = 2)  # B * N * D
            if point1 is not None:
                point1 = point1.transpose(1, 2)
                point = torch.cat([point1, point2], dim = -1)
            else:
                point = point2
            point = point.transpose(1, 2)
            for conv, bn in zip(self.conv_blocks, self.bn_blocks):
                point = F.relu(bn(conv(point)))
            return point
    
    
    class PointNet_add_Msg(nn.Module):
        """Per-point segmentation network built from multi-scale-grouping layers.

        Three set-abstraction stages (the last one pools over all points) are
        followed by three feature-propagation stages and a two-layer 1x1
        convolution head with ``cfg.num_seg`` outputs per point.
        """

        def __init__(self):
            super().__init__()
            self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
            self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 64 + 128 + 128, [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
            self.sa3 = PointNetSetAbstractionMsg(in_channel = 128 + 256 + 256 , mlp_list = [[256, 512, 1024]])
            self.fp1 = PointNetFeaturePropagation(1664, [256, 256])
            self.fp2 = PointNetFeaturePropagation(576, [256, 128])
            self.fp3 = PointNetFeaturePropagation(131, [128, 128])
            self.conv1 = nn.Conv1d(128, 128, 1)
            self.conv2 = nn.Conv1d(128, cfg.num_seg, 1)

            self.bn1 = nn.BatchNorm1d(128)

        def forward(self, xyz):
            """Return per-point class scores.

            Args:
                xyz: point coordinates, [B, 3, N]; they also serve as the
                    initial point features.

            Returns:
                Tensor of shape [B * N, cfg.num_seg].
            """
            xyz1, point1 = self.sa1(xyz, xyz)
            xyz2, point2 = self.sa2(xyz1, point1)
            xyz3, point3 = self.sa3(xyz2, point2)
            point2 = self.fp1(xyz2, xyz3, point2, point3)
            point1 = self.fp2(xyz1, xyz2, point1, point2)
            point0 = self.fp3(xyz, xyz1, xyz, point1)
            # F.dropout defaults to training=True; tie it to the module's mode
            # so that dropout is switched off by model.eval().
            point = F.dropout(F.relu(self.bn1(self.conv1(point0))), p=0.5, training=self.training)
            point = self.conv2(point)
            point = point.transpose(1, 2).contiguous()
            point = point.view(-1, cfg.num_seg)
            return point
    View Code

    train.py:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from tqdm import tqdm
    from DataSet import Dataset
    from Model import PointNet_add
    from configuration import config
    import torch.utils.data.dataloader as DataLoader
    from tensorboardX import SummaryWriter
    import os
    
    
    cfg = config()

    if __name__ == '__main__':
        # Trains the segmentation network, logging the per-epoch mean loss and
        # per-point accuracy to TensorBoard and saving a checkpoint at the end.

        model = PointNet_add()
        model.to(cfg.device)
        dataset = Dataset(cfg.dataset_root)
        dataloader = DataLoader.DataLoader(dataset, batch_size = cfg.batch_size, shuffle = True)
        optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
        loss = nn.CrossEntropyLoss()
        # Make sure the output directory exists before anything is written.
        os.makedirs(cfg.checkpoint_root, exist_ok = True)
        tbwrite = SummaryWriter(logdir = os.path.join(cfg.checkpoint_root, 'log'))
        model.train()
        for epoch in range(cfg.num_epochs):
            total_true = 0
            total_points = 0
            total_loss = 0.0
            cnt = 0
            for xyz, label in tqdm(dataloader):
                # The batch has to live on the same device as the model.
                xyz = xyz.to(cfg.device)
                label = label.to(cfg.device).reshape(-1)
                optimizer.zero_grad()
                output = model(xyz)
                loss_value = loss(output, label)
                loss_value.backward()
                optimizer.step()
                pred = torch.max(output, -1)[1]
                total_true += torch.sum(pred == label).item()
                total_points += label.numel()
                total_loss += loss_value.item()
                cnt += 1
            # Guard against an empty loader; accuracy is measured per point, so
            # it is normalised by the number of labelled points, not of clouds.
            mean_loss = total_loss / float(max(cnt, 1))
            accuracy = total_true / float(max(total_points, 1))
            tbwrite.add_scalar('Loss', mean_loss, epoch)
            tbwrite.add_scalar('Accuracy', accuracy, epoch)
            print('mean_loss:{:.4f}, accuracy:{:.4f}'.format(mean_loss, accuracy))
            # With this modulus only the final epoch writes a checkpoint.
            if (epoch + 1) % cfg.num_epochs == 0:
                state = {
                    'model': model.state_dict()
                }
                torch.save(state, os.path.join(cfg.checkpoint_root, 'checkpoint_{}.pth'.format(epoch)))
        # Flush pending events to disk.
        tbwrite.close()
    View Code

    DataSet.py:

    import numpy as np
    import torch.utils.data as data
    import os
    import random
    import torch
    
    
    
    class Dataset(data.Dataset):
        """Point clouds with per-point labels, loaded from text files.

        ``root/points`` holds the coordinate files and ``root/points_label`` the
        label files. Both listings are sorted by the integer before the first
        ``.`` of the file name, and files are paired by position.

        Args:
            root: dataset directory containing ``points`` and ``points_label``.
            num_points: number of points every returned cloud is resampled to.
        """

        def __init__(self, root, num_points = 2500):
            super().__init__()
            self.root = root
            self.num_points = num_points
            data_list = os.listdir(os.path.join(root, 'points'))
            label_list = os.listdir(os.path.join(root, 'points_label'))
            self.data_list = sorted(data_list, key = lambda x : int(x.split('.')[0]))
            self.label_list = sorted(label_list, key = lambda x : int(x.split('.')[0]))

        def __getitem__(self, index):
            """Return ``(points, labels)`` for sample ``index``.

            Returns:
                points: float32 tensor of shape [C, num_points].
                labels: int64 tensor of shape [num_points].

            Raises:
                ValueError: if the point file contains no points.
            """
            # ndmin keeps single-line files as [1, C] / [1] instead of lower rank.
            points = np.loadtxt(os.path.join(self.root, 'points', self.data_list[index]), ndmin = 2)
            labels = np.loadtxt(os.path.join(self.root, 'points_label', self.label_list[index]), ndmin = 1)

            count = points.shape[0]
            if count == 0:
                raise ValueError('empty point cloud: {}'.format(self.data_list[index]))

            if count >= self.num_points:
                # Enough points: draw a random subset without replacement.
                keep = random.sample(range(count), self.num_points)
                points = points[keep, :]
                labels = labels[keep]
            else:
                # Too few points: pad with random duplicates. Drawing WITH
                # replacement also works when more than ``count`` extra points
                # are needed, where random.sample would raise.
                extra = random.choices(range(count), k = self.num_points - count)
                points = np.concatenate([points, points[extra, :]], 0)
                labels = np.concatenate([labels, labels[extra]], 0)

            # Labels must be a LongTensor and coordinates float32.
            label = torch.tensor(labels).type(torch.LongTensor)
            cloud = torch.tensor(points.T).to(torch.float32)

            return cloud, label

        def __len__(self):
            return len(self.label_list)
    View Code

    以下是分类和分割相结合的代码,包含普通版本和 MSG 版本。

    Dataset.py:

    import numpy as np
    import torch.utils.data as data
    import os
    import random
    import torch
    import h5py
    
    
    
    class Dataset(data.Dataset):
        """Point clouds with per-point labels, loaded from text files.

        ``root/points`` holds the coordinate files and ``root/points_label`` the
        label files. Both listings are sorted by the integer before the first
        ``.`` of the file name, and files are paired by position.

        Args:
            root: dataset directory containing ``points`` and ``points_label``.
            num_points: number of points every returned cloud is resampled to.
        """

        def __init__(self, root, num_points = 2500):
            super().__init__()
            self.root = root
            self.num_points = num_points
            data_list = os.listdir(os.path.join(root, 'points'))
            label_list = os.listdir(os.path.join(root, 'points_label'))
            self.data_list = sorted(data_list, key = lambda x : int(x.split('.')[0]))
            self.label_list = sorted(label_list, key = lambda x : int(x.split('.')[0]))

        def __getitem__(self, index):
            """Return ``(points, labels)`` for sample ``index``.

            Returns:
                points: float32 tensor of shape [C, num_points].
                labels: int64 tensor of shape [num_points].

            Raises:
                ValueError: if the point file contains no points.
            """
            # ndmin keeps single-line files as [1, C] / [1] instead of lower rank.
            points = np.loadtxt(os.path.join(self.root, 'points', self.data_list[index]), ndmin = 2)
            labels = np.loadtxt(os.path.join(self.root, 'points_label', self.label_list[index]), ndmin = 1)

            count = points.shape[0]
            if count == 0:
                raise ValueError('empty point cloud: {}'.format(self.data_list[index]))

            if count >= self.num_points:
                # Enough points: draw a random subset without replacement.
                keep = random.sample(range(count), self.num_points)
                points = points[keep, :]
                labels = labels[keep]
            else:
                # Too few points: pad with random duplicates. Drawing WITH
                # replacement also works when more than ``count`` extra points
                # are needed, where random.sample would raise.
                extra = random.choices(range(count), k = self.num_points - count)
                points = np.concatenate([points, points[extra, :]], 0)
                labels = np.concatenate([labels, labels[extra]], 0)

            # Labels must be a LongTensor and coordinates float32.
            label = torch.tensor(labels).type(torch.LongTensor)
            cloud = torch.tensor(points.T).to(torch.float32)

            return cloud, label

        def __len__(self):
            return len(self.label_list)
    
    
    
    class claDataset(data.Dataset):
        """Classification dataset read from a single HDF5 file.

        The file must provide a ``data`` dataset (one point cloud per row) and a
        ``label`` dataset whose first column holds the class id. Both are copied
        into memory when the object is created.

        Args:
            root: path of the HDF5 file.
        """

        def __init__(self, root):
            super(claDataset, self).__init__()
            # Everything is copied into memory here, so the file can be closed
            # right away instead of leaking the open handle.
            with h5py.File(root, 'r') as dataset:
                self.data = dataset['data'][:]
                self.label = dataset['label'][:][:, 0]

        def __getitem__(self, index):
            """Return ``(points, label)``: the transposed cloud and its class id."""
            label = torch.tensor(self.label[index])
            label = label.type(torch.LongTensor)
            return torch.tensor(self.data[index].T), label

        def __len__(self):
            return len(self.label)
    View Code

    configuration.py:

    import torch.cuda
    
    
    class config:
        """Static settings read by the models and the training scripts."""

        # Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
        device = 'cpu' if not torch.cuda.is_available() else 'cuda'

        # Segmentation data, checkpoint / log output and classification data.
        dataset_root = 'C:/Users/Dell/PycharmProjects/PointNet++/dataset'
        checkpoint_root = 'C:/Users/Dell/PycharmProjects/PointNet++/checkpoint'
        cladataset_root = 'H:/DataSet/modelnet40_ply_hdf5_2048/ply_data_train0.h5'

        # Training schedule.
        num_epochs, batch_size = 10, 4

        # Output widths of the segmentation and the classification heads.
        num_seg, num_classes = 50, 40
    View Code

    Model.py:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.utils.data.dataloader as Dataloader
    import torch.utils.data as data
    from configuration import config
    
    cfg = config()
    
    def index_points(points, idx):
        """Gather rows of ``points`` selected by per-batch indices.

        Args:
            points: tensor of shape [B, N, C].
            idx: long tensor of shape [B, S] (or [B, S, K, ...]) whose values
                index dimension 1 of ``points``.

        Returns:
            Tensor of shape ``list(idx.shape) + [C]``; entry ``[b, ...]`` is
            ``points[b, idx[b, ...], :]``.
        """
        idx_shape = list(idx.shape)
        # One batch id per sample, shaped [B, 1, ..., 1] and then tiled to the
        # full shape of ``idx`` so both index tensors line up element-wise.
        batch_ids = torch.arange(points.shape[0], dtype=torch.long).to(points.device)
        batch_ids = batch_ids.view([idx_shape[0]] + [1] * (len(idx_shape) - 1))
        batch_ids = batch_ids.repeat([1] + idx_shape[1:])
        return points[batch_ids, idx, :]
    
    
    
    def farthest_sample(data, sample_n):
        """Pick ``sample_n`` points per cloud with farthest point sampling.

        Starting from a random point, each step adds the point whose distance to
        the already selected set (i.e. the minimum distance to any selected
        point) is largest.

        Args:
            data: point coordinates, tensor of shape [B, N, C].
            sample_n: number of points to select in every cloud.

        Returns:
            Tensor of shape [B, sample_n, C] holding the selected coordinates.
        """
        device = data.device
        B, N, C = data.size()
        batch_idx = torch.arange(B, dtype=torch.long, device=device)
        # min_dis[b, j]: squared distance from point j to the selected set.
        # Nothing is selected yet, so every point is infinitely far away.
        min_dis = torch.full((B, N), float('inf'), device=device)
        # The first point of every cloud is chosen at random.
        u = torch.randint(0, N, (B,), device=device)
        ret = torch.zeros((B, sample_n), dtype=torch.long, device=device)
        for i in range(sample_n):
            ret[:, i] = u
            # Coordinates of the point that was just selected.
            cen = data[batch_idx, u, :].view(B, 1, C)
            # Squared distance of every point to that new point.
            distance = torch.sum((data - cen) ** 2, -1)
            # Distance to a set is the minimum over its members; selected points
            # drop to 0 here, so they are not picked again while any
            # non-coincident point remains.
            min_dis = torch.min(min_dis, distance)
            u = torch.max(min_dis, -1)[1]
        # Gather the selected coordinates: [B, sample_n, C].
        return data[batch_idx.view(B, 1), ret, :]
    
    def square_distance(data, center):
        """Squared Euclidean distance from every centre to every point.

        Args:
            data: point coordinates, tensor of shape [B, N, C].
            center: centre coordinates, tensor of shape [B, S, C].

        Returns:
            Tensor of shape [B, S, N]; entry ``[i, j, k]`` is the squared
            distance between centre ``j`` and point ``k`` of cloud ``i``. The
            result lives on the same device as the inputs.
        """
        # [B, 1, N, C] - [B, S, 1, C] broadcasts to [B, S, N, C]; this replaces
        # the per-centre Python loop at the cost of one larger intermediate.
        diff = data.unsqueeze(1) - center.unsqueeze(2)
        return torch.sum(diff ** 2, -1)
    
    
    def query_ball_point(data, dis, radius, k):
        """Select up to ``k`` point indices inside a ball around every centre.

        Args:
            data: unused; kept so existing callers keep working.
            dis: squared distances, tensor of shape [B, n, N], where
                ``dis[b, j, i]`` is the squared distance from centre ``j`` to
                point ``i``.
            radius: ball radius (compared against ``dis`` as ``radius ** 2``).
            k: maximum number of neighbours kept per centre.

        Returns:
            Long tensor of shape [B, n, min(k, N)] with point indices. Within a
            ball the lowest indices are kept (not the closest points). Balls
            with fewer than ``k`` members are padded with their first kept
            index.
        """
        B, n, N = dis.size()
        # Every row starts as 0..N-1, created on the same device as ``dis``.
        group = torch.arange(N, dtype=torch.long, device=dis.device).view(1, 1, N).repeat(B, n, 1)
        # Points outside the ball get the sentinel N so that an ascending sort
        # pushes them behind every valid index.
        group[dis > radius ** 2] = N
        group = group.sort(dim=-1)[0][:, :, :k]
        # Replace remaining sentinels with the first (smallest) in-ball index.
        first = group[:, :, :1].expand_as(group)
        missing = group == N
        group[missing] = first[missing]

        return group
    
    
    # def cal_coor(data, group):
    #     B, n, k = group.size()
    #     B_list = torch.arange(B, dtype = torch.long)
    #     res = torch.zeros([B, n, k, 3])
    #     for i in range(n):
    #         for j in range(k):
    #             res[:, i, j, :] = data[B_list, group[:, i, j]]
    #     return res
    
    def sample_and_group(xyz, point, radius, nsample, dis):
        """Group coordinates and features around every centre.

        ``dis`` holds the precomputed squared centre-to-point distances; it is
        passed to ``query_ball_point`` to obtain the neighbour indices, which
        are then used to gather from both ``xyz`` and ``point``.

        Returns:
            ``(grouped_xyz, grouped_point)`` -- the neighbour coordinates and
            the neighbour features of every group.
        """
        neighbor_idx = query_ball_point(xyz, dis, radius, nsample)
        grouped_xyz = index_points(xyz, neighbor_idx)      # B * n * k * 3
        grouped_point = index_points(point, neighbor_idx)  # B * n * k * C
        return grouped_xyz, grouped_point
    
    
    
    
    class PointNetSetAbstractionMsg(nn.Module):
        """Set-abstraction layer with multi-scale grouping (MSG).

        Args:
            npoint: number of centres to sample; ``None`` turns the layer into a
                "group everything" layer that pools over all input points.
            radius_list: one ball radius per scale.
            nsample_list: number of neighbours kept per ball, one per scale.
            in_channel: number of input feature channels.
            mlp_list: one list of output widths per scale; each list defines a
                stack of 1x1 Conv2d + BatchNorm2d + ReLU layers.
        """

        def __init__(self, npoint = None, radius_list = None, nsample_list = None, in_channel = None, mlp_list = None):
            super().__init__()
            self.npoint = npoint
            self.radius_list = radius_list
            self.nsample_list = nsample_list
            self.conv_blocks = nn.ModuleList()
            self.bn_blocks = nn.ModuleList()
            for widths in mlp_list:
                convs = nn.ModuleList()
                bns = nn.ModuleList()
                last_channel = in_channel
                for out_channel in widths:
                    # Grouped features are B * C * k * n, so the shared per-point
                    # MLP is implemented with 1x1 2-D convolutions.
                    convs.append(nn.Conv2d(last_channel, out_channel, 1))
                    bns.append(nn.BatchNorm2d(out_channel))
                    last_channel = out_channel
                self.conv_blocks.append(convs)
                self.bn_blocks.append(bns)

        @staticmethod
        def _shared_mlp(feature, convs, bns):
            """Apply conv -> batch-norm -> ReLU for every layer of one scale."""
            for conv, bn in zip(convs, bns):
                feature = F.relu(bn(conv(feature)))
            return feature

        def forward(self, xyz, point):
            """Abstract a point set into fewer centres with richer features.

            Args:
                xyz: coordinates, [B, 3, N]; used only to build the groups.
                point: features, [B, C, N]; these are what gets grouped and
                    convolved.

            Returns:
                ``(center, new_point)``: centre coordinates [B, 3, npoint] and
                features [B, C', npoint], where C' sums the last width of every
                scale. With ``npoint is None`` the centre is all zeros with
                shape [B, 3, 1] and the features have shape [B, C', 1].
            """
            xyz = xyz.transpose(1, 2)
            if point is not None:
                point = point.transpose(1, 2)

            if self.npoint is None:
                # Group-all mode: a single group made of every input point. The
                # placeholder centre is created on the input's device and dtype.
                center = torch.zeros(xyz.size(0), 1, xyz.size(2), dtype=xyz.dtype, device=xyz.device)
                new_point = point.unsqueeze(1).permute(0, 3, 2, 1)  # B * C * N * 1
                new_point = self._shared_mlp(new_point, self.conv_blocks[0], self.bn_blocks[0])
                new_point = torch.max(new_point, dim = 2)[0]
                return center.transpose(1, 2), new_point

            center = farthest_sample(xyz, self.npoint)
            dis = square_distance(xyz, center)
            point_list = []
            for radius, nsample, convs, bns in zip(self.radius_list, self.nsample_list, self.conv_blocks, self.bn_blocks):
                _, new_point = sample_and_group(xyz, point, radius, nsample, dis)
                new_point = new_point.permute(0, 3, 2, 1)  # B * C * k * n
                new_point = self._shared_mlp(new_point, convs, bns)
                # Max over the k neighbours: each group is represented by its
                # strongest response per channel -> B * C' * n.
                point_list.append(torch.max(new_point, dim = 2)[0])
            # MSG: concatenate the descriptors produced at the different radii.
            new_point = torch.cat(point_list, dim = 1)
            return center.transpose(2, 1), new_point
    
    class PointNetSetAbstraction(nn.Module):
        """Single-scale set-abstraction layer.

        Args:
            npoint: number of centres to sample (ignored when ``group_all``).
            radius: ball radius used for grouping (ignored when ``group_all``).
            nsample: neighbours kept per ball (ignored when ``group_all``).
            in_channel: number of input feature channels.
            mlp: output widths of the 1x1 Conv2d + BatchNorm2d + ReLU stack.
            group_all: when true, pool over all input points as one group.
        """

        def __init__(self, npoint = None, radius = None, nsample = None, in_channel = None, mlp = None, group_all = False):
            super(PointNetSetAbstraction, self).__init__()
            self.group_all = group_all
            self.npoint = npoint
            self.radius = radius
            self.nsample = nsample
            self.conv_block = nn.ModuleList()
            self.bn_block = nn.ModuleList()
            last_channel = in_channel
            for out_channel in mlp:
                self.conv_block.append(nn.Conv2d(last_channel, out_channel, 1))
                self.bn_block.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel

        def _shared_mlp(self, feature):
            """Apply conv -> batch-norm -> ReLU for every layer of the stack."""
            for conv, bn in zip(self.conv_block, self.bn_block):
                feature = F.relu(bn(conv(feature)))
            return feature

        def forward(self, xyz, point):
            """Abstract a point set into fewer centres with richer features.

            Args:
                xyz: coordinates, [B, 3, N]; used only to build the groups.
                point: features, [B, C, N].

            Returns:
                ``(center, new_point)``: centre coordinates [B, 3, npoint] and
                features [B, mlp[-1], npoint]; with ``group_all`` the centre is
                all zeros with shape [B, 3, 1] and the features have shape
                [B, mlp[-1], 1].
            """
            xyz = xyz.transpose(1, 2)
            point = point.transpose(1, 2)

            if not self.group_all:
                center = farthest_sample(xyz, self.npoint)
                dis = square_distance(xyz, center)
                _, new_point = sample_and_group(xyz, point, self.radius, self.nsample, dis)
                new_point = new_point.permute(0, 3, 2, 1)  # B * C * k * n
            else:
                # Placeholder centre, created on the input's device and dtype.
                center = torch.zeros(xyz.size(0), 1, xyz.size(2), dtype=xyz.dtype, device=xyz.device)
                new_point = point.unsqueeze(1).permute(0, 3, 2, 1)  # B * C * N * 1

            new_point = self._shared_mlp(new_point)
            # Max-pool over the members of every group.
            new_point = torch.max(new_point, dim = 2)[0]
            return center.transpose(1, 2), new_point
    
    
    
    
    class PointNetFeaturePropagation(nn.Module):
        """Feature-propagation (up-sampling) layer.

        Features known on a coarse point set are interpolated onto a denser set,
        concatenated with the dense set's own features and passed through a
        stack of 1x1 Conv1d + BatchNorm1d + ReLU layers.

        Args:
            inchannel: channels after concatenation (dense + coarse features).
            mlp_list: output widths of the convolution stack.
        """

        def __init__(self, inchannel, mlp_list):
            super().__init__()
            self.conv_blocks = nn.ModuleList()
            self.bn_blocks = nn.ModuleList()
            last_channel = inchannel
            for out_channel in mlp_list:
                self.conv_blocks.append(nn.Conv1d(last_channel, out_channel, 1))
                self.bn_blocks.append(nn.BatchNorm1d(out_channel))
                last_channel = out_channel

        def forward(self, xyz1, xyz2, point1, point2):
            """Propagate ``point2`` (defined on ``xyz2``) onto the points ``xyz1``.

            Args:
                xyz1: dense coordinates, [B, 3, N].
                xyz2: coarse coordinates, [B, 3, n].
                point1: features of the dense set, [B, C1, N], or ``None``.
                point2: features of the coarse set, [B, C2, n].

            Returns:
                Tensor of shape [B, mlp_list[-1], N].
            """
            xyz1 = xyz1.transpose(1, 2)
            xyz2 = xyz2.transpose(1, 2)
            point2 = point2.transpose(1, 2)
            B, N, _ = xyz1.size()
            n = xyz2.size(1)
            if n == 1:
                # A single coarse point: every dense point copies its feature.
                point2 = point2.repeat(1, N, 1)
            else:
                # Interpolate from the (up to) three NEAREST coarse points,
                # starting at the closest one; fewer are used when the coarse
                # set has fewer than three points.
                k = min(3, n)
                dis, idx = square_distance(xyz2, xyz1).sort(dim = -1)
                dis = dis[:, :, :k]  # B * N * k squared distances
                idx = idx[:, :, :k]  # B * N * k indices into the coarse set
                feature = index_points(point2, idx)  # B * N * k * D
                # Inverse-distance weights, normalised to sum to one.
                w = 1.0 / (dis + 1e-8)
                ratio = w / torch.sum(w, dim = -1, keepdim = True)
                point2 = torch.sum(feature * ratio.unsqueeze(-1), dim = 2)  # B * N * D
            if point1 is not None:
                point1 = point1.transpose(1, 2)
                point = torch.cat([point1, point2], dim = -1)
            else:
                point = point2
            point = point.transpose(1, 2)
            for conv, bn in zip(self.conv_blocks, self.bn_blocks):
                point = F.relu(bn(conv(point)))
            return point
    
    class PointNet_add_cla(nn.Module):
        """Single-scale classification network.

        Three set-abstraction stages (the last one pools over all points) feed a
        three-layer fully connected head with ``cfg.num_classes`` outputs.
        """

        def __init__(self):
            super(PointNet_add_cla, self).__init__()
            self.sa1 = PointNetSetAbstraction(512, 0.2, 32, 3, [64, 64, 128])
            self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
            self.sa3 = PointNetSetAbstraction(in_channel = 256, mlp = [256, 512, 1024], group_all = True)
            self.fc1 = nn.Linear(1024, 512)
            self.fc2 = nn.Linear(512, 256)
            self.fc3 = nn.Linear(256, cfg.num_classes)

        def forward(self, xyz):
            """Return class scores of shape [B, cfg.num_classes].

            Args:
                xyz: point coordinates, [B, 3, N]; they also serve as the
                    initial point features.
            """
            xyz1, point1 = self.sa1(xyz, xyz)
            xyz2, point2 = self.sa2(xyz1, point1)
            xyz3, point3 = self.sa3(xyz2, point2)
            # Flatten the global descriptor: [B, 1024, 1] -> [B, 1024].
            point = point3.view(point3.size(0), -1)
            # F.dropout defaults to training=True; tie it to the module's mode
            # so that dropout is switched off by model.eval().
            point = F.dropout(F.relu(self.fc1(point)), p=0.5, training=self.training)
            point = F.dropout(F.relu(self.fc2(point)), p=0.5, training=self.training)
            point = self.fc3(point)
            return point
    
    
    
    
    class PointNet_add(nn.Module):
        """Single-scale per-point segmentation network.

        Three set-abstraction stages (the last one pools over all points) are
        followed by three feature-propagation stages and a three-layer 1x1
        convolution head with ``cfg.num_seg`` outputs per point.
        """

        def __init__(self):
            super(PointNet_add, self).__init__()
            self.sa1 = PointNetSetAbstraction(512, 0.2, 32, 3, [64, 64, 128])
            self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
            self.sa3 = PointNetSetAbstraction(in_channel = 256, mlp = [256, 512, 1024], group_all = True)
            self.fp1 = PointNetFeaturePropagation(1280, [256, 256])
            self.fp2 = PointNetFeaturePropagation(384, [256, 128])
            self.fp3 = PointNetFeaturePropagation(131, [128, 128])
            self.conv1 = nn.Conv1d(128, 128, 1)
            self.bn1 = nn.BatchNorm1d(128)
            self.conv2 = nn.Conv1d(128, 128, 1)
            self.bn2 = nn.BatchNorm1d(128)
            self.conv3 = nn.Conv1d(128, cfg.num_seg, 1)

        def forward(self, xyz):
            """Return per-point class scores.

            Args:
                xyz: point coordinates, [B, 3, N]; they also serve as the
                    initial point features.

            Returns:
                Tensor of shape [B * N, cfg.num_seg].
            """
            xyz1, point1 = self.sa1(xyz, xyz)
            xyz2, point2 = self.sa2(xyz1, point1)
            xyz3, point3 = self.sa3(xyz2, point2)
            point2 = self.fp1(xyz2, xyz3, point2, point3)
            point1 = self.fp2(xyz1, xyz2, point1, point2)
            point0 = self.fp3(xyz, xyz1, xyz, point1)
            # F.dropout defaults to training=True; tie it to the module's mode
            # so that dropout is switched off by model.eval().
            point = F.dropout(F.relu(self.bn1(self.conv1(point0))), p=0.5, training=self.training)
            point = F.dropout(F.relu(self.bn2(self.conv2(point))), p=0.5, training=self.training)
            point = self.conv3(point)
            point = point.transpose(1, 2).contiguous()
            point = point.view(-1, cfg.num_seg)
            return point
    
    
    
    
    class PointNet_add_Msg(nn.Module):
        """PointNet++ segmentation network using multi-scale grouping (MSG) layers.

        ``sa1`` and ``sa2`` are ``PointNetSetAbstractionMsg`` layers configured
        with three radii / group sizes / MLP lists each; the input width of each
        following layer is written as the sum of the previous layer's final MLP
        widths (``64 + 128 + 128`` and ``128 + 256 + 256``). Three
        feature-propagation layers and a 1x1 convolution head then emit
        ``cfg.num_seg`` scores for every point.

        NOTE(review): the positional arguments of ``PointNetSetAbstractionMsg``
        are presumably (npoint, radius_list, nsample_list, in_channel, mlp_list)
        -- confirm against its definition.
        """

        def __init__(self):
            super().__init__()
            self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
            self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 64 + 128 + 128, [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
            self.sa3 = PointNetSetAbstractionMsg(in_channel = 128 + 256 + 256 , mlp_list = [[256, 512, 1024]])
            self.fp1 = PointNetFeaturePropagation(1664, [256, 256])
            self.fp2 = PointNetFeaturePropagation(576, [256, 128])
            self.fp3 = PointNetFeaturePropagation(131, [128, 128])
            self.conv1 = nn.Conv1d(128, 128, 1)
            self.conv2 = nn.Conv1d(128, cfg.num_seg, 1)

            self.bn1 = nn.BatchNorm1d(128)

        def forward(self, xyz):
            """Return per-point segmentation scores.

            Args:
                xyz: point coordinates; they are also used as the initial
                    per-point features.
                    (assumes the layout expected by ``PointNetSetAbstractionMsg``
                    -- TODO confirm)

            Returns:
                Tensor of shape ``(batch * num_points, cfg.num_seg)``: the
                Conv1d output is transposed and flattened so that each row holds
                the scores of one point.
            """
            # Encoder: multi-scale set abstraction.
            xyz1, point1 = self.sa1(xyz, xyz)
            xyz2, point2 = self.sa2(xyz1, point1)
            xyz3, point3 = self.sa3(xyz2, point2)
            # Decoder: propagate features back towards the input points.
            point2 = self.fp1(xyz2, xyz3, point2, point3)
            point1 = self.fp2(xyz1, xyz2, point1, point2)
            point0 = self.fp3(xyz, xyz1, xyz, point1)
            # F.dropout defaults to training=True; tie it to the module mode so
            # that model.eval() actually disables dropout at inference time.
            point = F.dropout(F.relu(self.bn1(self.conv1(point0))), p=0.5, training=self.training)
            point = self.conv2(point)
            # (batch, num_seg, points) -> (batch, points, num_seg) -> rows per point.
            point = point.transpose(1, 2).contiguous()
            point = point.view(-1, cfg.num_seg)
            return point
    View Code

    train.py:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from tqdm import tqdm
    from DataSet import Dataset, claDataset
    from Model import PointNet_add_Msg, PointNet_add, PointNet_add_cla
    from configuration import config
    import torch.utils.data.dataloader as DataLoader
    from tensorboardX import SummaryWriter
    import os
    
    
    cfg = config()

    if __name__ == '__main__':
        # Training entry point: fits the selected model with Adam and
        # cross-entropy, logs per-epoch loss/accuracy to TensorBoard, and saves
        # a checkpoint after the final epoch.

        # Swap the model/dataset pair below to train a segmentation network
        # instead of the classifier.
        # model = PointNet_add_Msg()
        # model = PointNet_add()
        model = PointNet_add_cla()
        model.to(cfg.device)
        # dataset = Dataset(cfg.dataset_root)
        # NOTE(review): confirm that configuration.py defines cladataset_root.
        dataset = claDataset(cfg.cladataset_root)
        dataloader = DataLoader.DataLoader(dataset, batch_size = cfg.batch_size, shuffle = True)
        optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
        loss = nn.CrossEntropyLoss()
        tbwrite = SummaryWriter(logdir = os.path.join(cfg.checkpoint_root, 'log'))
        model.train()
        for epoch in range(cfg.num_epochs):
            total_true = 0
            total_seen = 0
            total_loss = 0.0
            cnt = 0
            for xyz, label in tqdm(dataloader):
                # The model lives on cfg.device, so the batch has to as well.
                # NOTE(review): helpers that allocate tensors without an explicit
                # device must also follow the input's device -- confirm.
                xyz = xyz.to(cfg.device)
                # Flatten labels to 1-D so they line up with the rows of `output`.
                label = label.to(cfg.device).view(-1, 1)[:, 0]
                optimizer.zero_grad()
                output = model(xyz)
                loss_value = loss(output, label)
                loss_value.backward()
                optimizer.step()
                pred = torch.max(output, -1)[1]
                # .item() yields plain Python numbers, so the running totals do
                # not hold on to tensors (and their autograd history).
                total_true += torch.sum(pred == label).item()
                total_seen += label.numel()
                total_loss += loss_value.item()
                cnt += 1
            # Guard against an empty loader to avoid a division by zero.
            mean_loss = total_loss / cnt if cnt else 0.0
            # Dividing by the number of labels seen gives per-sample accuracy for
            # classification and per-point accuracy for segmentation, without
            # hand-editing the denominator.
            accuracy = total_true / total_seen if total_seen else 0.0
            tbwrite.add_scalar('Loss', mean_loss, epoch)
            tbwrite.add_scalar('Accuracy', accuracy, epoch)
            print('mean_loss:{:.4f}, accuracy:{:.4f}'.format(mean_loss, accuracy))
            # With this modulus the condition only holds on the last epoch.
            if (epoch + 1) % cfg.num_epochs == 0:
                state = {
                    'model': model.state_dict()
                }
                torch.save(state, os.path.join(cfg.checkpoint_root, 'checkpoint_{}.pth'.format(epoch + 1)))
        # Flush pending events and release the log file handle.
        tbwrite.close()
    View Code
  • 相关阅读:
    [EOJ]2019 ECNU XCPC March Selection #1
    [模板]宏定义
    [POJ]poj1961,poj2406(KMP)
    [模板]KMP
    [CF]Avito Cool Challenge 2018
    [CF]Codeforces Round #528 Div.2
    [POJ]POJ1328(贪心)
    洛谷 P3808 【模板】AC自动机(简单版) 题解
    中科院的难题 题解
    【转】洛谷 P3722 [AH2017/HNOI2017]影魔 题解
  • 原文地址:https://www.cnblogs.com/WTSRUVF/p/15412898.html
Copyright © 2011-2022 走看看