MSG 版本训练起来巨慢,不过总算是复现出来了;之后再搞一下 MRG 版本,或者不带多尺度分组的基础版本。
configuration.py:
import torch.cuda


class config:
    """Static settings shared by the model, dataset and training script."""

    # Prefer the GPU whenever PyTorch reports one as available.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Locations of the input data and of the saved checkpoints / logs.
    dataset_root = 'C:/Users/Dell/PycharmProjects/PointNet++/dataset'
    checkpoint_root = 'C:/Users/Dell/PycharmProjects/PointNet++/checkpoint'

    # Training hyper-parameters.
    num_epochs = 10
    batch_size = 4

    # Number of output channels of the segmentation head.
    num_seg = 40
Model.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as Dataloader
import torch.utils.data as data
from configuration import config

cfg = config()


def index_points(points, idx):
    """Gather rows of ``points`` by index, independently for each batch item.

    Args:
        points: tensor of shape [B, N, C].
        idx: long tensor of indices into the N axis, shape [B, S] or
            [B, S, K] (any number of trailing dimensions).

    Returns:
        Tensor of shape ``idx.shape + [C]``.
    """
    B = points.shape[0]
    idx = idx.to(points.device)
    view_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device)
    batch_indices = batch_indices.view(view_shape).expand_as(idx)
    return points[batch_indices, idx, :]


def farthest_sample(data, sample_n):
    """Pick ``sample_n`` points per cloud with farthest point sampling.

    The first point is chosen at random; every following point is the one
    whose distance to the already selected set (i.e. the minimum distance to
    any selected point) is largest.

    Args:
        data: tensor of shape [B, N, C] holding point coordinates.
        sample_n: number of points to select per cloud.

    Returns:
        Tensor of shape [B, sample_n, C] with the selected coordinates.
    """
    B, N, C = data.size()
    device = data.device
    batch_idx = torch.arange(B, dtype=torch.long, device=device)
    # min_dis[b, k]: squared distance from point k to the selected set so far.
    min_dis = torch.full((B, N), float('inf'), dtype=data.dtype, device=device)
    # Start from a random point in every cloud.
    u = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    ret = torch.zeros((B, sample_n), dtype=torch.long, device=device)
    for i in range(sample_n):
        ret[:, i] = u
        cen = data[batch_idx, u, :].view(B, 1, C)
        distance = torch.sum((data - cen) ** 2, -1)
        # Keep the distance to the *closest* selected point; already selected
        # points drop to 0 and are therefore never picked again.
        min_dis = torch.min(min_dis, distance)
        u = torch.max(min_dis, -1)[1]
    return index_points(data, ret)


def square_distance(data, center):
    """Squared Euclidean distance between every center and every point.

    Args:
        data: tensor of shape [B, N, C].
        center: tensor of shape [B, S, C].

    Returns:
        Tensor ``dis`` of shape [B, S, N] where ``dis[b, j, k]`` is the
        squared distance between ``center[b, j]`` and ``data[b, k]``.
    """
    # Broadcasting replaces the former Python loop over the S centers.
    diff = center.unsqueeze(2) - data.unsqueeze(1)  # [B, S, N, C]
    return torch.sum(diff ** 2, -1)


def query_ball_point(data, dis, radius, k):
    """Select up to ``k`` point indices inside a ball around each center.

    Args:
        data: the point cloud (unused; kept for interface compatibility).
        dis: tensor [B, n, N] of squared distances from centers to points.
        radius: ball radius (compared against ``dis`` as ``radius ** 2``).
        k: number of neighbours to keep per center.

    Returns:
        Long tensor [B, n, min(k, N)] of point indices. Points inside the
        ball are taken in increasing index order; when fewer than ``k`` lie
        inside, the remaining slots repeat the first selected index.
    """
    B, n, N = dis.size()
    group = torch.arange(N, dtype=torch.long, device=dis.device)
    group = group.view(1, 1, N).repeat(B, n, 1)
    # Mark points outside the ball with the sentinel N so they sort last.
    group[dis > radius ** 2] = N
    group = group.sort(dim=-1)[0][:, :, :k]
    # Replace sentinel slots by the first in-ball index of the same group.
    first = group[:, :, :1].expand_as(group)
    return torch.where(group == N, first, group)


def sample_and_group(xyz, point, radius, nsample, dis):
    """Group coordinates and features around the centers described by ``dis``.

    Returns:
        (new_xyz, new_point): grouped coordinates [B, n, k, 3] and grouped
        features [B, n, k, C].
    """
    group = query_ball_point(xyz, dis, radius, nsample)
    new_xyz = index_points(xyz, group)
    new_point = index_points(point, group)
    return new_xyz, new_point


class PointNetSetAbstractionMsg(nn.Module):
    """Set-abstraction layer with multi-scale grouping (MSG).

    Args:
        npoint: number of centers to sample; ``None`` groups all points into
            a single set and only uses ``mlp_list[0]``.
        radius_list: ball radius for each scale.
        nsample_list: number of neighbours per ball for each scale.
        in_channel: number of input feature channels.
        mlp_list: per scale, the list of output channels of the 1x1 convs.
    """

    def __init__(self, npoint = None, radius_list = None, nsample_list = None, in_channel = None, mlp_list = None):
        super().__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for mlp in mlp_list:
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel
            for out_channel in mlp:
                # Grouped features are [B, C, k, n], hence 1x1 Conv2d acts as
                # a shared fully connected layer.
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, point):
        """Run the layer.

        Args:
            xyz: coordinates, [B, 3, N]; used to build the groups.
            point: features, [B, C, N]; ``None`` falls back to ``xyz``.

        Returns:
            (center, new_point): center coordinates [B, 3, n] and their
            features [B, C', n] (n == 1 when grouping all points).
        """
        xyz = xyz.transpose(1, 2)
        point = xyz if point is None else point.transpose(1, 2)
        if self.npoint is not None:
            center = farthest_sample(xyz, self.npoint)
            dis = square_distance(xyz, center)
            point_list = []
            scales = zip(self.radius_list, self.nsample_list, self.conv_blocks, self.bn_blocks)
            for radius, nsample, convs, bns in scales:
                _, new_point = sample_and_group(xyz, point, radius, nsample, dis)
                new_point = new_point.permute(0, 3, 2, 1)  # [B, C, k, n]
                for conv, bn in zip(convs, bns):
                    new_point = F.relu(bn(conv(new_point)))
                # Max over the k neighbours: one feature vector per group.
                point_list.append(torch.max(new_point, dim=2)[0])
            # Concatenate the features produced at the different radii.
            new_point = torch.cat(point_list, dim=1)
            return center.transpose(2, 1), new_point

        # Group-all: a single set containing every point, centered at 0.
        center = torch.zeros(xyz.size(0), 1, xyz.size(2), dtype=xyz.dtype, device=xyz.device)
        new_point = point.unsqueeze(1).permute(0, 3, 2, 1)  # [B, C, N, 1]
        for conv, bn in zip(self.conv_blocks[0], self.bn_blocks[0]):
            new_point = F.relu(bn(conv(new_point)))
        new_point = torch.max(new_point, dim=2)[0]
        return center.transpose(1, 2), new_point


class PointNetSetAbstraction(nn.Module):
    """Single-scale set-abstraction layer.

    With ``group_all=True`` the sampling arguments are ignored and all points
    form one group.
    """

    def __init__(self, npoint = None, radius = None, nsample = None, in_channel = None, mlp = None, group_all = False):
        super(PointNetSetAbstraction, self).__init__()
        self.group_all = group_all
        if not self.group_all:
            self.npoint = npoint
            self.radius = radius
            self.nsample = nsample
        self.conv_block = nn.ModuleList()
        self.bn_block = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.conv_block.append(nn.Conv2d(last_channel, out_channel, 1))
            self.bn_block.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, point):
        """Map coordinates [B, 3, N] and features [B, C, N] to
        (center [B, 3, n], features [B, C', n])."""
        xyz = xyz.transpose(1, 2)
        point = point.transpose(1, 2)
        if not self.group_all:
            center = farthest_sample(xyz, self.npoint)
            dis = square_distance(xyz, center)
            _, new_point = sample_and_group(xyz, point, self.radius, self.nsample, dis)
            new_point = new_point.permute(0, 3, 2, 1)  # [B, C, k, n]
        else:
            center = torch.zeros(xyz.size(0), 1, xyz.size(2), dtype=xyz.dtype, device=xyz.device)
            new_point = point.unsqueeze(1).permute(0, 3, 2, 1)  # [B, C, N, 1]
        for conv, bn in zip(self.conv_block, self.bn_block):
            new_point = F.relu(bn(conv(new_point)))
        new_point = torch.max(new_point, dim=2)[0]
        return center.transpose(1, 2), new_point


class PointNetFeaturePropagation(nn.Module):
    """Propagate features from a sparse point set back to a denser one."""

    def __init__(self, inchannel, mlp_list):
        super().__init__()
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        last_channel = inchannel
        for out_channel in mlp_list:
            self.conv_blocks.append(nn.Conv1d(last_channel, out_channel, 1))
            self.bn_blocks.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, point1, point2):
        """Interpolate ``point2`` (given at ``xyz2``) onto ``xyz1``.

        Args:
            xyz1: target coordinates, [B, 3, N].
            xyz2: source coordinates, [B, 3, n].
            point1: features already present at ``xyz1`` ([B, C1, N]) or
                ``None``; concatenated with the interpolated features.
            point2: source features, [B, D, n].

        Returns:
            Features at ``xyz1`` after the point-wise MLP, [B, C', N].
        """
        xyz1 = xyz1.transpose(1, 2)
        xyz2 = xyz2.transpose(1, 2)
        point2 = point2.transpose(1, 2)
        B, N, _ = xyz1.size()
        _, n, _ = xyz2.size()
        if n == 1:
            # A single global feature is simply copied to every point.
            point2 = point2.repeat(1, N, 1)
        else:
            # Inverse-squared-distance weighting over the (up to) three
            # nearest source points, *including* the nearest one.
            k = min(3, n)
            dis, idx = torch.topk(square_distance(xyz2, xyz1), k, dim=-1, largest=False)
            feature = index_points(point2, idx)  # [B, N, k, D]
            w = 1.0 / (dis + 1e-8)
            ratio = w / torch.sum(w, dim=-1, keepdim=True)
            point2 = torch.sum(feature * ratio.unsqueeze(-1), dim=2)  # [B, N, D]
        if point1 is not None:
            point = torch.cat([point1.transpose(1, 2), point2], dim=-1)
        else:
            point = point2
        point = point.transpose(1, 2)
        for conv, bn in zip(self.conv_blocks, self.bn_blocks):
            point = F.relu(bn(conv(point)))
        return point


class PointNet_add_Msg(nn.Module):
    """PointNet++ segmentation network with multi-scale grouping."""

    def __init__(self):
        super().__init__()
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 64 + 128 + 128, [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstractionMsg(in_channel = 128 + 256 + 256, mlp_list = [[256, 512, 1024]])
        self.fp1 = PointNetFeaturePropagation(1664, [256, 256])
        self.fp2 = PointNetFeaturePropagation(576, [256, 128])
        self.fp3 = PointNetFeaturePropagation(131, [128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.conv2 = nn.Conv1d(128, cfg.num_seg, 1)
        self.bn1 = nn.BatchNorm1d(128)

    def forward(self, xyz):
        """Map coordinates [B, 3, N] to per-point logits [B * N, cfg.num_seg]."""
        xyz1, point1 = self.sa1(xyz, xyz)
        xyz2, point2 = self.sa2(xyz1, point1)
        xyz3, point3 = self.sa3(xyz2, point2)
        point2 = self.fp1(xyz2, xyz3, point2, point3)
        point1 = self.fp2(xyz1, xyz2, point1, point2)
        point0 = self.fp3(xyz, xyz1, xyz, point1)
        # F.dropout defaults to training=True, so tie it to the module mode.
        point = F.dropout(F.relu(self.bn1(self.conv1(point0))), 0.5, training=self.training)
        point = self.conv2(point)
        point = point.transpose(1, 2).contiguous()
        return point.view(-1, cfg.num_seg)
train.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from DataSet import Dataset
# The accompanying Model.py only defines the MSG network.
from Model import PointNet_add_Msg
from configuration import config
import torch.utils.data.dataloader as DataLoader
from tensorboardX import SummaryWriter
import os

cfg = config()


def main():
    """Train the MSG segmentation network and save a final checkpoint."""
    model = PointNet_add_Msg()
    model.to(cfg.device)
    dataset = Dataset(cfg.dataset_root)
    dataloader = DataLoader.DataLoader(dataset, batch_size = cfg.batch_size, shuffle = True)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
    loss = nn.CrossEntropyLoss()
    tbwrite = SummaryWriter(logdir = os.path.join(cfg.checkpoint_root, 'log'))
    model.train()
    for epoch in range(cfg.num_epochs):
        total_true = 0
        total_count = 0
        total_loss = 0.0
        cnt = 0
        for xyz, label in tqdm(dataloader):
            # The model lives on cfg.device, so the batch must as well.
            xyz = xyz.to(cfg.device)
            label = label.to(cfg.device).view(-1)
            optimizer.zero_grad()
            output = model(xyz)
            loss_value = loss(output, label)
            loss_value.backward()
            optimizer.step()
            pred = torch.max(output, -1)[1]
            total_true += torch.sum(pred == label).item()
            # One prediction per point, so count labels rather than samples.
            total_count += label.numel()
            total_loss += loss_value.item()
            cnt += 1
        mean_loss = total_loss / float(max(cnt, 1))
        accuracy = total_true / float(max(total_count, 1))
        tbwrite.add_scalar('Loss', mean_loss, epoch)
        tbwrite.add_scalar('Accuracy', accuracy, epoch)
        print('mean_loss:{:.4f}, accuracy:{:.4f}'.format(mean_loss, accuracy))
        # Only true after the last epoch: a single final checkpoint is written.
        if (epoch + 1) % cfg.num_epochs == 0:
            state = {
                'model': model.state_dict()
            }
            os.makedirs(cfg.checkpoint_root, exist_ok=True)
            torch.save(state, os.path.join(cfg.checkpoint_root, 'checkpoint_{}.pth'.format(epoch)))
    tbwrite.close()


if __name__ == '__main__':
    main()
DataSet.py:
import numpy as np
import torch.utils.data as data
import os
import random
import torch


class Dataset(data.Dataset):
    """Point-cloud segmentation dataset read from text files.

    Expects ``root/points`` and ``root/points_label`` to contain files whose
    names start with an integer followed by a dot; files of both folders are
    paired after sorting by that integer.

    Args:
        root: dataset directory.
        num_points: number of points every returned cloud is resampled to.
    """

    def __init__(self, root, num_points=2500):
        super().__init__()
        self.root = root
        self.num_points = num_points
        data_list = os.listdir(os.path.join(root, 'points'))
        label_list = os.listdir(os.path.join(root, 'points_label'))
        self.data_list = sorted(data_list, key = lambda x : int(x.split('.')[0]))
        self.label_list = sorted(label_list, key = lambda x : int(x.split('.')[0]))

    def __getitem__(self, index):
        """Return ``(points, labels)`` for sample ``index``.

        Returns:
            points: float32 tensor of shape [C, num_points].
            labels: int64 tensor of shape [num_points].

        Raises:
            ValueError: if the points file contains no points.
        """
        # ndmin keeps the arrays 2-D / 1-D even for single-line files.
        points = np.loadtxt(os.path.join(self.root, 'points', self.data_list[index]), ndmin=2)
        labels = np.loadtxt(os.path.join(self.root, 'points_label', self.label_list[index]), ndmin=1)
        count = points.shape[0]
        if count == 0:
            raise ValueError('empty point cloud: {}'.format(self.data_list[index]))
        if count >= self.num_points:
            # Enough points: subsample without replacement.
            chosen = random.sample(range(count), self.num_points)
            points = points[chosen, :]
            labels = labels[chosen]
        else:
            # Too few points: keep all of them and pad with random
            # duplicates. Drawing with replacement also works when more
            # than ``count`` extra points are needed.
            extra = random.choices(range(count), k=self.num_points - count)
            points = np.concatenate([points, points[extra, :]], 0)
            labels = np.concatenate([labels, labels[extra]], 0)
        # The loss expects int64 labels and the model float32 inputs.
        label_tensor = torch.tensor(labels).type(torch.LongTensor)
        data_tensor = torch.tensor(points.T).to(torch.float32)
        return data_tensor, label_tensor

    def __len__(self):
        return len(self.label_list)
以下是分类和分割合并在一起的代码,包含普通版本和 MSG 版本:
DataSet.py:
import numpy as np
import torch.utils.data as data
import os
import random
import torch
import h5py


class Dataset(data.Dataset):
    """Point-cloud segmentation dataset read from text files.

    Expects ``root/points`` and ``root/points_label`` to contain files whose
    names start with an integer followed by a dot; files of both folders are
    paired after sorting by that integer.

    Args:
        root: dataset directory.
        num_points: number of points every returned cloud is resampled to.
    """

    def __init__(self, root, num_points=2500):
        super().__init__()
        self.root = root
        self.num_points = num_points
        data_list = os.listdir(os.path.join(root, 'points'))
        label_list = os.listdir(os.path.join(root, 'points_label'))
        self.data_list = sorted(data_list, key = lambda x : int(x.split('.')[0]))
        self.label_list = sorted(label_list, key = lambda x : int(x.split('.')[0]))

    def __getitem__(self, index):
        """Return ``(points, labels)`` for sample ``index``.

        Returns:
            points: float32 tensor of shape [C, num_points].
            labels: int64 tensor of shape [num_points].

        Raises:
            ValueError: if the points file contains no points.
        """
        # ndmin keeps the arrays 2-D / 1-D even for single-line files.
        points = np.loadtxt(os.path.join(self.root, 'points', self.data_list[index]), ndmin=2)
        labels = np.loadtxt(os.path.join(self.root, 'points_label', self.label_list[index]), ndmin=1)
        count = points.shape[0]
        if count == 0:
            raise ValueError('empty point cloud: {}'.format(self.data_list[index]))
        if count >= self.num_points:
            # Enough points: subsample without replacement.
            chosen = random.sample(range(count), self.num_points)
            points = points[chosen, :]
            labels = labels[chosen]
        else:
            # Too few points: keep all of them and pad with random
            # duplicates. Drawing with replacement also works when more
            # than ``count`` extra points are needed.
            extra = random.choices(range(count), k=self.num_points - count)
            points = np.concatenate([points, points[extra, :]], 0)
            labels = np.concatenate([labels, labels[extra]], 0)
        # The loss expects int64 labels and the model float32 inputs.
        label_tensor = torch.tensor(labels).type(torch.LongTensor)
        data_tensor = torch.tensor(points.T).to(torch.float32)
        return data_tensor, label_tensor

    def __len__(self):
        return len(self.label_list)


class claDataset(data.Dataset):
    """Classification dataset loaded from one HDF5 file.

    Reads the ``'data'`` and ``'label'`` datasets of the file fully into
    memory; only the first column of ``'label'`` is kept.
    """

    def __init__(self, root):
        super(claDataset, self).__init__()
        # Everything is copied into memory, so the file can be closed again.
        with h5py.File(root, 'r') as dataset:
            self.data = dataset['data'][:]
            self.label = dataset['label'][:][:, 0]

    def __getitem__(self, index):
        """Return ``(points, label)``: the transposed point array of sample
        ``index`` and its class label as an int64 scalar tensor."""
        label = torch.tensor(self.label[index]).type(torch.LongTensor)
        return torch.tensor(self.data[index].T), label

    def __len__(self):
        return len(self.label)
configuration.py:
import torch.cuda


class config:
    """Static settings shared by the models, datasets and training script."""

    # Prefer the GPU whenever PyTorch reports one as available.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Segmentation data directory and checkpoint / log directory.
    dataset_root = 'C:/Users/Dell/PycharmProjects/PointNet++/dataset'
    checkpoint_root = 'C:/Users/Dell/PycharmProjects/PointNet++/checkpoint'
    # HDF5 file used by the classification dataset.
    cladataset_root = 'H:/DataSet/modelnet40_ply_hdf5_2048/ply_data_train0.h5'

    # Training hyper-parameters.
    num_epochs = 10
    batch_size = 4

    # Output sizes of the segmentation and classification heads.
    num_seg = 50
    num_classes = 40
Model.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as Dataloader
import torch.utils.data as data
from configuration import config

cfg = config()


def index_points(points, idx):
    """Gather rows of ``points`` by index, independently for each batch item.

    Args:
        points: tensor of shape [B, N, C].
        idx: long tensor of indices into the N axis, shape [B, S] or
            [B, S, K] (any number of trailing dimensions).

    Returns:
        Tensor of shape ``idx.shape + [C]``.
    """
    B = points.shape[0]
    idx = idx.to(points.device)
    view_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device)
    batch_indices = batch_indices.view(view_shape).expand_as(idx)
    return points[batch_indices, idx, :]


def farthest_sample(data, sample_n):
    """Pick ``sample_n`` points per cloud with farthest point sampling.

    The first point is chosen at random; every following point is the one
    whose distance to the already selected set (i.e. the minimum distance to
    any selected point) is largest.

    Args:
        data: tensor of shape [B, N, C] holding point coordinates.
        sample_n: number of points to select per cloud.

    Returns:
        Tensor of shape [B, sample_n, C] with the selected coordinates.
    """
    B, N, C = data.size()
    device = data.device
    batch_idx = torch.arange(B, dtype=torch.long, device=device)
    # min_dis[b, k]: squared distance from point k to the selected set so far.
    min_dis = torch.full((B, N), float('inf'), dtype=data.dtype, device=device)
    # Start from a random point in every cloud.
    u = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    ret = torch.zeros((B, sample_n), dtype=torch.long, device=device)
    for i in range(sample_n):
        ret[:, i] = u
        cen = data[batch_idx, u, :].view(B, 1, C)
        distance = torch.sum((data - cen) ** 2, -1)
        # Keep the distance to the *closest* selected point; already selected
        # points drop to 0 and are therefore never picked again.
        min_dis = torch.min(min_dis, distance)
        u = torch.max(min_dis, -1)[1]
    return index_points(data, ret)


def square_distance(data, center):
    """Squared Euclidean distance between every center and every point.

    Args:
        data: tensor of shape [B, N, C].
        center: tensor of shape [B, S, C].

    Returns:
        Tensor ``dis`` of shape [B, S, N] where ``dis[b, j, k]`` is the
        squared distance between ``center[b, j]`` and ``data[b, k]``.
    """
    # Broadcasting replaces the former Python loop over the S centers.
    diff = center.unsqueeze(2) - data.unsqueeze(1)  # [B, S, N, C]
    return torch.sum(diff ** 2, -1)


def query_ball_point(data, dis, radius, k):
    """Select up to ``k`` point indices inside a ball around each center.

    Args:
        data: the point cloud (unused; kept for interface compatibility).
        dis: tensor [B, n, N] of squared distances from centers to points.
        radius: ball radius (compared against ``dis`` as ``radius ** 2``).
        k: number of neighbours to keep per center.

    Returns:
        Long tensor [B, n, min(k, N)] of point indices. Points inside the
        ball are taken in increasing index order; when fewer than ``k`` lie
        inside, the remaining slots repeat the first selected index.
    """
    B, n, N = dis.size()
    group = torch.arange(N, dtype=torch.long, device=dis.device)
    group = group.view(1, 1, N).repeat(B, n, 1)
    # Mark points outside the ball with the sentinel N so they sort last.
    group[dis > radius ** 2] = N
    group = group.sort(dim=-1)[0][:, :, :k]
    # Replace sentinel slots by the first in-ball index of the same group.
    first = group[:, :, :1].expand_as(group)
    return torch.where(group == N, first, group)


def sample_and_group(xyz, point, radius, nsample, dis):
    """Group coordinates and features around the centers described by ``dis``.

    Returns:
        (new_xyz, new_point): grouped coordinates [B, n, k, 3] and grouped
        features [B, n, k, C].
    """
    group = query_ball_point(xyz, dis, radius, nsample)
    new_xyz = index_points(xyz, group)
    new_point = index_points(point, group)
    return new_xyz, new_point


class PointNetSetAbstractionMsg(nn.Module):
    """Set-abstraction layer with multi-scale grouping (MSG).

    Args:
        npoint: number of centers to sample; ``None`` groups all points into
            a single set and only uses ``mlp_list[0]``.
        radius_list: ball radius for each scale.
        nsample_list: number of neighbours per ball for each scale.
        in_channel: number of input feature channels.
        mlp_list: per scale, the list of output channels of the 1x1 convs.
    """

    def __init__(self, npoint = None, radius_list = None, nsample_list = None, in_channel = None, mlp_list = None):
        super().__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for mlp in mlp_list:
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel
            for out_channel in mlp:
                # Grouped features are [B, C, k, n], hence 1x1 Conv2d acts as
                # a shared fully connected layer.
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, point):
        """Run the layer.

        Args:
            xyz: coordinates, [B, 3, N]; used to build the groups.
            point: features, [B, C, N]; ``None`` falls back to ``xyz``.

        Returns:
            (center, new_point): center coordinates [B, 3, n] and their
            features [B, C', n] (n == 1 when grouping all points).
        """
        xyz = xyz.transpose(1, 2)
        point = xyz if point is None else point.transpose(1, 2)
        if self.npoint is not None:
            center = farthest_sample(xyz, self.npoint)
            dis = square_distance(xyz, center)
            point_list = []
            scales = zip(self.radius_list, self.nsample_list, self.conv_blocks, self.bn_blocks)
            for radius, nsample, convs, bns in scales:
                _, new_point = sample_and_group(xyz, point, radius, nsample, dis)
                new_point = new_point.permute(0, 3, 2, 1)  # [B, C, k, n]
                for conv, bn in zip(convs, bns):
                    new_point = F.relu(bn(conv(new_point)))
                # Max over the k neighbours: one feature vector per group.
                point_list.append(torch.max(new_point, dim=2)[0])
            # Concatenate the features produced at the different radii.
            new_point = torch.cat(point_list, dim=1)
            return center.transpose(2, 1), new_point

        # Group-all: a single set containing every point, centered at 0.
        center = torch.zeros(xyz.size(0), 1, xyz.size(2), dtype=xyz.dtype, device=xyz.device)
        new_point = point.unsqueeze(1).permute(0, 3, 2, 1)  # [B, C, N, 1]
        for conv, bn in zip(self.conv_blocks[0], self.bn_blocks[0]):
            new_point = F.relu(bn(conv(new_point)))
        new_point = torch.max(new_point, dim=2)[0]
        return center.transpose(1, 2), new_point


class PointNetSetAbstraction(nn.Module):
    """Single-scale set-abstraction layer.

    With ``group_all=True`` the sampling arguments are ignored and all points
    form one group.
    """

    def __init__(self, npoint = None, radius = None, nsample = None, in_channel = None, mlp = None, group_all = False):
        super(PointNetSetAbstraction, self).__init__()
        self.group_all = group_all
        if not self.group_all:
            self.npoint = npoint
            self.radius = radius
            self.nsample = nsample
        self.conv_block = nn.ModuleList()
        self.bn_block = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.conv_block.append(nn.Conv2d(last_channel, out_channel, 1))
            self.bn_block.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, point):
        """Map coordinates [B, 3, N] and features [B, C, N] to
        (center [B, 3, n], features [B, C', n])."""
        xyz = xyz.transpose(1, 2)
        point = point.transpose(1, 2)
        if not self.group_all:
            center = farthest_sample(xyz, self.npoint)
            dis = square_distance(xyz, center)
            _, new_point = sample_and_group(xyz, point, self.radius, self.nsample, dis)
            new_point = new_point.permute(0, 3, 2, 1)  # [B, C, k, n]
        else:
            center = torch.zeros(xyz.size(0), 1, xyz.size(2), dtype=xyz.dtype, device=xyz.device)
            new_point = point.unsqueeze(1).permute(0, 3, 2, 1)  # [B, C, N, 1]
        for conv, bn in zip(self.conv_block, self.bn_block):
            new_point = F.relu(bn(conv(new_point)))
        new_point = torch.max(new_point, dim=2)[0]
        return center.transpose(1, 2), new_point


class PointNetFeaturePropagation(nn.Module):
    """Propagate features from a sparse point set back to a denser one."""

    def __init__(self, inchannel, mlp_list):
        super().__init__()
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        last_channel = inchannel
        for out_channel in mlp_list:
            self.conv_blocks.append(nn.Conv1d(last_channel, out_channel, 1))
            self.bn_blocks.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, point1, point2):
        """Interpolate ``point2`` (given at ``xyz2``) onto ``xyz1``.

        Args:
            xyz1: target coordinates, [B, 3, N].
            xyz2: source coordinates, [B, 3, n].
            point1: features already present at ``xyz1`` ([B, C1, N]) or
                ``None``; concatenated with the interpolated features.
            point2: source features, [B, D, n].

        Returns:
            Features at ``xyz1`` after the point-wise MLP, [B, C', N].
        """
        xyz1 = xyz1.transpose(1, 2)
        xyz2 = xyz2.transpose(1, 2)
        point2 = point2.transpose(1, 2)
        B, N, _ = xyz1.size()
        _, n, _ = xyz2.size()
        if n == 1:
            # A single global feature is simply copied to every point.
            point2 = point2.repeat(1, N, 1)
        else:
            # Inverse-squared-distance weighting over the (up to) three
            # nearest source points, *including* the nearest one.
            k = min(3, n)
            dis, idx = torch.topk(square_distance(xyz2, xyz1), k, dim=-1, largest=False)
            feature = index_points(point2, idx)  # [B, N, k, D]
            w = 1.0 / (dis + 1e-8)
            ratio = w / torch.sum(w, dim=-1, keepdim=True)
            point2 = torch.sum(feature * ratio.unsqueeze(-1), dim=2)  # [B, N, D]
        if point1 is not None:
            point = torch.cat([point1.transpose(1, 2), point2], dim=-1)
        else:
            point = point2
        point = point.transpose(1, 2)
        for conv, bn in zip(self.conv_blocks, self.bn_blocks):
            point = F.relu(bn(conv(point)))
        return point


class PointNet_add_cla(nn.Module):
    """PointNet++ classification network (single-scale grouping)."""

    def __init__(self):
        super(PointNet_add_cla, self).__init__()
        self.sa1 = PointNetSetAbstraction(512, 0.2, 32, 3, [64, 64, 128])
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
        self.sa3 = PointNetSetAbstraction(in_channel = 256, mlp = [256, 512, 1024], group_all = True)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, cfg.num_classes)

    def forward(self, xyz):
        """Map coordinates [B, 3, N] to class logits [B, cfg.num_classes]."""
        xyz1, point1 = self.sa1(xyz, xyz)
        xyz2, point2 = self.sa2(xyz1, point1)
        xyz3, point3 = self.sa3(xyz2, point2)
        point = point3.view(point3.size(0), -1)
        # F.dropout defaults to training=True, so tie it to the module mode.
        point = F.dropout(F.relu(self.fc1(point)), 0.5, training=self.training)
        point = F.dropout(F.relu(self.fc2(point)), 0.5, training=self.training)
        return self.fc3(point)


class PointNet_add(nn.Module):
    """PointNet++ segmentation network (single-scale grouping)."""

    def __init__(self):
        super(PointNet_add, self).__init__()
        self.sa1 = PointNetSetAbstraction(512, 0.2, 32, 3, [64, 64, 128])
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
        self.sa3 = PointNetSetAbstraction(in_channel = 256, mlp = [256, 512, 1024], group_all = True)
        self.fp1 = PointNetFeaturePropagation(1280, [256, 256])
        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
        self.fp3 = PointNetFeaturePropagation(131, [128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.conv2 = nn.Conv1d(128, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, cfg.num_seg, 1)

    def forward(self, xyz):
        """Map coordinates [B, 3, N] to per-point logits [B * N, cfg.num_seg]."""
        xyz1, point1 = self.sa1(xyz, xyz)
        xyz2, point2 = self.sa2(xyz1, point1)
        xyz3, point3 = self.sa3(xyz2, point2)
        point2 = self.fp1(xyz2, xyz3, point2, point3)
        point1 = self.fp2(xyz1, xyz2, point1, point2)
        point0 = self.fp3(xyz, xyz1, xyz, point1)
        # F.dropout defaults to training=True, so tie it to the module mode.
        point = F.dropout(F.relu(self.bn1(self.conv1(point0))), 0.5, training=self.training)
        point = F.dropout(F.relu(self.bn2(self.conv2(point))), 0.5, training=self.training)
        point = self.conv3(point)
        point = point.transpose(1, 2).contiguous()
        return point.view(-1, cfg.num_seg)


class PointNet_add_Msg(nn.Module):
    """PointNet++ segmentation network with multi-scale grouping."""

    def __init__(self):
        super().__init__()
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 64 + 128 + 128, [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstractionMsg(in_channel = 128 + 256 + 256, mlp_list = [[256, 512, 1024]])
        self.fp1 = PointNetFeaturePropagation(1664, [256, 256])
        self.fp2 = PointNetFeaturePropagation(576, [256, 128])
        self.fp3 = PointNetFeaturePropagation(131, [128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.conv2 = nn.Conv1d(128, cfg.num_seg, 1)
        self.bn1 = nn.BatchNorm1d(128)

    def forward(self, xyz):
        """Map coordinates [B, 3, N] to per-point logits [B * N, cfg.num_seg]."""
        xyz1, point1 = self.sa1(xyz, xyz)
        xyz2, point2 = self.sa2(xyz1, point1)
        xyz3, point3 = self.sa3(xyz2, point2)
        point2 = self.fp1(xyz2, xyz3, point2, point3)
        point1 = self.fp2(xyz1, xyz2, point1, point2)
        point0 = self.fp3(xyz, xyz1, xyz, point1)
        # F.dropout defaults to training=True, so tie it to the module mode.
        point = F.dropout(F.relu(self.bn1(self.conv1(point0))), 0.5, training=self.training)
        point = self.conv2(point)
        point = point.transpose(1, 2).contiguous()
        return point.view(-1, cfg.num_seg)
train.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from DataSet import Dataset, claDataset
from Model import PointNet_add_Msg, PointNet_add, PointNet_add_cla
from configuration import config
import torch.utils.data.dataloader as DataLoader
from tensorboardX import SummaryWriter
import os

cfg = config()


def main():
    """Train the selected network and save a final checkpoint.

    The classification model and dataset are active; the segmentation
    alternatives are kept as commented-out options and must be switched
    together.
    """
    # model = PointNet_add_Msg()
    # model = PointNet_add()
    model = PointNet_add_cla()
    model.to(cfg.device)
    # dataset = Dataset(cfg.dataset_root)
    dataset = claDataset(cfg.cladataset_root)
    dataloader = DataLoader.DataLoader(dataset, batch_size = cfg.batch_size, shuffle = True)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
    loss = nn.CrossEntropyLoss()
    tbwrite = SummaryWriter(logdir = os.path.join(cfg.checkpoint_root, 'log'))
    model.train()
    for epoch in range(cfg.num_epochs):
        total_true = 0
        total_count = 0
        total_loss = 0.0
        cnt = 0
        for xyz, label in tqdm(dataloader):
            # The model lives on cfg.device, so the batch must as well.
            xyz = xyz.to(cfg.device)
            label = label.to(cfg.device).view(-1)
            optimizer.zero_grad()
            output = model(xyz)
            loss_value = loss(output, label)
            loss_value.backward()
            optimizer.step()
            pred = torch.max(output, -1)[1]
            total_true += torch.sum(pred == label).item()
            # Counting labels gives the right denominator for both tasks:
            # one label per cloud (classification) or per point
            # (segmentation).
            total_count += label.numel()
            # .item() detaches the value; summing the tensor itself would
            # keep every iteration's autograd graph alive.
            total_loss += loss_value.item()
            cnt += 1
        mean_loss = total_loss / float(max(cnt, 1))
        accuracy = total_true / float(max(total_count, 1))
        tbwrite.add_scalar('Loss', mean_loss, epoch)
        tbwrite.add_scalar('Accuracy', accuracy, epoch)
        print('mean_loss:{:.4f}, accuracy:{:.4f}'.format(mean_loss, accuracy))
        # Only true after the last epoch: a single final checkpoint is written.
        if (epoch + 1) % cfg.num_epochs == 0:
            state = {
                'model': model.state_dict()
            }
            os.makedirs(cfg.checkpoint_root, exist_ok=True)
            torch.save(state, os.path.join(cfg.checkpoint_root, 'checkpoint_{}.pth'.format(epoch + 1)))
    tbwrite.close()


if __name__ == '__main__':
    main()