zoukankan      html  css  js  c++  java
  • BindsNET学习系列——Encoding

    相关源码:bindsnet/bindsnet/encoding/encodings.py;

    1、bernoulli

    def bernoulli(
        datum: torch.Tensor,
        time: Optional[int] = None,
        dt: float = 1.0,
        device="cpu",
        **kwargs
    ) -> torch.Tensor:
        # language=rst
        """
        Generates Bernoulli-distributed spike trains based on input intensity. Inputs must
        be non-negative. Spikes correspond to successful Bernoulli trials, with success
        probability equal to ``max_prob`` times the (normalized in [0, 1]) input value.
        The input tensor is never modified in place.

        :param datum: Tensor of shape ``[n_1, ..., n_k]``.
        :param time: Length of Bernoulli spike train per input variable. If ``None``, a
            single sample of shape ``[n_1, ..., n_k]`` is drawn.
        :param dt: Simulation time step; the train has ``int(time / dt)`` steps.
        :param device: Target device of the returned spikes.
        :return: Tensor of shape ``[time, n_1, ..., n_k]`` (or ``[n_1, ..., n_k]`` when
            ``time`` is ``None``) of Bernoulli-distributed spikes, dtype ``uint8``.

        Keyword arguments:

        :param float max_prob: Maximum probability of spike per Bernoulli trial.
        """
        # Setting kwargs.
        max_prob = kwargs.get("max_prob", 1.0)

        assert 0 <= max_prob <= 1, "Maximum firing probability must be in range [0, 1]"
        assert (datum >= 0).all(), "Inputs must be non-negative"

        shape = datum.shape
        datum = datum.flatten()

        if time is not None:
            time = int(time / dt)

        # Normalize inputs and rescale (spike probability proportional to input
        # intensity). Division is out-of-place: ``flatten`` may return the caller's
        # tensor itself (or a view of it), and in-place ``/=`` fails on integer tensors.
        # The numel() guard avoids calling max() on an empty tensor.
        if datum.numel() > 0 and datum.max() > 1.0:
            datum = datum / datum.max()

        # Make spike data from Bernoulli sampling.
        if time is None:
            # view(shape) (not view(*shape)) also handles 0-d inputs.
            spikes = torch.bernoulli(max_prob * datum).view(shape)
        else:
            spikes = torch.bernoulli(max_prob * datum.repeat([time, 1]))
            spikes = spikes.view(time, *shape)

        # Move to the requested device regardless of whether ``time`` was given.
        return spikes.to(device).byte()

    Bernoulli编码:基于输入强度生成服从Bernoulli分布的脉冲序列。输入必须为非负。每次成功的Bernoulli试验对应一个脉冲,成功概率等于输入值(归一化到[0, 1]后)乘以最大发放概率max_prob。

    2、poisson

    def poisson(
        datum: torch.Tensor,
        time: int,
        dt: float = 1.0,
        device="cpu",
        approx=False,
        **kwargs
    ) -> torch.Tensor:
        # language=rst
        """
        Generates Poisson-distributed spike trains based on input intensity. Inputs must be
        non-negative, and give the firing rate in Hz. For non-zero inputs, sampled
        inter-spike intervals (ISIs) equal to zero are incremented by one to avoid zero
        intervals while approximately maintaining ISI distributions.

        :param datum: Tensor of shape ``[n_1, ..., n_k]``.
        :param time: Length of Poisson spike train per input variable.
        :param dt: Simulation time step; the train has ``int(time / dt)`` steps.
        :param device: target destination of poisson spikes. The input is moved to this
            device before encoding.
        :param approx: Bool: use alternate faster, less accurate computation.
        :return: Tensor of shape ``[time, n_1, ..., n_k]`` of Poisson-distributed spikes,
            dtype ``uint8``.
        """
        assert (datum >= 0).all(), "Inputs must be non-negative"

        # Get shape and size of data. Move the data to the target device so that every
        # tensor used below (random draws, masks, indices) lives on the same device.
        shape, size = datum.shape, datum.numel()
        datum = datum.flatten().to(device)
        time = int(time / dt)

        if approx:
            # random normal power awful approximation
            x = torch.randn((time, size), device=device).abs()
            x = torch.pow(x, (datum * 0.11 + 5) / 50)
            # The comparison already yields a bool tensor on ``device``; wrapping it in
            # torch.tensor() would only copy it and emit a UserWarning.
            y = x < 0.6

            return y.view(time, *shape).byte()

        # Compute mean ISIs (in time steps) as a function of data intensity,
        # accounting for the simulation time step (the 1000 / dt factor presumes dt is
        # in milliseconds).
        nonzero = datum != 0
        rate = torch.zeros(size, device=device)
        rate[nonzero] = 1 / datum[nonzero] * (1000 / dt)

        # Create Poisson distribution and sample inter-spike intervals
        # (incrementing zero intervals of non-zero inputs by 1).
        dist = torch.distributions.Poisson(rate=rate, validate_args=False)
        intervals = dist.sample(sample_shape=torch.Size([time + 1]))
        intervals[:, nonzero] += (intervals[:, nonzero] == 0).float()

        # Calculate spike times by cumulatively summing over time dimension.
        # Times beyond the window are redirected to row 0, which is discarded below.
        times = torch.cumsum(intervals, dim=0).long()
        times[times >= time + 1] = 0

        # Create tensor of spikes.
        spikes = torch.zeros(time + 1, size, device=device).byte()
        spikes[times, torch.arange(size, device=device)] = 1
        spikes = spikes[1:]

        return spikes.view(time, *shape)

    Poisson编码:根据输入强度生成服从泊松分布的脉冲序列。输入必须为非负,表示以赫兹(Hz)为单位的发放率。对于非零输入,采样得到的值为0的脉冲间隔(ISI)会被加1,以避免出现零间隔,同时基本保持ISI分布。通过在时间维度上对ISI序列累加求和得到脉冲时刻,并据此构造脉冲张量。

    3、rank_order

    def rank_order(
        datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs
    ) -> torch.Tensor:
        # language=rst
        """
        Encodes data via a rank order coding-like representation. One spike per neuron
        with non-zero input, temporally ordered by decreasing intensity: the strongest
        input spikes first and the weakest spikes in the last time step. Zero inputs
        never spike. Inputs must be non-negative; the input tensor is not modified.

        :param datum: Tensor of shape ``[n_1, ..., n_k]``.
        :param time: Length of rank order-encoded spike train per input variable.
        :param dt: Simulation time step; the train has ``int(time / dt)`` steps.
        :return: Tensor of shape ``[time, n_1, ..., n_k]`` of rank order-encoded spikes,
            dtype ``uint8``.
        """
        assert (datum >= 0).all(), "Inputs must be non-negative"

        shape, size = datum.shape, datum.numel()
        datum = datum.flatten()
        time = int(time / dt)

        spikes = torch.zeros(time, size).byte()

        # Nothing can spike for empty or all-zero input (this also avoids 0 / 0 below).
        if size == 0 or datum.max() == 0:
            return spikes.reshape(time, *shape)

        # Create spike times in order of decreasing intensity. Division is out-of-place
        # so the caller's tensor (which ``flatten`` may return as-is) stays untouched
        # and integer inputs are supported.
        datum = datum / datum.max()
        nonzero = datum != 0
        times = torch.zeros(size)
        times[nonzero] = 1 / datum[nonzero]
        times *= time / times.max()  # Extended through simulation time.
        # The weakest input maps to step ``time``; the clamp guards against float
        # round-up pushing it to ``time + 1``.
        times = torch.ceil(times).long().clamp(max=time)

        # One spike per neuron with a positive (1-based) time, placed at row times - 1.
        idx = (times > 0).nonzero(as_tuple=True)[0]
        spikes[times[idx] - 1, idx] = 1

        return spikes.reshape(time, *shape)

    Rank order编码:通过类似秩序编码(rank order coding)的表征对数据进行编码。每个神经元发放一个脉冲,各脉冲按输入强度递减的顺序在时间上排列(强度越大,发放越早)。输入必须为非负。

  • 相关阅读:
    mysql创建用户,并赋予权限:只能查某个数据库中的某张表(只读)
    Fastjson toJSONString用单引号进行转换
    MyBatis传入参数为list、数组、map写法
    进制GB和GiB的区别
    leaflet 根据一个经纬度及距离角度,算出另外一个经纬度
    ubuntu下安装YApi
    Oracle 存储过程测试
    Oracle两种临时表的创建与使用详解
    一月到十二月的英文
    spring framework各个版本下载网址
  • 原文地址:https://www.cnblogs.com/lucifer1997/p/14348350.html
Copyright © 2011-2022 走看看