zoukankan      html  css  js  c++  java
  • Implementing TensorFlow's next_batch for your own data

    Version for NumPy arrays

    import numpy as np
    
    class Dataset:
        """Serve consecutive mini-batches from a NumPy array, epoch after epoch.

        The array is sliced along its first axis.  When a batch runs past the
        end of the data, the tail of the current epoch is combined with the
        head of the next one so every batch has exactly ``batch_size`` rows.
        """

        def __init__(self, data):
            """Wrap *data*; its first axis is the example axis."""
            self._index_in_epoch = 0
            self._epochs_completed = 0
            self._data = data
            self._num_examples = data.shape[0]

        @property
        def data(self):
            """Return the data in its current (possibly shuffled) order."""
            return self._data

        def _reshuffle(self):
            """Rebind ``self._data`` to a randomly permuted copy of itself."""
            order = np.arange(self._num_examples)
            np.random.shuffle(order)
            self._data = self._data[order]

        def next_batch(self, batch_size, shuffle=True):
            """Return the next *batch_size* examples.

            Args:
                batch_size: Number of examples to return; must be >= 0.
                shuffle: When true (the default) the data is permuted before
                    the first batch and again at every epoch boundary.  When
                    false the examples are served in their stored order.

            Returns:
                An array holding ``batch_size`` examples along axis 0 (an
                empty slice if the dataset itself is empty).

            Raises:
                ValueError: If *batch_size* is negative.
            """
            if batch_size < 0:
                raise ValueError("batch_size must be non-negative")

            start = self._index_in_epoch
            if shuffle and start == 0 and self._epochs_completed == 0:
                self._reshuffle()  # initial shuffle before the very first batch

            # Fast path: the whole batch fits in the current epoch.
            if start + batch_size <= self._num_examples:
                self._index_in_epoch = start + batch_size
                return self._data[start:self._index_in_epoch]

            # Nothing to cycle through; avoid looping forever on empty data.
            if self._num_examples == 0:
                return self._data[0:0]

            # Wrap-around path.  Note: when start == num_examples the leftover
            # part is simply an empty slice.
            pieces = [self._data[start:self._num_examples]]
            remaining = batch_size - (self._num_examples - start)
            taken = 0
            # A batch larger than the dataset spans several epochs, so keep
            # starting new epochs until the batch is full.
            while remaining > 0:
                self._epochs_completed += 1
                if shuffle:
                    self._reshuffle()
                taken = min(remaining, self._num_examples)
                pieces.append(self._data[0:taken])
                remaining -= taken

            self._index_in_epoch = taken
            return np.concatenate(pieces, axis=0)
    
    # Demo: draw ten batches of six from the integers 0..9, printing each one,
    # then print the dataset in whatever order it currently holds.
    dataset = Dataset(np.arange(10))
    for _ in range(10):
        batch = dataset.next_batch(6)
        print(batch)
    print(dataset.data)
    

    Version for pandas DataFrames

    import numpy as np
    import pandas as pd
    class Dataset:
        """Serve consecutive mini-batches of rows from a pandas object.

        Rows are always selected by position (``.iloc``), so the index labels
        of the wrapped DataFrame/Series do not matter.  When a batch runs past
        the end of the data, the tail of the current epoch is combined with
        the head of the next one so every batch has exactly ``batch_size``
        rows.
        """

        def __init__(self, data):
            """Wrap *data*; ``data.shape[0]`` is taken as the number of rows."""
            self._index_in_epoch = 0
            self._epochs_completed = 0
            self._data = data
            self._num_examples = data.shape[0]

        @property
        def data(self):
            """Return the data in its current (possibly shuffled) row order."""
            return self._data

        def _reshuffle(self):
            """Rebind ``self._data`` to a randomly row-permuted copy of itself."""
            order = np.arange(self._num_examples)
            np.random.shuffle(order)
            self._data = self._data.iloc[order]

        def next_batch(self, batch_size, shuffle=True):
            """Return the next *batch_size* rows.

            Args:
                batch_size: Number of rows to return; must be >= 0.
                shuffle: When true (the default) the rows are permuted before
                    the first batch and again at every epoch boundary.  When
                    false the rows are served in their stored order.

            Returns:
                A pandas object of the same kind as the wrapped data holding
                ``batch_size`` rows (an empty slice if the dataset is empty).

            Raises:
                ValueError: If *batch_size* is negative.
            """
            if batch_size < 0:
                raise ValueError("batch_size must be non-negative")

            start = self._index_in_epoch
            if shuffle and start == 0 and self._epochs_completed == 0:
                self._reshuffle()  # initial shuffle before the very first batch

            # Fast path: the whole batch fits in the current epoch.  Use
            # positional slicing, consistent with every other access here.
            if start + batch_size <= self._num_examples:
                self._index_in_epoch = start + batch_size
                return self._data.iloc[start:self._index_in_epoch]

            # Nothing to cycle through; avoid looping forever on empty data.
            if self._num_examples == 0:
                return self._data.iloc[0:0]

            # Wrap-around path.  The leftover part is empty when the previous
            # batch ended exactly on the epoch boundary; skip it in that case
            # so no empty frame is passed to pd.concat.
            pieces = []
            if start < self._num_examples:
                pieces.append(self._data.iloc[start:self._num_examples])
            remaining = batch_size - (self._num_examples - start)
            taken = 0
            # A batch larger than the dataset spans several epochs, so keep
            # starting new epochs until the batch is full.
            while remaining > 0:
                self._epochs_completed += 1
                if shuffle:
                    self._reshuffle()
                taken = min(remaining, self._num_examples)
                pieces.append(self._data.iloc[0:taken])
                remaining -= taken

            self._index_in_epoch = taken
            return pd.concat(pieces, axis=0)
    
    # Demo: build a two-column frame (a = 0..9, b = 10 * a), draw ten batches
    # of five rows, printing each one, then print the dataset in whatever row
    # order it currently holds.
    df = pd.DataFrame({'a': np.arange(10), 'b': np.arange(10) * 10})
    dataset = Dataset(df)
    for _ in range(10):
        batch = dataset.next_batch(5)
        print(batch)
    print(dataset.data)
    
  • 相关阅读:
    1.2 JAVA的String类和StringBuffer类
    1.7 JAVA异常总结
    2.1 JQuery框架(封装JavaScript框架)
    1.6 JSON存储数据方式(JavaScript对象表示法)
    1.33 JavaScript之HTML的DOM(三)
    1.32 JavaScript的BOM(二)
    【转】SQL 生成连续字符
    木兰国产编程语言 Mulan--附带下载地址
    【python】两行代码实现近百年的正反日期查询--20200202
    Linux下扫描服务器IP地址是否冲突(arp-scan)
  • 原文地址:https://www.cnblogs.com/ZeroTensor/p/10394989.html
Copyright © 2011-2022 走看看