There is a good Zhihu article on whole-graph classification: 【图分类】10分钟就学会的图分类教程,基于pytorch和dgl ("a 10-minute graph-classification tutorial, based on PyTorch and DGL"). The code below comes from that article.
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import MiniGCDataset
from dgl.nn.pytorch import GraphConv
from sklearn.metrics import accuracy_score
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)        # first graph-convolution layer
        self.conv2 = GraphConv(hidden_dim, hidden_dim)    # second graph-convolution layer
        self.classify = nn.Linear(hidden_dim, n_classes)  # classification head

    def forward(self, g):
        """g is the batched graph; N is the total number of nodes in the
        batched graph, and n is the number of graphs in the batch."""
        # For simplicity, use the node in-degree as the initial node feature.
        # For an undirected graph, in-degree == out-degree.
        h = g.in_degrees().view(-1, 1).float()  # [N, 1]
        # Apply the graph convolutions with ReLU activations
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        g.ndata['h'] = h  # store node representations on the graph
        # Average-pool node representations to get one vector per graph
        hg = dgl.mean_nodes(g, 'h')   # [n, hidden_dim]
        return self.classify(hg)      # [n, n_classes]
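Before wiring up the data, the model can be sanity-checked on a single toy graph (a minimal sketch; the 3-node cycle and the sizes 16 and 4 are made up for illustration, not part of the tutorial):
toy_g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3-node directed cycle; every node has in-degree 1
toy_net = Classifier(in_dim=1, hidden_dim=16, n_classes=4)
print(toy_net(toy_g).shape)  # torch.Size([1, 4]): one graph, four class scores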
def collate(samples):
    # `samples` is a list of (graph, label) pairs:
    # [(graph1, label1), (graph2, label2), ...]
    # zip(*samples) unzips it into
    # [(graph1, graph2, ...), (label1, label2, ...)]
    graphs, labels = map(list, zip(*samples))
    # dgl.batch merges a list of graphs into one large graph whose
    # mutually disconnected components are the original graphs
    return dgl.batch(graphs), torch.tensor(labels, dtype=torch.long)
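What dgl.batch does is easy to see on two toy graphs (a minimal sketch; the graphs are made up for illustration):
g1 = dgl.graph(([0, 1], [1, 0]))        # 2 nodes
g2 = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3 nodes
bg = dgl.batch([g1, g2])
print(bg.num_nodes(), bg.batch_size)    # 5 2: one big graph with two components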
# Create the training and test sets
trainset = MiniGCDataset(2000, 10, 20)  # 2000 graphs, each with between 10 and 20 nodes
testset = MiniGCDataset(1000, 10, 20)
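Each sample is a (graph, label) pair; MiniGCDataset synthesizes eight graph types (cycle, star, wheel, and so on) and the label is the type index. A quick inspection (a minimal sketch; the printed values will vary):
sample_graph, sample_label = trainset[0]
print(sample_graph.num_nodes(), int(sample_label))  # e.g. 15 0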
# Use PyTorch's DataLoader together with the collate function defined above
data_loader = DataLoader(trainset, batch_size=64, shuffle=True,
                         collate_fn=collate)
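Pulling one batch out of the loader confirms the batching works as expected (a minimal sketch):
first_batch, first_labels = next(iter(data_loader))
print(first_batch.batch_size, first_labels.shape)  # 64 torch.Size([64])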
# Select a device (the original hard-codes "cuda:2"; fall back to CPU if no GPU is available)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the model
model = Classifier(1, 256, trainset.num_classes)
model.to(DEVICE)
# Cross-entropy loss for classification
loss_func = nn.CrossEntropyLoss()
# Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Train the model
model.train()
epoch_losses = []
for epoch in range(100):
    epoch_loss = 0
    for batch_idx, (batchg, label) in enumerate(data_loader):
        batchg, label = batchg.to(DEVICE), label.to(DEVICE)
        prediction = model(batchg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (batch_idx + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
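Since the per-epoch losses are accumulated in epoch_losses, the training curve can be plotted (a minimal sketch, assuming matplotlib is installed; plotting is not part of the original script):
import matplotlib.pyplot as plt
plt.plot(epoch_losses)
plt.xlabel('epoch')
plt.ylabel('average cross-entropy loss')
plt.show()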
# Evaluate on the test set
test_loader = DataLoader(testset, batch_size=64, shuffle=False,
                         collate_fn=collate)
model.eval()
test_pred, test_label = [], []
with torch.no_grad():
    for batchg, label in test_loader:
        batchg, label = batchg.to(DEVICE), label.to(DEVICE)
        # softmax is monotonic, so argmax over the raw logits would give
        # the same predictions; it is kept here as in the original
        pred = torch.softmax(model(batchg), 1)
        pred = torch.max(pred, 1)[1].view(-1)
        test_pred += pred.detach().cpu().numpy().tolist()
        test_label += label.cpu().numpy().tolist()
print("Test accuracy: ", accuracy_score(test_label, test_pred))
Output: the script prints the average loss for each epoch, followed by the final test accuracy.