zoukankan      html  css  js  c++  java
  • BindsNET学习系列 —— Nodes

    相关源码:bindsnet/bindsnet/network/nodes.py

    1、Nodes

    class Nodes(torch.nn.Module):
        # language=rst
        """
        Abstract base class for groups of neurons.

        Holds the state shared by all layer types: the spike buffer ``s`` and, when enabled, the spike-trace
        buffer ``x`` and the summed-input buffer ``summed``.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            learning: bool = True,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Abstract base class constructor.

            At least one of ``n`` and ``shape`` must be given; if both are given, the product of ``shape`` must
            equal ``n``.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record decaying spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace (only stored when ``traces_additive`` is set).
            :param sum_input: Whether to sum all inputs.
            :param learning: Whether to be in learning or testing.
            :param kwargs: Extra keyword arguments; accepted and ignored by this constructor.
            :raises AssertionError: If neither ``n`` nor ``shape`` is given, or if they disagree.
            """
            super().__init__()

            assert (
                n is not None or shape is not None
            ), "Must provide either no. of neurons or shape of layer"

            if n is None:
                self.n = reduce(mul, shape)  # No. of neurons product of shape.
            else:
                self.n = n  # No. of neurons provided.

            if shape is None:
                self.shape = [self.n]  # Shape is equal to the size of the layer.
            else:
                self.shape = shape  # Shape is passed in as an argument.

            assert self.n == reduce(
                mul, self.shape
            ), "No. of neurons and shape do not match"

            self.traces = traces  # Whether to record synaptic traces.
            self.traces_additive = (
                traces_additive  # Whether to record spike traces additively.
            )
            # Empty placeholder; given its real (batched) shape in set_batch_size.
            self.register_buffer("s", torch.ByteTensor())  # Spike occurrences.

            self.sum_input = sum_input  # Whether to sum all inputs.

            if self.traces:
                self.register_buffer("x", torch.Tensor())  # Firing traces.
                self.register_buffer(
                    "tc_trace", torch.tensor(tc_trace)
                )  # Time constant of spike trace decay.
                if self.traces_additive:
                    self.register_buffer(
                        "trace_scale", torch.tensor(trace_scale)
                    )  # Scaling factor for spike trace.
                self.register_buffer(
                    "trace_decay", torch.empty_like(self.tc_trace)
                )  # Set in compute_decays, e.g. math.exp(-1/20)=0.9512

            if self.sum_input:
                self.register_buffer("summed", torch.FloatTensor())  # Summed inputs.

            self.dt = None
            self.batch_size = None
            # Cleared until compute_decays runs (this also blanks the buffer registered above).
            self.trace_decay = None
            self.learning = learning

        @abstractmethod
        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Abstract base class method for a single simulation step.

            Updates the spike traces from the current contents of ``self.s`` and, if enabled, adds ``x`` to the
            running input sum. Subclasses set ``self.s`` before delegating here.

            :param x: Inputs to the layer.
            """

            if self.traces:
                # Decay and set spike traces.
                self.x *= self.trace_decay

                if self.traces_additive:
                    self.x += self.trace_scale * self.s.float()
                else:
                    # masked_fill_ only accepts a boolean mask, but subclasses may store
                    # non-bool spikes in ``s`` (e.g. a raw input tensor), so cast explicitly.
                    self.x.masked_fill_(self.s.bool(), 1)

            if self.sum_input:
                # Add current input to running sum.
                self.summed += x.float()

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Abstract base class method for resetting state variables.

            Zeroes the spikes and, when enabled, the spike traces and the summed inputs.
            """
            self.s.zero_()

            if self.traces:
                self.x.zero_()  # Spike traces.

            if self.sum_input:
                self.summed.zero_()  # Summed inputs.

        def compute_decays(self, dt) -> None:
            # language=rst
            """
            Abstract base class method for setting decays.

            :param dt: Simulation time step; stored as a tensor in ``self.dt``.
            """
            self.dt = torch.tensor(dt)
            if self.traces:
                self.trace_decay = torch.exp(
                    -self.dt / self.tc_trace
                )  # Spike trace decay (per timestep).

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size. Called when layer is added to a network.

            Re-allocates the state tensors with shape ``(batch_size, *self.shape)`` on the device of the
            tensor they replace.

            :param batch_size: Mini-batch size.
            """
            self.batch_size = batch_size
            self.s = torch.zeros(
                batch_size, *self.shape, device=self.s.device, dtype=torch.bool
            )

            if self.traces:
                self.x = torch.zeros(batch_size, *self.shape, device=self.x.device)

            if self.sum_input:
                self.summed = torch.zeros(
                    batch_size, *self.shape, device=self.summed.device
                )

        def train(self, mode: bool = True) -> "Nodes":
            # language=rst
            """
            Sets the layer in training mode.

            :param bool mode: Turn training on or off
            :return: self as specified in `torch.nn.Module`
            """
            # Keep the layer's own learning flag in step with the module's training mode.
            self.learning = mode
            return super().train(mode)
    View Code

    神经元群的抽象基类。当self.traces = True时,先对发放迹进行衰减;若traces_additive = True,则在发放脉冲的神经元上累加trace_scale,否则将发放脉冲的神经元的发放迹直接置为1。当self.sum_input = True时,还会把当前输入累加到summed中。

    2、Input

    class AbstractInput(ABC):
        # language=rst
        """
        Abstract base class for groups of input neurons.

        Declares no methods or attributes of its own; it serves as a common base type for input layers.
        """
    
    
    class Input(Nodes, AbstractInput):
        # language=rst
        """
        Input layer: its spikes are whatever the caller feeds in at each step.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Builds a layer of input neurons.

            All named arguments are handed to the ``Nodes`` constructor unchanged; ``kwargs`` is accepted but
            not forwarded.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record decaying spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            """
            super().__init__(
                n, shape, traces, traces_additive, tc_trace, trace_scale, sum_input
            )

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Performs one simulation step by adopting the input as this layer's spikes.

            :param x: Inputs to the layer; stored as ``self.s`` without copying.
            """
            # The spikes of an input layer are the inputs themselves.
            self.s = x

            # Let the base class update traces / summed inputs.
            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets the state variables; nothing beyond what the base class resets.
            """
            super().reset_state_variables()
    View Code

    具有用户指定的脉冲行为的节点层。先将前向输入赋值给脉冲信号,然后执行基类函数。

    3、McCullochPitts

    class McCullochPitts(Nodes):
        # language=rst
        """
        Layer of `McCulloch-Pitts neurons
        <http://wwwold.ece.utep.edu/research/webfuzzy/docs/kk-thesis/kk-thesis-html/node12.html>`_.

        A neuron spikes on a step exactly when its input on that step reaches the threshold.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            thresh: Union[float, torch.Tensor] = 1.0,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Builds a layer of McCulloch-Pitts neurons.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param thresh: Spike threshold voltage.
            """
            super().__init__(
                n, shape, traces, traces_additive, tc_trace, trace_scale, sum_input
            )

            # Spike threshold voltage, stored as a float tensor.
            threshold = torch.tensor(thresh, dtype=torch.float)
            self.register_buffer("thresh", threshold)

            # Neuron voltages; empty until set_batch_size allocates them.
            self.register_buffer("v", torch.FloatTensor())

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Advances the layer by one simulation step.

            :param x: Inputs to the layer.
            """
            # The voltage is simply the current input (no integration over time).
            self.v = x

            # A neuron fires when its voltage is at or above threshold.
            self.s = self.v.ge(self.thresh)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets the state variables; nothing beyond what the base class resets.
            """
            super().reset_state_variables()

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)

            # Fresh zero voltages on the same device as the previous voltage tensor.
            batched_shape = (batch_size, *self.shape)
            self.v = torch.zeros(batched_shape, device=self.v.device)
    View Code

    McCulloch-Pitts神经元层。先将前向输入赋值给电压信号;然后将电压与阈值作比较,得到脉冲信号;最后执行基类函数。

    4、IFNodes

    class IFNodes(Nodes):
        # language=rst
        """
        Layer of `integrate-and-fire (IF) neurons <http://neuronaldynamics.epfl.ch/online/Ch1.S3.html>`_.

        Voltages accumulate inputs without leak; a neuron that reaches threshold spikes, is reset, and
        ignores input while its refractory counter is positive.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            thresh: Union[float, torch.Tensor] = -52.0,
            reset: Union[float, torch.Tensor] = -65.0,
            refrac: Union[int, torch.Tensor] = 5,
            lbound: float = None,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Builds a layer of IF neurons.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param thresh: Spike threshold voltage.
            :param reset: Post-spike reset voltage.
            :param refrac: Refractory (non-firing) period of the neuron.
            :param lbound: Lower bound of the voltage; ``None`` disables clipping.
            """
            super().__init__(
                n, shape, traces, traces_additive, tc_trace, trace_scale, sum_input
            )

            # Scalar hyper-parameters, kept as buffers.
            reset_voltage = torch.tensor(reset, dtype=torch.float)
            self.register_buffer("reset", reset_voltage)  # Post-spike reset voltage.
            threshold = torch.tensor(thresh, dtype=torch.float)
            self.register_buffer("thresh", threshold)  # Spike threshold voltage.
            self.register_buffer("refrac", torch.tensor(refrac))  # Refractory period.

            # Per-neuron state; empty until set_batch_size allocates it.
            self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
            self.register_buffer("refrac_count", torch.FloatTensor())  # Refractory counters.

            self.lbound = lbound  # Lower bound of voltage.

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Advances the layer by one simulation step.

            :param x: Inputs to the layer.
            """
            # Only neurons that are out of their refractory period integrate the input.
            receptive = (self.refrac_count <= 0).float()
            self.v += receptive * x

            # One time step elapses on every refractory counter.
            self.refrac_count -= self.dt

            # Neurons at or above threshold spike.
            self.s = self.v.ge(self.thresh)

            # Spiking neurons start a new refractory period and return to the reset voltage.
            self.refrac_count.masked_fill_(self.s, self.refrac)
            self.v.masked_fill_(self.s, self.reset)

            # Optionally clip voltages from below.
            if self.lbound is not None:
                too_low = self.v < self.lbound
                self.v.masked_fill_(too_low, self.lbound)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets the state variables: voltages go to the reset voltage, refractory counters to zero.
            """
            super().reset_state_variables()
            self.v.fill_(self.reset)
            self.refrac_count.zero_()

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)

            # Voltages start at the reset voltage; counters start at zero.
            ones = torch.ones(batch_size, *self.shape, device=self.v.device)
            self.v = self.reset * ones
            self.refrac_count = torch.zeros_like(
                self.v, device=self.refrac_count.device
            )
    View Code

    IF神经元层。先整合输入电压(同时考虑不应期计数);然后降低不应期计数;接着将电压与阈值作比较,得到脉冲信号;之后基于脉冲信号对不应期计数器和电压进行修改;电压截断至下限;最后执行基类函数。

    5、LIFNodes

    class LIFNodes(Nodes):
        # language=rst
        """
        Layer of `leaky integrate-and-fire (LIF) neurons
        <http://web.archive.org/web/20190318204706/http://icwww.epfl.ch/~gerstner/SPNM/node26.html#SECTION02311000000000000000>`_.

        Voltages decay exponentially towards ``rest``, integrate inputs outside the refractory period, and
        are set to ``reset`` after a spike.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            thresh: Union[float, torch.Tensor] = -52.0,
            rest: Union[float, torch.Tensor] = -65.0,
            reset: Union[float, torch.Tensor] = -65.0,
            refrac: Union[int, torch.Tensor] = 5,
            tc_decay: Union[float, torch.Tensor] = 100.0,
            lbound: float = None,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Instantiates a layer of LIF neurons.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param thresh: Spike threshold voltage.
            :param rest: Resting membrane voltage.
            :param reset: Post-spike reset voltage.
            :param refrac: Refractory (non-firing) period of the neuron.
            :param tc_decay: Time constant of neuron voltage decay.
            :param lbound: Lower bound of the voltage; ``None`` disables clipping.
            """
            super().__init__(
                n=n,
                shape=shape,
                traces=traces,
                traces_additive=traces_additive,
                tc_trace=tc_trace,
                trace_scale=trace_scale,
                sum_input=sum_input,
            )

            self.register_buffer(
                "rest", torch.tensor(rest, dtype=torch.float)
            )  # Rest voltage.
            self.register_buffer(
                "reset", torch.tensor(reset, dtype=torch.float)
            )  # Post-spike reset voltage.
            self.register_buffer(
                "thresh", torch.tensor(thresh, dtype=torch.float)
            )  # Spike threshold voltage.
            self.register_buffer(
                "refrac", torch.tensor(refrac)
            )  # Post-spike refractory period.
            self.register_buffer(
                "tc_decay", torch.tensor(tc_decay, dtype=torch.float)
            )  # Time constant of neuron voltage decay.
            self.register_buffer(
                "decay", torch.zeros(*self.shape)
            )  # Set in compute_decays.
            self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
            self.register_buffer(
                "refrac_count", torch.FloatTensor()
            )  # Refractory period counters.

            if lbound is None:
                self.lbound = None  # Lower bound of voltage.
            else:
                self.lbound = torch.tensor(
                    lbound, dtype=torch.float
                )  # Lower bound of voltage.

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Runs a single simulation step.

            The caller's tensor is left untouched: inputs to refractory neurons are zeroed on a copy.

            :param x: Inputs to the layer.
            """
            # Decay voltages.
            self.v = self.decay * (self.v - self.rest) + self.rest

            # Drop the input of neurons that are still refractory. Done out of place so the
            # caller's tensor is not silently modified; the masked copy is what gets integrated
            # below and passed on to the base class.
            x = x.masked_fill(self.refrac_count > 0, 0.0)

            # Decrement refractory counters.
            self.refrac_count -= self.dt

            # Integrate inputs.
            self.v += x

            # Check for spiking neurons.
            self.s = self.v >= self.thresh

            # Refractoriness and voltage reset.
            self.refrac_count.masked_fill_(self.s, self.refrac)
            self.v.masked_fill_(self.s, self.reset)

            # Voltage clipping to lower bound.
            if self.lbound is not None:
                self.v.masked_fill_(self.v < self.lbound, self.lbound)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets relevant state variables: voltages go to ``rest``, refractory counters to zero.
            """
            super().reset_state_variables()
            self.v.fill_(self.rest)  # Neuron voltages.
            self.refrac_count.zero_()  # Refractory period counters.

        def compute_decays(self, dt) -> None:
            # language=rst
            """
            Sets the relevant decays.

            :param dt: Simulation time step.
            """
            super().compute_decays(dt=dt)
            self.decay = torch.exp(
                -self.dt / self.tc_decay
            )  # Neuron voltage decay (per timestep).

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)
            self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device)
            self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)
    View Code

    LIF神经元层。先整合输入电压(同时考虑不应期计数和电压衰减);然后降低不应期计数器;接着将电压与阈值作比较,得到脉冲信号;之后基于脉冲信号对不应期计数器以及电压进行修改;电压截断至下限;最后执行基类函数。

    6、BoostedLIFNodes

    class BoostedLIFNodes(Nodes):
        # language=rst
        """
        Simplified, faster variant of LIF neurons: no rest voltage, no configurable reset voltage and no
        lower bound. Voltages decay towards zero and are set to zero after a spike; the refractory period
        is kept.
        """

        # Same as LIFNodes, faster: no rest, no reset, no lbound
        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            thresh: Union[float, torch.Tensor] = 13.0,
            refrac: Union[int, torch.Tensor] = 5,
            tc_decay: Union[float, torch.Tensor] = 100.0,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Instantiates a layer of boosted LIF neurons.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param thresh: Spike threshold voltage.
            :param refrac: Refractory (non-firing) period of the neuron.
            :param tc_decay: Time constant of neuron voltage decay.
            """
            super().__init__(
                n=n,
                shape=shape,
                traces=traces,
                traces_additive=traces_additive,
                tc_trace=tc_trace,
                trace_scale=trace_scale,
                sum_input=sum_input,
            )

            self.register_buffer(
                "thresh", torch.tensor(thresh, dtype=torch.float)
            )  # Spike threshold voltage.
            self.register_buffer(
                "refrac", torch.tensor(refrac)
            )  # Post-spike refractory period.
            self.register_buffer(
                "tc_decay", torch.tensor(tc_decay, dtype=torch.float)
            )  # Time constant of neuron voltage decay.
            self.register_buffer(
                "decay", torch.zeros(*self.shape)
            )  # Set in compute_decays.
            self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
            # Scalar placeholder; replaced by a per-neuron float tensor in set_batch_size.
            self.register_buffer(
                "refrac_count", torch.tensor(0)
            )  # Refractory period counters.

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Runs a single simulation step.

            The caller's tensor is left untouched: inputs to refractory neurons are zeroed on a copy.

            :param x: Inputs to the layer, or ``None`` to apply only decay and thresholding.
            """
            # Decay voltages.
            self.v *= self.decay

            # Drop the input of neurons that are still refractory. Done out of place so the
            # caller's tensor is not silently modified.
            if x is not None:
                x = x.masked_fill(self.refrac_count > 0, 0.0)

            # Decrement refractory counters.
            self.refrac_count -= self.dt

            # Integrate inputs.
            if x is not None:
                self.v += x

            # Check for spiking neurons.
            self.s = self.v >= self.thresh

            # Refractoriness and voltage reset.
            self.refrac_count.masked_fill_(self.s, self.refrac)
            self.v.masked_fill_(self.s, 0)

            # NOTE(review): with sum_input enabled the base class calls x.float(), so x=None
            # would fail there — confirm whether None is ever passed in that configuration.
            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets relevant state variables: voltages and refractory counters go to zero.
            """
            super().reset_state_variables()
            self.v.fill_(0)  # Neuron voltages.
            self.refrac_count.zero_()  # Refractory period counters.

        def compute_decays(self, dt) -> None:
            # language=rst
            """
            Sets the relevant decays.

            :param dt: Simulation time step.
            """
            super().compute_decays(dt=dt)
            self.decay = torch.exp(
                -self.dt / self.tc_decay
            )  # Neuron voltage decay (per timestep).

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)
            self.v = torch.zeros(batch_size, *self.shape, device=self.v.device)
            self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)
    View Code

    与LIFNodes相同,但更快——没有静息电位(rest)项,没有可配置的重置电压(放电后电压直接归零),并且没有电压下界;不应期机制仍然保留。

    7、CurrentLIFNodes

    class CurrentLIFNodes(Nodes):
        # language=rst
        """
        Layer of `current-based leaky integrate-and-fire (LIF) neurons
        <http://web.archive.org/web/20190318204706/http://icwww.epfl.ch/~gerstner/SPNM/node26.html#SECTION02313000000000000000>`_.
        Inputs feed a decaying synaptic current ``i``; that current, rather than the raw input, is what the
        membrane voltage integrates.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            thresh: Union[float, torch.Tensor] = -52.0,
            rest: Union[float, torch.Tensor] = -65.0,
            reset: Union[float, torch.Tensor] = -65.0,
            refrac: Union[int, torch.Tensor] = 5,
            tc_decay: Union[float, torch.Tensor] = 100.0,
            tc_i_decay: Union[float, torch.Tensor] = 2.0,
            lbound: float = None,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Builds a layer of synaptic input current-based LIF neurons.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param thresh: Spike threshold voltage.
            :param rest: Resting membrane voltage.
            :param reset: Post-spike reset voltage.
            :param refrac: Refractory (non-firing) period of the neuron.
            :param tc_decay: Time constant of neuron voltage decay.
            :param tc_i_decay: Time constant of synaptic input current decay.
            :param lbound: Lower bound of the voltage; ``None`` disables clipping.
            """
            super().__init__(
                n, shape, traces, traces_additive, tc_trace, trace_scale, sum_input
            )

            # Voltage-related constants.
            self.register_buffer("rest", torch.tensor(rest))  # Rest voltage.
            self.register_buffer("reset", torch.tensor(reset))  # Post-spike reset voltage.
            self.register_buffer("thresh", torch.tensor(thresh))  # Spike threshold voltage.
            self.register_buffer("refrac", torch.tensor(refrac))  # Refractory period.

            # Time constants, each paired with a per-step decay factor filled in by compute_decays.
            self.register_buffer("tc_decay", torch.tensor(tc_decay))
            self.register_buffer("decay", torch.empty_like(self.tc_decay))
            self.register_buffer("tc_i_decay", torch.tensor(tc_i_decay))
            self.register_buffer("i_decay", torch.empty_like(self.tc_i_decay))

            # Per-neuron state; empty until set_batch_size allocates it.
            self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
            self.register_buffer("i", torch.FloatTensor())  # Synaptic input currents.
            self.register_buffer("refrac_count", torch.FloatTensor())  # Refractory counters.

            self.lbound = lbound  # Lower bound of voltage.

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Advances the layer by one simulation step.

            :param x: Inputs to the layer.
            """
            # Voltage relaxes towards rest; synaptic current relaxes towards zero.
            displacement = self.v - self.rest
            self.v = self.decay * displacement + self.rest
            self.i *= self.i_decay

            # One time step elapses on every refractory counter.
            self.refrac_count -= self.dt

            # The input charges the current; the current charges non-refractory neurons.
            self.i += x
            receptive = (self.refrac_count <= 0).float()
            self.v += receptive * self.i

            # Neurons at or above threshold spike.
            self.s = self.v.ge(self.thresh)

            # Spiking neurons start a new refractory period and return to the reset voltage.
            self.refrac_count.masked_fill_(self.s, self.refrac)
            self.v.masked_fill_(self.s, self.reset)

            # Optionally clip voltages from below.
            if self.lbound is not None:
                too_low = self.v < self.lbound
                self.v.masked_fill_(too_low, self.lbound)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets the state variables: voltages go to ``rest``; currents and refractory counters to zero.
            """
            super().reset_state_variables()
            self.v.fill_(self.rest)
            self.i.zero_()
            self.refrac_count.zero_()

        def compute_decays(self, dt) -> None:
            # language=rst
            """
            Computes the per-step decay factors for voltage and synaptic current.

            :param dt: Simulation time step.
            """
            super().compute_decays(dt=dt)
            # Neuron voltage decay (per timestep).
            self.decay = torch.exp(-self.dt / self.tc_decay)
            # Synaptic input current decay (per timestep).
            self.i_decay = torch.exp(-self.dt / self.tc_i_decay)

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)

            # Voltages start at rest; currents and counters start at zero.
            ones = torch.ones(batch_size, *self.shape, device=self.v.device)
            self.v = self.rest * ones
            self.i = torch.zeros_like(self.v, device=self.i.device)
            self.refrac_count = torch.zeros_like(
                self.v, device=self.refrac_count.device
            )
    View Code

    基于电流的LIF神经元层。具体的单步仿真过程如下:

    1. 电压/电流衰减;
    2. 不应期计数器自减;

    3. 整合输入(电流自增输入项,电压自增不应期计数器小于等于零的电流项);
    4. 检查脉冲神经元(当电压≥阈值时,发放脉冲);
    5. 不应性/电压复位;
    6. 电压截断至下限;
    7. Nodes基类前向传播(迹/累积和)。

    8、AdaptiveLIFNodes

    class AdaptiveLIFNodes(Nodes):
        # language=rst
        """
        Leaky integrate-and-fire (LIF) layer with activity-dependent thresholds. Each spike raises the threshold
        offset ``theta`` of the emitting neuron by ``theta_plus``; while learning, the offset is multiplied by a
        decay factor on every step.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            rest: Union[float, torch.Tensor] = -65.0,
            reset: Union[float, torch.Tensor] = -65.0,
            thresh: Union[float, torch.Tensor] = -52.0,
            refrac: Union[int, torch.Tensor] = 5,
            tc_decay: Union[float, torch.Tensor] = 100.0,
            theta_plus: Union[float, torch.Tensor] = 0.05,
            tc_theta_decay: Union[float, torch.Tensor] = 1e7,
            lbound: float = None,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Builds the layer and registers its constants and state tensors as buffers.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param rest: Resting membrane voltage.
            :param reset: Post-spike reset voltage.
            :param thresh: Spike threshold voltage.
            :param refrac: Refractory (non-firing) period of the neuron.
            :param tc_decay: Time constant of neuron voltage decay.
            :param theta_plus: Voltage increase of threshold after spiking.
            :param tc_theta_decay: Time constant of adaptive threshold decay.
            :param lbound: Lower bound of the voltage.
            """
            super().__init__(
                n=n,
                shape=shape,
                traces=traces,
                traces_additive=traces_additive,
                tc_trace=tc_trace,
                trace_scale=trace_scale,
                sum_input=sum_input,
            )

            # Membrane constants, wrapped as tensors and registered in a fixed order.
            for name, value in (
                ("rest", rest),  # Rest voltage.
                ("reset", reset),  # Post-spike reset voltage.
                ("thresh", thresh),  # Spike threshold voltage.
                ("refrac", refrac),  # Post-spike refractory period.
                ("tc_decay", tc_decay),  # Time constant of neuron voltage decay.
            ):
                self.register_buffer(name, torch.tensor(value))

            # Placeholder for the voltage decay factor; the real value is set in compute_decays.
            self.register_buffer(
                "decay", torch.empty_like(self.tc_decay, dtype=torch.float32)
            )

            # Adaptive-threshold constants.
            for name, value in (
                ("theta_plus", theta_plus),  # Constant threshold increase on spike.
                ("tc_theta_decay", tc_theta_decay),  # Time constant of adaptive threshold decay.
            ):
                self.register_buffer(name, torch.tensor(value))

            # Placeholder for the threshold decay factor; the real value is set in compute_decays.
            self.register_buffer("theta_decay", torch.empty_like(self.tc_theta_decay))

            # State tensors. ``v`` and ``refrac_count`` start empty and are allocated in set_batch_size.
            self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
            self.register_buffer("theta", torch.zeros(*self.shape))  # Adaptive thresholds.
            self.register_buffer("refrac_count", torch.FloatTensor())  # Refractory period counters.

            self.lbound = lbound  # Lower bound of voltage.

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Advances the layer by one simulation step.

            :param x: Inputs to the layer.
            """
            # Leak the voltages towards rest; while learning, also decay the threshold offsets.
            self.v = self.rest + self.decay * (self.v - self.rest)
            if self.learning:
                self.theta *= self.theta_decay

            # Only neurons outside their refractory period integrate the input.
            not_refractory = (self.refrac_count <= 0).float()
            self.v += not_refractory * x

            # Count down the refractory timers.
            self.refrac_count -= self.dt

            # A neuron fires when its voltage reaches the base threshold plus its own offset.
            firing_threshold = self.thresh + self.theta
            self.s = self.v >= firing_threshold

            # Spiking neurons become refractory and have their voltage reset.
            self.refrac_count.masked_fill_(self.s, self.refrac)
            self.v.masked_fill_(self.s, self.reset)

            # While learning, raise the offsets by theta_plus per spike, summed over dimension 0 of the spikes.
            if self.learning:
                self.theta += self.theta_plus * self.s.float().sum(0)

            # Clamp voltages from below when a lower bound is configured.
            if self.lbound is not None:
                self.v.masked_fill_(self.v < self.lbound, self.lbound)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Restores voltages and refractory counters to their initial values; ``theta`` is left untouched.
            """
            super().reset_state_variables()
            self.v.fill_(self.rest)  # Neuron voltages.
            self.refrac_count.zero_()  # Refractory period counters.

        def compute_decays(self, dt) -> None:
            # language=rst
            """
            Recomputes the per-timestep decay factors.

            :param dt: Simulation time step, forwarded to the parent class.
            """
            super().compute_decays(dt=dt)

            # Neuron voltage decay (per timestep).
            voltage_exponent = -self.dt / self.tc_decay
            self.decay = torch.exp(voltage_exponent)

            # Adaptive threshold decay (per timestep).
            theta_exponent = -self.dt / self.tc_theta_decay
            self.theta_decay = torch.exp(theta_exponent)

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Re-allocates the batched state tensors. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)

            # State tensors have shape (batch_size, *layer shape) and keep the device of the old buffers.
            batched_shape = (batch_size, *self.shape)
            self.v = self.rest * torch.ones(batched_shape, device=self.v.device)
            self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)
    View Code

    具有自适应阈值的LIF神经元层。神经元的电压阈值在每次脉冲时都会增加一个常数;否则,它就会衰减回默认值。

    9、DiehlAndCookNodes

    class DiehlAndCookNodes(Nodes):
        # language=rst
        """
        Adaptive-threshold leaky integrate-and-fire (LIF) layer, modified for replicating Diehl & Cook 2015. When
        ``one_spike`` is set, at most one spiking neuron per batch sample is kept on each step.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            thresh: Union[float, torch.Tensor] = -52.0,
            rest: Union[float, torch.Tensor] = -65.0,
            reset: Union[float, torch.Tensor] = -65.0,
            refrac: Union[int, torch.Tensor] = 5,
            tc_decay: Union[float, torch.Tensor] = 100.0,
            theta_plus: Union[float, torch.Tensor] = 0.05,
            tc_theta_decay: Union[float, torch.Tensor] = 1e7,
            lbound: float = None,
            one_spike: bool = True,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Builds the layer and registers its constants and state tensors as buffers.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param thresh: Spike threshold voltage.
            :param rest: Resting membrane voltage.
            :param reset: Post-spike reset voltage.
            :param refrac: Refractory (non-firing) period of the neuron.
            :param tc_decay: Time constant of neuron voltage decay.
            :param theta_plus: Voltage increase of threshold after spiking.
            :param tc_theta_decay: Time constant of adaptive threshold decay.
            :param lbound: Lower bound of the voltage.
            :param one_spike: Whether to allow only one spike per timestep.
            """
            super().__init__(
                n=n,
                shape=shape,
                traces=traces,
                traces_additive=traces_additive,
                tc_trace=tc_trace,
                trace_scale=trace_scale,
                sum_input=sum_input,
            )

            # Membrane constants, wrapped as tensors and registered in a fixed order.
            for name, value in (
                ("rest", rest),  # Rest voltage.
                ("reset", reset),  # Post-spike reset voltage.
                ("thresh", thresh),  # Spike threshold voltage.
                ("refrac", refrac),  # Post-spike refractory period.
                ("tc_decay", tc_decay),  # Time constant of neuron voltage decay.
            ):
                self.register_buffer(name, torch.tensor(value))

            # Placeholder for the voltage decay factor; the real value is set in compute_decays.
            self.register_buffer("decay", torch.empty_like(self.tc_decay))

            # Adaptive-threshold constants.
            for name, value in (
                ("theta_plus", theta_plus),  # Constant threshold increase on spike.
                ("tc_theta_decay", tc_theta_decay),  # Time constant of adaptive threshold decay.
            ):
                self.register_buffer(name, torch.tensor(value))

            # Placeholder for the threshold decay factor; the real value is set in compute_decays.
            self.register_buffer("theta_decay", torch.empty_like(self.tc_theta_decay))

            # State tensors. ``v`` and ``refrac_count`` start empty and are allocated in set_batch_size.
            self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
            self.register_buffer("theta", torch.zeros(*self.shape))  # Adaptive thresholds.
            self.register_buffer("refrac_count", torch.FloatTensor())  # Refractory period counters.

            self.lbound = lbound  # Lower bound of voltage.
            self.one_spike = one_spike  # One spike per timestep.

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Advances the layer by one simulation step.

            Voltage reset, refractoriness and the threshold increase are applied to every neuron that crossed its
            threshold; the one-spike selection afterwards only changes which spikes remain in ``self.s``.

            :param x: Inputs to the layer.
            """
            # Leak the voltages towards rest; while learning, also decay the threshold offsets.
            self.v = self.rest + self.decay * (self.v - self.rest)
            if self.learning:
                self.theta *= self.theta_decay

            # Only neurons outside their refractory period integrate the input.
            not_refractory = (self.refrac_count <= 0).float()
            self.v += not_refractory * x

            # Count down the refractory timers.
            self.refrac_count -= self.dt

            # A neuron fires when its voltage reaches the base threshold plus its own offset.
            firing_threshold = self.thresh + self.theta
            self.s = self.v >= firing_threshold

            # Spiking neurons become refractory and have their voltage reset.
            self.refrac_count.masked_fill_(self.s, self.refrac)
            self.v.masked_fill_(self.s, self.reset)

            # While learning, raise the offsets by theta_plus per spike, summed over dimension 0 of the spikes.
            if self.learning:
                self.theta += self.theta_plus * self.s.float().sum(0)

            # Keep a single, randomly drawn spike in every batch row that has at least one spike.
            if self.one_spike and self.s.any():
                flat_spikes = self.s.view(self.batch_size, -1)
                row_has_spike = flat_spikes.any(1)
                # Draw one index per spiking row, uniformly among that row's spiking neurons.
                winners = torch.multinomial(flat_spikes.float()[row_has_spike], 1)
                spiking_rows = row_has_spike.nonzero()
                self.s.zero_()
                self.s.view(self.batch_size, -1)[spiking_rows, winners] = 1

            # Clamp voltages from below when a lower bound is configured.
            if self.lbound is not None:
                self.v.masked_fill_(self.v < self.lbound, self.lbound)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Restores voltages and refractory counters to their initial values; ``theta`` is left untouched.
            """
            super().reset_state_variables()
            self.v.fill_(self.rest)  # Neuron voltages.
            self.refrac_count.zero_()  # Refractory period counters.

        def compute_decays(self, dt) -> None:
            # language=rst
            """
            Recomputes the per-timestep decay factors.

            :param dt: Simulation time step, forwarded to the parent class.
            """
            super().compute_decays(dt=dt)

            # Neuron voltage decay (per timestep).
            voltage_exponent = -self.dt / self.tc_decay
            self.decay = torch.exp(voltage_exponent)

            # Adaptive threshold decay (per timestep).
            theta_exponent = -self.dt / self.tc_theta_decay
            self.theta_decay = torch.exp(theta_exponent)

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Re-allocates the batched state tensors. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)

            # State tensors have shape (batch_size, *layer shape) and keep the device of the old buffers.
            batched_shape = (batch_size, *self.shape)
            self.v = self.rest * torch.ones(batched_shape, device=self.v.device)
            self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)
    View Code

    具有自适应阈值的LIF神经元层(针对Diehl & Cook 2015复现进行了修改);当one_spike=True(默认)时,每个时间步内每个样本最多只保留一个神经元的脉冲。 

    10、IzhikevichNodes

    class IzhikevichNodes(Nodes):
        # language=rst
        """
        Layer of `Izhikevich neurons <https://www.izhikevich.org/publications/spikes.htm>`_.

        The per-neuron parameters ``a``, ``b``, ``c``, ``d`` are created as flat tensors of length ``n`` and the
        lateral weight matrix ``S`` as an ``n x n`` tensor; all of them are drawn once in the constructor.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            excitatory: float = 1,
            thresh: Union[float, torch.Tensor] = 45.0,
            rest: Union[float, torch.Tensor] = -65.0,
            lbound: Optional[float] = None,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Instantiates a layer of Izhikevich neurons.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param excitatory: Fraction of excitatory (vs. inhibitory) neurons in the layer; values outside
                ``[0, 1]`` are clamped to that range.
            :param thresh: Spike threshold voltage.
            :param rest: Resting membrane voltage.
            :param lbound: Lower bound of the voltage.
            """
            super().__init__(
                n=n,
                shape=shape,
                traces=traces,
                traces_additive=traces_additive,
                tc_trace=tc_trace,
                trace_scale=trace_scale,
                sum_input=sum_input,
            )

            # The base class derives the neuron count from ``shape`` when ``n`` is omitted, so use the
            # resolved value; the raw argument may be None.
            n = self.n

            self.register_buffer("rest", torch.tensor(rest))  # Rest voltage.
            self.register_buffer("thresh", torch.tensor(thresh))  # Spike threshold voltage.
            self.lbound = lbound

            # Declare the parameter buffers up front; they are filled in below.
            self.register_buffer("r", None)
            self.register_buffer("a", None)
            self.register_buffer("b", None)
            self.register_buffer("c", None)
            self.register_buffer("d", None)
            self.register_buffer("S", None)
            self.register_buffer("excitatory", None)

            # Clamp the excitatory fraction to [0, 1].
            if excitatory > 1:
                excitatory = 1
            elif excitatory < 0:
                excitatory = 0

            if excitatory == 1:
                # Purely excitatory layer: fixed (a, b); (c, d) depend on the squared random variable r.
                self.r = torch.rand(n)
                self.a = 0.02 * torch.ones(n)
                self.b = 0.2 * torch.ones(n)
                self.c = -65.0 + 15 * (self.r ** 2)
                self.d = 8 - 6 * (self.r ** 2)
                self.S = 0.5 * torch.rand(n, n)
                self.excitatory = torch.ones(n).byte()

            elif excitatory == 0:
                # Purely inhibitory layer: (a, b) depend on r; fixed (c, d); negative lateral weights.
                self.r = torch.rand(n)
                self.a = 0.02 + 0.08 * self.r
                self.b = 0.25 - 0.05 * self.r
                self.c = -65.0 * torch.ones(n)
                self.d = 2 * torch.ones(n)
                self.S = -torch.rand(n, n)

                self.excitatory = torch.zeros(n).byte()

            else:
                # Mixed layer: the first ``ex`` neurons are excitatory, the remaining ``inh`` inhibitory.
                self.excitatory = torch.zeros(n).byte()

                ex = int(n * excitatory)
                inh = n - ex

                # init
                self.r = torch.zeros(n)
                self.a = torch.zeros(n)
                self.b = torch.zeros(n)
                self.c = torch.zeros(n)
                self.d = torch.zeros(n)
                self.S = torch.zeros(n, n)

                # excitatory
                self.r[:ex] = torch.rand(ex)
                self.a[:ex] = 0.02 * torch.ones(ex)
                self.b[:ex] = 0.2 * torch.ones(ex)
                self.c[:ex] = -65.0 + 15 * self.r[:ex] ** 2
                self.d[:ex] = 8 - 6 * self.r[:ex] ** 2
                self.S[:, :ex] = 0.5 * torch.rand(n, ex)
                self.excitatory[:ex] = 1

                # inhibitory
                self.r[ex:] = torch.rand(inh)
                self.a[ex:] = 0.02 + 0.08 * self.r[ex:]
                self.b[ex:] = 0.25 - 0.05 * self.r[ex:]
                self.c[ex:] = -65.0 * torch.ones(inh)
                self.d[ex:] = 2 * torch.ones(inh)
                self.S[:, ex:] = -torch.rand(n, inh)
                self.excitatory[ex:] = 0

            self.register_buffer("v", self.rest * torch.ones(n))  # Neuron voltages.
            self.register_buffer("u", self.b * self.v)  # Neuron recovery.

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Runs a single simulation step.

            The caller's ``x`` tensor is not modified; lateral input is added to a new tensor.

            :param x: Inputs to the layer.
            """
            # Check for spiking neurons.
            self.s = self.v >= self.thresh

            # Voltage and recovery reset: v <- c and u <- u + d for neurons that spiked.
            self.v = torch.where(self.s, self.c, self.v)
            self.u = torch.where(self.s, self.u + self.d, self.u)

            # Add inter-columnar input: for each row of the spike tensor, sum the columns of S that
            # belong to the neurons that spiked.
            if self.s.any():
                lateral = torch.cat(
                    [self.S[:, spikes].sum(dim=1)[None] for spikes in self.s],
                    dim=0,
                )
                x = x + lateral  # Out-of-place, so the caller's tensor stays untouched.

            # Apply v and u updates. The voltage is deliberately advanced in two half-steps of 0.5 * dt.
            self.v += self.dt * 0.5 * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + x)
            self.v += self.dt * 0.5 * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + x)
            self.u += self.dt * self.a * (self.b * self.v - self.u)

            # Voltage clipping to lower bound.
            if self.lbound is not None:
                self.v.masked_fill_(self.v < self.lbound, self.lbound)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets relevant state variables.
            """
            super().reset_state_variables()
            self.v.fill_(self.rest)  # Neuron voltages.
            self.u = self.b * self.v  # Neuron recovery.

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)
            self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device)
            self.u = self.b * self.v
    View Code

    Izhikevich神经元层。各变量含义如下:r:神经元异质性变量(在[0, 1)上均匀分布的随机数)。原则上,可以使用RS细胞来模拟所有兴奋性神经元,而使用FS细胞来模拟所有抑制性神经元;为了引入异质性,兴奋性神经元取 c = -65 + 15r²、d = 8 - 6r²,因此,ri = 0对应于常规脉冲(RS)细胞,ri = 1对应于抖振(CH)细胞;a:恢复变量u的时间常数,a越小,恢复得越慢;b:恢复变量u对膜电位v的阈下波动的敏感程度;c:发放脉冲后,v的复位值;d:发放脉冲后,u的增量(u ← u + d);S:神经元之间的突触连接权重;excitatory:该层中兴奋性神经元的比例(范围为0-1);v:膜电位;u:恢复变量,用来代替生理模型中激活的K离子电流和失活的Na离子电流,实现对膜电位v的负反馈。

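    • 理论公式:$\frac{dv}{dt} = 0.04v^2 + 5v + 140 - u + I$,$\frac{du}{dt} = a(bv - u)$;当 $v$ 达到阈值(原论文中为 30 mV)时,$v \leftarrow c$,$u \leftarrow u + d$。

    • 实际实现:先对满足 $v \ge \text{thresh}$ 的神经元执行复位($v \leftarrow c$,$u \leftarrow u + d$),并把发放神经元对应的 $S$ 的列之和加到输入 $x$ 上;随后 $v$ 以半步长连续更新两次,$v \leftarrow v + 0.5\,dt\,(0.04v^2 + 5v + 140 - u + x)$,最后 $u \leftarrow u + dt \cdot a\,(bv - u)$。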

    11、CSRMNodes

    class CSRMNodes(Nodes):
        # language=rst
        """
        A layer of Cumulative Spike Response Model (Gerstner and van Hemmen 1992, Gerstner et al. 1996) nodes.
        It accounts for a model where refractoriness and adaptation were modeled by the combined effects
        of the spike after potentials of several previous spikes, rather than only the most recent spike.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            rest: Union[float, torch.Tensor] = -65.0,
            thresh: Union[float, torch.Tensor] = -52.0,
            responseKernel: str = "ExponentialKernel",
            refractoryKernel: str = "EtaKernel",
            tau: Union[float, torch.Tensor] = 1,
            res_window_size: Union[float, torch.Tensor] = 20,
            ref_window_size: Union[float, torch.Tensor] = 10,
            reset_const: Union[float, torch.Tensor] = 50,
            tc_decay: Union[float, torch.Tensor] = 100.0,
            theta_plus: Union[float, torch.Tensor] = 0.05,
            tc_theta_decay: Union[float, torch.Tensor] = 1e7,
            lbound: Optional[float] = None,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Instantiates a layer of Cumulative Spike Response Model nodes.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param rest: Resting membrane voltage.
            :param thresh: Spike threshold voltage.
            :param responseKernel: Name of the spike response kernel; one of ``"AlphaKernel"``,
                ``"AlphaKernelSLAYER"``, ``"LaplacianKernel"``, ``"ExponentialKernel"``, ``"RectangularKernel"``
                or ``"TriangularKernel"``. Validated in ``set_batch_size``.
            :param refractoryKernel: Name of the refractory kernel; only ``"EtaKernel"`` is available. Validated
                in ``set_batch_size``.
            :param tau: Time constant used by the response and refractory kernels.
            :param res_window_size: Window size for sampling incoming input current.
            :param ref_window_size: Window size for sampling previous spikes.
            :param reset_const: Reset constant of the refractory kernel.
            :param tc_decay: Time constant of neuron voltage decay.
            :param theta_plus: Voltage increase of threshold after spiking.
            :param tc_theta_decay: Time constant of adaptive threshold decay.
            :param lbound: Lower bound of the voltage.
            """
            super().__init__(
                n=n,
                shape=shape,
                traces=traces,
                traces_additive=traces_additive,
                tc_trace=tc_trace,
                trace_scale=trace_scale,
                sum_input=sum_input,
            )

            self.register_buffer("rest", torch.tensor(rest))  # Rest voltage.
            self.register_buffer("thresh", torch.tensor(thresh))  # Spike threshold voltage.
            self.register_buffer(
                "tau", torch.tensor(tau)
            )  # Time constant of Spike Response Kernel
            self.register_buffer(
                "reset_const", torch.tensor(reset_const)
            )  # Reset constant of refractory kernel
            self.register_buffer(
                "res_window_size", torch.tensor(res_window_size)
            )  # Window size for sampling incoming input current
            self.register_buffer(
                "ref_window_size", torch.tensor(ref_window_size)
            )  # Window size for sampling previous spikes
            self.register_buffer(
                "tc_decay", torch.tensor(tc_decay)
            )  # Time constant of neuron voltage decay.
            self.register_buffer(
                "decay", torch.empty_like(self.tc_decay, dtype=torch.float32)
            )  # Set in compute_decays.
            self.register_buffer(
                "theta_plus", torch.tensor(theta_plus)
            )  # Constant threshold increase on spike.
            self.register_buffer(
                "tc_theta_decay", torch.tensor(tc_theta_decay)
            )  # Time constant of adaptive threshold decay.
            self.register_buffer(
                "theta_decay", torch.empty_like(self.tc_theta_decay)
            )  # Set in compute_decays.

            self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
            self.register_buffer(
                "last_spikes", torch.ByteTensor()
            )  # Previous spikes occurrences in time window
            self.register_buffer("theta", torch.zeros(*self.shape))  # Adaptive thresholds.
            self.lbound = lbound  # Lower bound of voltage.

            self.responseKernel = responseKernel  # Type of spike response kernel used

            self.refractoryKernel = refractoryKernel  # Type of refractory kernel used

            self.register_buffer(
                "resKernel", torch.FloatTensor()
            )  # Vector of synaptic response kernel values over a window of time

            self.register_buffer(
                "refKernel", torch.FloatTensor()
            )  # Vector of refractory kernel values over a window of time

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Runs a single simulation step.

            :param x: Inputs to the layer. The einsum below requires a rank-3 tensor whose second dimension
                matches the length of the response kernel.
            """
            # Decay voltages.
            self.v *= self.decay

            if self.learning:
                self.theta *= self.theta_decay

            # Integrate inputs: weight the windowed input and the windowed spike history by their kernels
            # and sum over the window dimension.
            v = torch.einsum(
                "i,kij->kj", self.resKernel, x
            )  # Response due to incoming current
            v += torch.einsum(
                "i,kij->kj", self.refKernel, self.last_spikes
            )  # Refractoriness due to previous spikes
            self.v += v.view(x.size(0), *self.shape)

            # Check for spiking neurons.
            self.s = self.v >= self.thresh + self.theta

            if self.learning:
                self.theta += self.theta_plus * self.s.float().sum(0)

            # Add the spike vector into the first in first out matrix of windowed (ref) spike trains:
            # drop the oldest entry and append the current spikes.
            self.last_spikes = torch.cat(
                (self.last_spikes[:, 1:, :], self.s[:, None, :]), 1
            )

            # Voltage clipping to lower bound.
            if self.lbound is not None:
                self.v.masked_fill_(self.v < self.lbound, self.lbound)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets relevant state variables.
            """
            super().reset_state_variables()
            # NOTE(review): only the voltages are reset here; ``last_spikes`` keeps its history. Confirm that
            # this is intended.
            self.v.fill_(self.rest)  # Neuron voltages.

        def compute_decays(self, dt) -> None:
            # language=rst
            """
            Sets the relevant decays.

            :param dt: Simulation time step, forwarded to the parent class.
            """
            super().compute_decays(dt=dt)
            self.decay = torch.exp(
                -self.dt / self.tc_decay
            )  # Neuron voltage decay (per timestep).
            self.theta_decay = torch.exp(
                -self.dt / self.tc_theta_decay
            )  # Adaptive threshold decay (per timestep).

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size and builds the kernel vectors. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            :raises ValueError: If the configured response or refractory kernel name is unknown.
            """
            super().set_batch_size(batch_size=batch_size)
            self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device)
            self.last_spikes = torch.zeros(batch_size, self.ref_window_size, *self.shape)

            # Dispatch table from kernel name to the method that samples it.
            resKernels = {
                "AlphaKernel": self.AlphaKernel,
                "AlphaKernelSLAYER": self.AlphaKernelSLAYER,
                "LaplacianKernel": self.LaplacianKernel,
                "ExponentialKernel": self.ExponentialKernel,
                "RectangularKernel": self.RectangularKernel,
                "TriangularKernel": self.TriangularKernel,
            }

            if self.responseKernel not in resKernels:
                raise ValueError(" The given response Kernel is not implemented")

            self.resKernel = resKernels[self.responseKernel](self.dt)

            refKernels = {"EtaKernel": self.EtaKernel}

            if self.refractoryKernel not in refKernels:
                raise ValueError(" The given refractory Kernel is not implemented")

            self.refKernel = refKernels[self.refractoryKernel](self.dt)

        # Each kernel method samples its function on ``arange(0, window, dt)`` and returns the samples
        # flipped along dimension 0.

        def AlphaKernel(self, dt):
            # Alpha kernel: t / tau^2 * exp(-t / tau).
            t = torch.arange(0, self.res_window_size, dt)
            kernelVec = (1 / (self.tau ** 2)) * t * torch.exp(-t / self.tau)
            return torch.flip(kernelVec, [0])

        def AlphaKernelSLAYER(self, dt):
            # Alpha kernel variant: t / tau * exp(1 - t / tau).
            t = torch.arange(0, self.res_window_size, dt)
            kernelVec = (1 / self.tau) * t * torch.exp(1 - t / self.tau)
            return torch.flip(kernelVec, [0])

        def LaplacianKernel(self, dt):
            # Laplacian kernel: exp(-|t / tau|) / (2 * tau).
            t = torch.arange(0, self.res_window_size, dt)
            kernelVec = (1 / (self.tau * 2)) * torch.exp(-1 * torch.abs(t / self.tau))
            return torch.flip(kernelVec, [0])

        def ExponentialKernel(self, dt):
            # Exponential kernel: exp(-t / tau) / tau.
            t = torch.arange(0, self.res_window_size, dt)
            kernelVec = (1 / self.tau) * torch.exp(-t / self.tau)
            return torch.flip(kernelVec, [0])

        def RectangularKernel(self, dt):
            # Rectangular kernel: the constant 1 / (2 * tau) at every sample of the window. It is built as a
            # vector with one entry per sample so that it can be flipped and contracted like the others.
            t = torch.arange(0, self.res_window_size, dt)
            kernelVec = torch.ones_like(t, dtype=torch.float32) / (self.tau * 2)
            return torch.flip(kernelVec, [0])

        def TriangularKernel(self, dt):
            # Triangular kernel: (1 - t / tau) / tau.
            t = torch.arange(0, self.res_window_size, dt)
            kernelVec = (1 / self.tau) * (1 - (t / self.tau))
            return torch.flip(kernelVec, [0])

        def EtaKernel(self, dt):
            # Refractory kernel: -reset_const * exp(-t / tau), sampled over the refractory window.
            t = torch.arange(0, self.ref_window_size, dt)
            kernelVec = -self.reset_const * torch.exp(-t / self.tau)
            return torch.flip(kernelVec, [0])
    View Code

    累积脉冲响应模型(Gerstner and van Hemmen 1992, Gerstner et al. 1996)节点层。它对应这样一种模型:不应期和适应性由之前若干次脉冲(而不仅仅是最近一次脉冲)的脉冲后电位(spike after-potential)的综合效应来建模。 

    12、SRM0Nodes

    class SRM0Nodes(Nodes):
        # language=rst
        """
        Layer of simplified spike response model (SRM0) neurons with stochastic threshold (escape noise). Adapted from
        `(Vasilaki et al., 2009) <https://intranet.physio.unibe.ch/Publikationen/Dokumente/Vasilaki2009PloSComputBio_1.pdf>`_.

        On every simulation step the voltage decays toward ``rest``, input is integrated on neurons whose
        refractory counter has run out, the voltage is mapped to a spike probability, and spikes are sampled
        from that probability.
        """

        def __init__(
            self,
            n: Optional[int] = None,
            shape: Optional[Iterable[int]] = None,
            traces: bool = False,
            traces_additive: bool = False,
            tc_trace: Union[float, torch.Tensor] = 20.0,
            trace_scale: Union[float, torch.Tensor] = 1.0,
            sum_input: bool = False,
            thresh: Union[float, torch.Tensor] = -50.0,
            rest: Union[float, torch.Tensor] = -70.0,
            reset: Union[float, torch.Tensor] = -70.0,
            refrac: Union[int, torch.Tensor] = 5,
            tc_decay: Union[float, torch.Tensor] = 10.0,
            lbound: Optional[float] = None,
            eps_0: Union[float, torch.Tensor] = 1.0,
            rho_0: Union[float, torch.Tensor] = 1.0,
            d_thresh: Union[float, torch.Tensor] = 5.0,
            **kwargs,
        ) -> None:
            # language=rst
            """
            Instantiates a layer of SRM0 neurons.

            :param n: The number of neurons in the layer.
            :param shape: The dimensionality of the layer.
            :param traces: Whether to record spike traces.
            :param traces_additive: Whether to record spike traces additively.
            :param tc_trace: Time constant of spike trace decay.
            :param trace_scale: Scaling factor for spike trace.
            :param sum_input: Whether to sum all inputs.
            :param thresh: Spike threshold voltage.
            :param rest: Resting membrane voltage.
            :param reset: Post-spike reset voltage.
            :param refrac: Refractory (non-firing) period of the neuron.
            :param tc_decay: Time constant of neuron voltage decay.
            :param lbound: Lower bound of the voltage; ``None`` disables clipping.
            :param eps_0: Scaling factor for pre-synaptic spike contributions.
            :param rho_0: Stochastic intensity at threshold.
            :param d_thresh: Width of the threshold region.
            :param kwargs: Extra keyword arguments forwarded to the ``Nodes`` constructor (e.g. ``learning``).
            """
            # Forward ``kwargs`` so base-class options such as ``learning`` are honoured rather than
            # silently dropped.
            super().__init__(
                n=n,
                shape=shape,
                traces=traces,
                traces_additive=traces_additive,
                tc_trace=tc_trace,
                trace_scale=trace_scale,
                sum_input=sum_input,
                **kwargs,
            )

            self.register_buffer("rest", torch.tensor(rest))  # Rest voltage.
            self.register_buffer("reset", torch.tensor(reset))  # Post-spike reset voltage.
            self.register_buffer("thresh", torch.tensor(thresh))  # Spike threshold voltage.
            self.register_buffer(
                "refrac", torch.tensor(refrac)
            )  # Post-spike refractory period.
            self.register_buffer(
                "tc_decay", torch.tensor(tc_decay)
            )  # Time constant of neuron voltage decay.
            # Placeholder value only; the per-timestep decay factor is set in compute_decays.
            self.register_buffer("decay", torch.tensor(tc_decay))
            self.register_buffer(
                "eps_0", torch.tensor(eps_0)
            )  # Scaling factor for pre-synaptic spike contributions.
            self.register_buffer(
                "rho_0", torch.tensor(rho_0)
            )  # Stochastic intensity at threshold.
            self.register_buffer(
                "d_thresh", torch.tensor(d_thresh)
            )  # Width of the threshold region.
            # Empty placeholders; both are allocated with their real shape in set_batch_size.
            self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
            self.register_buffer(
                "refrac_count", torch.FloatTensor()
            )  # Refractory period counters.

            self.lbound = lbound  # Lower bound of voltage.

        def forward(self, x: torch.Tensor) -> None:
            # language=rst
            """
            Runs a single simulation step.

            :param x: Inputs to the layer.
            """
            # Decay voltages toward the resting potential.
            self.v = self.decay * (self.v - self.rest) + self.rest

            # Integrate inputs, but only on neurons whose refractory counter has run out.
            self.v += (self.refrac_count <= 0).float() * self.eps_0 * x

            # Compute (instantaneous) probabilities of spiking, clamp between 0 and 1 using exponentials.
            # Also known as 'escape noise', this simulates nearby neurons.
            self.rho = self.rho_0 * torch.exp((self.v - self.thresh) / self.d_thresh)
            self.s_prob = 1.0 - torch.exp(-self.rho * self.dt)

            # Decrement refractory counters.
            self.refrac_count -= self.dt

            # Check for spiking neurons (spike when probability > some random number).
            self.s = torch.rand_like(self.s_prob) < self.s_prob

            # Refractoriness and voltage reset for the neurons that just spiked.
            self.refrac_count.masked_fill_(self.s, self.refrac)
            self.v.masked_fill_(self.s, self.reset)

            # Voltage clipping to lower bound.
            if self.lbound is not None:
                self.v.masked_fill_(self.v < self.lbound, self.lbound)

            super().forward(x)

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Resets relevant state variables: voltages return to ``rest`` and refractory counters to zero.
            """
            super().reset_state_variables()
            self.v.fill_(self.rest)  # Neuron voltages.
            self.refrac_count.zero_()  # Refractory period counters.

        def compute_decays(self, dt) -> None:
            # language=rst
            """
            Sets the relevant decays.

            :param dt: Simulation time step, passed on to the base class.
            """
            super().compute_decays(dt=dt)
            self.decay = torch.exp(
                -self.dt / self.tc_decay
            )  # Neuron voltage decay (per timestep).

        def set_batch_size(self, batch_size) -> None:
            # language=rst
            """
            Sets mini-batch size. Called when layer is added to a network.

            :param batch_size: Mini-batch size.
            """
            super().set_batch_size(batch_size=batch_size)
            # Allocate per-sample state with a leading batch dimension, starting at the rest voltage.
            self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device)
            self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)
    View Code

    具有随机阈值(逃逸噪声)的简化脉冲响应模型(SRM0)神经元层。改编自(Vasilaki et al., 2009)。关于论文与具体实现的关系,开发者回答如下:

      “首先,论文有点不清楚。在第13页,等式(23)规定了SRM模型的使用。然而,第15页的图9突然将这些神经元称为"带有逃逸噪声的LIF"。那么,是哪一个?因为文章对等式(23)中的epsilon(t)和eta(t)使用了单指数核,所以SRM模型实际上只是一个具有随机阈值的LIF,我们可以用迭代的方式来计算,而不必在所有的时间步骤上进行卷积。现在让我们继续使用图9的定义。图9中的神经元和BindsNET中的SRM0节点之间的区别纯粹是架构上的:如果我没记错的话,文章中的神经元有前馈连接(一层到下一层)和横向连接(一层神经元之间)。这将解释图9中的亚阈动态方程中的两个输入项:xi(或其对应的希腊字母)的前馈权重,W的横向权重。其他一切都应该是一样的:随机发放与逃逸噪声已经实现,核衰减是相同的。我和其中一位作者谈过一次,她告诉我他们仅用了一个计数器来计算不应期,就像SRM0Nodes一样。”

      该神经元具体过程如下:

    1. 电压衰减;
    2. 整合输入(不应期计数器小于等于零的输入项);
    3. 计算(瞬时)脉冲概率,使用指数在0和1之间截断(也被称为“逃逸噪音”,它模拟附近的神经元);
    4. 不应期计数器自减;
    5. 检查脉冲神经元(当概率>某个随机数时,发放脉冲);
    6. 不应性/电压复位;
    7. 电压截断至下限;
    8. Nodes基类前向传播(迹/累积和)。
  • 相关阅读:
    1008 Elevator (20)(20 分)
    1007 Maximum Subsequence Sum (25)(25 分)
    1006 Sign In and Sign Out (25)(25 分)
    酒店前台管理系统主要代码截图
    酒店前台管理系统
    酒店前台管理系统需求分析
    session与cookie的区别和联系
    ldap的实现单点登录
    git如何使用
    javaweb项目中的web.xml
  • 原文地址:https://www.cnblogs.com/lucifer1997/p/14314148.html
Copyright © 2011-2022 走看看