zoukankan      html  css  js  c++  java
  • C++ 模板实现败者树,进行多路归并

    项目需要实现一个败者树,今天研究了一下,附上实现代码。

    几点说明:

    1. 败者树思想及实现参考这里:http://www.cnblogs.com/benjamin-t/p/3325401.html

    2. 多路归并中“多路”的容器采用 C 语言数组 + 数组长度的形式(即 `const ContainerType* ways, size_t num`),而没有使用 STL 容器。这是项目的要求,日后再改成 STL 容器;

    3. _losers 存储的是下标,目前用的是 int 类型,还需要修改:程序中其他下标的类型都是 size_t,但 _losers 需要用 -1 表示“无效”(即构造败者树时尚未填入败者的节点)。

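    4. Foo 还不能配合 std::copy + std::ostream_iterator<Foo> 输出,待修正。原因是 operator<< 的参数声明成了非 const 引用 Foo&,而 std::ostream_iterator 输出时传入的是 const Foo&,无法绑定;把参数改为 const Foo& 即可;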

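    5. 使用了 FooContainer<Foo> 类型以后,data[i][j] 不能直接输出,必须先定义一个变量再输出。原因是 FooContainer::operator[] 按值返回一个临时对象,而临时对象无法绑定到 operator<< 的非 const 引用参数 Foo&;把 operator<< 的参数改为 const Foo& 后即可直接输出: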

                //std::cout << data[i][j] << ", ";  
                Foo foo = data[i][j];
                std::cout << foo << ", ";

    6. 把 const 变量作为引用类型的非类型模板参数(non-type template parameter)的实参时,必须把该 const 变量定义在全局,并且加 extern,原因见这里:http://stackoverflow.com/questions/9183485/const-variable-as-non-type-template-parameter-variable-cannot-appear-in-a-const (注:这是 C++03 的要求,即此类实参必须具有外部链接;C++11 起内部链接的对象也可以作为此类实参。)

    7. 代码放在Github上:https://github.com/qinpeixi/code-pieces/blob/master/loser_tree.cpp

    8. 一个专业级的实现在这里:https://github.com/MITRECND/snugglefish/blob/master/include/loserTree.hpp

    代码如下:

    #include <cassert>
    #include <climits>
    #include <cstddef>
    #include <cstdlib>
    #include <algorithm>
    #include <functional>
    #include <iostream>
    #include <iterator>
    #include <sstream>
    #include <stdexcept>
    #include <string>
    #include <utility>
    #include <vector>
    
    /*
     * Minimal wrapper around an int, used to exercise LoserTree with a
     * user-defined value type. Copying and assignment use the compiler
     * defaults (member-wise copy of the wrapped int).
     */
    class Foo {
    public:
        // Value-initialize so a default-constructed Foo never holds an
        // indeterminate int.
        Foo() : _v(0) {}
        explicit Foo(int v): _v(v) {}
        int value() const { return _v; }
        // const so it can be called on const objects and const references.
        bool operator==(const Foo& foo) const { return _v == foo._v; }

    private:
        int _v;
    };

    // Streams the wrapped int. Takes `const Foo&` so temporaries (e.g. the
    // result of a by-value operator[]) and std::ostream_iterator<Foo>, which
    // streams a const Foo&, can be printed directly.
    std::ostream& operator<<(std::ostream& os, const Foo& foo) {
        return os << foo.value();
    }

    // Sentinel holding INT_MAX; used as LoserTree's forever_lose_value for
    // Foo ways. Declared extern so it can serve as a reference template
    // argument under pre-C++11 linkage rules.
    extern const Foo FOO_MAX(INT_MAX);
    
    namespace std {
    /*
     * Orders Foo by its wrapped int, so LoserTree's default
     * Compare = std::less<Foo> works. std::binary_function (the original base)
     * was deprecated in C++11 and removed in C++17, so the adaptor typedefs it
     * used to supply are declared directly. `struct` matches the class-key of
     * the primary std::less template.
     */
    template<>
    struct less<Foo>
    {
        typedef Foo first_argument_type;
        typedef Foo second_argument_type;
        typedef bool result_type;

        bool operator() (const Foo& x, const Foo& y) const {
            return x.value() < y.value();
        }
    };
    } // namespace std
    
    /*
     * Minimal vector-backed container offering push_back, size() and const
     * indexed access — the operations LoserTree and generate_data use on a way.
     */
    template<class ValueType>
    class FooContainer
    {
    public:
        // Returns a reference to the stored element rather than a copy,
        // matching std::vector::operator[] const. No bounds checking.
        const ValueType& operator[](size_t idx) const { return _container[idx]; }
        size_t size() const { return _container.size(); }
        void push_back(const ValueType& value) {
            _container.push_back(value);
        }

    private:
        std::vector<ValueType> _container;
    };
    
    /*
     * Loser tree for k-way merging of sorted ways; each extract_one() yields
     * the smallest remaining element under Compare.
     *
     * ValueType          - element type; must be copy-constructible/assignable.
     * ContainerType      - per-way container providing size() and a const
     *                      operator[](size_t) convertible to ValueType.
     * forever_lose_value - value stored in _data for exhausted ways. Exhaustion
     *                      itself is tracked by index, so real elements equal to
     *                      this value are still merged and never cut the merge short.
     * Compare            - strict weak ordering each way is sorted by.
     *
     * The tree only points at `ways`; the caller must keep that array (and its
     * containers) alive and unmodified while extracting.
     */
    template< class ValueType,
              class ContainerType,
              const ValueType& forever_lose_value,
              class Compare = std::less<ValueType> >
    class LoserTree
    {
    public:
        /*
         * ways - array of `num` containers, each sorted according to Compare.
         * Throws std::invalid_argument if ways is NULL or num is 0.
         */
        LoserTree(const ContainerType* ways, size_t num) :
            _num(num), _ways(ways)
        {
            // Validate before allocating anything.
            if (ways == NULL || num == 0) {
                throw std::invalid_argument("invalid ways or number of ways");
            }
            _indexes.assign(_num, 0);
            _data.assign(_num, forever_lose_value);
            _losers.assign(_num, empty_slot());
            // Internal node k has children 2k and 2k+1; way w is the implicit
            // leaf _num + w. Every internal node receives exactly two values
            // while building: the first one parks there, the second is played
            // against it, so the order ways are inserted in does not matter.
            for (size_t way_idx = _num; way_idx-- > 0; ) {
                load(way_idx);
                adjust(way_idx);
            }
        }

        /*
         * Stores the smallest remaining element in v and advances its way.
         * Returns false, leaving v untouched, once every way is exhausted.
         */
        bool extract_one(ValueType& v) {
            const size_t way_idx = _losers[0];
            // Exhausted ways never win a match, so an exhausted overall winner
            // means all ways are exhausted.
            if (exhausted(way_idx))
                return false;
            v = _data[way_idx];
            ++_indexes[way_idx];
            load(way_idx);
            adjust(way_idx);
            return true;
        }

    private:
        size_t _num;
        const ContainerType* _ways;
        std::vector<size_t> _indexes;   // next unread position in each way
        std::vector<ValueType> _data;   // current head element of each way
        std::vector<size_t> _losers;    // [0]: overall winner; [1.._num-1]: loser kept at each internal node

        // Marks an internal node that has not received a loser yet.
        static size_t empty_slot() { return static_cast<size_t>(-1); }

        bool exhausted(size_t way_idx) const {
            return _indexes[way_idx] >= _ways[way_idx].size();
        }

        // Refreshes _data[way_idx] from the way's current read position.
        void load(size_t way_idx) {
            if (exhausted(way_idx))
                _data[way_idx] = forever_lose_value;
            else
                _data[way_idx] = _ways[way_idx][_indexes[way_idx]];
        }

        // True if way `a` must be extracted before way `b`. Exhausted ways
        // always lose, regardless of the value stored for them.
        bool beats(size_t a, size_t b) const {
            if (exhausted(a))
                return false;
            if (exhausted(b))
                return true;
            return Compare()(_data[a], _data[b]);
        }

        // Replays the matches on the path from leaf winner_idx to the root,
        // leaving the loser of each match in its node and the overall winner
        // in _losers[0]. On ties the value already stored in the node wins.
        void adjust(size_t winner_idx) {
            using std::swap;
            size_t node = (winner_idx + _num) / 2;
            while (node != 0 && winner_idx != empty_slot()) {
                if (_losers[node] == empty_slot() || !beats(winner_idx, _losers[node]))
                    swap(winner_idx, _losers[node]);
                node /= 2;
            }
            _losers[0] = winner_idx;
        }
    };
    
    /*
     * Reads the ways of a merge from standard input, one way per line with
     * whitespace-separated ints, e.g.
     *
     *   1 10 100 1000
     *   2 20 200 2000
     *   3 30 300
     *   4 40 400 4000 40000
     *
     * A blank line yields an empty way. After all input has been read, every
     * way is echoed to std::cout as "v1, v2, ..., " followed by a newline.
     */
    std::vector<std::vector<int> > get_input()
    {
        std::vector<std::vector<int> > ways;
        for (std::string text; std::getline(std::cin, text); ) {
            std::istringstream tokens(text);
            std::istream_iterator<int> first(tokens);
            std::istream_iterator<int> last;
            ways.push_back(std::vector<int>(first, last));
        }

        for (const std::vector<int>& way : ways) {
            std::copy(way.begin(), way.end(), std::ostream_iterator<int>(std::cout, ", "));
            std::cout << std::endl;
        }

        return ways;
    }
    
    /*
     * Builds `way_num` ways and deals ValueType(0) .. ValueType(value_num - 1),
     * in increasing order, to ways chosen with rand(). Each way is therefore
     * sorted, and merging all ways yields 0, 1, ..., value_num - 1.
     * The defaults (20 ways, 10 values) match the original hard-coded sizes.
     * Throws std::invalid_argument if way_num is not positive (rand() % 0 is
     * undefined, and a negative count would wrap to a huge vector size).
     */
    template<class ValueType, class ContainerType>
    std::vector<ContainerType> generate_data(int way_num = 20, int value_num = 10)
    {
        if (way_num <= 0)
            throw std::invalid_argument("way_num must be positive");
        std::vector<ContainerType> data(way_num);
        for (int num = 0; num < value_num; ++num) {
            data[rand() % way_num].push_back(ValueType(num));
        }

        return data;
    }
    
    /*
     * Merges randomly generated FooContainer<Foo> ways with a LoserTree and
     * checks that the output is exactly 0, 1, 2, ... with no element missing,
     * printing each value as it is extracted.
     */
    void test_foo()
    {
        std::vector<FooContainer<Foo> > data = generate_data<Foo, FooContainer<Foo> >();
        size_t total = 0;
        for (size_t i = 0; i < data.size(); ++i)
            total += data[i].size();

        LoserTree<Foo, FooContainer<Foo>, FOO_MAX> lt(data.data(), data.size());
        Foo v;
        int expected = 0;
        while (lt.extract_one(v)) {
            assert(v.value() == expected);
            // Kept outside assert() so it still runs when NDEBUG is defined.
            ++expected;
            std::cout << v.value() << ", ";
        }
        // Every generated element must have been extracted.
        assert(static_cast<size_t>(expected) == total);
        std::cout << std::endl;
    }
    
    // Sentinel for exhausted int ways; extern so it can be used as a reference
    // template argument under pre-C++11 linkage rules.
    extern const int int_max = INT_MAX;

    /*
     * Merges randomly generated std::vector<int> ways with a LoserTree and
     * checks that the output is exactly 0, 1, 2, ... with no element missing,
     * printing each value as it is extracted.
     */
    void test()
    {
        std::vector<std::vector<int> > data = generate_data<int ,std::vector<int> >();
        size_t total = 0;
        for (size_t i = 0; i < data.size(); ++i)
            total += data[i].size();

        LoserTree<int, std::vector<int>, int_max> lt(data.data(), data.size());
        int v;
        int expected = 0;
        while (lt.extract_one(v)) {
            assert(v == expected);
            // No side effects inside assert(): it is compiled out under NDEBUG.
            ++expected;
            std::cout << v << ", ";
        }
        // Every generated element must have been extracted.
        assert(static_cast<size_t>(expected) == total);
        std::cout << std::endl;
    }
    
    int main()
    {
        try {
            //LoserTree<int, std::vector<int>, int_max> lt(NULL, 3);
            test_foo();
            test();
        } catch (const std::exception& exc){
            std::cerr << exc.what() << std::endl;
        }
    
        return 0;
    }
  • 相关阅读:
    PHP删除文件
    PHP定时执行任务
    PHP设置30秒内对页面的访问次数
    PHP抓取网页内容的几种方法
    QQ,新浪,SNS等公众平台的登录及api操作
    php,javascript设置和读取cookie
    php验证邮箱,手机号是否正确
    php自定义加密和解密
    Linux下安装启动多个Mysql
    linux-gcc 编译时头文件和库文件搜索路径
  • 原文地址:https://www.cnblogs.com/yding9/p/4032155.html
Copyright © 2011-2022 走看看