zoukankan      html  css  js  c++  java
  • Python使用C扩展介绍

    Python作为一种动态语言,使用C扩展的主要目的是加快程序的运行速度,一般有三种方式去实现:swig、Python/C API、ctypes,由于swig会增加额外的复杂性,这里只对后两种方式进行简单的介绍。

    1.Python/C API

    Python/C API由于可以在C代码中操作Python对象,使用范围更广。这里的例子针对python3作了些许修改,函数主要实现了对列表的求和,首先给出C文件如下:

    #include <Python.h>   
    #define PyInt_AsLong(x) (PyLong_AsLong((x))) 
    
    //实现的函数,addList_add: 模块名_函数名(python中)
    static PyObject* addList_add(PyObject* self, PyObject* args) 
    {
        PyObject* listObj;
       
        //解析参数
        if (! PyArg_ParseTuple( args, "O", &listObj ))
            return NULL;
    	/* 注:
    	若传入一个字符串,一个整数和一个Python列表,则这样写:
    	int n;
    	char *s;
    	PyObject* list;
    	PyArg_ParseTuple(args, "siO", &s, &n, &list);
    	*/
    
        long length = PyList_Size(listObj);  //获取长度
        int i, sum =0;
        for (i = 0; i < length; i++)
        {
            PyObject* temp = PyList_GetItem(listObj, i); //获取每个元素
            long elem = PyInt_AsLong(temp);              //将PyInt对象转换为长整型
            sum += elem;
        }
        return Py_BuildValue("i", sum); //长整型sum被转化为Python整形对象并返回给Python代码
    }
    
    //实现的函数的信息表,每行一个函数,以空行作为结束
    static PyMethodDef addList_funcs[] = 
    {
        {"add", (PyCFunction)addList_add, METH_VARARGS, "Add all elements of the list."},
        {NULL, NULL, 0, NULL}
    };
    
    //模块定义结构
    static struct PyModuleDef addList_module = {
        PyModuleDef_HEAD_INIT,
        "addList",   /* 模块名 */
        "",          /* 模块文档 */
        -1,          /* size of per-interpreter state of the module,
                     or -1 if the module keeps state in global variables. */
        addList_funcs
    };
    
    //模块初始化
    PyMODINIT_FUNC PyInit_addList(void)
    {
        return PyModule_Create(&addList_module);
    }
    

    然后编写setup.py如下,并执行命令:sudo python3 setup.py install,完成对python模块的安装。

    from distutils.core import setup, Extension
    setup(name='addList', version='1.0', ext_modules=[Extension('addList', ['addList.c'])])
    

    最后便可以在python中使用该模块:

    import addList
    l = [1,2,3,4,5]
    print(str(addList.add(l)))
    

    2.ctypes

    ctypes使用方法简单,需要进行数据类型转换(除了字符串型和整型),比较适合轻量的快速开发环境。

    2.1基本使用

    //test.c
    #include <stdio.h>
    
    int add_int(int num1, int num2)
    {
        return num1 + num2;
    }
    
    float add_float(float num1, float num2)
    {
        return num1 + num2;
    }
    
    char* str_print(char *str)  
    {  
        puts(str);  
        return str;  
    } 
    
    float add_float_list(float* num, int length)
    {
        float sum=0;
        for(int i=0; i<length; i++)
        {
            sum += num[i];
        }
        return sum;
    }
    
    void str_list_print(char** str_list, int length)
    {
        for(int i=0; i<length; i++)
        {
            puts(str_list[i]);
        }
    }
    

    将test.c编译为动态库:gcc test.c --shared -fPIC -o test.so,然后python可以直接加载动态库并调用其中的函数:

    from ctypes import *
    
    lib = CDLL('./test.so')   # 加载.so动态库
    
    # 传整数
    res_int = lib.add_int(4,5)
    print("Sum of 4 and 5 = " + str(res_int))
    
    # 传浮点数
    a = c_float(5.5)                 # 浮点数需要先进行转换
    b = c_float(4.1)
    lib.add_float.restype = c_float  # 返回值类型转换
    print("Sum of 5.5 and 4.1 = ", str(lib.add_float(a, b)))
    
    # 传字符串
    lib.str_print.restype = c_char_p
    res_str = lib.str_print(b'Hello')
    print(res_str)
    
    # 传浮点数列表
    c = (c_float*5)()
    for i in range(len(c)):
        c[i] = i
    lib.add_float_list.restype = c_float
    print("Sum of list = ", str(lib.add_float_list(c, len(c))))
    
    # 传字符串列表
    d = (c_char_p*3)()
    for i in range(len(d)):
        d[i] = b'HELLO'
    lib.str_list_print(d, len(d))
    

    2.2结合指针使用

    结合指针可以就地改变变量的值,具体使用如下,test.c:

    #include <stdio.h>
    
    void one_ptr_func(int* num, int length)
    {
        for(int i=0; i<length; i++)
        {
            num[i] *= 2;
        }
    }
    
    void two_ptr_func(int** num, int row, int column)
    {
        for(int i=0; i<row; i++)
        {
            for(int j=0; j<column; j++)
            {
                num[i][j] *= 2;
            }    
        }
    }
    
    void three_ptr_func(int*** num, int x, int row, int column)
    {
        for(int a=0; a<x; a++){
            for(int i=0; i<row; i++){
                for(int j=0; j<column; j++){     
                    num[a][i][j] *= 2;
                }    
            }
        }
    }
    
    

    和上述一样,先编译出动态库,在python端使用方法如下:

    from ctypes import *
    
    lib = CDLL('./test.so')   # 加载.so动态库
    
    # 一级指针
    data = [1,2,3,4,5]
    one_arr = (c_int*5)(*data)              #一维数组
    one_ptr = cast(one_arr, POINTER(c_int)) #一维数组转换为一级指针
    
    lib.one_ptr_func(one_ptr, 5)
    for i in range(5):
        print(one_ptr[i], end=' ')
    print("
    ")
    
    # 二级指针
    data = [(1,2,3), (4,5,6)]
    two_arr = (c_int*3*2)(*data)                  #二维数组
    one_ptr_list = [] 
    for i in range(2):
        one_ptr_list.append(cast(two_arr[i], POINTER(c_int))) #一级指针添加进列表
    two_ptr=(POINTER(c_int)*2)(*one_ptr_list)                 #转化为二级指针
    
    lib.two_ptr_func(two_ptr, 2, 3)
    for i in range(2):
        for j in range(3):
            print(two_ptr[i][j], end=' ')
        print("
    ")
    
    # 三级指针
    data =[((1,2,3), (4,5,6)),((7,8,9), (10,11,12))]      #2x2x3
    three_arr =(c_int*3*2*2)(*data)                       #三维数组
    two_ptr=[]
    for i in range(2):
        one_ptr=[]
        for j in range(2):
            one_ptr.append(cast(three_arr[i][j], POINTER(c_int))) #一级指针添加进列表   
        two_ptr.append((POINTER(c_int)*2)(*one_ptr))              #转化为二级指针添加进列表                     
    three_ptr = (POINTER(POINTER(c_int))*2)(*two_ptr)             #转换为三级指针
    
    lib.three_ptr_func(three_ptr, 2, 2, 3)
    for i in range(2):
        for j in range(2):
            for k in range(3):
                print(three_ptr[i][j][k], end=' ')
            print("
    ")
    

    2.3结合numpy使用

    有时候需要对numpy数组进行操作,基本使用方法和上面差不多,仍然使用上面的动态库,python端仅需作如下修改:

    from ctypes import *
    import numpy as np
    
    lib = CDLL('./test.so')   # 加载.so动态库
    
    a = np.asarray(range(16), dtype=np.int32)
    if not a.flags['C_CONTIGUOUS']:
        a = np.ascontiguous(a, dtype=data.dtype)   # 如果不是C连续的内存,必须强制转换
    ptr = cast(a.ctypes.data, POINTER(c_int))      # 转换为一级指针
    for i in range(16):
        print(ptr[i], end=' ')
    print('
    ')
    
    lib.one_ptr_func(ptr, 16)   
    for i in range(16):
        print(a[i], end=' ')  # 注意此时变量a也被改变了
    print('
    ')
    

    另外numpy还提供了numpy.ctypeslib的解决方案:

    from ctypes import *
    import numpy as np
    import numpy.ctypeslib as npct
    
    a = np.ones((3, 3, 3), np.int32)
    lib = npct.load_library("test", ".")                 #引入动态链接库
    lib.np_ptr_func.argtypes = [npct.ndpointer(dtype=np.int32, ndim=3, flags="C_CONTIGUOUS"), c_int, c_int, c_int]          #参数说明
    lib.np_ptr_func(a, c_int(3), c_int(3), c_int(3))     #函数调用
    
    for i in range(3):
        for j in range(3):
            for k in range(3):
                print(a[i][j][k], end=' ')  
            print('
    ')
    

    其中np_ptr_func()函数如下所示:

    void np_ptr_func(int* num, int x, int row, int column)
    {
        for(int a=0; a<x; a++){
            for(int i=0; i<row; i++){
                for(int j=0; j<column; j++){     
                    num[a*row*column+i*column+j] *= 2;
                }    
            }
        }
    } 
    
  • 相关阅读:
    (BFS 二叉树) leetcode 515. Find Largest Value in Each Tree Row
    (二叉树 BFS) leetcode513. Find Bottom Left Tree Value
    (二叉树 BFS DFS) leetcode 104. Maximum Depth of Binary Tree
    (二叉树 BFS DFS) leetcode 111. Minimum Depth of Binary Tree
    (BFS) leetcode 690. Employee Importance
    (BFS/DFS) leetcode 200. Number of Islands
    (最长回文子串 线性DP) 51nod 1088 最长回文子串
    (链表 importance) leetcode 2. Add Two Numbers
    (链表 set) leetcode 817. Linked List Components
    (链表 双指针) leetcode 142. Linked List Cycle II
  • 原文地址:https://www.cnblogs.com/qxcheng/p/12541186.html
Copyright © 2011-2022 走看看