zoukankan      html  css  js  c++  java
  • 数据挖掘 之 关联规则求解算法Apriori的实现

    关联规则求解算法Apriori的实现

    code + 报告 见:https://github.com/JianmingS/Apriori

      1 // by Shi Jianming
      2 /*
      3 数据挖掘:关联规则求解算法Apriori的实现
      4 */
      5 
      6 #define _CRT_SECURE_NO_WARNINGS
      7 #define HOME
      8 
      9 #include <iostream>
     10 #include <cstdio>
     11 #include <vector>
     12 #include <string>
     13 #include <cmath>
     14 #include <map>
     15 #include <locale>
     16 using namespace std;
     17 const double eps = 1e-8;
     18 const int MaxColNum = 100;
     19 
     20 int rowNum, columnNum; // 行数,列数
     21 double supportMin, confidenceMin; // 最小支持度, 最小置信度
     22 int supporNum; // 最小支持数
     23 int total;
     24 int Case;
     25 
     26 vector<vector<int> > dataBase; // 保存原始数据集
     27 vector<string> columnName; // 保存每一列的栏目名
     28 
     29 // 数据集
     30 struct itemset
     31 {
     32     vector<int> item; // 事务(包含0个或多个项)
     33     int cnt; // 事务出现次数
     34     int id; // 事务唯一标识
     35     itemset()
     36     {
     37         cnt = 0;
     38         id = -1;
     39     }
     40 };
     41 
     42 vector<itemset> preL; // 频繁(k-1)-项集
     43 vector<itemset> C; // 候选(k)-项集
     44 vector<itemset> L; // 频繁(k)-项集
     45 
     46 map<int, itemset> forL; // 为构造频繁(k)-项集
     47 
     48 int C1[MaxColNum]; // 记录C1
     49 
     50 
     51 /****************************************************/
     52 /*
     53 Hash树:
     54 Hash函数: h(p) = p mod k
     55 时间复杂度:O(k)
     56 */
     57 
     58 struct hashTrie
     59 {
     60     hashTrie *next[MaxColNum]; // Hash树后继节点
     61     vector<itemset> C; // 候选(k)-项集
     62     hashTrie()
     63     {
     64         fill(next, next + MaxColNum, nullptr);
     65     }
     66 };
     67 // 创建Hash树
     68 void CrehashTrie(hashTrie *root, vector<int> branch)
     69 {
     70     hashTrie *p = root;
     71     for (auto i = 0; i < branch.size(); ++i)
     72     {
     73         int pos = branch[i] % branch.size();
     74         if (nullptr == p->next[pos])
     75         {
     76             p->next[pos] = new hashTrie;
     77         }
     78         p = p->next[pos];
     79     }
     80     itemset itsetTmp;
     81     itsetTmp.item = branch;
     82     itsetTmp.id = (total++);
     83     p->C.push_back(itsetTmp);
     84 }
     85 // 查找branch的值,判断是否可以在Hash树中匹配成功,并记录在Hash树中匹配成功的次数,保存频繁集
     86 bool FindhashTrie(hashTrie *root, vector<int> branch)
     87 {
     88     hashTrie *p = root;
     89     for (auto i = 0; i < branch.size(); ++i)
     90     {
     91         int pos = branch[i] % branch.size();
     92         if (nullptr == p->next[pos])
     93         {
     94             return false;
     95         }
     96         p = p->next[pos];
     97     }
     98     for (auto &tmp : p->C)
     99     {
    100         auto i = 0;
    101         for (; i != tmp.item.size(); ++i)
    102         {
    103             if (tmp.item[i] != branch[i])
    104             {
    105                 break;
    106             }
    107         }
    108         if (i == tmp.item.size())
    109         {
    110 
    111             ++(tmp.cnt);
    112             if (tmp.cnt >= (supporNum))
    113             {
    114                 if (forL.find(tmp.id) != forL.end())
    115                 {
    116                     ++(forL[tmp.id].cnt);
    117                 }else
    118                 {
    119                     forL.insert({tmp.id, tmp});
    120                 }
    121             }
    122             return true;
    123         }
    124     }
    125     return false;
    126 }
    127 // 销毁Hash树
    128 void DelhashTrie(hashTrie *T, int len)
    129 {
    130     for (int i = 0; i < len; ++i)
    131     {
    132         if (T->next[i] != nullptr)
    133         {
    134             DelhashTrie(T->next[i], len);
    135         }
    136     }
    137     if (!T->C.empty())
    138     {
    139         T->C.clear();
    140     }
    141     delete[] T->next;
    142     total = 0;
    143 }
    144 
    145 /****************************************************/
    146 
    147 
    148 
    149 /****************************************************/
    150 /*
    151 从集合{0,1,2,3..,(N-1)} 中找出所有大小为k的子集, 并按照字典序排序
    152 */
    153 vector<vector<int>> combine;
    154 int arr[MaxColNum];
    155 int visit[MaxColNum];
    156 int combineN, combineK;
    157 // 起始:dfs(0, 0)
    158 void dfs(int d, int pos)
    159 {
    160     if (d == combineK)
    161     {
    162         vector<int> tmp;
    163         for (int i = 0; i < combineK; ++i)
    164         {
    165             tmp.push_back(arr[i]);
    166         }
    167         combine.push_back(tmp);
    168         return;
    169     }
    170     for (int i = pos; i < combineN; ++i)
    171     {
    172         if (!visit[i])
    173         {
    174             visit[i] = true;
    175             arr[d] = i;
    176             dfs(d + 1, i + 1);
    177             visit[i] = false;
    178         }
    179     }
    180 }
    181 /****************************************************/
    182 
    183 // 读取原始数据集
    184 void Input()
    185 {
    186     cin >> rowNum >> columnNum;
    187     supporNum = ceil(supportMin*(rowNum - 1));
    188     string rowFirst;
    189     for (auto i = 0; i < rowNum; ++i)
    190     {
    191         cin >> rowFirst;
    192         vector<int> dataRow;
    193         int valueTmp;
    194         // 去掉输入数据的第一列
    195         for (auto j = 0; j < (columnNum - 1); ++j)
    196         {
    197             if (i != 0)
    198             {
    199                 cin >> valueTmp;
    200                 if (valueTmp) {
    201                     ++C1[j];
    202                     dataRow.push_back(j);
    203                 }
    204             }
    205             else
    206             {
    207                 string colNameTmp;
    208                 cin >> colNameTmp; 
    209                 columnName.push_back(colNameTmp);
    210             }
    211         }
    212         if (i != 0) dataBase.push_back(dataRow);
    213     }
    214 }
    215 
    216 // 获取频繁1-项集
    217 void Ini()
    218 {
    219     for (auto i = 0; i < (columnNum - 1); ++i)
    220     {
    221         if (C1[i] >= supporNum)
    222         {
    223             itemset itemsetTmp;
    224             itemsetTmp.item.push_back(i);
    225             itemsetTmp.cnt = C1[i];
    226             preL.push_back(itemsetTmp);
    227         }
    228     }
    229 }
    230 
    231 
    232 // 枚举所有事务(即原始数据)包含的k-项集,计算支持度
    233 void getItemsK(hashTrie *root, int k)
    234 {
    235     vector<int> branch;
    236 //    int bbb = 0;
    237     for (auto tmp : dataBase)
    238     {
    239 //        cout << bbb++ << " : " << endl;
    240         if (tmp.size() < k) continue;
    241 
    242         combineN = tmp.size();
    243         combineK = k;
    244         dfs(0, 0);
    245 
    246         for (int i = 0; i < combine.size(); ++i)
    247         {
    248             for (int j = 0; j < combine[i].size(); ++j)
    249             {
    250                 branch.push_back(tmp[combine[i][j]]);
    251             }
    252             /***********************/
    253             /*
    254             匹配候选k-项集,计算支持数
    255             */
    256             FindhashTrie(root, branch);
    257 //            if (FindhashTrie(root, branch))
    258 //            {
    259 //                for (auto aaa = 0; aaa < branch.size(); ++aaa)
    260 //                {
    261 //                    cout << branch[aaa] << " ";
    262 //                }
    263 //                cout << endl;
    264 //            }
    265 //            /***********************/
    266             branch.clear();
    267         }
    268         combine.clear();
    269 //        cout << endl;
    270     }
    271     
    272 }
    273 
    274 // 判断生成的候选(k)-项集的某个(k-1)-项子集是否为频繁项集
    275 bool isInfrequentSubset(itemset c)
    276 {
    277     hashTrie *root = new hashTrie;
    278     int k = c.item.size() - 1;
    279     for (auto tmp : preL)
    280     {
    281         CrehashTrie(root, tmp.item);
    282     }
    283     vector<int> branch;
    284 
    285     combineN = c.item.size();
    286     combineK = k;
    287     dfs(0, 0);
    288 
    289     for (int i = 0; i < combine.size(); ++i)
    290     {
    291         for (int j = 0; j < combine[i].size(); ++j)
    292         {
    293             branch.push_back(c.item[combine[i][j]]);
    294         }
    295 
    296         /***********************/
    297         /*
    298         判断生成的((k-1)-项子集是否为频繁的。
    299         */
    300         if (!FindhashTrie(root, branch))
    301         {
    302             combine.clear();
    303             DelhashTrie(root, k);
    304             return true;
    305         }
    306         /***********************/
    307         branch.clear();
    308     }
    309     combine.clear();
    310     DelhashTrie(root, k);
    311     return false;
    312 }
    313 
    314 // 产生候选(k)-项集
    315 void apriori_gen(int k)
    316 {
    317     for (auto L1 = 0; L1 < preL.size(); ++L1)
    318     {
    319         for (auto L2 = L1 + 1; L2 < preL.size(); ++L2)
    320         {
    321             auto judge = true;
    322             for (auto i = 0; i < (k - 1); ++i)
    323             {
    324                 if (preL[L1].item[i] != preL[L2].item[i])
    325                 {
    326                     judge = false;
    327                 }
    328             }
    329             if (!judge) continue;
    330             itemset itemsetTmp;
    331             for (auto i = 0; i < (k - 1); ++i)
    332             {
    333                 itemsetTmp.item.push_back(preL[L1].item[i]);
    334             }
    335             itemsetTmp.item.push_back(preL[L1].item[k - 1]);
    336             itemsetTmp.item.push_back(preL[L2].item[k - 1]);
    337             if (isInfrequentSubset(itemsetTmp)) {
    338                 continue;
    339             }
    340             C.push_back(itemsetTmp);
    341         }
    342     }
    343 }
    344 // Apriori算法实现,并输出关联规则集
    345 void Apriori()
    346 {
    347     for (auto k = 2; !preL.empty(); ++k)
    348     {
    349         hashTrie *root = new hashTrie;
    350         apriori_gen(k - 1); // 求出候选(k)-项集;
    351         for (auto i = 0; i < C.size(); ++i)
    352         {
    353             CrehashTrie(root, C[i].item);
    354         }
    355         C.clear();
    356         getItemsK(root, k);
    357         DelhashTrie(root, k);
    358         for (auto tmp : forL)
    359         {
    360             L.push_back(tmp.second);
    361         }
    362         forL.clear();
    363         if (L.empty())
    364         {
    365             break;
    366         }
    367         for (auto fromTmp : L)
    368         {
    369             for (auto toTmp : preL)
    370             {
    371                 auto i = 0;
    372                 for (; i < toTmp.item.size(); ++i)
    373                 {
    374                     if (toTmp.item[i] != fromTmp.item[i])
    375                     {
    376                         break;
    377                     }
    378                 }
    379                 if (i == toTmp.item.size())
    380                 {
    381 //                    double aaa = (1.0*fromTmp.cnt) / (1.0*toTmp.cnt);
    382 //                    double bbb = (1.0*fromTmp.cnt) / (1.0*toTmp.cnt) - confidenceMin;
    383                     if ((1.0*fromTmp.cnt)/(1.0*toTmp.cnt) - confidenceMin >= 0.0)
    384                     {
    385                         cout << "Case " << Case++ << " : " << endl;
    386                         for (auto j = 0; j < toTmp.item.size(); ++j)
    387                         {
    388                             cout << columnName[toTmp.item[j]];
    389                             if (j != toTmp.item.size() - 1)
    390                             {
    391                                 cout << ",";
    392                             }
    393                         }
    394                         cout << " => " << columnName[fromTmp.item[toTmp.item.size()]] << endl;
    395                     }
    396                 }
    397             }
    398         }
    399         preL.clear();
    400         preL = L;
    401         L.clear();
    402     }
    403 }
    404 
    405 int main()
    406 {
    407 #ifdef HOME
    408     freopen("in", "r", stdin);
    409     freopen("out", "w", stdout);
    410 #endif
    411     cin >> supportMin >> confidenceMin;
    412     Case = 0;
    413     total = 0;
    414     Input();
    415     Ini();
    416     Apriori();
    417 
    418 #ifdef HOME
    419     cerr << "Time elapsed: " << clock() / CLOCKS_PER_SEC << " ms" << endl;
    420 #endif
    421     return 0;
    422 }
  • 相关阅读:
    C#对象深度克隆(转)
    .Net Core 图片文件上传下载(转)
    事件总线(Event Bus)知多少(转)
    深入理解C#:编程技巧总结(一)(转)
    asp.net core源码飘香:Configuration组件(转)
    asp.net core源码飘香:Logging组件(转)
    基于C#.NET的高端智能化网络爬虫(下)(转)
    基于C#.NET的高端智能化网络爬虫(转)
    30分钟掌握 C#7(转)
    30分钟掌握 C#6(转)
  • 原文地址:https://www.cnblogs.com/shijianming/p/4992610.html
Copyright © 2011-2022 走看看