zoukankan      html  css  js  c++  java
  • 矩阵乘向量 基于哈希表的稀疏矩阵优化

    计算N*N 方阵A 乘以 向量B

    记录所需时间

    CPU:I5 750 @3.6G

    填充同样的 A 方阵数据 和 B向量 数据

    N的规模

    N = 1 << 12 = 4096

    稀疏矩阵有效数据填充率:30%

    demo1:

    使用数组计算

    time1 : 19ms

    demo2:

    使用Integer ,Float 包装类,二维稀疏矩阵

    time2 : 83ms

    demo3:

    使用基本类型的int 和 float,并只将一维的数组用散列表替换,一维稀疏矩阵

    行方向上,用一个数组

    time3: 44ms

    ======

    第二次尝试

    提高N = 1 << 13  = 8192

    稀疏矩阵有效数据填充率:30%

    time : 66ms

    time2:java.lang.OutOfMemoryError: GC overhead limit exceeded

    time3: 150ms

    ======

    第三次

    N = 1 << 13  = 8192

    稀疏矩阵有效数据填充率:5%

    time1 : 61ms

    time3: 47ms

    ====

    第四次

    稀疏矩阵有效数据填充率:2%


    time1 : 62ms

    time2 : 87ms

    time3: 22ms

    可以看到time3 得到了极大的提升,如果代码再优化的好一些,hash算法再快速一些,特别是对执行次数非常多语句,比如for中的语句,或循环嵌套中的语句做一些优化,时间成本和收益比会再初期非常大

    google的 PageRank

    使用了类似的方法减少时间复杂度,他的数据量是千亿级,但是矩阵中的稀疏度比较大,为了减少无意义的 0 x 0 和 sum += 0

    但是google用的是一种图(不是本例中一样的数据结构)

    suck code:

    import java.util.Random;
    
    public class Matrix {
        public static int N = 1 << 13;
        public static int matrixRandomSeed = 1234;
        public static float fillFactor = 0.98f;
    
        //一行中的一维数组做成基于 符号表的稀疏向量(set)
        //列: key 行号,val 一行的hashSet
    
        //优化:1.换数据结构,拉链,或者 红黑
        //2.使用基本类型值拷贝
        LinearProbingHashST<Integer, LinearProbingHashST<Integer, Float>> m1;
    
        public Matrix() {
            m1 = new LinearProbingHashST<>();
            randInit();
        }
    
        public void randInit() {
            Random rand = new Random(matrixRandomSeed);
    
            LinearProbingHashST<Integer, Float> row;
            for (int i = 0; i < N; i++) {
                m1.put(i, row = new LinearProbingHashST<Integer, Float>());
                for (int j = 0; j < N; j++) {
                    float v = 0;
                    if (rand.nextFloat() > fillFactor) { //填充%的随机数据
                        v = rand.nextFloat() + 0.1f;
                        row.put(j, v);
                    }
                }
            }
        }
    
    
        public static float[] matrixDemo1() {
            float a[][] = new float[N][N];
            float b[];
            float c[] = new float[N];
            //a*b=c   b postMult a ,b后乘a
    
            Random rand = new Random(matrixRandomSeed);
    
            for (int i = 0; i < N; i++) {
                for (int j = 0; j < N; j++) {
                    if (rand.nextFloat() > fillFactor) { //填充30%的随机数据
                        a[i][j] = rand.nextFloat() + 0.1f;
                    }
                }
            }
            b = genB(N);
    
            long s1 = System.currentTimeMillis();
            for (int i = 0; i < N; i++) {
                c[i] = mult(a[i], b);
            }
            long s2 = System.currentTimeMillis();
            System.out.println("
    time1 : " + (s2 - s1) + "ms
    ");
            return c;
        }
    
        public static float mult(float a[], float b[]) {
            float sum = 0.0f;
            for (int i = 0; i < a.length; i++) {
                sum += a[i] * b[i];
            }
            return sum;
        }
    
        public static float[] genB(int n) {
            Random rand = new Random(matrixRandomSeed);
            float b[] = new float[n];
            for (int i = 0; i < n; i++) {
                b[i] = rand.nextFloat();
            }
            return b;
        }
    
        public static float[] matrixDemo2() {
            Matrix m = new Matrix();
    
            float[] b = genB(N);
            float[] c = m.mult(b);
    
            return c;
        }
    
        public float multRow(LinearProbingHashST<Integer, Float> r, float[] b) {
            float sum = 0;
            Node<Integer, Void> hl = r.headList.head;
    
            for (int index = 0; hl != null; ) {
                index = hl.k;
                float v = r.get(index);
                sum += b[index] * v;
                hl = hl.n;
            }
            return sum;
        }
    
        public float[] mult(float[] b) {
            Node<Integer, Void> hl = m1.headList.head;
            float[] c = new float[N];
            long s1 = System.currentTimeMillis();
            while (hl != null) {
                int rowN = hl.k;
                LinearProbingHashST<Integer, Float> row = m1.get(rowN);
                c[rowN] = multRow(row, b);
                hl = hl.n;
            }
            long s2 = System.currentTimeMillis();
            System.out.println("time2 : " + (s2 - s1) + "ms");
            return c;
        }
    
    
        public static boolean cmp(float[] a, float[] b) {
            if (a.length != b.length)
                return false;
    
            for (int i = 0; i < a.length; i++) {
                if (a[i] != b[i])
                    return false;
            }
    
            return true;
        }
    
    
        public static SparseVector[] initm3() {
            SparseVector[] m3 = new SparseVector[N];
            Random rand = new Random(matrixRandomSeed);
    
            //填充矩阵a
            SparseVector row;
            for (int i = 0; i < N; i++) {
                row = m3[i] = new SparseVector();
    
                for (int j = 0; j < N; j++) {
                    float v = 0;
                    if (rand.nextFloat() > fillFactor) { //填充接近%的随机数据
                        v = rand.nextFloat() + 0.1f;
                        row.put(j, v);
                    }
                }
            }
            return m3;
        }
    
        public static float mult(SparseVector row, float[] b) {
            IntNode intNode = row.headList.head; //keySet ,没用Iterator,直接用链表遍历
            float sum = 0;
            while (intNode != null) {
                int index = intNode.k;
                sum += row.get(index) * b[index];
                intNode = intNode.n;
            }
            return sum;
        }
    
        public static float[] matrixDemo3() {
            //init
            SparseVector[] m3a = Matrix.initm3();
            float[] b = genB(N);
            float[] c = new float[N];
            System.out.println();
            long s31 = System.currentTimeMillis();
            for (int i = 0; i < N; i++) {
                c[i] = mult(m3a[i], b);
            }
            long s32 = System.currentTimeMillis();
            System.out.println("time3: " + (s32 - s31) + "ms");
            return c;
        }
    
        public static void main(String args[]) {
            float[] res1, res2, res3;
            System.out.println("N " + N);
            res1 = matrixDemo1();
    
            System.out.println("
    matrixDemo2
    ");
            res2 = matrixDemo2();
    
            System.out.println("
    matrixDemo3
    ");
            res3 = matrixDemo3();
        }
    }

    其他的 More suck code就不贴了,太low太suck

  • 相关阅读:
    ORACLE日期时间函数大全
    orcal基础
    javaweb学习总结——基于Servlet+JSP+JavaBean开发模式的用户登录注册
    一个DataTable赋值给另一个DataTable的常用方法
    ios开发 解释器和编译器
    ios面试题(五)-多线程
    ios面试题(四)-block
    ios面试题(三)
    ios开发面试题(二)
    ios开发面试题(一)
  • 原文地址:https://www.cnblogs.com/cyy12/p/11581894.html
Copyright © 2011-2022 走看看