zoukankan      html  css  js  c++  java
  • BindsNET学习系列——BasePipeline

    相关源码:bindsnet/bindsnet/pipeline/base_pipeline.py

    class BasePipeline:
        # language=rst
        """
        A generic pipeline that handles high level functionality.

        The base class stores the network, optionally attaches spike and voltage
        monitors to each of its layers, selects a torch device, and runs the
        print / plot / save / test bookkeeping around every call to ``step()``.

        Subclasses must implement ``init_fn()``, ``step_()`` and ``plots()``, and
        ``test()`` whenever ``test_interval`` is used. ``init_fn()`` is invoked by
        the constructor, so the base class itself cannot be instantiated.
        """

        def __init__(self, network: Network, **kwargs) -> None:
            # language=rst
            """
            Initializes the pipeline.

            :param network: Arbitrary network object, will be managed by the
                ``BasePipeline`` class.

            Keyword arguments:

            :param int save_interval: Save the network every ``save_interval``
                steps. ``None`` (default) or ``0`` disables saving.
            :param str save_dir: Path handed to ``network.save()``. Defaults to
                ``"network.pt"``.
            :param Dict[str, Any] plot_config: Dict containing the plot
                configuration. The base class reads ``"data_step"`` (monitors are
                attached unless it is ``None`` or absent) and ``"data_length"``
                (forwarded to every ``Monitor``, 100 if absent). Defaults to
                ``{"data_step": True, "data_length": 100}``.
            :param int print_interval: Print the iteration count and elapsed time
                every ``print_interval`` steps. ``None`` (default) or ``0``
                disables printing.
            :param int test_interval: Call ``test()`` every ``test_interval``
                steps. ``None`` (default) or ``0`` disables testing.
            :param bool allow_gpu: Allows automatic transfer to the GPU. Defaults
                to ``True``.
            """
            self.network = network

            # Network saving handles caching of intermediate results.
            self.save_dir = kwargs.get("save_dir", "network.pt")
            self.save_interval = kwargs.get("save_interval", None)

            # Handles plotting of all layer spikes and voltages.
            # This constructs monitors at every level.
            self.plot_config = kwargs.get(
                "plot_config", {"data_step": True, "data_length": 100}
            )

            # Read with ``.get`` so that a caller-supplied config which omits a
            # key does not abort construction with a ``KeyError``.
            if self.plot_config.get("data_step") is not None:
                data_length = self.plot_config.get("data_length", 100)
                for name in self.network.layers:
                    layer = self.network.layers[name]
                    self.network.add_monitor(
                        Monitor(layer, "s", data_length), name=f"{name}_spikes"
                    )
                    # Only layers exposing a ``v`` attribute get a voltage monitor.
                    if hasattr(layer, "v"):
                        self.network.add_monitor(
                            Monitor(layer, "v", data_length), name=f"{name}_voltages"
                        )

            self.print_interval = kwargs.get("print_interval", None)
            self.test_interval = kwargs.get("test_interval", None)
            self.step_count = 0

            # NOTE: the subclass hook runs before ``self.clock``, ``self.allow_gpu``
            # and ``self.device`` exist, so ``init_fn()`` must not rely on them.
            self.init_fn()
            self.clock = time.time()
            self.allow_gpu = kwargs.get("allow_gpu", True)

            # Checking ``allow_gpu`` first skips the CUDA probe when it is not wanted.
            if self.allow_gpu and torch.cuda.is_available():
                self.device = torch.device("cuda")
            else:
                self.device = torch.device("cpu")

            self.network.to(self.device)

        def _interval_due(self, interval: Any) -> bool:
            # language=rst
            """
            Tell whether a periodic action is due at the current step.

            :param interval: Period in steps. ``None`` and ``0`` both mean the
                action is disabled (``0`` would otherwise divide by zero).
            :return: ``True`` if ``step_count`` is a multiple of ``interval``.
            """
            return bool(interval) and self.step_count % interval == 0

        def reset_state_variables(self) -> None:
            # language=rst
            """
            Reset the pipeline: resets the network's state variables and sets the
            step counter back to zero.
            """
            self.network.reset_state_variables()
            self.step_count = 0

        def step(self, batch: Any, **kwargs) -> Any:
            # language=rst
            """
            Single step of any pipeline at a high level.

            Increments the step counter, passes the batch through
            ``recursive_to(batch, self.device)``, runs ``step_()``, and then, in
            this order, prints timing, calls ``plots()``, saves the network and
            calls ``test()`` whenever the matching interval is due.

            :param batch: A batch of inputs to be handed to the ``step_()`` function.
                          Standard in subclasses of ``BasePipeline``.
            :param kwargs: Forwarded unchanged to ``step_()``.
            :return: The output from the subclass's ``step_()`` method, which could be
                anything. Passed to plotting to accommodate this.
            """
            self.step_count += 1

            batch = recursive_to(batch, self.device)
            step_out = self.step_(batch, **kwargs)

            if self._interval_due(self.print_interval):
                now = time.time()
                print(f"Iteration: {self.step_count} (Time: {now - self.clock:.4f})")
                # Restart the clock so the next report covers only the new interval.
                self.clock = now

            self.plots(batch, step_out)

            if self._interval_due(self.save_interval):
                self.network.save(self.save_dir)

            if self._interval_due(self.test_interval):
                self.test()

            return step_out

        def get_spike_data(self) -> Dict[str, torch.Tensor]:
            # language=rst
            """
            Get the spike data from all layers in the pipeline's network.

            Layers without a ``"<layer>_spikes"`` monitor (for example when
            monitoring was disabled through ``plot_config``) are left out instead
            of raising ``KeyError``.

            :return: A dictionary mapping layer name to the ``"s"`` recording of its
                spike monitor.
            """
            monitors = self.network.monitors
            spikes = {}
            for name in self.network.layers:
                key = f"{name}_spikes"
                if key in monitors:
                    spikes[name] = monitors[key].get("s")

            return spikes

        def get_voltage_data(
            self,
        ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
            # language=rst
            """
            Get the voltage data and threshold value from all applicable layers in the
            pipeline's network.

            A layer contributes a voltage record when it has a ``v`` attribute and a
            ``"<layer>_voltages"`` monitor exists; it contributes a threshold when it
            has a ``thresh`` attribute.

            :return: Two dictionaries containing the voltage data and threshold values from
                the network.
            """
            monitors = self.network.monitors
            voltage_record = {}
            threshold_value = {}
            for name in self.network.layers:
                layer = self.network.layers[name]
                key = f"{name}_voltages"
                if hasattr(layer, "v") and key in monitors:
                    voltage_record[name] = monitors[key].get("v")
                if hasattr(layer, "thresh"):
                    threshold_value[name] = layer.thresh

            return voltage_record, threshold_value

        def step_(self, batch: Any, **kwargs) -> Any:
            # language=rst
            """
            Perform a pass of the network given the input batch.

            :param batch: The current batch. This could be anything as long as the subclass
                agrees upon the format in some way.
            :return: Any output that is needed for recording purposes.
            :raises NotImplementedError: Always, unless overridden by a subclass.
            """
            raise NotImplementedError("You need to provide a step_ method.")

        def train(self) -> None:
            # language=rst
            """
            A fully self-contained training loop.

            :raises NotImplementedError: Always, unless overridden by a subclass.
            """
            raise NotImplementedError("You need to provide a train method.")

        def test(self) -> None:
            # language=rst
            """
            A fully self-contained test function. Called from ``step()`` whenever
            ``test_interval`` is due.

            :raises NotImplementedError: Always, unless overridden by a subclass.
            """
            raise NotImplementedError("You need to provide a test method.")

        def init_fn(self) -> None:
            # language=rst
            """
            Placeholder function for subclass-specific actions that need to
            happen during the construction of the ``BasePipeline``.

            It is called from ``__init__`` after the monitors are attached and
            before the device is chosen.

            :raises NotImplementedError: Always, unless overridden by a subclass.
            """
            raise NotImplementedError("You need to provide an init_fn method.")

        def plots(self, batch: Any, step_out: Any) -> None:
            # language=rst
            """
            Create any plots and logs for a step given the input batch and step output.
            Called on every ``step()``.

            :param batch: The current batch. This could be anything as long as the subclass
                agrees upon the format in some way.
            :param step_out: The output from the ``step_()`` method.
            :raises NotImplementedError: Always, unless overridden by a subclass.
            """
            raise NotImplementedError("You need to provide a plots method.")
  • 相关阅读:
    java Android get date before 7 days (one week) Stack Overflow
    计算机网络与分布式系统实验室 北京大学
    得到IFrame中的Document
    eclipse如何把多个项目放在一个文件夹下
    windows 32位程序编译成64位
    iPhone5和iOS6上HTML5开发的新增功能
    Thinking in Java之接口回调改版
    Java学习笔记35:Java常用字符串操作函数
    进一步优化Bitmap Cache策略
    微软安全新闻聚焦双周刊第三十期
  • 原文地址:https://www.cnblogs.com/lucifer1997/p/14346152.html
Copyright © 2011-2022 走看看