zoukankan      html  css  js  c++  java
  • BindsNET学习系列 —— LearningRule

    相关源码:bindsnet/bindsnet/learning/learning.py

    1、LearningRule

    class LearningRule(ABC):
        # language=rst
        """
        Abstract base class for learning rules.

        Stores the connection being trained, its pair of learning rates, the method used
        to reduce per-sample updates along the batch dimension, and a multiplicative
        weight-decay factor. Subclasses call :meth:`update` after applying their own
        weight change so that weight decay and weight clamping are applied uniformly.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Abstract constructor for the ``LearningRule`` object.

            :param connection: An ``AbstractConnection`` object.
            :param nu: Single or pair of learning rates for pre- and post-synaptic events.
                A scalar is used for both events; ``None`` sets both rates to zero.
            :param reduction: Method for reducing parameter updates along the batch
                dimension. Defaults to ``torch.squeeze`` when the source batch size is 1
                and to ``torch.sum`` otherwise.
            :param weight_decay: Constant multiple to decay weights by on each iteration;
                weights are multiplied by ``1 - weight_decay`` on every :meth:`update`.
            """
            # Connection parameters.
            self.connection = connection
            self.source = connection.source
            self.target = connection.target

            self.wmin = connection.wmin
            self.wmax = connection.wmax

            # Learning rate(s): normalise to a (pre, post) pair.
            if nu is None:
                nu = [0.0, 0.0]
            elif isinstance(nu, (int, float)):
                nu = [nu, nu]

            self.nu = torch.zeros(2, dtype=torch.float)
            self.nu[0] = nu[0]
            self.nu[1] = nu[1]

            # Parameter update reduction across minibatch dimension.
            if reduction is None:
                if self.source.batch_size == 1:
                    self.reduction = torch.squeeze
                else:
                    self.reduction = torch.sum
            else:
                self.reduction = reduction

            # Weight decay, stored as the multiplicative factor applied on each update
            # (1.0 means "no decay").
            self.weight_decay = 1.0 - weight_decay if weight_decay else 1.0

        def update(self) -> None:
            # language=rst
            """
            Apply the post-processing shared by all learning rules.

            Multiplies the weights by the decay factor (skipped when there is no decay)
            and, unless this rule is a ``NoOp``, clamps them to ``[wmin, wmax]`` when at
            least one of the connection's bounds is finite.
            """
            # Implement weight decay. ``self.weight_decay`` is a multiplicative factor:
            # 1.0 means no decay (skip the pointless multiply), while 0.0 (i.e. a
            # requested weight_decay of 1.0) must still be applied.
            if self.weight_decay != 1.0:
                self.connection.w *= self.weight_decay

            # Bound weights.
            if (
                self.connection.wmin != -np.inf or self.connection.wmax != np.inf
            ) and not isinstance(self, NoOp):
                self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)

    学习规则的抽象基类:保存连接、学习率 nu(突触前/突触后两项)以及沿批维度的归约方式,并在每次更新后统一执行权重衰减,以及(当 wmin/wmax 至少有一个为有限值时)将权重截断(clamp)到 [wmin, wmax] 区间。

    2、NoOp(没有效果的学习规则)

    class NoOp(LearningRule):
        # language=rst
        """
        Learning rule that applies no plasticity term to the weights.

        Updating only runs the base-class post-processing; the base implementation
        skips weight clamping for this class, while any configured weight decay is
        still handled by ``LearningRule.update``.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for the ``NoOp`` learning rule.

            :param connection: An ``AbstractConnection`` object.
            :param nu: Single or pair of learning rates for pre- and post-synaptic events
                (stored, but never used to change the weights).
            :param reduction: Method for reducing parameter updates along the batch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.
            """
            super().__init__(connection, nu, reduction, weight_decay, **kwargs)

        def update(self, **kwargs) -> None:
            # language=rst
            """
            Run only the base-class post-processing; keyword arguments are ignored.
            """
            return super().update()

    不产生可塑性更新的学习规则(默认),直接继承自学习规则的抽象基类。注意:其 update 仍会调用基类的 update,因此若设置了 weight_decay,权重衰减仍然生效,只是跳过了权重截断。

    3、PostPre(STDP的在线实现)

    class PostPre(LearningRule):
        # language=rst
        """
        Simple STDP rule involving both pre- and post-synaptic spiking activity. By default,
        pre-synaptic update is negative and the post-synaptic update is positive.

        Pre-synaptic spikes depress weights in proportion to the post-synaptic trace
        (scaled by ``nu[0]``); post-synaptic spikes potentiate weights in proportion to
        the pre-synaptic trace (scaled by ``nu[1]``). Each term is skipped when its rate
        is zero.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for ``PostPre`` learning rule.

            :param connection: An ``AbstractConnection`` object whose weights the
                ``PostPre`` learning rule will modify.
            :param nu: Single or pair of learning rates for pre- and post-synaptic events.
            :param reduction: Method for reducing parameter updates along the batch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.
            :raises AssertionError: If either node group does not record spike traces.
            :raises NotImplementedError: If the connection type is not supported.
            """
            super().__init__(
                connection=connection,
                nu=nu,
                reduction=reduction,
                weight_decay=weight_decay,
                **kwargs
            )

            assert (
                self.source.traces and self.target.traces
            ), "Both pre- and post-synaptic nodes must record spike traces."

            dense_types = (Connection, LocalConnection)
            if isinstance(connection, dense_types):
                self.update = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                self.update = self._conv2d_connection_update
            else:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Post-pre learning rule for ``Connection`` / ``LocalConnection`` subclasses of
            ``AbstractConnection``.
            """
            n_batch = self.source.batch_size

            # Pre-synaptic spikes depress the weights using the post-synaptic trace.
            if self.nu[0]:
                pre_spikes = self.source.s.view(n_batch, -1).unsqueeze(2).float()
                post_trace = self.target.x.view(n_batch, -1).unsqueeze(1) * self.nu[0]
                self.connection.w -= self.reduction(
                    torch.bmm(pre_spikes, post_trace), dim=0
                )
                del pre_spikes, post_trace

            # Post-synaptic spikes potentiate the weights using the pre-synaptic trace.
            if self.nu[1]:
                post_spikes = (
                    self.target.s.view(n_batch, -1).unsqueeze(1).float() * self.nu[1]
                )
                pre_trace = self.source.x.view(n_batch, -1).unsqueeze(2)
                self.connection.w += self.reduction(
                    torch.bmm(pre_trace, post_spikes), dim=0
                )
                del pre_trace, post_spikes

            super().update()

        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Post-pre learning rule for ``Conv2dConnection`` subclass of
            ``AbstractConnection`` class.
            """
            w_shape = self.connection.w.size()
            out_channels, _, k_h, k_w = w_shape
            pad, stride = self.connection.padding, self.connection.stride
            n_batch = self.source.batch_size

            def unfold(tensor):
                # Rearrange a source-layer tensor with ``im2col_indices`` so the
                # convolution update becomes a batched matrix product.
                return im2col_indices(tensor, k_h, k_w, padding=pad, stride=stride)

            pre_trace = unfold(self.source.x)
            post_trace = self.target.x.view(n_batch, out_channels, -1)
            pre_spikes = unfold(self.source.s.float())
            post_spikes = self.target.s.view(n_batch, out_channels, -1).float()

            # Depression driven by pre-synaptic spikes.
            if self.nu[0]:
                depression = self.reduction(
                    torch.bmm(post_trace, pre_spikes.permute((0, 2, 1))), dim=0
                )
                self.connection.w -= self.nu[0] * depression.view(w_shape)

            # Potentiation driven by post-synaptic spikes.
            if self.nu[1]:
                potentiation = self.reduction(
                    torch.bmm(post_spikes, pre_trace.permute((0, 2, 1))), dim=0
                )
                self.connection.w += self.nu[1] * potentiation.view(w_shape)

            super().update()

    简单的STDP规则,包括突触前和突触后的脉冲活动。默认情况下,突触前更新为负(-nu[0] × 突触前脉冲 × 突触后发放迹),突触后更新为正(nu[1] × 突触前发放迹 × 突触后脉冲);两项都先沿批维度用 reduction 归约,再作用到权重上,对应学习率为 0 时该项被跳过。

    4、WeightDependentPostPre

    class WeightDependentPostPre(LearningRule):
        # language=rst
        """
        STDP rule involving both pre- and post-synaptic spiking activity. The post-synaptic
        update is positive and the pre-synaptic update is negative, and both are dependent
        on the magnitude of the synaptic weights.

        Depression is scaled by ``(w - wmin)`` and potentiation by ``(wmax - w)``, so
        updates shrink as a weight approaches the corresponding bound.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for ``WeightDependentPostPre`` learning rule.

            :param connection: An ``AbstractConnection`` object whose weights the
                ``WeightDependentPostPre`` learning rule will modify.
            :param nu: Single or pair of learning rates for pre- and post-synaptic events.
            :param reduction: Method for reducing parameter updates along the batch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.
            :raises AssertionError: If either node group does not record spike traces, or
                if the connection's ``wmin``/``wmax`` are not both finite.
            :raises NotImplementedError: If the connection type is not supported.
            """
            super().__init__(
                connection=connection,
                nu=nu,
                reduction=reduction,
                weight_decay=weight_decay,
                **kwargs
            )

            # Both update paths read traces on both sides: depression uses the
            # post-synaptic trace ``target.x`` and potentiation the pre-synaptic trace
            # ``source.x``, so both node groups must record them.
            assert (
                self.source.traces and self.target.traces
            ), "Both pre- and post-synaptic nodes must record spike traces."
            # The soft bounds (w - wmin) and (wmax - w) need finite limits.
            assert (
                connection.wmin != -np.inf and connection.wmax != np.inf
            ), "Connection must define finite wmin and wmax."

            # ``self.wmin`` / ``self.wmax`` are already copied from the connection by
            # ``LearningRule.__init__``.

            if isinstance(connection, (Connection, LocalConnection)):
                self.update = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                self.update = self._conv2d_connection_update
            else:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Weight-dependent post-pre learning rule for ``Connection`` /
            ``LocalConnection`` subclasses of ``AbstractConnection``.
            """
            batch_size = self.source.batch_size

            source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
            source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
            target_s = self.target.s.view(batch_size, -1).unsqueeze(1).float()
            target_x = self.target.x.view(batch_size, -1).unsqueeze(1)

            update = 0

            # Pre-synaptic update (depression, soft-bounded by w - wmin).
            if self.nu[0]:
                outer_product = self.reduction(torch.bmm(source_s, target_x), dim=0)
                update -= self.nu[0] * outer_product * (self.connection.w - self.wmin)

            # Post-synaptic update (potentiation, soft-bounded by wmax - w).
            if self.nu[1]:
                outer_product = self.reduction(torch.bmm(source_x, target_s), dim=0)
                update += self.nu[1] * outer_product * (self.wmax - self.connection.w)

            self.connection.w += update

            super().update()

        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Weight-dependent post-pre learning rule for ``Conv2dConnection`` subclass of
            ``AbstractConnection`` class.
            """
            # Get convolutional layer parameters.
            (
                out_channels,
                in_channels,
                kernel_height,
                kernel_width,
            ) = self.connection.w.size()
            padding, stride = self.connection.padding, self.connection.stride
            batch_size = self.source.batch_size

            # Reshaping spike traces and spike occurrences.
            source_x = im2col_indices(
                self.source.x, kernel_height, kernel_width, padding=padding, stride=stride
            )
            target_x = self.target.x.view(batch_size, out_channels, -1)
            source_s = im2col_indices(
                self.source.s.float(),
                kernel_height,
                kernel_width,
                padding=padding,
                stride=stride,
            )
            target_s = self.target.s.view(batch_size, out_channels, -1).float()

            update = 0

            # Pre-synaptic update (depression, soft-bounded by w - wmin).
            if self.nu[0]:
                pre = self.reduction(
                    torch.bmm(target_x, source_s.permute((0, 2, 1))), dim=0
                )
                update -= (
                    self.nu[0]
                    * pre.view(self.connection.w.size())
                    * (self.connection.w - self.wmin)
                )

            # Post-synaptic update (potentiation, soft-bounded by wmax - w). The factor
            # must depend on the current weights, as in the dense-connection update.
            if self.nu[1]:
                post = self.reduction(
                    torch.bmm(target_s, source_x.permute((0, 2, 1))), dim=0
                )
                update += (
                    self.nu[1]
                    * post.view(self.connection.w.size())
                    * (self.wmax - self.connection.w)
                )

            self.connection.w += update

            super().update()

    STDP规则涉及突触前和突触后的脉冲活动。突触前更新为负(-学习率1 × 突触前脉冲 × 突触后发放迹 × (weight - weightmin)),突触后更新为正(学习率2 × 突触前发放迹 × 突触后脉冲 × (weightmax - weight)),两者都依赖于突触权重的大小。

    5、Hebbian

    class Hebbian(LearningRule):
        # language=rst
        """
        Simple Hebbian learning rule. Pre- and post-synaptic updates are both positive.

        Each step adds ``nu[0]`` times the (pre-synaptic spike x post-synaptic trace)
        term and ``nu[1]`` times the (pre-synaptic trace x post-synaptic spike) term to
        the weights, each reduced along the batch dimension.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for ``Hebbian`` learning rule.

            :param connection: An ``AbstractConnection`` object whose weights the
                ``Hebbian`` learning rule will modify.
            :param nu: Single or pair of learning rates for pre- and post-synaptic events.
            :param reduction: Method for reducing parameter updates along the batch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.
            :raises AssertionError: If either node group does not record spike traces.
            :raises NotImplementedError: If the connection type is not supported.
            """
            super().__init__(
                connection=connection,
                nu=nu,
                reduction=reduction,
                weight_decay=weight_decay,
                **kwargs
            )

            assert (
                self.source.traces and self.target.traces
            ), "Both pre- and post-synaptic nodes must record spike traces."

            if isinstance(connection, (Connection, LocalConnection)):
                handler = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                handler = self._conv2d_connection_update
            else:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )
            self.update = handler

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Hebbian learning rule for ``Connection`` / ``LocalConnection`` subclasses of
            ``AbstractConnection``.
            """
            n_batch = self.source.batch_size

            pre_spikes = self.source.s.view(n_batch, -1).unsqueeze(2).float()
            pre_trace = self.source.x.view(n_batch, -1).unsqueeze(2)
            post_spikes = self.target.s.view(n_batch, -1).unsqueeze(1).float()
            post_trace = self.target.x.view(n_batch, -1).unsqueeze(1)

            # (rate, left factor, right factor): pre-synaptic term first, then the
            # post-synaptic term; both are added to the weights.
            terms = (
                (self.nu[0], pre_spikes, post_trace),
                (self.nu[1], pre_trace, post_spikes),
            )
            for rate, left, right in terms:
                self.connection.w += rate * self.reduction(torch.bmm(left, right), dim=0)

            super().update()

        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Hebbian learning rule for ``Conv2dConnection`` subclass of
            ``AbstractConnection`` class.
            """
            w_shape = self.connection.w.size()
            out_channels, _, k_h, k_w = w_shape
            pad, stride = self.connection.padding, self.connection.stride
            n_batch = self.source.batch_size

            def unfold(tensor):
                # Rearrange a source-layer tensor with ``im2col_indices`` so the
                # convolution update becomes a batched matrix product.
                return im2col_indices(tensor, k_h, k_w, padding=pad, stride=stride)

            pre_trace = unfold(self.source.x)
            post_trace = self.target.x.view(n_batch, out_channels, -1)
            pre_spikes = unfold(self.source.s.float())
            post_spikes = self.target.s.view(n_batch, out_channels, -1).float()

            # (rate, per-output-channel factor, unfolded source factor).
            terms = (
                (self.nu[0], post_trace, pre_spikes),
                (self.nu[1], post_spikes, pre_trace),
            )
            for rate, rows, cols in terms:
                delta = self.reduction(torch.bmm(rows, cols.permute((0, 2, 1))), dim=0)
                self.connection.w += rate * delta.view(w_shape)

            super().update()

    简单的Hebbian学习规则。突触前后的更新都为正(分别为学习率1 × 突触前脉冲 × 突触后发放迹,以及学习率2 × 突触前发放迹× 突触后脉冲)

    6、MSTDP

    class MSTDP(LearningRule):
        # language=rst
        """
        Reward-modulated STDP. Adapted from `(Florian 2007)
        <https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>`_.

        On each call the weights are first changed by ``nu[0] * reward * eligibility``
        using the eligibility computed on the *previous* call; then the traces ``P^+``
        (pre-synaptic) and ``P^-`` (post-synaptic) are decayed and incremented, and a
        new eligibility is computed from them.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for ``MSTDP`` learning rule.

            :param connection: An ``AbstractConnection`` object whose weights the ``MSTDP``
                learning rule will modify.
            :param nu: Single or pair of learning rates for pre- and post-synaptic events,
                respectively. Only ``nu[0]`` scales the reward-modulated weight update.
            :param reduction: Method for reducing parameter updates along the minibatch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.

            Keyword arguments:

            :param tc_plus: Time constant for pre-synaptic firing trace (default 20.0).
            :param tc_minus: Time constant for post-synaptic firing trace (default 20.0).
            :raises NotImplementedError: If the connection type is not supported.
            """
            super().__init__(
                connection=connection,
                nu=nu,
                reduction=reduction,
                weight_decay=weight_decay,
                **kwargs
            )

            if isinstance(connection, (Connection, LocalConnection)):
                self.update = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                self.update = self._conv2d_connection_update
            else:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )

            self.tc_plus = torch.tensor(kwargs.get("tc_plus", 20.0))
            self.tc_minus = torch.tensor(kwargs.get("tc_minus", 20.0))

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            MSTDP learning rule for ``Connection`` subclass of ``AbstractConnection`` class.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
                learning task.
            :param float a_plus: Learning rate (post-synaptic).
            :param float a_minus: Learning rate (pre-synaptic).
            """
            batch_size = self.source.batch_size

            # Initialize eligibility, P^+, and P^-. The traces are kept flat as
            # (batch, n) because they are combined with the flattened spikes below and
            # fed to ``torch.bmm``; a (batch, *shape) layout breaks for layers whose
            # shape has more than one dimension.
            if not hasattr(self, "p_plus"):
                self.p_plus = torch.zeros(
                    batch_size, self.source.n, device=self.source.s.device
                )
            if not hasattr(self, "p_minus"):
                self.p_minus = torch.zeros(
                    batch_size, self.target.n, device=self.target.s.device
                )
            if not hasattr(self, "eligibility"):
                self.eligibility = torch.zeros(
                    batch_size, *self.connection.w.shape, device=self.connection.w.device
                )

            # Reshape pre- and post-synaptic spikes.
            source_s = self.source.s.view(batch_size, -1).float()
            target_s = self.target.s.view(batch_size, -1).float()

            # Parse keyword arguments.
            reward = kwargs["reward"]
            a_plus = torch.tensor(
                kwargs.get("a_plus", 1.0), device=self.connection.w.device
            )
            a_minus = torch.tensor(
                kwargs.get("a_minus", -1.0), device=self.connection.w.device
            )

            # Compute weight update based on the eligibility value of the past timestep.
            update = reward * self.eligibility
            self.connection.w += self.nu[0] * self.reduction(update, dim=0)

            # Update P^+ and P^- values.
            self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
            self.p_plus += a_plus * source_s
            self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
            self.p_minus += a_minus * target_s

            # Calculate point eligibility value.
            self.eligibility = torch.bmm(
                self.p_plus.unsqueeze(2), target_s.unsqueeze(1)
            ) + torch.bmm(source_s.unsqueeze(2), self.p_minus.unsqueeze(1))

            super().update()

        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            MSTDP learning rule for ``Conv2dConnection`` subclass of ``AbstractConnection``
            class.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
                learning task.
            :param float a_plus: Learning rate (post-synaptic).
            :param float a_minus: Learning rate (pre-synaptic).
            """
            batch_size = self.source.batch_size

            # Initialize eligibility (one weight-shaped copy per minibatch sample).
            if not hasattr(self, "eligibility"):
                self.eligibility = torch.zeros(
                    batch_size, *self.connection.w.shape, device=self.connection.w.device
                )

            # Parse keyword arguments.
            reward = kwargs["reward"]
            a_plus = torch.tensor(
                kwargs.get("a_plus", 1.0), device=self.connection.w.device
            )
            a_minus = torch.tensor(
                kwargs.get("a_minus", -1.0), device=self.connection.w.device
            )

            # Compute weight update based on the eligibility value of the past timestep,
            # reduced over the minibatch dimension with the configured reduction (as in
            # the dense-connection update).
            update = reward * self.eligibility
            self.connection.w += self.nu[0] * self.reduction(update, dim=0)

            out_channels, _, kernel_height, kernel_width = self.connection.w.size()
            padding, stride = self.connection.padding, self.connection.stride

            # Initialize P^+ (rearranged with im2col_indices) and P^-.
            if not hasattr(self, "p_plus"):
                self.p_plus = torch.zeros(
                    batch_size, *self.source.shape, device=self.connection.w.device
                )
                self.p_plus = im2col_indices(
                    self.p_plus, kernel_height, kernel_width, padding=padding, stride=stride
                )
            if not hasattr(self, "p_minus"):
                self.p_minus = torch.zeros(
                    batch_size, *self.target.shape, device=self.connection.w.device
                )
                self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float()

            # Reshaping spike occurrences.
            source_s = im2col_indices(
                self.source.s.float(),
                kernel_height,
                kernel_width,
                padding=padding,
                stride=stride,
            )
            target_s = self.target.s.view(batch_size, out_channels, -1).float()

            # Update P^+ and P^- values: P^+ tracks the influence of pre-synaptic spikes,
            # P^- tracks the influence of post-synaptic spikes.
            self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
            self.p_plus += a_plus * source_s
            self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
            self.p_minus += a_minus * target_s

            # Calculate point eligibility value. Keep the minibatch dimension so the next
            # call's reduction over dim=0 runs over samples rather than output channels
            # (viewing to the bare weight shape also fails for batch sizes > 1).
            self.eligibility = torch.bmm(
                target_s, self.p_plus.permute((0, 2, 1))
            ) + torch.bmm(self.p_minus, source_s.permute((0, 2, 1)))
            self.eligibility = self.eligibility.view(
                batch_size, *self.connection.w.size()
            )

            super().update()

    奖励调节STDP (R-STDP),改编自(Florian 2007)<https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>

    论文参见:Reinforcement Learning Through Modulation of Spike-Timing-Dependent Synaptic Plasticity - 穷酸秀才大艹包 - 博客园 (cnblogs.com)

    7、MSTDPET

    class MSTDPET(LearningRule):
        # language=rst
        """
        Reward-modulated STDP with eligibility trace. Adapted from
        `(Florian 2007) <https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>`_.
        """
    
        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Build an ``MSTDPET`` rule and select the update routine for the connection.

            :param connection: An ``AbstractConnection`` object whose weights the
                ``MSTDPET`` learning rule will modify.
            :param nu: Single or pair of learning rates for pre- and post-synaptic events,
                respectively.
            :param reduction: Method for reducing parameter updates along the minibatch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.
            :raises NotImplementedError: If the connection type is not supported.

            Keyword arguments:

            :param float tc_plus: Time constant for pre-synaptic firing trace
                (default 20.0).
            :param float tc_minus: Time constant for post-synaptic firing trace
                (default 20.0).
            :param float tc_e_trace: Time constant for the eligibility trace
                (default 25.0).
            """
            super().__init__(connection, nu, reduction, weight_decay, **kwargs)

            if isinstance(connection, (Connection, LocalConnection)):
                chosen_update = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                chosen_update = self._conv2d_connection_update
            else:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )
            self.update = chosen_update

            # Time constants, stored as tensors for use inside ``torch.exp``.
            for name, default in (
                ("tc_plus", 20.0),
                ("tc_minus", 20.0),
                ("tc_e_trace", 25.0),
            ):
                setattr(self, name, torch.tensor(kwargs.get(name, default)))
    
        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            MSTDPET learning rule for ``Connection`` / ``LocalConnection`` subclasses of
            ``AbstractConnection``.

            The weight change uses the eligibility trace, which has already been decayed
            and fed the previous call's point eligibility. The traces ``P^+``/``P^-``
            and the point eligibility are refreshed afterwards. Spikes are flattened with
            ``view(-1)``, so this path keeps no separate batch dimension.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
                learning task.
            :param float a_plus: Learning rate (post-synaptic).
            :param float a_minus: Learning rate (pre-synaptic).
            """
            w_device = self.connection.w.device
            w_shape = self.connection.w.shape
            dt = self.connection.dt

            # Lazily create the state tensors on first use: flat per-neuron traces and
            # weight-shaped eligibility buffers.
            if not hasattr(self, "p_plus"):
                self.p_plus = torch.zeros(self.source.n, device=self.source.s.device)
            if not hasattr(self, "p_minus"):
                self.p_minus = torch.zeros(self.target.n, device=self.target.s.device)
            for buffer_name in ("eligibility", "eligibility_trace"):
                if not hasattr(self, buffer_name):
                    setattr(self, buffer_name, torch.zeros(*w_shape, device=w_device))

            pre_spikes = self.source.s.view(-1).float()
            post_spikes = self.target.s.view(-1).float()

            reward = kwargs["reward"]
            a_plus = torch.tensor(kwargs.get("a_plus", 1.0), device=w_device)
            a_minus = torch.tensor(kwargs.get("a_minus", -1.0), device=w_device)

            # Decay the eligibility trace, then feed in the point eligibility computed
            # on the previous call.
            self.eligibility_trace *= torch.exp(-dt / self.tc_e_trace)
            self.eligibility_trace += self.eligibility / self.tc_e_trace

            # Reward-modulated weight change.
            self.connection.w += self.nu[0] * dt * reward * self.eligibility_trace

            # Decay and increment the pre- (P^+) and post-synaptic (P^-) traces.
            self.p_plus *= torch.exp(-dt / self.tc_plus)
            self.p_plus += a_plus * pre_spikes
            self.p_minus *= torch.exp(-dt / self.tc_minus)
            self.p_minus += a_minus * post_spikes

            # New point eligibility from the outer products of traces and spikes.
            potentiation_part = torch.ger(self.p_plus, post_spikes)
            depression_part = torch.ger(pre_spikes, self.p_minus)
            self.eligibility = potentiation_part + depression_part

            super().update()
    
        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            MSTDPET learning rule for ``Conv2dConnection`` subclass of
            ``AbstractConnection`` class.

            On each call the eligibility trace is decayed and then incremented by the
            point eligibility computed on the previous call (mirroring the
            ``Connection`` variant of this rule). The weights are then updated from the
            reward-modulated trace summed over the batch dimension. Finally the P^+ /
            P^- traces and the point eligibility are refreshed for the next call.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
                learning task.
            :param float a_plus: Learning rate (post-synaptic).
            :param float a_minus: Learning rate (pre-synaptic).
            """
            batch_size = self.source.batch_size
            device = self.connection.w.device

            # Initialize eligibility and eligibility trace. Both carry a leading batch
            # dimension so that per-sample contributions can be summed below.
            if not hasattr(self, "eligibility"):
                self.eligibility = torch.zeros(
                    batch_size, *self.connection.w.shape, device=device
                )
            if not hasattr(self, "eligibility_trace"):
                self.eligibility_trace = torch.zeros(
                    batch_size, *self.connection.w.shape, device=device
                )

            # Parse keyword arguments.
            reward = kwargs["reward"]
            a_plus = torch.tensor(kwargs.get("a_plus", 1.0), device=device)
            a_minus = torch.tensor(kwargs.get("a_minus", -1.0), device=device)

            # Calculate value of eligibility trace based on the value
            # of the point eligibility value of the past timestep: exponential decay
            # followed by accumulation of the previous point eligibility.
            self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
            self.eligibility_trace += self.eligibility / self.tc_e_trace

            # Compute weight update, summing the per-sample updates over the batch.
            update = reward * self.eligibility_trace
            self.connection.w += self.nu[0] * self.connection.dt * torch.sum(update, dim=0)

            out_channels, _, kernel_height, kernel_width = self.connection.w.size()
            padding, stride = self.connection.padding, self.connection.stride

            # Initialize P^+ (im2col layout of the source) and
            # P^- (batch, out_channels, -1 layout of the target).
            if not hasattr(self, "p_plus"):
                self.p_plus = torch.zeros(batch_size, *self.source.shape, device=device)
                self.p_plus = im2col_indices(
                    self.p_plus, kernel_height, kernel_width, padding=padding, stride=stride
                )
            if not hasattr(self, "p_minus"):
                self.p_minus = torch.zeros(batch_size, *self.target.shape, device=device)
                self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float()

            # Reshape spike occurrences into the same layouts as P^+ and P^-.
            source_s = im2col_indices(
                self.source.s.float(),
                kernel_height,
                kernel_width,
                padding=padding,
                stride=stride,
            )
            # Same (batch, out_channels, -1) layout used for P^- above. The previous
            # permute(1, 2, 3, 0) only coincided with this for batch_size == 1.
            target_s = self.target.s.reshape(batch_size, out_channels, -1).float()

            # Update P^+ and P^- values.
            self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
            self.p_plus += a_plus * source_s
            self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
            self.p_minus += a_minus * target_s

            # Calculate point eligibility value. The batch dimension is kept so the
            # result matches the shape of ``eligibility_trace``.
            self.eligibility = torch.bmm(
                target_s, self.p_plus.permute((0, 2, 1))
            ) + torch.bmm(self.p_minus, source_s.permute((0, 2, 1)))
            self.eligibility = self.eligibility.view(
                batch_size, *self.connection.w.size()
            )

            super().update()

    带资格迹的奖励调节 STDP(R-STDP with eligibility trace),改编自 Florian (2007):<https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>

    论文参见:Reinforcement Learning Through Modulation of Spike-Timing-Dependent Synaptic Plasticity - 穷酸秀才大艹包 - 博客园 (cnblogs.com)

    8、Rmax

    class Rmax(LearningRule):
        # language=rst
        """
        Reward-modulated learning rule derived from reward maximization principles. Adapted
        from `(Vasilaki et al., 2009)
        <https://intranet.physio.unibe.ch/Publikationen/Dokumente/Vasilaki2009PloSComputBio_1.pdf>`_.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for ``R-max`` learning rule.

            :param connection: An ``AbstractConnection`` object whose weights the ``R-max``
                learning rule will modify. Only ``Connection`` and ``LocalConnection``
                are supported.
            :param nu: Single or pair of learning rates for pre- and post-synaptic events,
                respectively. Only ``nu[0]`` is used by the weight update.
            :param reduction: Method for reducing parameter updates along the minibatch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each iteration.

            Keyword arguments:

            :param float tc_c: Time constant for balancing naive Hebbian and policy gradient
                learning (default ``5.0``). With ``tc_c = 0`` the post-synaptic factor of
                the eligibility update is ``s - p`` (policy gradient). As ``tc_c -> inf``
                it tends to ``s`` (naive Hebbian).
            :param float tc_e_trace: Time constant for the eligibility trace (default
                ``25.0``).

            :raises AssertionError: If the pre-synaptic nodes do not use additive spike
                traces, or the post-synaptic nodes are not ``SRM0Nodes``.
            :raises NotImplementedError: If ``connection`` is neither a ``Connection`` nor
                a ``LocalConnection``.
            """
            super().__init__(
                connection=connection,
                nu=nu,
                reduction=reduction,
                weight_decay=weight_decay,
                **kwargs
            )

            # Trace is needed for computing epsilon (read as ``source.x`` in the update).
            assert (
                self.source.traces and self.source.traces_additive
            ), "Pre-synaptic nodes must use additive spike traces."

            # Derivation of R-max depends on stochastic SRM neurons!
            assert isinstance(
                self.target, SRM0Nodes
            ), "R-max needs stochastically firing neurons, use SRM0Nodes."

            if isinstance(connection, (Connection, LocalConnection)):
                self.update = self._connection_update
            else:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )

            # tc_c = 0 gives a post-synaptic factor of (s - p), i.e. pure policy
            # gradient. tc_c -> inf gives s, i.e. pure naive Hebbian.
            self.tc_c = torch.tensor(kwargs.get("tc_c", 5.0))
            self.tc_e_trace = torch.tensor(kwargs.get("tc_e_trace", 25.0))

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            R-max learning rule for ``Connection`` subclass of ``AbstractConnection`` class.

            Each call updates the eligibility trace ``e`` with a forward-Euler step

            ``e <- e * (1 - dt / tc_e_trace) + (s - p / (1 + tc_c / dt * p)) * x_pre``

            where ``s`` / ``p`` are the post-synaptic spikes / spike probabilities and
            ``x_pre`` is the pre-synaptic spike trace. The weights then change by
            ``nu[0] * reward * e``.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
                learning task.
            """
            # Initialize eligibility trace.
            if not hasattr(self, "eligibility_trace"):
                self.eligibility_trace = torch.zeros(
                    *self.connection.w.shape, device=self.connection.w.device
                )

            # Reshape variables. view(-1) flattens every dimension, so this presumably
            # assumes batch_size == 1. NOTE(review): confirm for larger batches.
            target_s = self.target.s.view(-1).float()
            target_s_prob = self.target.s_prob.view(-1)
            source_x = self.source.x.view(-1)

            # Parse keyword arguments.
            reward = kwargs["reward"]

            # New eligibility trace: Euler-step decay, then add the outer product of
            # the pre-synaptic trace (rows) and the post-synaptic factor (columns).
            self.eligibility_trace *= 1 - self.connection.dt / self.tc_e_trace
            self.eligibility_trace += (
                target_s
                - (target_s_prob / (1.0 + self.tc_c / self.connection.dt * target_s_prob))
            ) * source_x[:, None]

            # Compute weight update (reward-modulated eligibility trace).
            self.connection.w += self.nu[0] * reward * self.eligibility_trace

            super().update()

    基于奖励最大化原理的奖励调节学习规则,改编自 Vasilaki et al. (2009):<https://intranet.physio.unibe.ch/Publikationen/Dokumente/Vasilaki2009PloSComputBio_1.pdf>

  • 相关阅读:
    PATA 1071 Speech Patterns.
    PATA 1027 Colors In Mars
    PATB 1038. 统计同成绩学生(20)
    1036. 跟奥巴马一起编程(15)
    PATA 1036. Boys vs Girls (25)
    PATA 1006. Sign In and Sign Out (25)
    读取web工程目录之外的图片并显示
    DOS命令
    java连接oracle集群
    servlet
  • 原文地址:https://www.cnblogs.com/lucifer1997/p/14313716.html
Copyright © 2011-2022 走看看