zoukankan      html  css  js  c++  java
  • LMDB数据库加速Pytorch文件读取速度

    问题背景

    训练深度学习模型往往需要大规模的数据集,这些数据集往往无法直接一次性加载到计算机的内存中,通常需要分批加载。数据的I/O很可能成为训练深度网络模型的瓶颈,因此数据的读取速度对于大规模的数据集(几十G甚至上千G)是非常关键的。例如:https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977

    采用数据库能够大大提升数据的读写速度,例如caffe支持从lmdb、leveldb文件读取训练样本。

    lmdb和leveldb的使用方式差不多,Leveldb lmdb性能对比。但是,数据集转换为LMDB或leveldb之后文件会变大(数据以二进制形式保存),即采用空间换取时间效率。

    caffe先支持leveldb,后支持lmdb。lmdb读取的效率更高,而且支持不同程序同时读取,而leveldb只允许一个程序读取。这一点在使用同样的数据跑不同的配置程序时很重要。

    lmdb

    参考:https://zhuanlan.zhihu.com/p/70359311

    LMDB 全称为 Lightning Memory-Mapped Database,就是非常快的内存映射型数据库,LMDB使用内存映射文件,可以提供更好的输入/输出性能,对于用于神经网络的大型数据集( 比如 ImageNet ),可以将其存储在 LMDB 中。

    因为最开始 Caffe 就是使用的这个数据库,所以网上的大多数关于 LMDB 的教程都通过 Caffe 实现的,对于不了解 Caffe 的同学很不友好,所以本篇文章只讲解 LMDB。

    LMDB属于key-value数据库,而不是关系型数据库( 比如 MySQL ),LMDB提供 key-value 存储,其中每个键值对都是我们数据集中的一个样本。LMDB的主要作用是提供数据管理,可以将各种各样的原始数据转换为统一的key-value存储。

    LMDB效率高的一个关键原因是它是基于内存映射的,这意味着它返回指向键和值的内存地址的指针,而不需要像大多数其他数据库那样复制内存中的任何内容。

    LMDB不仅可以用来存放训练和测试用的数据集,还可以存放神经网络提取出的特征数据。如果数据的结构很简单,就是大量的矩阵和向量,而且数据之间没有什么关联,数据内没有复杂的对象结构,那么就可以选择LMDB这个简单的数据库来存放数据。

    LMDB的文件结构很简单,一个文件夹,里面是一个数据文件和一个锁文件。数据随意复制,随意传输。它的访问简单,不需要单独的数据管理进程。只要在访问代码里引用LMDB库,访问时给文件路径即可。

    用LMDB数据库来存放图像数据,而不是直接读取原始图像数据的原因:

    • 数据类型多种多样,比如:二进制文件、文本文件、编码后的图像文件jpeg、png等,不可能用一套代码实现所有类型的输入数据读取,因此通过LMDB数据库,转换为统一数据格式可以简化数据读取层的实现。
    • lmdb具有极高的存取速度,大大减少了系统访问大量小文件时的磁盘IO的时间开销。LMDB将整个数据集都放在一个文件里,避免了文件系统寻址的开销,你的存储介质有多快,就能访问多快,不会因为文件多而导致时间长。LMDB使用了内存映射的方式访问文件,这使得文件内寻址的开销大幅度降低。

    LMDB 的基本函数

    • env = lmdb.open():创建 lmdb 环境
    • txn = env.begin():建立事务
    • txn.put(key, value):进行插入和修改
    • txn.delete(key):进行删除
    • txn.get(key):进行查询
    • txn.cursor():进行遍历
    • txn.commit():提交更改

    创建一个 lmdb 环境:

    1 import lmdb
    2 
    3 env = lmdb.open('D:/desktop/lmdb', map_size=10*1024**2)

    指定存放生成的lmdb数据库的文件夹路径,如果没有该文件夹则自动创建。

    map_size 指定创建的新数据库所需磁盘空间的最小值,1099511627776B=1T。可以在这里进行 存储单位换算

    会在指定路径下创建 data.mdb 和 lock.mdb 两个文件,一是个数据文件,一个是锁文件。

    修改数据库内容:

     1 # 创建一个事务Transaction对象
     2 txn = env.begin(write=True)
     3 
     4 # insert/modify
     5 # txn.put(key, value)
     6 txn.put(str(1).encode(), "Alice".encode()) # .encode()编码为字节bytes格式
     7 txn.put(str(2).encode(), "Bob".encode())
     8 txn.put(str(3).encode(), "Jack".encode())
     9 
    10 # delete
    11 # txn.delete(key)
    12 txn.delete(str(1).encode())
    13 
    14 # 提交待处理的事务
    15 txn.commit()

    先创建一个事务(transaction) 对象 txn,所有的操作都必须经过这个事务对象。因为我们要对数据库进行写入操作,所以将 write 参数置为 True,默认其为 False

    使用 .put(key, value) 对数据库进行插入和修改操作,传入的参数为键值对。

    值得注意的是,需要在键值字符串后加 .encode() 改变其编码格式,将 str 转换为 bytes 格式,否则会报该错误:TypeError: Won't implicitly convert Unicode to bytes; use .encode()。在后面使用 .decode() 对其进行解码得到原数据。

    使用 .delete(key) 删除指定键值对。

     对LMDB的读写操作在事务中执行,需要使用 commit 方法提交待处理的事务。

     

    查询数据库内容:

    1 # 数据库查询
    2 txn = env.begin() # 每个commit()之后都需要使用begin()方法更新txn得到最新数据库
    3 
    4 print(txn.get(str(2).encode()))
    5 
    6 for key, value in txn.cursor():
    7     print(key, value)
    8 
    9 env.close

    每次 commit() 之后都要用 env.begin() 更新 txn(得到最新的lmdb数据库)。

    使用 .get(key) 查询数据库中的单条记录。

    使用 .cursor() 遍历数据库中的所有记录,其返回一个可迭代对象,相当于关系数据库中的游标,每读取一次,游标下移一位。

    也可以想文件一样使用 with 语法:

    1 # 可以像文件一样使用with语法
    2 with env.begin() as txn:
    3     print(txn.get(str(2).encode()))
    4 
    5     for key, value in txn.cursor():
    6         print(key, value)
    7 env.close

    完整的demo如下:

     1 import lmdb
     2 import os, sys
     3 
     4 def initialize(lmdb_dir, map_size):
     5     # map_size: bytes
     6     env = lmdb.open(lmdb_dir, map_size)
     7     return env
     8 
     9 def insert(env, key, value):
    10     txn = env.begin(write=True)
    11     txn.put(str(key).encode(), value.encode())
    12     txn.commit()
    13 
    14 def delete(env, key):
    15     txn = env.begin(write=True)
    16     txn.delete(str(key).encode())
    17     txn.commit()
    18 
    19 def update(env, key, value):
    20     txn = env.begin(write=True)
    21     txn.put(str(key).encode(), value.encode())
    22     txn.commit()
    23 
    24 def search(env, key):
    25     txn = env.begin()
    26     value = txn.get(str(key).encode())
    27     return value
    28 
    29 def display(env):
    30     txn = env.begin()
    31     cursor = txn.cursor()
    32     for key, value in cursor:
    33         print(key, value)
    34 
    35 
    36 if __name__ == '__main__':
    37     path = 'D:/desktop/lmdb'
    38     env = initialize(path, 10*1024*1024)
    39 
    40     print("Insert 3 records.")
    41     insert(env, 1, "Alice")
    42     insert(env, 2, "Bob")
    43     insert(env, 3, "Peter")
    44     display(env)
    45 
    46     print("Delete the record where key = 1")
    47     delete(env, 1)
    48     display(env)
    49 
    50     print("Update the record where key = 3")
    51     update(env, 3, "Mark")
    52     display(env)
    53 
    54     print("Get the value whose key = 3")
    55     name = search(env, 3)
    56     print(name)
    57 
    58     # 最后需要关闭lmdb数据库
    59     env.close()
    View Code

    图片数据示例

    在图像深度学习训练中我们一般都会把大量原始数据集转化为lmdb格式以方便后续的网络训练。因此我们也需要对该数据集进行lmdb格式转化。

    将图片和对应的文本标签存放到lmdb数据库:

     1 import lmdb
     2 
     3 image_path = './cat.jpg'
     4 label = 'cat'
     5 
     6 env = lmdb.open('lmdb_dir')
     7 cache = {}  # 存储键值对
     8 
     9 with open(image_path, 'rb') as f:
    10     # 读取图像文件的二进制格式数据
    11     image_bin = f.read()
    12 
    13 # 用两个键值对表示一个数据样本
    14 cache['image_000'] = image_bin
    15 cache['label_000'] = label
    16 
    17 with env.begin(write=True) as txn:
    18     for k, v in cache.items():
    19         if isinstance(v, bytes):
    20             # 图片类型为bytes
    21             txn.put(k.encode(), v)
    22         else:
    23             # 标签类型为str, 转为bytes
    24             txn.put(k.encode(), v.encode())  # 编码
    25 
    26 env.close()
    View Code

    这里需要获取图像文件的二进制格式数据,然后用两个键值对保存一个数据样本,即分开保存图片和其标签。

    然后分别将图像和标签写入到lmdb数据库中,和上面例子一样都需要将键值转换为 bytes 格式。因为此处读取的图片格式本身就为 bytes,所以不需要转换,标签格式为 str,写入数据库之前需要先进行编码将其转换为 bytes

    从lmdb数据库中读取图片数据:

     1 import cv2
     2 import lmdb
     3 import numpy as np
     4 
     5 env = lmdb.open('lmdb_dir')
     6 
     7 with env.begin(write=False) as txn:
     8     # 获取图像数据
     9     image_bin = txn.get('image_000'.encode())
    10     label = txn.get('label_000'.encode()).decode()  # 解码
    11 
    12     # 将二进制文件转为十进制文件(一维数组)
    13     image_buf = np.frombuffer(image_bin, dtype=np.uint8)
    14     # 将数据转换(解码)成图像格式
    15     # cv2.IMREAD_GRAYSCALE为灰度图,cv2.IMREAD_COLOR为彩色图
    16     img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR)
    17     cv2.imshow('image', img)
    18     cv2.waitKey(0)
    View Code

    先通过 lmdb.open() 获取之前创建的lmdb数据库。

    这里通过键得到图片和其标签,因为写入数据库之前进行了编码,所以这里需要先解码。

    • 标签通过 .decode() 进行解码重新得到字符串格式。
    • 读取到的图片数据为二进制格式,所以先使用 np.frombuffer() 将其转换为十进制格式的文件,这是一维数组。然后可以使用 cv2.imdecode() 将其转换为灰度图(二维数组)或者彩色图(三维数组)。

    leveldb

    leveldb的使用与lmdb差不多,然而LevelDB 是单进程的服务。

    https://www.jianshu.com/p/66496c8726a1

    https://github.com/liquidconv/py4db

    https://github.com/google/leveldb

     1 #!/usr/bin/env python
     2 
     3 import leveldb
     4 import os, sys
     5 
     6 def initialize():
     7     db = leveldb.LevelDB("students");
     8     return db;
     9 
    10 def insert(db, sid, name):
    11     db.Put(str(sid), name);
    12 
    13 def delete(db, sid):
    14     db.Delete(str(sid));
    15 
    16 def update(db, sid, name):
    17     db.Put(str(sid), name);
    18 
    19 def search(db, sid):
    20     name = db.Get(str(sid));
    21     return name;
    22 
    23 def display(db):
    24     for key, value in db.RangeIter():
    25         print (key, value);
    26 
    27 db = initialize();
    28 
    29 print "Insert 3 records."
    30 insert(db, 1, "Alice");
    31 insert(db, 2, "Bob");
    32 insert(db, 3, "Peter");
    33 display(db);
    34 
    35 print "Delete the record where sid = 1."
    36 delete(db, 1);
    37 display(db);
    38 
    39 print "Update the record where sid = 3."
    40 update(db, 3, "Mark");
    41 display(db);
    42 
    43 print "Get the name of student whose sid = 3."
    44 name = search(db, 3);
    45 print name;
    View Code

     pytorch从lmdb中加载数据

    这里给出一种pytorch从lmdb中加载数据的参考示例,来自:https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977

    需要指出的是,pytorch的Dataset并不支持lmdb的迭代器。Dataset通过__getitem__(index)方法通过index获取一个样本,因此无法整合lmdb的cursor进行遍历,只能通过

    with self.data_env.begin() as f: data = f.get(key)的方式,即每次打开一个事务txn,这会降低读取速度。如果设置shuffle=False则可以利用cursor按顺序遍历。

     1 from __future__ import print_function
     2 import torch.utils.data as data
     3 # import h5py
     4 import numpy as np
     5 import lmdb
     6 
     7 
     8 class onlineHCCR(data.Dataset):
     9     def __init__(self, train=True):
    10         # self.root = root
    11         self.train = train
    12 
    13         if self.train:
    14             datalmdb_path = 'traindata_lmdb'
    15             labellmdb_path = 'trainlabel_lmdb'
    16             self.data_env = lmdb.open(datalmdb_path, readonly=True)
    17             self.label_env = lmdb.open(labellmdb_path, readonly=True)
    18 
    19         else:
    20             datalmdb_path = 'testdata_lmdb'
    21             labellmdb_path = 'testlabel_lmdb'
    22             self.data_env = lmdb.open(datalmdb_path, readonly=True)
    23             self.label_env = lmdb.open(labellmdb_path, readonly=True)
    24 
    25 
    26     def __getitem__(self, index):
    27 
    28         Data = []
    29         Target = []
    30 
    31         if self.train:
    32             with self.data_env.begin() as f:
    33                 key = '{:08}'.format(index)
    34                 data = f.get(key)
    35                 flat_data = np.fromstring(data, dtype=float)
    36                 data = flat_data.reshape(150, 6).astype('float32')
    37                 Data = data
    38 
    39             with self.label_env.begin() as f:
    40                 key = '{:08}'.format(index)
    41                 data = f.get(key)
    42                 label = np.fromstring(data, dtype=int)
    43                 Target = label[0]
    44 
    45         else:
    46 
    47             with self.data_env.begin() as f:
    48                 key = '{:08}'.format(index)
    49                 data = f.get(key)
    50                 flat_data = np.fromstring(data, dtype=float)
    51                 data = flat_data.reshape(150, 6).astype('float32')
    52                 Data = data
    53 
    54             with self.label_env.begin() as f:
    55                 key = '{:08}'.format(index)
    56                 data = f.get(key)
    57                 label = np.fromstring(data, dtype=int)
    58                 Target = label[0]
    59 
    60         return Data, Target
    61         
    62 
    63     def __len__(self):
    64         if self.train:
    65             return 2693931
    66         else:
    67             return 224589

     另一个示例:

     1 # https://github.com/pytorch/vision/blob/master/torchvision/datasets/lsun.py#L19-L20
     2 
     3 class LSUNClass(VisionDataset):
     4     def __init__(self, root, transform=None, target_transform=None):
     5         import lmdb
     6         super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform)
     7 
     8         self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
     9         with self.env.begin(write=False) as txn:
    10             self.length = txn.stat()['entries']
    11         cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters)
    12         if os.path.isfile(cache_file):
    13             self.keys = pickle.load(open(cache_file, "rb"))
    14         else:
    15             with self.env.begin(write=False) as txn:
    16                 self.keys = [key for key, _ in txn.cursor()]
    17             pickle.dump(self.keys, open(cache_file, "wb"))
    18 
    19     def __getitem__(self, index):
    20         img, target = None, None
    21         env = self.env
    22         with env.begin(write=False) as txn:
    23             imgbuf = txn.get(self.keys[index])
    24 
    25         buf = io.BytesIO()
    26         buf.write(imgbuf)
    27         buf.seek(0)
    28         img = Image.open(buf).convert('RGB')
    29 
    30         if self.transform is not None:
    31             img = self.transform(img)
    32 
    33         if self.target_transform is not None:
    34             target = self.target_transform(target)
    35 
    36         return img, target
    37 
    38     def __len__(self):
    39         return self.length

    参考:

    lmdb 数据库

    Python操作SQLite/MySQL/LMDB/LevelDB

    https://github.com/liquidconv/py4db

    https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977

    https://www.programcreek.com/python/example/106501/lmdb.open

    https://realpython.com/storing-images-in-python/

     https://www.cnblogs.com/skyfsm/p/10345305.html

  • 相关阅读:
    虚函数
    类的继承
    析构
    构造
    枚举类型
    c++中的静态类型 static
    c++中的类
    sizeof和strlen的区别
    剑指36 二叉搜索书与双向链表
    剑指35 复杂链表的复制
  • 原文地址:https://www.cnblogs.com/jiangkejie/p/13192518.html
Copyright © 2011-2022 走看看