zoukankan      html  css  js  c++  java
  • Apriori算法--关联规则挖掘

    我的数据挖掘算法代码:https://github.com/linyiqun/DataMiningAlgorithm

    介绍

    Apriori算法是一个经典的数据挖掘算法,Apriori的单词的意思是"先验的",说明这个算法是具有先验性质的,就是说要通过上一次的结果推导出下一次的结果,这个如何体现将会在下面的分析中会慢慢的体现出来。Apriori算法的用处是挖掘频繁项集的,频繁项集粗俗的理解就是找出经常出现的组合,然后根据这些组合最终推出我们的关联规则。

    Apriori算法原理

    Apriori算法是一种逐层搜索的迭代式算法,其中k项集用于挖掘(k+1)项集,这是依靠他的先验性质的:

    频繁项集的所有非空子集一定是也是频繁的。

    通过这个性质可以对候选集进行剪枝。用k项集如何生成(k+1)项集呢,这个是算法里面最难也是最核心的部分。

    通过2个步骤

    1、连接步,将频繁项自己与自己进行连接运算。

    2、剪枝步,去除候选集项中的不符合要求的候选项,不符合要求指的是这个候选项的子集并非都是频繁项,要遵守上文提到的先验性质。

    3、通过1,2步骤还不够,在后面还要根据支持度计数筛选掉不满足最小支持度数的候选集。

    算法实例

    首先是测试数据:

    交易ID

    商品ID列表

    T100

    I1I2I5

    T200

    I2I4

    T300

    I2I3

    T400

    I1I2I4

    T500

    I1I3

    T600

    I2I3

    T700

    I1I3

    T800

    I1I2I3I5

    T900

    I1I2I3

    算法的步骤图:


    最后我们可以看到频繁3项集的结果为{1, 2, 3}和{1, 2, 5},然后我们去后者{1, 2, 5}作为频繁项集来生产他的关联规则,但是在这之前得先知道一些概念,怎么样才能够成为一条关联规则,关有频繁项集还是不够的。

    关联规则

    confidence(置信度)

    confidence的中文意思为自信的,在这里其实表示的是一种条件概率,当在A条件下,B发生的概率就可以表示为confidence(A->B)=p(B|A),意为在A的情况下,推出B的概率。那么关联规则与有什么关系呢,请继续往下看。

    最小置信度阈值

    按照字面上的意思就是限制置信度值的一个限制条件嘛,这个很好理解。

    强规则

    强规则就是指的是置信度满足最小置信度(就是>=最小置信度)的推断就是一个强规则,也就是文中所说的关联规则了。这个在下面的程序中会有所体现。

    算法的代码实现

    我自己写的算法实现可能会让你有点晦涩难懂,不过重在理解算法的整个思路即可,尤其是连接步和剪枝步是最难点所在,可能还存在bug。

    输入数据:

    T1 1 2 5
    T2 2 4
    T3 2 3
    T4 1 2 4
    T5 1 3
    T6 2 3
    T7 1 3
    T8 1 2 3 5
    T9 1 2 3
    频繁项类:

    /**
     * 频繁项集
     * 
     * @author lyq
     * 
     */
    public class FrequentItem implements Comparable<FrequentItem>{
    	// 频繁项集的集合ID
    	private String[] idArray;
    	// 频繁项集的支持度计数
    	private int count;
    	//频繁项集的长度,1项集或是2项集,亦或是3项集
    	private int length;
    
    	public FrequentItem(String[] idArray, int count){
    		this.idArray = idArray;
    		this.count = count;
    		length = idArray.length;
    	}
    
    	public String[] getIdArray() {
    		return idArray;
    	}
    
    	public void setIdArray(String[] idArray) {
    		this.idArray = idArray;
    	}
    
    	public int getCount() {
    		return count;
    	}
    
    	public void setCount(int count) {
    		this.count = count;
    	}
    
    	public int getLength() {
    		return length;
    	}
    
    	public void setLength(int length) {
    		this.length = length;
    	}
    
    	@Override
    	public int compareTo(FrequentItem o) {
    		// TODO Auto-generated method stub
    		return this.getIdArray()[0].compareTo(o.getIdArray()[0]);
    	}
    	
    }
    
    主程序类:

    package DataMining_Apriori;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.text.MessageFormat;
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * apriori算法工具类
     * 
     * @author lyq
     * 
     */
    public class AprioriTool {
    	// 最小支持度计数
    	private int minSupportCount;
    	// 测试数据文件地址
    	private String filePath;
    	// 每个事务中的商品ID
    	private ArrayList<String[]> totalGoodsIDs;
    	// 过程中计算出来的所有频繁项集列表
    	private ArrayList<FrequentItem> resultItem;
    	// 过程中计算出来频繁项集的ID集合
    	private ArrayList<String[]> resultItemID;
    
    	public AprioriTool(String filePath, int minSupportCount) {
    		this.filePath = filePath;
    		this.minSupportCount = minSupportCount;
    		readDataFile();
    	}
    
    	/**
    	 * 从文件中读取数据
    	 */
    	private void readDataFile() {
    		File file = new File(filePath);
    		ArrayList<String[]> dataArray = new ArrayList<String[]>();
    
    		try {
    			BufferedReader in = new BufferedReader(new FileReader(file));
    			String str;
    			String[] tempArray;
    			while ((str = in.readLine()) != null) {
    				tempArray = str.split(" ");
    				dataArray.add(tempArray);
    			}
    			in.close();
    		} catch (IOException e) {
    			e.getStackTrace();
    		}
    
    		String[] temp = null;
    		totalGoodsIDs = new ArrayList<>();
    		for (String[] array : dataArray) {
    			temp = new String[array.length - 1];
    			System.arraycopy(array, 1, temp, 0, array.length - 1);
    
    			// 将事务ID加入列表吧中
    			totalGoodsIDs.add(temp);
    		}
    	}
    
    	/**
    	 * 判读字符数组array2是否包含于数组array1中
    	 * 
    	 * @param array1
    	 * @param array2
    	 * @return
    	 */
    	public boolean iSStrContain(String[] array1, String[] array2) {
    		if (array1 == null || array2 == null) {
    			return false;
    		}
    
    		boolean iSContain = false;
    		for (String s : array2) {
    			// 新的字母比较时,重新初始化变量
    			iSContain = false;
    			// 判读array2中每个字符,只要包括在array1中 ,就算包含
    			for (String s2 : array1) {
    				if (s.equals(s2)) {
    					iSContain = true;
    					break;
    				}
    			}
    
    			// 如果已经判断出不包含了,则直接中断循环
    			if (!iSContain) {
    				break;
    			}
    		}
    
    		return iSContain;
    	}
    
    	/**
    	 * 项集进行连接运算
    	 */
    	private void computeLink() {
    		// 连接计算的终止数,k项集必须算到k-1子项集为止
    		int endNum = 0;
    		// 当前已经进行连接运算到几项集,开始时就是1项集
    		int currentNum = 1;
    		// 商品,1频繁项集映射图
    		HashMap<String, FrequentItem> itemMap = new HashMap<>();
    		FrequentItem tempItem;
    		// 初始列表
    		ArrayList<FrequentItem> list = new ArrayList<>();
    		// 经过连接运算后产生的结果项集
    		resultItem = new ArrayList<>();
    		resultItemID = new ArrayList<>();
    		// 商品ID的种类
    		ArrayList<String> idType = new ArrayList<>();
    		for (String[] a : totalGoodsIDs) {
    			for (String s : a) {
    				if (!idType.contains(s)) {
    					tempItem = new FrequentItem(new String[] { s }, 1);
    					idType.add(s);
    					resultItemID.add(new String[] { s });
    				} else {
    					// 支持度计数加1
    					tempItem = itemMap.get(s);
    					tempItem.setCount(tempItem.getCount() + 1);
    				}
    				itemMap.put(s, tempItem);
    			}
    		}
    		// 将初始频繁项集转入到列表中,以便继续做连接运算
    		for (Map.Entry entry : itemMap.entrySet()) {
    			list.add((FrequentItem) entry.getValue());
    		}
    		// 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少
    		Collections.sort(list);
    		resultItem.addAll(list);
    
    		String[] array1;
    		String[] array2;
    		String[] resultArray;
    		ArrayList<String> tempIds;
    		ArrayList<String[]> resultContainer;
    		// 总共要算到endNum项集
    		endNum = list.size() - 1;
    
    		while (currentNum < endNum) {
    			resultContainer = new ArrayList<>();
    			for (int i = 0; i < list.size() - 1; i++) {
    				tempItem = list.get(i);
    				array1 = tempItem.getIdArray();
    				for (int j = i + 1; j < list.size(); j++) {
    					tempIds = new ArrayList<>();
    					array2 = list.get(j).getIdArray();
    					for (int k = 0; k < array1.length; k++) {
    						// 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作
    						if (array1[k].equals(array2[k])) {
    							tempIds.add(array1[k]);
    						} else {
    							tempIds.add(array1[k]);
    							tempIds.add(array2[k]);
    						}
    					}
    					resultArray = new String[tempIds.size()];
    					tempIds.toArray(resultArray);
    
    					boolean isContain = false;
    					// 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的
    					if (resultArray.length == (array1.length + 1)) {
    						isContain = isIDArrayContains(resultContainer,
    								resultArray);
    						if (!isContain) {
    							resultContainer.add(resultArray);
    						}
    					}
    				}
    			}
    
    			// 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集
    			list = cutItem(resultContainer);
    			currentNum++;
    		}
    
    		// 输出频繁项集
    		for (int k = 1; k <= currentNum; k++) {
    			System.out.println("频繁" + k + "项集:");
    			for (FrequentItem i : resultItem) {
    				if (i.getLength() == k) {
    					System.out.print("{");
    					for (String t : i.getIdArray()) {
    						System.out.print(t + ",");
    					}
    					System.out.print("},");
    				}
    			}
    			System.out.println();
    		}
    	}
    
    	/**
    	 * 判断列表结果中是否已经包含此数组
    	 * 
    	 * @param container
    	 *            ID数组容器
    	 * @param array
    	 *            待比较数组
    	 * @return
    	 */
    	private boolean isIDArrayContains(ArrayList<String[]> container,
    			String[] array) {
    		boolean isContain = true;
    		if (container.size() == 0) {
    			isContain = false;
    			return isContain;
    		}
    
    		for (String[] s : container) {
    			// 比较的视乎必须保证长度一样
    			if (s.length != array.length) {
    				continue;
    			}
    
    			isContain = true;
    			for (int i = 0; i < s.length; i++) {
    				// 只要有一个id不等,就算不相等
    				if (s[i] != array[i]) {
    					isContain = false;
    					break;
    				}
    			}
    
    			// 如果已经判断是包含在容器中时,直接退出
    			if (isContain) {
    				break;
    			}
    		}
    
    		return isContain;
    	}
    
    	/**
    	 * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集
    	 */
    	private ArrayList<FrequentItem> cutItem(ArrayList<String[]> resultIds) {
    		String[] temp;
    		// 忽略的索引位置,以此构建子集
    		int igNoreIndex = 0;
    		FrequentItem tempItem;
    		// 剪枝生成新的频繁项集
    		ArrayList<FrequentItem> newItem = new ArrayList<>();
    		// 不符合要求的id
    		ArrayList<String[]> deleteIdArray = new ArrayList<>();
    		// 子项集是否也为频繁子项集
    		boolean isContain = true;
    
    		for (String[] array : resultIds) {
    			// 列举出其中的一个个的子项集,判断存在于频繁项集列表中
    			temp = new String[array.length - 1];
    			for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) {
    				isContain = true;
    				for (int j = 0, k = 0; j < array.length; j++) {
    					if (j != igNoreIndex) {
    						temp[k] = array[j];
    						k++;
    					}
    				}
    
    				if (!isIDArrayContains(resultItemID, temp)) {
    					isContain = false;
    					break;
    				}
    			}
    
    			if (!isContain) {
    				deleteIdArray.add(array);
    			}
    		}
    
    		// 移除不符合条件的ID组合
    		resultIds.removeAll(deleteIdArray);
    
    		// 移除支持度计数不够的id集合
    		int tempCount = 0;
    		for (String[] array : resultIds) {
    			tempCount = 0;
    			for (String[] array2 : totalGoodsIDs) {
    				if (isStrArrayContain(array2, array)) {
    					tempCount++;
    				}
    			}
    
    			// 如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中
    			if (tempCount >= minSupportCount) {
    				tempItem = new FrequentItem(array, tempCount);
    				newItem.add(tempItem);
    				resultItemID.add(array);
    				resultItem.add(tempItem);
    			}
    		}
    
    		return newItem;
    	}
    
    	/**
    	 * 数组array2是否包含于array1中,不需要完全一样
    	 * 
    	 * @param array1
    	 * @param array2
    	 * @return
    	 */
    	private boolean isStrArrayContain(String[] array1, String[] array2) {
    		boolean isContain = true;
    		for (String s2 : array2) {
    			isContain = false;
    			for (String s1 : array1) {
    				// 只要s2字符存在于array1中,这个字符就算包含在array1中
    				if (s2.equals(s1)) {
    					isContain = true;
    					break;
    				}
    			}
    
    			// 一旦发现不包含的字符,则array2数组不包含于array1中
    			if (!isContain) {
    				break;
    			}
    		}
    
    		return isContain;
    	}
    
    	/**
    	 * 根据产生的频繁项集输出关联规则
    	 * 
    	 * @param minConf
    	 *            最小置信度阈值
    	 */
    	public void printAttachRule(double minConf) {
    		// 进行连接和剪枝操作
    		computeLink();
    
    		int count1 = 0;
    		int count2 = 0;
    		ArrayList<String> childGroup1;
    		ArrayList<String> childGroup2;
    		String[] group1;
    		String[] group2;
    		// 以最后一个频繁项集做关联规则的输出
    		String[] array = resultItem.get(resultItem.size() - 1).getIdArray();
    		// 子集总数,计算的时候除去自身和空集
    		int totalNum = (int) Math.pow(2, array.length);
    		String[] temp;
    		// 二进制数组,用来代表各个子集
    		int[] binaryArray;
    		// 除去头和尾部
    		for (int i = 1; i < totalNum - 1; i++) {
    			binaryArray = new int[array.length];
    			numToBinaryArray(binaryArray, i);
    
    			childGroup1 = new ArrayList<>();
    			childGroup2 = new ArrayList<>();
    			count1 = 0;
    			count2 = 0;
    			// 按照二进制位关系取出子集
    			for (int j = 0; j < binaryArray.length; j++) {
    				if (binaryArray[j] == 1) {
    					childGroup1.add(array[j]);
    				} else {
    					childGroup2.add(array[j]);
    				}
    			}
    
    			group1 = new String[childGroup1.size()];
    			group2 = new String[childGroup2.size()];
    
    			childGroup1.toArray(group1);
    			childGroup2.toArray(group2);
    
    			for (String[] a : totalGoodsIDs) {
    				if (isStrArrayContain(a, group1)) {
    					count1++;
    
    					// 在group1的条件下,统计group2的事件发生次数
    					if (isStrArrayContain(a, group2)) {
    						count2++;
    					}
    				}
    			}
    
    			// {A}-->{B}的意思为在A的情况下发生B的概率
    			System.out.print("{");
    			for (String s : group1) {
    				System.out.print(s + ", ");
    			}
    			System.out.print("}-->");
    			System.out.print("{");
    			for (String s : group2) {
    				System.out.print(s + ", ");
    			}
    			System.out.print(MessageFormat.format(
    					"},confidence(置信度):{0}/{1}={2}", count2, count1, count2
    							* 1.0 / count1));
    			if (count2 * 1.0 / count1 < minConf) {
    				// 不符合要求,不是强规则
    				System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则");
    			} else {
    				System.out.println("为强规则");
    			}
    		}
    
    	}
    
    	/**
    	 * 数字转为二进制形式
    	 * 
    	 * @param binaryArray
    	 *            转化后的二进制数组形式
    	 * @param num
    	 *            待转化数字
    	 */
    	private void numToBinaryArray(int[] binaryArray, int num) {
    		int index = 0;
    		while (num != 0) {
    			binaryArray[index] = num % 2;
    			index++;
    			num /= 2;
    		}
    	}
    
    }
    
    调用类:

    /**
     * apriori关联规则挖掘算法调用类
     * @author lyq
     *
     */
    public class Client {
    	public static void main(String[] args){
    		String filePath = "C:\Users\lyq\Desktop\icon\testInput.txt";
    		
    		AprioriTool tool = new AprioriTool(filePath, 2);
    		tool.printAttachRule(0.7);
    	}
    }
    输出的结果:

    频繁1项集:
    {1,},{2,},{3,},{4,},{5,},
    频繁2项集:
    {1,2,},{1,3,},{1,5,},{2,3,},{2,4,},{2,5,},
    频繁3项集:
    {1,2,3,},{1,2,5,},
    频繁4项集:
    
    {1, }-->{2, 5, },confidence(置信度):2/6=0.333由于此规则置信度未达到最小置信度的要求,不是强规则
    {2, }-->{1, 5, },confidence(置信度):2/7=0.286由于此规则置信度未达到最小置信度的要求,不是强规则
    {1, 2, }-->{5, },confidence(置信度):2/4=0.5由于此规则置信度未达到最小置信度的要求,不是强规则
    {5, }-->{1, 2, },confidence(置信度):2/2=1为强规则
    {1, 5, }-->{2, },confidence(置信度):2/2=1为强规则
    {2, 5, }-->{1, },confidence(置信度):2/2=1为强规则

    程序算法的问题和技巧

    在实现Apiori算法的时候,碰到的一些问题和待优化的点特别要提一下:

    1、首先程序的运行效率不高,里面有大量的for嵌套循环叠加上循环,当然这有本身算法的原因(连接运算所致)还有我的各个的方法选择,很多一部分用来比较字符串数组。

    2、这个是我觉得会是程序的一个漏洞,当生成的候选项集加入resultItemId时,会出现{1, 2, 3}和{3, 2, 1}会被当成不同的侯选集,未做顺序的判断。

    3、程序的调试过程中由于未按照从小到大的排序,导致,生成的候选集与真实值不一致的情况,所以这里必须在频繁1项集的时候就应该是有序的。

    4、在输出关联规则的时候,用到了数字转二进制数组的形式,输出他的各个非空子集,然后最出关联规则的判断。

    Apriori算法的缺点

    此算法的的应用非常广泛,但是他在运算的过程中会产生大量的侯选集,而且在匹配的时候要进行整个数据库的扫描,因为要做支持度计数的统计操作,在小规模的数据上操作还不会有大问题,如果是大型的数据库上呢,他的效率还是有待提高的。

  • 相关阅读:
    my.cnf
    js日期和毫秒互转
    传送门
    js 十进制转十六进制
    关键字
    常见异常
    Map迭代
    Hibernate

    MySql Host is blocked because of many connection errors; unblock with 'mysqladmin flushhosts' 解决方法
  • 原文地址:https://www.cnblogs.com/bianqi/p/12184035.html
Copyright © 2011-2022 走看看