zoukankan      html  css  js  c++  java
  • Python使用ctypes模块调用C/C++

    最近在做图卷积相关的实验,里面涉及到图采样,该过程可以抽象为:从一个包含n个节点,m条边的图中根据一定规则采样一个连通图。由于实验使用的是FB15k-237数据集,共包含14541个节点,272115条边,每次采样30000条边,采样一次需要8s,这对于深度学习实验来说是难以接受的,会导致GPU长时间空闲。因此我开始尝试使用C/C++优化代码,虽然最后优化效果不行,但是也是对python调用C代码的一次学习,因此在此纪录一下。

    Python原代码

     def get_adj_and_degrees(num_nodes, triplets):
        """ Get adjacency list and degrees of the graph"""
        adj_list = [[] for _ in range(num_nodes)]
        for i, triplet in enumerate(triplets):
            adj_list[triplet[0]].append([i, triplet[2]])
            adj_list[triplet[2]].append([i, triplet[0]])
    
        degrees = np.array([len(a) for a in adj_list])
        adj_list = [np.array(a) for a in adj_list]
        return adj_list, degrees
    

    这里以get_adj_and_degrees函数为例,我们使用C/C++优化该函数。该函数只是起演示作用,具体细节不重要。

    C/C++实现代码

    我们在sampler.hpp中对该函数进行优化,该文件的定义如下:

    #ifndef SAMPLER_H
    #define SAMPLER_H
    
    #include <vector>
    #include "utility.hpp"
    
    using namespace std;
    
    // global graph data
    int num_node = 0;
    int num_edge = 0;
    vector<int> degrees; // shape=[N]
    vector<vector<vector<int>>> adj_list; // shape=[N, variable_size, 2]
    
    
    void build_graph(int* src, int* rel, int* dst, int num_node_m, int num_edge_m) {
        num_node = num_node_m;
        num_edge = num_edge_m;
    
        // resize the vectors
        degrees.resize(num_node);
        adj_list.resize(num_node);
    
        for (int i = 0; i < num_edge; i++) {
            int s = src[i];
            int r = rel[i];
            int d = dst[i];
    
            vector<int> p = {i, d};
            vector<int> q = {i, s};
            adj_list[s].push_back(p);
            adj_list[d].push_back(q);
        }
    
        for (int i = 0; i < num_node; i++) {
            degrees[i] = adj_list[i].size();
        }
    }
    
    #endif
    

    这里C/C++函数把结果作为全局变量进行存储,是为了后一步使用。具体的函数细节也不在讲述,因为我们的重点是如何用python调用。

    生成so库

    ctypes只能调用C函数,因此我们需要把上述C++函数导出为C函数。因此我们在lib.cpp中做出如下定义:

    #ifndef LIB_H
    #define LIB_H
    
    #include "sampler.hpp"
    
    extern "C" {
        void build_graph_c(int* src, int* rel, int* dst, int num_node, int num_edge) {
            build_graph(src, rel, dst, num_node, num_edge);
        }
    }
    
    #endif
    

    然后使用如下命令进行编译,为了优化代码,加上了O3march=native选项:

    g++ lib.cpp -fPIC -shared -o libsampler.so -O3 -march=native
    

    Python调用C/C++函数

    编译完之后,在当前目录下生成了libsampler.so库,我们就可以编写python代码调用C/C++函数了,Python代码如下:

    import numpy as np
    import time
    from ctypes import cdll, POINTER, Array, cast
    from ctypes import c_int
    
    
    class CPPLib:
        """Class for operating CPP library
    
        Attributes:
            lib_path: (str) the path of a library, e.g. 'lib.so.6'
        """
        def __init__(self, lib_path):
            self.lib = cdll.LoadLibrary(lib_path)
    
            IntArray = IntArrayType()
            self.lib.build_graph_c.argtypes = (IntArray, IntArray, IntArray, c_int, c_int)
            self.lib.build_graph_c.restype = None
    
        def build_graph(self, src, rel, dst, num_node, num_edge):
            self.lib.build_graph_c(src, rel, dst, num_node, num_edge)
    
    class IntArrayType:
        # Define a special type for the 'int *' argument
        def from_param(self, param):
            typename = type(param).__name__
            if hasattr(self, 'from_' + typename):
                return getattr(self, 'from_' + typename)(param)
            elif isinstance(param, Array):
                return param
            else:
                raise TypeError("Can't convert %s" % typename)
    
        # Cast from array.array objects
        def from_array(self, param):
            if param.typecode != 'i':
                raise TypeError('must be an array of doubles')
            ptr, _ = param.buffer_info()
            return cast(ptr, POINTER(c_int))
    
        # Cast from lists/tuples
        def from_list(self, param):
            val = ((c_int) * len(param))(*param)
            return val
    
        from_tuple = from_list
    
        # Cast from a numpy array
        def from_ndarray(self, param):
            return param.ctypes.data_as(POINTER(c_int))
    

    总结

    python使用ctypes库调用C/C++函数本身不难,但是优化代码确是一个深坑,尤其是优化Numpy等科学计算库时。因为这些库本身已经进行了大量优化,自己使用C++实现的话,很有可能就比优化前还更差。

  • 相关阅读:
    ajax学习笔记
    CSS3伪类
    《HTML5与CSS3基础教程》学习笔记 ——Four Day
    《HTML5与CSS3基础教程》学习笔记 ——Three Day
    《HTML5与CSS3基础教程》学习笔记 ——Two Day
    《HTML5与CSS3基础教程》学习笔记 ——One Day
    js面向对象笔记
    《锋利的jQuery》心得笔记--Four Sections
    《锋利的jQuery》心得笔记--Three Sections
    《锋利的jQuery》心得笔记--Two Sections
  • 原文地址:https://www.cnblogs.com/weilonghu/p/12122063.html
Copyright © 2011-2022 走看看