zoukankan      html  css  js  c++  java
  • c++runtime

    #include <cstddef>
    #include <iostream>
    #include <memory>
    #include <stdexcept>
    #include <string>
    #include <typeinfo>
    #include <unordered_map>
    #include <utility>
    #include <vector>
    
    namespace dl {

    /// Loosely typed argument bag handed to every registered operator.
    /// Each map holds the named arguments of one value type.
    struct Params {
        std::unordered_map<std::string, std::string> s;          ///< string arguments
        std::unordered_map<std::string, int> i;                  ///< int arguments
        std::unordered_map<std::string, float> f;                ///< float arguments
        std::unordered_map<std::string, std::vector<int>> vi;    ///< int-list arguments
        std::unordered_map<std::string, std::vector<float>> vf;  ///< float-list arguments
    };

    /// Type-erased tensor record stored in the global tensor registry.
    struct Tensor {
        std::string dtype;       ///< element type name as produced by TYPE(...)
        std::vector<int> shape;  ///< extent of every dimension
        void *data = nullptr;    ///< first byte of the element buffer (non-owning view)
        int numel = 0;           ///< number of elements (product of the extents)
        int device = 0;          ///< not read by any operator in this file; meaning to be confirmed
        /// Owner of the buffer `data` points into when the tensor was created by
        /// empty_cpu; copies of the Tensor share it and the last one frees it.
        /// Empty for tensors whose `data` was set by hand.
        std::shared_ptr<char> storage;
    };

    /// Operator registry: name -> type-erased pointer to a `void(Params &)` function.
    std::unordered_map<std::string, void *> _func_d;
    /// Tensor registry: name -> tensor record.
    std::unordered_map<std::string, Tensor> _tensor_d;

    // REGISTER(FUNC, NAME): stores FUNC under NAME. The inner static_cast checks
    // that FUNC really is a `void(Params &)` function (EXEC calls it as one) and
    // picks that overload when the bare name is overloaded.
    #define REGISTER(FUNC, NAME)                         \
        (::dl::_func_d[NAME] = reinterpret_cast<void *>( \
             static_cast<void (*)(::dl::Params &)>(FUNC)))

    // EXEC(NAME, PARAMS): calls the operator registered under NAME.
    // Throws std::out_of_range for an unregistered name instead of calling
    // through a null pointer.
    #define EXEC(NAME, PARAMS) \
        (reinterpret_cast<void (*)(::dl::Params &)>(::dl::_func_d.at(NAME)))(PARAMS)

    // TYPE(DATATYPE): implementation-defined type name used as the dtype tag.
    #define TYPE(DATATYPE) typeid(DATATYPE).name()

    // GLOBAL_TENSOR: the tensor registry.
    #define GLOBAL_TENSOR ::dl::_tensor_d

    // HANDLE_DTYPE(T, TYPENAME, FUNC, args...): when TYPENAME names T, calls
    // FUNC<T>(args...). `scalar_t` is an alias for T that args may refer to.
    // The comparison goes through std::string so that a `const char *` TYPENAME
    // is compared by content, not by address.
    #define HANDLE_DTYPE(T, TYPENAME, FUNC, ...)     \
        if (std::string(TYPE(T)) == (TYPENAME)) {    \
            using scalar_t = T;                      \
            FUNC<scalar_t>(__VA_ARGS__);             \
        }

    // HANDLE_DTYPE2 / HANDLE_DTYPE3: try two / three candidate element types.
    #define HANDLE_DTYPE2(T1, T2, ...) \
        HANDLE_DTYPE(T1, __VA_ARGS__)  \
        HANDLE_DTYPE(T2, __VA_ARGS__)
    #define HANDLE_DTYPE3(T1, T2, T3, ...) \
        HANDLE_DTYPE(T1, __VA_ARGS__)      \
        HANDLE_DTYPE(T2, __VA_ARGS__)      \
        HANDLE_DTYPE(T3, __VA_ARGS__)

    /// Creates an uninitialized tensor of `scalar_t` elements and stores it in the
    /// tensor registry under `name`, replacing any tensor already stored there.
    ///
    /// @param name   registry key of the new tensor
    /// @param shape  extent of every dimension; an empty shape yields one element
    /// @throws std::invalid_argument if any extent is negative
    template <typename scalar_t>
    void empty_cpu(std::string name, std::vector<int> shape) {
        int numel = 1;
        for (const int dim : shape) {
            if (dim < 0) {
                throw std::invalid_argument("empty: negative dimension in shape");
            }
            numel *= dim;
        }

        Tensor t;
        t.dtype = TYPE(scalar_t);
        t.numel = numel;
        // The buffer is owned by `storage`, so replacing or erasing the registry
        // entry releases it instead of leaking it.
        t.storage = std::shared_ptr<char>(
            new char[static_cast<std::size_t>(numel) * sizeof(scalar_t)],
            std::default_delete<char[]>());
        t.data = t.storage.get();
        t.shape = std::move(shape);
        GLOBAL_TENSOR[std::move(name)] = std::move(t);
    }

    /// Operator "empty": creates tensor p.s["name"] with element type p.s["dtype"]
    /// (the TYPE(...) name of int or float) and shape p.vi["shape"].
    /// Any other dtype matches no branch and nothing is created.
    void empty(Params &p) {
        HANDLE_DTYPE2(int, float, p.s["dtype"], empty_cpu, p.s["name"], p.vi["shape"])
    }

    /// Writes 0, 1, ..., numel - 1 (converted to scalar_t) into data[0 .. numel).
    template <typename scalar_t>
    void linspace_cpu_(scalar_t *data, int numel) {
        for (int i = 0; i < numel; i++) data[i] = static_cast<scalar_t>(i);
    }

    /// Operator "linspace_": fills the tensor named p.s["name"] in place with
    /// 0, 1, ..., numel - 1. Tensors whose dtype is neither int nor float are
    /// left untouched.
    ///
    /// @throws std::out_of_range if no tensor is registered under that name
    ///         (the lookup no longer inserts an empty placeholder tensor)
    void linspace_(Params &p) {
        Tensor &self = GLOBAL_TENSOR.at(p.s["name"]);
        const int numel = self.numel;
        HANDLE_DTYPE2(int, float, self.dtype, linspace_cpu_, static_cast<scalar_t *>(self.data), numel)
    }

    } // namespace dl
    
    int main() {
        using namespace std;
        using namespace dl;
    
        REGISTER(empty, "empty");
        REGISTER(linspace_, "linspace_");
    
        Params p;
        p.s["name"] = "Tensor0";
        p.s["dtype"] = TYPE(float);
        p.vi["shape"] = std::vector<int>{2, 2};
        EXEC("empty", p);
    
        Params p2;
        p2.s["name"] = "Tensor0";
        EXEC("linspace_", p2);
    
        Tensor &t = GLOBAL_TENSOR["Tensor0"];
        float *ptr = (float *)t.data;
        for (int i = 0; i < t.numel; i++) {
            cout << ptr[i] << endl;
        }
    }
    
  • 相关阅读:
    Android 常见adb命令
    下载安装JDK,并且配置java环境变量
    安装黑苹果教程
    创建不死目录、不死文件
    win10下安装centos7双系统
    Hadoop 3.0完全分布式集群搭建方法(CentOS 7+Hadoop 3.2.0)
    Hadoop 2.0完全分布式集群搭建方法(CentOS7+Hadoop 2.7.7)
    启动HBase脚本start-hbase.sh时报Class path contains multiple SLF4J bindings.解决方法
    HQuorumPeer和QuorumPeerMain进程的区别
    Zookeeper集群安装与配置
  • 原文地址:https://www.cnblogs.com/xytpai/p/15511436.html
Copyright © 2011-2022 走看看