1. PointNet论文理解:
https://blog.csdn.net/phosphenesvision/article/details/106724377?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param
https://zhuanlan.zhihu.com/p/86331508
PointNet 源码理解:
https://www.jianshu.com/p/4646016620db
https://blog.csdn.net/cg129054036/article/details/105456002
https://blog.csdn.net/jcsm__/article/details/109244748
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
"""PointNet building blocks: input/feature transform nets, the shared feature
extractor, and the classification / segmentation heads.

All networks take point clouds laid out as (B, C, N): batch, channels, points.
"""
import numpy as np  # kept from the original listing; not used below
import torch
import torch.nn as nn
import torch.nn.functional as F


class _TNet(nn.Module):
    """Shared implementation of the k x k transform-regression network.

    A per-point MLP (1x1 convolutions) is max-pooled over the point axis and
    fed through fully connected layers that regress a k*k matrix. The identity
    is added to the regressed values, so a network whose last layer outputs
    zeros yields the identity transform.
    """

    def __init__(self, k):
        super().__init__()
        # Layer names and registration order match the original STN3d/STNkd
        # so existing state_dict checkpoints keep loading.
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        # Unused by forward() (F.relu is called instead); kept so the module
        # exposes the same attributes as before.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        """Map a (B, k, N) tensor to a (B, k, k) transform matrix."""
        x = F.relu(self.bn1(self.conv1(x)))    # (B, 64, N)
        x = F.relu(self.bn2(self.conv2(x)))    # (B, 128, N)
        x = F.relu(self.bn3(self.conv3(x)))    # (B, 1024, N)
        x = torch.max(x, 2, keepdim=True)[0]   # max over points: (B, 1024, 1)
        x = x.view(-1, 1024)                   # (B, 1024)

        x = F.relu(self.bn4(self.fc1(x)))      # (B, 512)
        x = F.relu(self.bn5(self.fc2(x)))      # (B, 256)
        x = self.fc3(x)                        # (B, k*k)

        # Build the flattened identity directly on x's device/dtype; it
        # broadcasts over the batch axis, so no explicit repeat is needed.
        identity = torch.eye(self.k, dtype=x.dtype, device=x.device).flatten()
        x = x + identity
        return x.view(-1, self.k, self.k)      # (B, k, k)


class STN3d(_TNet):
    '''
    3x3 transform
    '''

    def __init__(self):
        super().__init__(3)


class STNkd(_TNet):
    '''
    64x64 transform (k x k in general)
    '''

    def __init__(self, k=64):
        super().__init__(k)


def feature_transform_regularizer(trans):
    """Return the mean over the batch of ||A @ A^T - I||_F.

    Args:
        trans: tensor of shape (B, d, d) holding one square matrix per sample.

    Returns:
        A scalar tensor; it is zero when every matrix is orthogonal.
    """
    d = trans.size(1)
    eye = torch.eye(d, dtype=trans.dtype, device=trans.device)[None, :, :]
    gram = torch.bmm(trans, trans.transpose(2, 1))
    return torch.mean(torch.norm(gram - eye, dim=(1, 2)))


class PointNetfeat(nn.Module):
    '''
    Output: global feature / local+global feature
    '''

    def __init__(self, global_feat=True, feature_transform=False):
        super().__init__()
        self.stn = STN3d()
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        """Extract features from a (B, 3, N) point cloud.

        Returns:
            (features, trans, trans_feat). ``features`` is (B, 1024) when
            ``global_feat`` is true, otherwise (B, 1088, N): the global
            feature repeated per point and concatenated with the 64-channel
            per-point feature. ``trans`` is the (B, 3, 3) input transform;
            ``trans_feat`` is the (B, 64, 64) feature transform, or None when
            ``feature_transform`` is false.
        """
        n_pts = x.size(2)
        trans = self.stn(x)                     # (B, 3, 3)
        x = x.transpose(2, 1)                   # (B, N, 3)
        x = torch.bmm(x, trans)                 # (B, N, 3) @ (B, 3, 3) -> (B, N, 3)
        x = x.transpose(2, 1)                   # (B, 3, N)
        x = F.relu(self.bn1(self.conv1(x)))     # (B, 64, N)

        if self.feature_transform:
            trans_feat = self.fstn(x)           # input (B, 64, N), output (B, 64, 64)
            x = x.transpose(2, 1)               # (B, N, 64)
            x = torch.bmm(x, trans_feat)        # (B, N, 64) @ (B, 64, 64) -> (B, N, 64)
            x = x.transpose(2, 1)               # (B, 64, N)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))     # (B, 128, N)
        x = self.bn3(self.conv3(x))             # (B, 1024, N)
        x = torch.max(x, 2, keepdim=True)[0]    # max over points: (B, 1024, 1)
        x = x.view(-1, 1024)                    # (B, 1024)

        if self.global_feat:
            return x, trans, trans_feat
        x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)         # (B, 1024, N)
        return torch.cat([x, pointfeat], 1), trans, trans_feat


class PointNetCls(nn.Module):
    """Classification network: global feature -> k class log-probabilities."""

    def __init__(self, k=2, feature_transform=False):
        super().__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ((B, k) log-probabilities, trans, trans_feat) for a (B, 3, N) input."""
        x, trans, trans_feat = self.feat(x)                  # (B, 1024)
        x = F.relu(self.bn1(self.fc1(x)))                    # (B, 512)
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))      # (B, 256)
        x = self.fc3(x)                                      # (B, k)
        return F.log_softmax(x, dim=1), trans, trans_feat


class PointNetDenseCls(nn.Module):
    """Segmentation network: per-point features -> per-point k class log-probabilities."""

    def __init__(self, k=2, feature_transform=False):
        super().__init__()
        self.k = k
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = nn.Conv1d(1088, 512, 1)
        self.conv2 = nn.Conv1d(512, 256, 1)
        self.conv3 = nn.Conv1d(256, 128, 1)
        self.conv4 = nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        """Return ((B, N, k) log-probabilities, trans, trans_feat) for a (B, 3, N) input."""
        x, trans, trans_feat = self.feat(x)      # (B, 1088, N)

        x = F.relu(self.bn1(self.conv1(x)))      # (B, 512, N)
        x = F.relu(self.bn2(self.conv2(x)))      # (B, 256, N)
        x = F.relu(self.bn3(self.conv3(x)))      # (B, 128, N)
        x = self.conv4(x)                        # (B, k, N)
        x = x.transpose(2, 1).contiguous()       # (B, N, k)
        # Softmax over the class axis; equivalent to flattening to (B*N, k),
        # applying log_softmax, and reshaping back.
        x = F.log_softmax(x, dim=-1)
        return x, trans, trans_feat


def _demo():
    """Push random data through every network and print the output sizes."""
    sim_data = torch.rand(32, 3, 2500)
    trans = STN3d()
    out = trans(sim_data)
    print('stn', out.size())
    print('loss', feature_transform_regularizer(out))

    sim_data_64d = torch.rand(32, 64, 2500)
    trans = STNkd(k=64)
    out = trans(sim_data_64d)
    print('stn64d', out.size())
    print('loss', feature_transform_regularizer(out))

    pointfeat = PointNetfeat(global_feat=True)
    out, _, _ = pointfeat(sim_data)
    print('global feat', out.size())

    pointfeat = PointNetfeat(global_feat=False)
    out, _, _ = pointfeat(sim_data)
    print('point feat', out.size())

    cls = PointNetCls(k=5)
    out, _, _ = cls(sim_data)
    print('class', out.size())

    seg = PointNetDenseCls(k=3)
    out, _, _ = seg(sim_data)
    print('seg', out.size())


if __name__ == '__main__':
    _demo()
2. PointNet++论文理解:
https://blog.csdn.net/yong_qi2015/article/details/108957905
https://zhuanlan.zhihu.com/p/88238420
PointNet++ 源码理解:
https://blog.csdn.net/cg129054036/article/details/105545895?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param