zoukankan      html  css  js  c++  java
  • 大数乘法的C代码实现

    在C语言(C99)中,宽度最大的标准无符号整数类型是unsigned long long,它至少有64位,在常见平台上占8个字节。那么,如果整数超过8个字节,如何进行大数乘法呢?例如:

    $ python
    Python 2.7.6 (default, Oct 26 2016, 20:32:47) 
    ...<snip>....
    >>> a = 0x123456781234567812345678
    >>> b = 0x876543211234567887654321
    >>> print "a * b = 0x%x" % (a * b)
    a * b = 0x9a0cd057ba4c159a33a669f0a522711984e32bd70b88d78

    用C语言实现大数乘法,跟十进制的多位数乘法(竖式乘法)类似,基本思路是分而治之:把大数拆成若干个32位的字,两两相乘得到64位的部分积,再按位置累加;难点在于进位处理相对比较复杂。本文尝试给出C代码实现(数组按小端顺序存放,即data[0]是最低的32位),并使用Python脚本验证计算结果。

    1. foo.c

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
      4 
      5 typedef unsigned char           byte;   /* 1 byte */
      6 typedef unsigned short          word;   /* 2 bytes */
      7 typedef unsigned int            dword;  /* 4 bytes */
      8 typedef unsigned long long      qword;  /* 8 bytes */
      9 
     10 typedef struct big_number_s {
     11         dword   *data;
     12         dword   size;
     13 } big_number_t;
     14 
     15 static void
     16 dump(char *tag, big_number_t *p)
     17 {
     18         if (p == NULL)
     19                 return;
     20 
     21         printf("%s : data=%p : size=%d:	", tag, p, p->size);
     22         for (dword i = 0; i < p->size; i++)
     23                 printf("0x%08x ", (p->data)[i]);
     24         printf("
    ");
     25 }
     26 
     27 /*
     28  * Add 64-bit number (8 bytes) to a[] whose element is 32-bit int (4 bytes)
     29  *
     30  * e.g.
     31  *      a[] = {0x12345678,0x87654321,0x0}; n = 3;
     32  *      n64 =  0xffffffff12345678
     33  *
     34  *      The whole process of add64() looks like:
     35  *
     36  *             0x12345678 0x87654321 0x00000000
     37  *          +  0x12345678 0xffffffff
     38  *          -----------------------------------
     39  *          =  0x2468acf0 0x87654321 0x00000000
     40  *          +             0xffffffff
     41  *          -----------------------------------
     42  *          =  0x2468acf0 0x87654320 0x00000001
     43  *
     44  *      Finally,
     45  *      a[] = {0x2468acf0,0x87654320,0x00000001}
     46  */
     47 static void
     48 add64(dword a[], dword n, qword n64)
     49 {
     50         dword carry = 0;
     51 
     52         carry = n64 & 0xFFFFFFFF; /* low 32 bits of n64 */
     53         for (dword i = 0; i < n; i++) {
     54                 if (carry == 0x0)
     55                         break;
     56 
     57                 qword t = (qword)a[i] + (qword)carry;
     58                 a[i] = t & 0xFFFFFFFF;
     59                 carry = (dword)(t >> 32); /* next carry */
     60         }
     61 
     62         carry = (dword)(n64 >> 32); /* high 32 bits of n64 */
     63         for (dword i = 1; i < n; i++) {
     64                 if (carry == 0x0)
     65                         break;
     66 
     67                 qword t = (qword)a[i] + (qword)carry;
     68                 a[i] = t & 0xFFFFFFFF;
     69                 carry = (dword)(t >> 32); /* next carry */
     70         }
     71 }
     72 
     73 static big_number_t *
     74 big_number_mul(big_number_t *a, big_number_t *b)
     75 {
     76         big_number_t *c = (big_number_t *)malloc(sizeof(big_number_t));
     77         if (c == NULL) /* malloc error */
     78                 return NULL;
     79 
     80         c->size = a->size + b->size;
     81         c->data = (dword *)malloc(sizeof(dword) * c->size);
     82         if (c->data == NULL) /* malloc error */
     83                 return NULL;
     84 
     85         memset(c->data, 0, sizeof(dword) * c->size);
     86 
     87         dword *adp = a->data;
     88         dword *bdp = b->data;
     89         dword *cdp = c->data;
     90         for (dword i = 0; i < a->size; i++) {
     91                 if (adp[i] == 0x0)
     92                         continue;
     93 
     94                 for (dword j = 0; j < b->size; j++) {
     95                         if (bdp[j] == 0x0)
     96                                 continue;
     97 
     98                         qword n64 = (qword)adp[i] * (qword)bdp[j];
     99                         dword *dst = cdp + i + j;
    100                         add64(dst, c->size - (i + j), n64);
    101                 }
    102         }
    103 
    104         return c;
    105 }
    106 
    107 static void
    108 free_big_number(big_number_t *p)
    109 {
    110         if (p == NULL)
    111                 return;
    112 
    113         if (p->data != NULL)
    114                 free(p->data);
    115 
    116         free(p);
    117 }
    118 
    119 int
    120 main(int argc, char *argv[])
    121 {
    122         dword a_data[] = {0x12345678, 0x9abcdef0, 0xffffffff, 0x9abcdefa, 0x0};
    123         dword b_data[] = {0xfedcba98, 0x76543210, 0x76543210, 0xfedcba98, 0x0};
    124 
    125         big_number_t a;
    126         a.data = (dword *)a_data;
    127         a.size = sizeof(a_data) / sizeof(dword);
    128 
    129         big_number_t b;
    130         b.data = (dword *)b_data;
    131         b.size = sizeof(b_data) / sizeof(dword);
    132 
    133         dump("BigNumber A", &a);
    134         dump("BigNumber B", &b);
    135         big_number_t *c = big_number_mul(&a, &b);
    136         dump("  C = A * B", c);
    137         free_big_number(c);
    138 
    139         return 0;
    140 }

    2. bar.py

     1 #!/usr/bin/python
     2 
     3 import sys
     4 
     5 def str2hex(s):
     6     l = s.split(' ')
     7 
     8     i = len(l)
     9     out = ""
    10     while i > 0:
    11         i -= 1
    12         e = l[i]
    13         if e.startswith("0x"):
    14             e = e[2:]
    15         out += e
    16 
    17     out = "0x%s" % out
    18     n = eval("%s * %d" % (out, 0x1))
    19     return n
    20 
    21 def hex2str(n):
    22     s_hex = "%x" % n
    23     if s_hex.startswith("0x"):
    24         s_hex = s_hex[2:]
    25 
    26     n = len(s_hex)
    27     m = n % 8
    28     if m != 0:
    29         s_hex = '0' * (8 - m) + s_hex
    30         n += (8 - m)
    31     i = n
    32     l = []
    33     while i >= 8:
    34         l.append('0x' + s_hex[i-8:i])
    35         i -= 8
    36     return "%s" % ' '.join(l)
    37 
    38 def main(argc, argv):
    39     if argc != 4:
    40         sys.stderr.write("Usage: %s <a> <b> <c>
    " % argv[0])
    41         return 1
    42 
    43     a = argv[1]
    44     b = argv[2]
    45     c = argv[3]
    46     ax = str2hex(a)
    47     bx = str2hex(b)
    48     cx = str2hex(c)
    49 
    50     axbx = ax * bx
    51     if axbx != cx:
    52         print "0x%x * 0x%x = " % (ax, bx)
    53         print "got: 0x%x" % axbx
    54         print "exp: 0x%x" % cx
    55         print "res: FAIL"
    56         return 1
    57 
    58     print "got: %s" % hex2str(axbx)
    59     print "exp: %s" % c
    60     print "res: PASS"
    61     return 0
    62 
    63 if __name__ == '__main__':
    64     argv = sys.argv
    65     argc = len(argv)
    66     sys.exit(main(argc, argv))

    3. Makefile

    CC        = gcc
    CFLAGS        = -g -Wall -m32 -std=c99
    
    TARGETS        = foo bar
    
    all: $(TARGETS)
    
    foo: foo.c
        $(CC) $(CFLAGS) -o $@ $<
    
    bar: bar.py
        cp $< $@ && chmod +x $@
    
    clean:
        rm -f *.o
    clobber: clean
        rm -f $(TARGETS)
    cl: clobber

    4. 编译并测试

    $ make
    gcc -g -Wall -m32 -std=c99 -o foo foo.c
    cp bar.py bar && chmod +x bar
    $ ./foo
    BigNumber A : data=0xbfc2a7c8 : size=5: 0x12345678 0x9abcdef0 0xffffffff 0x9abcdefa 0x00000000
    BigNumber B : data=0xbfc2a7d0 : size=5: 0xfedcba98 0x76543210 0x76543210 0xfedcba98 0x00000000
      C = A * B : data=0x8967008 : size=10: 0x35068740 0xee07360a 0x053bd8c9 0x2895f6cd 0xb973e57e 0x4e6cfe66 0x0b60b60b 0x9a0cd056 0x00000000 0x00000000
    $ A="0x12345678 0x9abcdef0 0xffffffff 0x9abcdefa 0x00000000"
    $ B="0xfedcba98 0x76543210 0x76543210 0xfedcba98 0x00000000"
    $ C="0x35068740 0xee07360a 0x053bd8c9 0x2895f6cd 0xb973e57e 0x4e6cfe66 0x0b60b60b 0x9a0cd056 0x00000000 0x00000000"
    $
    $ ./bar "$A" "$B" "$C"
    got: 0x35068740 0xee07360a 0x053bd8c9 0x2895f6cd 0xb973e57e 0x4e6cfe66 0x0b60b60b 0x9a0cd056
    exp: 0x35068740 0xee07360a 0x053bd8c9 0x2895f6cd 0xb973e57e 0x4e6cfe66 0x0b60b60b 0x9a0cd056 0x00000000 0x00000000
    res: PASS
    $

    结束语:

    本文给出的是串行化的大数乘法实现方法。 A * B = C设定如下:

    • 大数A对应的数组长度为M, A = {a0, a1, ..., a(M-1)};
    • 大数B对应的数组长度为N, B = {b0, b1, ..., b(N-1)};
    • A*B的结果C对应的数组长度为(M+N)。这个长度一定够用:因为 $A < 2^{32M}$, $B < 2^{32N}$, 所以 $A \times B < 2^{32(M+N)}$。
    其中每个元素 a[i]、b[j] 都是 unsigned int(4字节,32位),数组按小端顺序存放,即

    $$A = \sum_{i=0}^{M-1} a_i \cdot 2^{32i}, \qquad B = \sum_{j=0}^{N-1} b_j \cdot 2^{32j}$$

    于是

    $$C = A \times B = \sum_{i=0}^{M-1} \sum_{j=0}^{N-1} a_i b_j \cdot 2^{32(i+j)}$$

    展开即 $a_0 b_0 + a_0 b_1 \cdot 2^{32} + \cdots + a_{M-1} b_{N-1} \cdot 2^{32(M+N-2)}$,其中 i = 0, 1, ..., M-1; j = 0, 1, ..., N-1。每个64位的部分积 a[i] * b[j] 都累加到 c[i+j] 处:低32位加到 c[i+j],高32位加到 c[i+j+1],产生的进位再继续向更高位传播(这正是 add64() 所做的事情)。

    算法的时间复杂度为O(M*N);除了存放结果所需的M+N个字之外,额外空间复杂度为O(1)。为了缩短运行时间,我们也可以采用并行化的实现方法。

    • 启动M个线程同时计算, T0 = a0 * B, T1 = a1 * B, ..., T(M-1) = a(M-1) * B;
    • 接下来,只要M个线程都把活干完了,主线程就可以对T0, T1, ..., T(M-1)进行合并:把 Ti 左移 i 个32位字(即乘以 $2^{32i}$)后累加,就得到C。

    不过,在并行化的实现方法中,对每一个线程来说,时间复杂度为O(N),空间复杂度为O(N)(部分积 $T_i = a_i \times B < 2^{32(N+1)}$,需要N+1个32位字来存放)。因为有M个线程并行计算,于是总的空间复杂度为O(M*N)。

  • 相关阅读:
    初识react hooks
    react初识生命周期
    在调用setState之后发生了什么
    课后作业四
    课后作业2
    课后作业1
    自我介绍
    电脑软件推荐
    数据结构
    数组(一维数组)
  • 原文地址:https://www.cnblogs.com/idorax/p/7119252.html
Copyright © 2011-2022 走看看