zoukankan      html  css  js  c++  java
  • Implementing TensorFlow's next_batch for your own data

    Version for NumPy array data

    import numpy as np
    
    class Dataset:
        """Serve successive mini-batches from a NumPy array.

        Examples are indexed along axis 0. A cursor walks through the data;
        when a requested batch would run past the last example, the unused
        tail of the current epoch is joined with the head of the next epoch
        so the batch still holds ``batch_size`` examples.
        """

        def __init__(self, data):
            """Wrap ``data``; ``data.shape[0]`` is taken as the example count."""
            self._index_in_epoch = 0  # position of the next batch's first example
            self._epochs_completed = 0  # how many times the cursor has wrapped
            self._data = data
            self._num_examples = data.shape[0]

        @property
        def data(self):
            """The wrapped array in its current (possibly shuffled) order."""
            return self._data

        def _reshuffle(self):
            """Rebind ``_data`` to a randomly permuted copy of itself."""
            order = np.random.permutation(self._num_examples)
            # Fancy indexing copies, so the caller's original array is untouched.
            self._data = self._data[order]

        def next_batch(self, batch_size, shuffle=True):
            """Return the next ``batch_size`` examples and advance the cursor.

            Args:
                batch_size: Number of examples to return. Must satisfy
                    ``0 <= batch_size <= number of examples``.
                shuffle: When true (the default) the data is permuted before
                    the very first batch and again every time an epoch ends.
                    When false the examples are served in their stored order,
                    cycling back to the start at the end of each epoch.

            Returns:
                An array with ``batch_size`` examples along axis 0. A batch
                that straddles an epoch boundary is the remainder of the old
                epoch followed by the start of the new one; with shuffling on
                it may therefore contain the same example twice.

            Raises:
                ValueError: If ``batch_size`` is negative or larger than the
                    number of examples (a single wrap-around cannot fill such
                    a batch and the cursor would be left out of range).
            """
            if not 0 <= batch_size <= self._num_examples:
                raise ValueError(
                    "batch_size must be between 0 and %d, got %r"
                    % (self._num_examples, batch_size)
                )

            start = self._index_in_epoch
            if shuffle and start == 0 and self._epochs_completed == 0:
                self._reshuffle()

            end = start + batch_size
            if end <= self._num_examples:
                self._index_in_epoch = end
                return self._data[start:end]

            # The batch crosses the epoch boundary. When start equals the
            # number of examples the tail is simply empty.
            self._epochs_completed += 1
            tail = self._data[start:self._num_examples]
            if shuffle:
                self._reshuffle()
            # Whatever the tail could not supply comes from the new epoch.
            self._index_in_epoch = end - self._num_examples
            head = self._data[0:self._index_in_epoch]
            return np.concatenate((tail, head), axis=0)
    
    # Demo: draw ten batches of six from a ten-element array, then show the
    # order the data was left in after the last reshuffle.
    dataset = Dataset(np.arange(10))
    for _ in range(10):
        batch = dataset.next_batch(6)
        print(batch)
    print(dataset.data)
    

    Version for pandas data

    import numpy as np
    import pandas as pd
    class Dataset:
        """Serve successive mini-batches from a pandas DataFrame or Series.

        Rows are addressed by position (``.iloc``), never by index label. A
        cursor walks through the rows; when a requested batch would run past
        the last row, the unused tail of the current epoch is joined with the
        head of the next epoch so the batch still holds ``batch_size`` rows.
        """

        def __init__(self, data):
            """Wrap ``data``; ``data.shape[0]`` is taken as the row count."""
            self._index_in_epoch = 0  # position of the next batch's first row
            self._epochs_completed = 0  # how many times the cursor has wrapped
            self._data = data
            self._num_examples = data.shape[0]

        @property
        def data(self):
            """The wrapped object in its current (possibly shuffled) row order."""
            return self._data

        def _reshuffle(self):
            """Rebind ``_data`` to a randomly row-permuted copy of itself."""
            order = np.random.permutation(self._num_examples)
            # Row-only .iloc works for both Series and DataFrame.
            self._data = self._data.iloc[order]

        def next_batch(self, batch_size, shuffle=True):
            """Return the next ``batch_size`` rows and advance the cursor.

            Args:
                batch_size: Number of rows to return. Must satisfy
                    ``0 <= batch_size <= number of rows``.
                shuffle: When true (the default) the rows are permuted before
                    the very first batch and again every time an epoch ends.
                    When false the rows are served in their stored order,
                    cycling back to the start at the end of each epoch.

            Returns:
                An object of the same pandas type with ``batch_size`` rows,
                keeping the original index labels. A batch that straddles an
                epoch boundary is the remainder of the old epoch followed by
                the start of the new one; with shuffling on it may therefore
                contain the same row twice.

            Raises:
                ValueError: If ``batch_size`` is negative or larger than the
                    number of rows (a single wrap-around cannot fill such a
                    batch and the cursor would be left out of range).
            """
            if not 0 <= batch_size <= self._num_examples:
                raise ValueError(
                    "batch_size must be between 0 and %d, got %r"
                    % (self._num_examples, batch_size)
                )

            start = self._index_in_epoch
            if shuffle and start == 0 and self._epochs_completed == 0:
                self._reshuffle()

            end = start + batch_size
            if end <= self._num_examples:
                self._index_in_epoch = end
                # Positional slice, consistent with every other access here.
                return self._data.iloc[start:end]

            # The batch crosses the epoch boundary. When start equals the
            # number of rows the tail is simply an empty frame/series.
            self._epochs_completed += 1
            tail = self._data.iloc[start:self._num_examples]
            if shuffle:
                self._reshuffle()
            # Whatever the tail could not supply comes from the new epoch.
            self._index_in_epoch = end - self._num_examples
            head = self._data.iloc[0:self._index_in_epoch]
            return pd.concat([tail, head], axis=0)
    
    # Demo: a ten-row frame with columns 'a' (0..9) and 'b' (ten times 'a'),
    # consumed in ten batches of five, then shown in its final row order.
    df = pd.DataFrame({'a': np.arange(10), 'b': np.arange(10) * 10})
    dataset = Dataset(df)
    for _ in range(10):
        batch = dataset.next_batch(5)
        print(batch)
    print(dataset.data)
    
  • 相关阅读:
    PHP之简单实现MVC框架
    socket泄露的问题
    gdb 调试多线程
    MMAP和DIRECT IO区别
    三年回首:C基础
    定时器管理:nginx的红黑树和libevent的堆
    strsep和strtok_r替代strtok
    缓存穿透和缓存失效
    mmap为什么比read/write快(兼论buffercache和pagecache)
    B+Tree和MySQL索引分析
  • 原文地址:https://www.cnblogs.com/ZeroTensor/p/10394989.html
Copyright © 2011-2022 走看看