  • DGL SAGEConv source code

    DGL version: 0.5.x

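    Note: to use edge features during message passing in SAGEConv, replace fn.copy_src('h', 'm') with fn.copy_e('h', 'm'), which sends the edge feature 'h' as the message instead of the source-node feature 'h'.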
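
To make the difference concrete, here is a minimal pure-Python sketch (no DGL required) of what the two message functions hand to a mean reducer. The helpers `copy_src_messages`, `copy_e_messages`, and `mean_reduce` are hypothetical stand-ins written for this note, not DGL APIs, and the toy graph and feature values are made up.

```python
from collections import defaultdict

# Toy directed graph: edge k goes from src[k] to dst[k].
src = [0, 1, 2]
dst = [2, 2, 0]

node_h = {0: [1.0, 1.0], 1: [3.0, 5.0], 2: [10.0, 10.0]}  # per-node feature 'h'
edge_h = [[0.5, 0.5], [1.5, 2.5], [7.0, 9.0]]             # per-edge feature 'h'


def copy_src_messages(src, node_h):
    """Mimics fn.copy_src('h', 'm'): each edge carries its source node's feature."""
    return [node_h[u] for u in src]


def copy_e_messages(edge_h):
    """Mimics fn.copy_e('h', 'm'): each edge carries its own edge feature."""
    return list(edge_h)


def mean_reduce(dst, messages):
    """Mimics fn.mean('m', 'neigh'): average the messages arriving at each destination."""
    mailbox = defaultdict(list)
    for v, m in zip(dst, messages):
        mailbox[v].append(m)
    return {v: [sum(col) / len(msgs) for col in zip(*msgs)]
            for v, msgs in mailbox.items()}


neigh_from_nodes = mean_reduce(dst, copy_src_messages(src, node_h))
neigh_from_edges = mean_reduce(dst, copy_e_messages(edge_h))
print(neigh_from_nodes[2])  # mean of the features of nodes 0 and 1
print(neigh_from_edges[2])  # mean of the features of edges 0 and 1
```

One thing to watch when making the swap in the real layer: the aggregated result is still passed to fc_neigh, so the edge feature size has to match the source feature size that fc_neigh was built with.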

    \[
    \begin{aligned}
    h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}\left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i)\}\right) \\
    h_{i}^{(l+1)} &= \sigma\left(W \cdot \mathrm{concat}(h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1})\right) \\
    h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{l})
    \end{aligned}
    \]
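
The concat in the second equation never appears literally in the layer code, which instead adds the outputs of two linear maps (fc_self and fc_neigh). A short derivation shows why the two forms agree, assuming W is split column-wise into two blocks; the names W_self and W_neigh are chosen here to mirror those attributes and are not notation from the paper.

```latex
\[
W \cdot \mathrm{concat}\left(h_{i}^{l},\, h_{\mathcal{N}(i)}^{l+1}\right)
= \begin{bmatrix} W_{\mathrm{self}} & W_{\mathrm{neigh}} \end{bmatrix}
  \begin{bmatrix} h_{i}^{l} \\ h_{\mathcal{N}(i)}^{l+1} \end{bmatrix}
= W_{\mathrm{self}}\, h_{i}^{l} + W_{\mathrm{neigh}}\, h_{\mathcal{N}(i)}^{l+1}
\]
```

Splitting W this way also lets the two inputs have different sizes, which is what the bipartite case needs.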

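    How DGL's SAGEConv works:
    feat (n × dim) holds the node features, where n is the number of nodes and dim is the feature dimension.
    For a homogeneous graph:
    conv = SAGEConv(dim, dim_out, 'pool') returns a conv layer instance.
    res = conv(g, feat) applies SAGEConv to feat on graph g. The input dimension is dim and the output dimension is dim_out. A homogeneous graph has as many output nodes as input nodes, so res has shape n × dim_out.
    For a bipartite graph:
    conv = SAGEConv((dim_v, dim_u), dim_out, 'pool')
    res = conv(g, (feat_v, feat_u)) takes inputs of dimension dim_v and dim_u, where the v nodes act as source nodes and the u nodes act as destination nodes. The output dimension is dim_out, and the output has one row per u node, so res has shape n_u × dim_out.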
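
As a quick way to keep these shapes straight, the sketch below mirrors how the constructor turns in_feats into the sizes of its linear layers. `expand_as_pair_sketch` and `sage_shapes` are hypothetical helpers written for this note; the first is a simplified stand-in for DGL's `expand_as_pair` that ignores its block-graph handling.

```python
def expand_as_pair_sketch(in_feats):
    """Simplified stand-in: an int becomes (src, dst) with equal sizes; a pair is kept."""
    if isinstance(in_feats, tuple):
        return in_feats
    return in_feats, in_feats


def sage_shapes(in_feats, out_feats, n_src, n_dst, aggregator_type="mean"):
    """Return expected input shapes, (in, out) sizes of each linear layer, and the output shape."""
    in_src, in_dst = expand_as_pair_sketch(in_feats)
    shapes = {
        "feat_src": (n_src, in_src),
        "feat_dst": (n_dst, in_dst),
        "fc_neigh": (in_src, out_feats),         # applied to aggregated neighbor features
    }
    if aggregator_type != "gcn":
        shapes["fc_self"] = (in_dst, out_feats)  # applied to each destination node's own feature
    if aggregator_type == "pool":
        shapes["fc_pool"] = (in_src, in_src)     # applied to source features before max-pooling
    shapes["output"] = (n_dst, out_feats)        # one row per destination node
    return shapes


# Homogeneous graph: n = 6 nodes, dim = 10, dim_out = 2.
homo = sage_shapes(10, 2, n_src=6, n_dst=6, aggregator_type="pool")
# Bipartite graph: 4 source nodes with 10-dim features, 2 destination nodes with 5-dim features.
bip = sage_shapes((10, 5), 2, n_src=4, n_dst=2, aggregator_type="pool")
print(homo["output"], bip["output"])
```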
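    Message-passing stage. When a single tensor is passed, feat_src and feat_dst are the same tensor; if the graph is a block (graph.is_block), feat_dst is the first graph.number_of_dst_nodes() rows of feat_src. When a pair is passed (bipartite graph), feat_src and feat_dst are taken from the pair directly.
    h_self = feat_dst  # use feat_dst as each destination node's own feature
    graph.srcdata['h'] = feat_src
    Each aggregator type uses its own message-passing call. For aggre_type == 'mean':
    # message_func fn.copy_src('h', 'm') sends each source node's 'h' along its out-edges as message 'm';
    # reduce_func fn.mean('m', 'neigh') averages the messages 'm' arriving at each destination node into 'neigh'.
    graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
    At this point h_self has the destination-node feature dimension and h_neigh has the source-node feature dimension, so each goes through its own linear layer to reach dim_out:
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)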
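
The steps above can be traced by hand with a small pure-Python sketch of the 'mean' path. It is an illustration under simplifying assumptions (no bias, dropout, activation, or normalization; plain lists instead of tensors), and `matvec` and `sage_mean_forward` are hypothetical names, not part of DGL.

```python
def matvec(W, x):
    """Multiply a weight matrix (list of rows, shape out x in) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def sage_mean_forward(src, dst, feat_src, feat_dst, W_self, W_neigh):
    """Mean-aggregator forward pass: fc_self(h_self) + fc_neigh(h_neigh), without bias."""
    n_dst = len(feat_dst)
    dim_src = len(feat_src[0])
    # Message passing: collect the source features arriving at each destination node.
    mailbox = [[] for _ in range(n_dst)]
    for u, v in zip(src, dst):
        mailbox[v].append(feat_src[u])
    out = []
    for v in range(n_dst):
        msgs = mailbox[v]
        # A destination with no incoming edge gets a zero vector (an assumption of this sketch).
        h_neigh = ([sum(col) / len(msgs) for col in zip(*msgs)]
                   if msgs else [0.0] * dim_src)
        h_self = feat_dst[v]
        out.append([a + b for a, b in zip(matvec(W_self, h_self),
                                          matvec(W_neigh, h_neigh))])
    return out


# Bipartite toy graph: 3 source nodes (2-dim) -> 2 destination nodes (1-dim), dim_out = 2.
src = [0, 1, 2]
dst = [0, 0, 1]
feat_src = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
feat_dst = [[10.0], [20.0]]
W_self = [[1.0], [0.0]]             # shape (dim_out, dim_dst)
W_neigh = [[1.0, 0.0], [0.0, 1.0]]  # shape (dim_out, dim_src)

rst = sage_mean_forward(src, dst, feat_src, feat_dst, W_self, W_neigh)
print(rst)  # one row per destination node, each of length dim_out
```

Note that the output has two rows, one per destination node, even though there are three source nodes.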

    """Torch Module for GraphSAGE layer"""
    # pylint: disable= no-member, arguments-differ, invalid-name
    import torch
    from torch import nn
    from torch.nn import functional as F
    
    from .... import function as fn
    from ....utils import expand_as_pair, check_eq_shape
    
    
    class SAGEConv(nn.Module):
        r"""
        Description
        -----------
        GraphSAGE layer from paper `Inductive Representation Learning on
        Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.

        .. math::
            h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
            \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

            h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
            (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)

            h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{l})

        Parameters
        ----------
        in_feats : int, or pair of ints
            Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
            SAGEConv can be applied on homogeneous graph and unidirectional
            `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
            If the layer applies on a unidirectional bipartite graph, ``in_feats``
            specifies the input feature size on both the source and destination nodes.  If
            a scalar is given, the source and destination node feature size would take the
            same value.
            If aggregator type is ``gcn``, the feature size of source and destination nodes
            are required to be the same.
        out_feats : int
            Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
        feat_drop : float
            Dropout rate on features, default: ``0``.
        aggregator_type : str
            Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
        bias : bool
            If True, adds a learnable bias to the output. Default: ``True``.
        norm : callable activation function/layer or None, optional
            If not None, applies normalization to the updated node features.
        activation : callable activation function/layer or None, optional
            If not None, applies an activation function to the updated node features.
            Default: ``None``.
        Examples
        --------
        >>> import dgl
        >>> import numpy as np
        >>> import torch as th
        >>> from dgl.nn import SAGEConv
        >>> # Case 1: Homogeneous graph
        >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
        >>> g = dgl.add_self_loop(g)
        >>> feat = th.ones(6, 10)
        >>> conv = SAGEConv(10, 2, 'pool')
        >>> res = conv(g, feat)
        >>> res
        tensor([[-1.0888, -2.1099],
                [-1.0888, -2.1099],
                [-1.0888, -2.1099],
                [-1.0888, -2.1099],
                [-1.0888, -2.1099],
                [-1.0888, -2.1099]], grad_fn=<AddBackward0>)
        >>> # Case 2: Unidirectional bipartite graph
        >>> u = [0, 1, 0, 0, 1]
        >>> v = [0, 1, 2, 3, 2]
        >>> g = dgl.bipartite((u, v))
        >>> u_fea = th.rand(2, 5)
        >>> v_fea = th.rand(4, 10)
        >>> conv = SAGEConv((5, 10), 2, 'mean')
        >>> res = conv(g, (u_fea, v_fea))
        >>> res
        tensor([[ 0.3163,  3.1166],
                [ 0.3866,  2.5398],
                [ 0.5873,  1.6597],
                [-0.2502,  2.8068]], grad_fn=<AddBackward0>)
        """
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     feat_drop=0.,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()
    
            # Expand in_feats into a (source, destination) pair of feature sizes
            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.feat_drop = nn.Dropout(feat_drop)
            self.activation = activation
            # aggregator type: mean/pool/lstm/gcn
                
            if aggregator_type == 'pool':
                self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
            if aggregator_type == 'lstm':
                self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
            if aggregator_type != 'gcn':
                self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
            # Linear transform that maps aggregated neighbor features to out_feats
            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
            self.reset_parameters()
    
        def reset_parameters(self):
            r"""
            Description
            -----------
            Reinitialize learnable parameters.
            Note
            ----
            The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
            The LSTM module is using xavier initialization method for its weights.
            """
            gain = nn.init.calculate_gain('relu')
            if self._aggre_type == 'pool':
                nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
            if self._aggre_type == 'lstm':
                self.lstm.reset_parameters()
            if self._aggre_type != 'gcn':
                nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
    
        def _lstm_reducer(self, nodes):
            """LSTM reducer
            NOTE(zihao): lstm reducer with default schedule (degree bucketing)
            is slow, we could accelerate this with degree padding in the future.
            """
            m = nodes.mailbox['m'] # (B, L, D)
            batch_size = m.shape[0]
            h = (m.new_zeros((1, batch_size, self._in_src_feats)),
                 m.new_zeros((1, batch_size, self._in_src_feats)))
            _, (rst, _) = self.lstm(m, h)
            return {'neigh': rst.squeeze(0)}
    
        def forward(self, graph, feat):
            r"""
            Description
            -----------
            Compute GraphSAGE layer.
            Parameters
            ----------
            graph : DGLGraph
                The graph.
            feat : torch.Tensor or pair of torch.Tensor
                If a torch.Tensor is given, it represents the input feature of shape
                :math:`(N, D_{in})`
                where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
                If a pair of torch.Tensor is given, the pair must contain two tensors of shape
                :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
            Returns
            -------
            torch.Tensor
                The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
                is size of output feature.
            """
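            # Single tensor (homogeneous graph): feat_src and feat_dst share the same features;
            # for a block, feat_dst is the first graph.number_of_dst_nodes() rows of feat_src.
            # Pair of tensors (bipartite graph): take feat_src and feat_dst from the pair directly.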
            with graph.local_scope():
                if isinstance(feat, tuple):
                    feat_src = self.feat_drop(feat[0])
                    feat_dst = self.feat_drop(feat[1])
                else:
                    feat_src = feat_dst = self.feat_drop(feat)
                    if graph.is_block:
                        feat_dst = feat_src[:graph.number_of_dst_nodes()]
                # Use feat_dst as each destination node's own feature (h_self)
                h_self = feat_dst
    
                # Handle the case of graphs without edges
                if graph.number_of_edges() == 0:
                    graph.dstdata['neigh'] = torch.zeros(
                        feat_dst.shape[0], self._in_src_feats).to(feat_dst)
    
                if self._aggre_type == 'mean':
                    graph.srcdata['h'] = feat_src
                    graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
                    h_neigh = graph.dstdata['neigh']
                elif self._aggre_type == 'gcn':
                    check_eq_shape(feat)
                    graph.srcdata['h'] = feat_src
                    graph.dstdata['h'] = feat_dst     # same as above if homogeneous
                    graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
                    # divide in_degrees
                    degs = graph.in_degrees().to(feat_dst)
                    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
                elif self._aggre_type == 'pool':
                    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                    graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                    h_neigh = graph.dstdata['neigh']
                elif self._aggre_type == 'lstm':
                    graph.srcdata['h'] = feat_src
                    graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
                    h_neigh = graph.dstdata['neigh']
                else:
                    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
    
                # GraphSAGE GCN does not require fc_self.
                if self._aggre_type == 'gcn':
                    rst = self.fc_neigh(h_neigh)
                else:
                    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
                # activation
                if self.activation is not None:
                    rst = self.activation(rst)
                # normalization
                if self.norm is not None:
                    rst = self.norm(rst)
                return rst
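
The 'gcn' and 'pool' branches of forward differ from 'mean' only in how h_neigh is built. The pure-Python sketch below reproduces just that step on a toy homogeneous graph. `gcn_neigh` and `pool_neigh` are hypothetical helper names, the feature values are made up, and `pool_neigh` assumes its input has already gone through ReLU(fc_pool(...)).

```python
def gcn_neigh(src, dst, feat):
    """'gcn' branch on a homogeneous graph: (sum of neighbors + self) / (in_degree + 1)."""
    n, dim = len(feat), len(feat[0])
    summed = [[0.0] * dim for _ in range(n)]
    in_deg = [0] * n
    for u, v in zip(src, dst):
        in_deg[v] += 1
        summed[v] = [s + x for s, x in zip(summed[v], feat[u])]
    return [[(s + x) / (in_deg[v] + 1) for s, x in zip(summed[v], feat[v])]
            for v in range(n)]


def pool_neigh(src, dst, transformed_src, n_dst):
    """'pool' branch: element-wise max over the incoming, already transformed source features."""
    dim = len(transformed_src[0])
    mailbox = [[] for _ in range(n_dst)]
    for u, v in zip(src, dst):
        mailbox[v].append(transformed_src[u])
    # A node with no incoming edge gets a zero vector (an assumption of this sketch).
    return [[max(col) for col in zip(*msgs)] if msgs else [0.0] * dim
            for msgs in mailbox]


# Toy graph with edges 0->2, 1->2, 2->0; node 1 has no incoming edge.
src = [0, 1, 2]
dst = [2, 2, 0]
feat = [[1.0, 4.0], [3.0, 2.0], [8.0, 0.0]]

h_gcn = gcn_neigh(src, dst, feat)
h_pool = pool_neigh(src, dst, feat, n_dst=3)
print(h_gcn)
print(h_pool)
```

Because the 'gcn' branch already folds the node's own feature into h_neigh, the layer applies only fc_neigh in that case and skips fc_self.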
    
  • Original post: https://www.cnblogs.com/sandy-t/p/13573305.html