zoukankan      html  css  js  c++  java
  • BindsNET学习系列 ——Connection

    相关源码:bindsnet/bindsnet/network/topology.py

    1、AbstractConnection

    class AbstractConnection(ABC, Module):
        # language=rst
        """
        Abstract base class for connections between two ``Nodes`` layers.

        Concrete connections implement ``compute``, ``update`` and
        ``reset_state_variables``.
        """

        def __init__(
            self,
            source: Nodes,
            target: Nodes,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Set up the state shared by every connection type.

            :param source: Layer of nodes the connection originates from.
            :param target: Layer of nodes the connection projects to.
            :param nu: Learning rate for both pre- and post-synaptic events; forwarded
                to the update rule.
            :param reduction: Method for reducing parameter updates along the minibatch
                dimension; stored and forwarded to the update rule.
            :param weight_decay: Constant multiple to decay weights by on each
                iteration; stored and forwarded to the update rule.

            Keyword arguments:

            :param LearningRule update_rule: Learning rule class instantiated for this
                connection. ``NoOp`` is used when it is omitted or ``None``.
            :param float wmin: Minimum value on the connection weights (default
                ``-inf``).
            :param float wmax: Maximum value on the connection weights (default
                ``inf``).
            :param float norm: Total weight per target neuron normalization.
            :param decay: Stored as ``self.decay`` (default ``None``).
            """
            super().__init__()

            # Both endpoints must be ``Nodes`` instances; source is checked first.
            for layer, complaint in (
                (source, "Source is not a Nodes object"),
                (target, "Target is not a Nodes object"),
            ):
                assert isinstance(layer, Nodes), complaint

            self.source, self.target = source, target

            self.weight_decay = weight_decay
            self.reduction = reduction

            # Imported at call time rather than at module import time.
            from ..learning import NoOp

            # The rule *class* is stored first; it is replaced by an instance below,
            # once the remaining attributes have been populated.
            rule_cls = kwargs.get("update_rule", NoOp)
            self.update_rule = NoOp if rule_cls is None else rule_cls

            self.wmin = kwargs.get("wmin", -np.inf)
            self.wmax = kwargs.get("wmax", np.inf)
            self.norm = kwargs.get("norm", None)
            self.decay = kwargs.get("decay", None)

            rule_arguments = dict(
                connection=self,
                nu=nu,
                reduction=reduction,
                weight_decay=weight_decay,
            )
            self.update_rule = self.update_rule(**rule_arguments, **kwargs)

        @abstractmethod
        def compute(self, s: torch.Tensor) -> None:
            # language=rst
            """
            Compute pre-activations of downstream neurons from upstream spikes.

            :param s: Incoming spikes.
            """

        @abstractmethod
        def update(self, **kwargs) -> None:
            # language=rst
            """
            Run the connection's update rule, then apply the optional weight mask.

            Keyword arguments:

            :param bool learning: Whether to allow connection updates (default
                ``True``).
            :param ByteTensor mask: Boolean mask determining which weights to clamp to
                zero.
            """
            if kwargs.get("learning", True):
                self.update_rule.update(**kwargs)

            mask = kwargs.get("mask", None)
            if mask is None:
                return

            # Zero the masked entries of the weights in place.
            self.w.masked_fill_(mask, 0)

        @abstractmethod
        def reset_state_variables(self) -> None:
            # language=rst
            """
            Hook for resetting connection state; the base implementation does nothing.
            """

    2、Connection 

    class Connection(AbstractConnection):
        # language=rst
        """
        Specifies synapses between one or two populations of neurons.
        """

        def __init__(
            self,
            source: Nodes,
            target: Nodes,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Instantiates a :code:`Connection` object.

            :param source: A layer of nodes from which the connection originates.
            :param target: A layer of nodes to which the connection connects.
            :param nu: Learning rate for both pre- and post-synaptic events.
            :param reduction: Method for reducing parameter updates along the minibatch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.

            Keyword arguments:

            :param LearningRule update_rule: Modifies connection parameters according to
                some rule.
            :param torch.Tensor w: Strengths of synapses. When omitted, weights of shape
                ``(source.n, target.n)`` are drawn uniformly at random.
            :param torch.Tensor b: Target population bias. No bias is used when omitted.
            :param float wmin: Minimum allowed value on the connection weights.
            :param float wmax: Maximum allowed value on the connection weights.
            :param float norm: Total weight per target neuron normalization constant.
            """
            super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

            w = kwargs.get("w", None)
            if w is None:
                if self.wmin == -np.inf or self.wmax == np.inf:
                    # At least one bound is infinite: sample in [0, 1) and clamp.
                    w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
                else:
                    # Both bounds finite: sample uniformly in [wmin, wmax).
                    w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
            else:
                # Convert unconditionally so array-like weights are accepted whether
                # or not bounds are configured.
                w = torch.as_tensor(w)
                if self.wmin != -np.inf or self.wmax != np.inf:
                    w = torch.clamp(w, self.wmin, self.wmax)

            self.w = Parameter(w, requires_grad=False)

            b = kwargs.get("b", None)
            if b is not None:
                self.b = Parameter(b, requires_grad=False)
            else:
                self.b = None

            # The windowed spike buffer used by ``compute_window`` is only set up
            # for ``CSRMNodes`` targets; it is allocated lazily on first use.
            if isinstance(self.target, CSRMNodes):
                self.s_w = None

        def compute(self, s: torch.Tensor) -> torch.Tensor:
            # language=rst
            """
            Compute pre-activations given spikes using connection weights.

            :param s: Incoming spikes; the first dimension is treated as the batch.
            :return: Incoming spikes multiplied by synaptic weights (plus bias, if any),
                reshaped to ``(batch, *target.shape)``.
            """
            # Flatten everything but the batch dimension and multiply by the weights.
            post = s.view(s.size(0), -1).float() @ self.w
            if self.b is not None:
                post = post + self.b
            return post.view(s.size(0), *self.target.shape)

        def compute_window(self, s: torch.Tensor) -> torch.Tensor:
            # language=rst
            """
            Compute pre-activations over a sliding window of recent spikes.

            A first-in-first-out buffer of the last ``target.res_window_size`` spike
            tensors is kept in ``self.s_w``; each call drops the oldest entry and
            appends ``s``. The buffer attribute is only initialised by the
            constructor when the target is a ``CSRMNodes`` layer.

            :param s: Incoming spikes for the current step.
            :return: Windowed spikes multiplied by synaptic weights (plus bias, if any),
                reshaped to ``(batch, target.res_window_size, *target.shape)``.
            """
            if self.s_w is None:
                # Buffer of shape batch size * window size * dimension of layer,
                # created on the same device as the incoming spikes so the
                # concatenation below does not mix devices.
                self.s_w = torch.zeros(
                    self.target.batch_size,
                    self.target.res_window_size,
                    *self.source.shape,
                    device=s.device,
                )

            # Shift the window by one step: drop the oldest entry, append the newest.
            self.s_w = torch.cat((self.s_w[:, 1:, :], s[:, None, :]), 1)

            # Multiply the flattened window by the weights and add the bias, if any.
            post = self.s_w.view(self.s_w.size(0), self.s_w.size(1), -1).float() @ self.w
            if self.b is not None:
                post = post + self.b

            return post.view(
                self.s_w.size(0), self.target.res_window_size, *self.target.shape
            )

        def update(self, **kwargs) -> None:
            # language=rst
            """
            Compute connection's update rule.
            """
            super().update(**kwargs)

        def normalize(self) -> None:
            # language=rst
            """
            Normalize weights so each target neuron has sum of absolute connection
            weights equal to ``self.norm``. Does nothing when ``self.norm`` is ``None``.
            """
            if self.norm is not None:
                w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
                # Avoid division by zero for targets with all-zero incoming weights.
                w_abs_sum[w_abs_sum == 0] = 1.0
                self.w *= self.norm / w_abs_sum

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Contains resetting logic for the connection.
            """
            super().reset_state_variables()

    3、Conv2dConnection

    class Conv2dConnection(AbstractConnection):
        # language=rst
        """
        Specifies convolutional synapses between one or two populations of neurons.
        """

        def __init__(
            self,
            source: Nodes,
            target: Nodes,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int]] = 0,
            dilation: Union[int, Tuple[int, int]] = 1,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Instantiates a ``Conv2dConnection`` object.

            :param source: A layer of nodes from which the connection originates; its
                ``shape`` is read as ``(in_channels, height, width)``.
            :param target: A layer of nodes to which the connection connects; its
                ``shape`` is read as ``(out_channels, height, width)``.
            :param kernel_size: Horizontal and vertical size of convolutional kernels.
            :param stride: Horizontal and vertical stride for convolution.
            :param padding: Horizontal and vertical padding for convolution.
            :param dilation: Horizontal and vertical dilation for convolution.
            :param nu: Learning rate for both pre- and post-synaptic events.
            :param reduction: Method for reducing parameter updates along the minibatch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.

            Keyword arguments:

            :param LearningRule update_rule: Modifies connection parameters according to
                some rule.
            :param torch.Tensor w: Strengths of synapses.
            :param torch.Tensor b: Target population bias (zeros when omitted or
                ``None``).
            :param float wmin: Minimum allowed value on the connection weights.
            :param float wmax: Maximum allowed value on the connection weights.
            :param float norm: Total weight per target neuron normalization constant.
            """
            super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

            self.kernel_size = _pair(kernel_size)
            self.stride = _pair(stride)
            self.padding = _pair(padding)
            self.dilation = _pair(dilation)

            self.in_channels = source.shape[0]
            input_height = source.shape[1]
            input_width = source.shape[2]
            self.out_channels = target.shape[0]

            # Expected spatial output size of the convolution. The dilation term
            # reduces to the plain (input - kernel + 2 * padding) / stride + 1
            # formula when dilation is 1, and keeps the check in agreement with
            # the ``F.conv2d`` call in ``compute`` when it is larger.
            expected_height = int(
                (
                    input_height
                    + 2 * self.padding[0]
                    - self.dilation[0] * (self.kernel_size[0] - 1)
                    - 1
                )
                / self.stride[0]
                + 1
            )
            expected_width = int(
                (
                    input_width
                    + 2 * self.padding[1]
                    - self.dilation[1] * (self.kernel_size[1] - 1)
                    - 1
                )
                / self.stride[1]
                + 1
            )

            error = (
                "Target dimensionality must be (out_channels, "
                "(input_height + 2 * padding_height"
                " - dilation_height * (filter_height - 1) - 1) / stride_height + 1, "
                "(input_width + 2 * padding_width"
                " - dilation_width * (filter_width - 1) - 1) / stride_width + 1)"
            )

            assert (
                target.shape[1] == expected_height
                and target.shape[2] == expected_width
            ), error

            weight_shape = (self.out_channels, self.in_channels, *self.kernel_size)

            w = kwargs.get("w", None)
            if w is None:
                if self.wmin == -np.inf or self.wmax == np.inf:
                    # At least one bound is infinite: sample in [0, 1) and clamp.
                    w = torch.clamp(torch.rand(*weight_shape), self.wmin, self.wmax)
                else:
                    # Both bounds finite: sample uniformly in [wmin, wmax).
                    w = (self.wmax - self.wmin) * torch.rand(*weight_shape)
                    w += self.wmin
            else:
                w = torch.as_tensor(w)
                if self.wmin != -np.inf or self.wmax != np.inf:
                    w = torch.clamp(w, self.wmin, self.wmax)

            self.w = Parameter(w, requires_grad=False)

            # An explicit ``b=None`` is treated like an omitted bias; passing
            # ``None`` to ``Parameter`` would create an empty tensor instead.
            b = kwargs.get("b", None)
            if b is None:
                b = torch.zeros(self.out_channels)
            self.b = Parameter(b, requires_grad=False)

        def compute(self, s: torch.Tensor) -> torch.Tensor:
            # language=rst
            """
            Compute convolutional pre-activations given spikes using layer weights.

            :param s: Incoming spikes.
            :return: Result of a 2D convolution of the spikes with the connection
                weights and bias.
            """
            return F.conv2d(
                s.float(),
                self.w,
                self.b,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
            )

        def update(self, **kwargs) -> None:
            # language=rst
            """
            Compute connection's update rule.
            """
            super().update(**kwargs)

        def normalize(self) -> None:
            # language=rst
            """
            Normalize each (output channel, input channel) kernel so that its weights
            sum to ``self.norm``. Does nothing when ``self.norm`` is ``None``.
            """
            if self.norm is not None:
                # View with one row per kernel; scaling the view in place updates
                # ``self.w`` itself.
                w = self.w.view(
                    self.w.shape[0] * self.w.shape[1], self.w.shape[2] * self.w.shape[3]
                )
                w *= self.norm / w.sum(1, keepdim=True)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Contains resetting logic for the connection.
            """
            super().reset_state_variables()

    4、MaxPool2dConnection

    class MaxPool2dConnection(AbstractConnection):
        # language=rst
        """
        Specifies max-pooling synapses between one or two populations of neurons by keeping
        online estimates of maximally firing neurons.
        """

        def __init__(
            self,
            source: Nodes,
            target: Nodes,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int]] = 0,
            dilation: Union[int, Tuple[int, int]] = 1,
            **kwargs
        ) -> None:
            # language=rst
            """
            Instantiates a ``MaxPool2dConnection`` object.

            :param source: A layer of nodes from which the connection originates.
            :param target: A layer of nodes to which the connection connects.
            :param kernel_size: Horizontal and vertical size of the pooling window.
            :param stride: Horizontal and vertical stride for pooling.
            :param padding: Horizontal and vertical padding for pooling.
            :param dilation: Horizontal and vertical dilation for pooling.

            Keyword arguments:

            :param decay: Decay rate of online estimates of average firing activity.
                When omitted, the estimates are not decayed.
            """
            # No learning rate, reduction or weight decay: this connection has no
            # weights of its own.
            super().__init__(source, target, None, None, 0.0, **kwargs)

            self.kernel_size = _pair(kernel_size)
            self.stride = _pair(stride)
            self.padding = _pair(padding)
            self.dilation = _pair(dilation)

            self.register_buffer("firing_rates", torch.zeros(source.s.shape))

        def compute(self, s: torch.Tensor) -> torch.Tensor:
            # language=rst
            """
            Compute max-pool pre-activations given spikes using online firing rate
            estimates.

            :param s: Incoming spikes.
            :return: Spikes of the source neurons whose firing-rate estimate is
                maximal within each pooling window, as a float tensor.
            """
            # ``self.decay`` defaults to ``None`` in the base class; multiplying a
            # tensor by ``None`` would raise, so a missing rate means "no decay".
            if self.decay is not None:
                self.firing_rates -= self.decay * self.firing_rates
            self.firing_rates += s.float().squeeze()

            # Only the argmax positions are needed, not the pooled rate values.
            _, indices = F.max_pool2d(
                self.firing_rates,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                return_indices=True,
            )

            # Pick the current spikes at the positions of the maximal firing rates.
            return s.flatten(2).gather(2, indices.flatten(2)).view_as(indices).float()

        def update(self, **kwargs) -> None:
            # language=rst
            """
            Compute connection's update rule.
            """
            super().update(**kwargs)

        def normalize(self) -> None:
            # language=rst
            """
            No weights -> no normalization.
            """
            pass

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Reset the firing-rate estimates to zero, sized to the source's current
            spike tensor.
            """
            super().reset_state_variables()

            # Re-create the buffer on its current device so a connection that was
            # moved off the CPU does not silently fall back to it.
            self.firing_rates = torch.zeros(
                self.source.s.shape, device=self.firing_rates.device
            )

    5、LocalConnection

    class LocalConnection(AbstractConnection):
        # language=rst
        """
        Specifies a locally connected connection between one or two populations of neurons.
        """

        def __init__(
            self,
            source: Nodes,
            target: Nodes,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]],
            n_filters: int,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Instantiates a ``LocalConnection`` object. Source population should be
            two-dimensional.

            Neurons in the post-synaptic population are ordered by receptive field; that is,
            if there are ``n_conv`` neurons in each post-synaptic patch, then the first
            ``n_conv`` neurons in the post-synaptic population correspond to the first
            receptive field, the second ``n_conv`` to the second receptive field, and so on.

            :param source: A layer of nodes from which the connection originates.
            :param target: A layer of nodes to which the connection connects.
            :param kernel_size: Horizontal and vertical size of convolutional kernels.
            :param stride: Horizontal and vertical stride for convolution.
            :param n_filters: Number of locally connected filters per pre-synaptic region.
            :param nu: Learning rate for both pre- and post-synaptic events.
            :param reduction: Method for reducing parameter updates along the minibatch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.

            Keyword arguments:

            :param LearningRule update_rule: Modifies connection parameters according to
                some rule.
            :param torch.Tensor w: Strengths of synapses.
            :param torch.Tensor b: Target population bias (zeros when omitted or
                ``None``).
            :param float wmin: Minimum allowed value on the connection weights.
            :param float wmax: Maximum allowed value on the connection weights.
            :param float norm: Total weight per target neuron normalization constant.
            :param Tuple[int, int] input_shape: Shape of input population if it's not
                ``[sqrt, sqrt]``.
            """
            super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

            kernel_size = _pair(kernel_size)
            stride = _pair(stride)

            self.kernel_size = kernel_size
            self.stride = stride
            self.n_filters = n_filters

            shape = kwargs.get("input_shape", None)
            if shape is None:
                # Default to a square input whose side is sqrt(source.n).
                sqrt = int(np.sqrt(source.n))
                shape = _pair(sqrt)

            if kernel_size == shape:
                conv_size = [1, 1]
            else:
                conv_size = (
                    int((shape[0] - kernel_size[0]) / stride[0]) + 1,
                    int((shape[1] - kernel_size[1]) / stride[1]) + 1,
                )

            self.conv_size = conv_size

            conv_prod = int(np.prod(conv_size))
            kernel_prod = int(np.prod(kernel_size))

            assert (
                target.n == n_filters * conv_prod
            ), "Target layer size must be n_filters * prod(conv_size)."

            # locations[k1, k2, c1, c2] is the flattened (row-major) index of the
            # source neuron at kernel offset (k1, k2) of receptive field (c1, c2).
            locations = torch.zeros(
                kernel_size[0], kernel_size[1], conv_size[0], conv_size[1]
            ).long()
            for c1 in range(conv_size[0]):
                for c2 in range(conv_size[1]):
                    for k1 in range(kernel_size[0]):
                        for k2 in range(kernel_size[1]):
                            row = c1 * stride[0] + k1
                            col = c2 * stride[1] + k2
                            # Row stride of the flattened input is its width,
                            # shape[1]; this matters for non-square inputs.
                            locations[k1, k2, c1, c2] = row * shape[1] + col

            self.register_buffer("locations", locations.view(kernel_prod, conv_prod))
            w = kwargs.get("w", None)

            if w is None:
                w = torch.zeros(source.n, target.n)
                # The bounds do not change inside the loops; test them once.
                unbounded = self.wmin == -np.inf or self.wmax == np.inf
                for f in range(n_filters):
                    for c in range(conv_prod):
                        for k in range(kernel_prod):
                            sample = np.random.rand()
                            if unbounded:
                                value = np.clip(sample, self.wmin, self.wmax)
                            else:
                                value = self.wmin + sample * (self.wmax - self.wmin)
                            w[self.locations[k, c], f * conv_prod + c] = value
            else:
                if self.wmin != -np.inf or self.wmax != np.inf:
                    w = torch.clamp(w, self.wmin, self.wmax)

            self.w = Parameter(w, requires_grad=False)

            # Entries that are zero at construction time are kept at zero by
            # ``update`` via this mask.
            self.register_buffer("mask", self.w == 0)

            # An explicit ``b=None`` is treated like an omitted bias; passing
            # ``None`` to ``Parameter`` would create an empty tensor instead.
            b = kwargs.get("b", None)
            if b is None:
                b = torch.zeros(target.n)
            self.b = Parameter(b, requires_grad=False)

            if self.norm is not None:
                self.norm *= kernel_prod

        def compute(self, s: torch.Tensor) -> torch.Tensor:
            # language=rst
            """
            Compute pre-activations given spikes using layer weights.

            :param s: Incoming spikes.
            :return: Incoming spikes multiplied by synaptic weights plus bias, viewed
                with ``target.shape``.
            """
            # Compute multiplication of pre-activations by connection weights.
            a_post = (
                s.float().view(s.size(0), -1) @ self.w.view(self.source.n, self.target.n)
                + self.b
            )
            # NOTE(review): the view does not include a batch dimension, so this
            # presumably only works for a batch of one - confirm against callers.
            return a_post.view(*self.target.shape)

        def update(self, **kwargs) -> None:
            # language=rst
            """
            Compute connection's update rule.

            Keyword arguments:

            :param ByteTensor mask: Boolean mask determining which weights to clamp to
                zero. The connection's own locality mask is used when it is omitted or
                ``None``.
            """
            # ``.get`` rather than indexing: ``mask`` is optional for callers.
            if kwargs.get("mask", None) is None:
                kwargs["mask"] = self.mask

            super().update(**kwargs)

        def normalize(self) -> None:
            # language=rst
            """
            Normalize weights so each target neuron has sum of connection weights equal to
            ``self.norm``. Does nothing when ``self.norm`` is ``None``.
            """
            if self.norm is not None:
                # Scaling the view in place updates ``self.w`` itself.
                w = self.w.view(self.source.n, self.target.n)
                w *= self.norm / self.w.sum(0).view(1, -1)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Contains resetting logic for the connection.
            """
            super().reset_state_variables()

    6、MeanFieldConnection

    class MeanFieldConnection(AbstractConnection):
        # language=rst
        """
        A connection between one or two populations of neurons which computes a summary of
        the pre-synaptic population to use as weighted input to the post-synaptic
        population.
        """

        def __init__(
            self,
            source: Nodes,
            target: Nodes,
            nu: Optional[Union[float, Sequence[float]]] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Instantiates a :code:`MeanFieldConnection` object.

            :param source: A layer of nodes from which the connection originates.
            :param target: A layer of nodes to which the connection connects.
            :param nu: Learning rate for both pre- and post-synaptic events.
            :param weight_decay: Constant multiple to decay weights by on each iteration.

            Keyword arguments:

            :param LearningRule update_rule: Modifies connection parameters according to
                some rule.
            :param torch.Tensor w: Strengths of synapses.
            :param float wmin: Minimum allowed value on the connection weights.
            :param float wmax: Maximum allowed value on the connection weights.
            :param float norm: Total weight per target neuron normalization constant.
            """
            # Pass by keyword: the base constructor's fourth positional parameter
            # is ``reduction``, so a positional ``weight_decay`` would land there.
            super().__init__(
                source, target, nu=nu, weight_decay=weight_decay, **kwargs
            )

            w = kwargs.get("w", None)
            if w is None:
                # Default weight is a single scalar drawn from (N(0, 1) + 1) / 10.
                if self.wmin == -np.inf or self.wmax == np.inf:
                    w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax)
                else:
                    w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin)
            else:
                if self.wmin != -np.inf or self.wmax != np.inf:
                    w = torch.clamp(w, self.wmin, self.wmax)

            self.w = Parameter(w, requires_grad=False)

        def compute(self, s: torch.Tensor) -> torch.Tensor:
            # language=rst
            """
            Compute pre-activations given spikes using layer weights.

            :param s: Incoming spikes.
            :return: Mean of the incoming spikes (over all elements) multiplied by the
                connection weights.
            """
            # Compute multiplication of mean-field pre-activation by connection weights.
            return s.float().mean() * self.w

        def update(self, **kwargs) -> None:
            # language=rst
            """
            Compute connection's update rule.
            """
            super().update(**kwargs)

        def normalize(self) -> None:
            # language=rst
            """
            Rescale the weights in place so that they sum to ``self.norm``. Does
            nothing when ``self.norm`` is ``None``.
            """
            if self.norm is not None:
                # Scale in place: ``w`` is a registered Parameter, and rebinding
                # it to a reshaped plain tensor is rejected by ``torch.nn.Module``.
                self.w *= self.norm / self.w.sum()

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Contains resetting logic for the connection.
            """
            super().reset_state_variables()

    7、SparseConnection

    class SparseConnection(AbstractConnection):
        # language=rst
        """
        Specifies sparse synapses between one or two populations of neurons.
        """

        def __init__(
            self,
            source: Nodes,
            target: Nodes,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: Optional[float] = None,
            **kwargs
        ) -> None:
            # language=rst
            """
            Instantiates a :code:`Connection` object with sparse weights.

            Exactly one of the ``w`` and ``sparsity`` keyword arguments must be given.

            :param source: A layer of nodes from which the connection originates.
            :param target: A layer of nodes to which the connection connects.
            :param nu: Learning rate for both pre- and post-synaptic events.
            :param reduction: Method for reducing parameter updates along the minibatch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.

            Keyword arguments:

            :param torch.Tensor w: Strengths of synapses; must be a sparse tensor.
            :param float sparsity: Fraction of sparse connections to use; each synapse
                is kept with probability ``1 - sparsity``.
            :param LearningRule update_rule: Modifies connection parameters according to
                some rule.
            :param float wmin: Minimum allowed value on the connection weights.
            :param float wmax: Maximum allowed value on the connection weights.
            :param float norm: Total weight per target neuron normalization constant.
            """
            super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

            w = kwargs.get("w", None)
            self.sparsity = kwargs.get("sparsity", None)

            assert (
                w is not None
                and self.sparsity is None
                or w is None
                and self.sparsity is not None
            ), 'Only one of "weights" or "sparsity" must be specified'

            if w is None and self.sparsity is not None:
                dense_shape = (*source.shape, *target.shape)

                # Bernoulli draw: 1 marks a synapse that is kept.
                i = torch.bernoulli(1 - self.sparsity * torch.ones(*dense_shape))
                if self.wmin == -np.inf or self.wmax == np.inf:
                    v = torch.clamp(
                        torch.rand(*dense_shape)[i.bool()],
                        self.wmin,
                        self.wmax,
                    )
                else:
                    v = self.wmin + torch.rand(*dense_shape)[i.bool()] * (
                        self.wmax - self.wmin
                    )
                # Pass the size explicitly: without it the shape is inferred from
                # the largest sampled index and can come out smaller than the
                # full source-by-target shape.
                w = torch.sparse_coo_tensor(i.nonzero().t(), v, size=dense_shape)
            elif w is not None and self.sparsity is None:
                assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)"
                if self.wmin != -np.inf or self.wmax != np.inf:
                    # NOTE(review): confirm ``torch.clamp`` accepts sparse tensors
                    # in the targeted torch version.
                    w = torch.clamp(w, self.wmin, self.wmax)

            self.w = Parameter(w, requires_grad=False)

        def compute(self, s: torch.Tensor) -> torch.Tensor:
            # language=rst
            """
            Compute pre-activations given spikes using the sparse weights.

            :param s: Incoming spikes.
            :return: Product of the weight matrix with the spikes taken as a column
                vector, with the trailing singleton dimension removed.
            """
            # NOTE(review): ``torch.mm`` needs 2-D operands, so this presumably
            # expects a 1-D ``s`` whose length matches ``w``'s second dimension -
            # confirm against callers.
            return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)

        def update(self, **kwargs) -> None:
            # language=rst
            """
            Intentionally a no-op: no update rule is applied to this connection.
            """
            pass

        def normalize(self) -> None:
            # language=rst
            """
            Intentionally a no-op: no normalization is performed on this connection.
            """
            pass

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Contains resetting logic for the connection.
            """
            super().reset_state_variables()
  • 相关阅读:
    与众不同 LibreOJ
    清点人数 LibreOJ
    数星星 Stars LibreOJ
    树状数组 1 :单点修改,区间查询 Gym
    DDL与DML
    605. 种花问题
    Minimum Spanning Tree Gym
    最小生成树动态演示
    ssm框架添加分页
    ssm框架下的文件自动生成模板(简单方便)
  • 原文地址:https://www.cnblogs.com/lucifer1997/p/14350180.html
Copyright © 2011-2022 走看看