zoukankan      html  css  js  c++  java
  • UDAFTest

    package com.XX.udf;
    
    import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
    import org.apache.hadoop.hive.ql.metadata.HiveException;
    import org.apache.hadoop.hive.ql.parse.SemanticException;
    import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
    import org.apache.hadoop.io.LongWritable;
    
    public class UDAFTest extends AbstractGenericUDAFResolver{
        //判断
        @Override
        public GenericUDAFEvaluator getEvaluator(TypeInfo[] info)//字段的描述信息参数parameters
                throws SemanticException {
            if(info.length !=2){
                throw new UDFArgumentTypeException(info.length-1,
                        "Exactly two argument is expected.");
            }
    
            //返回处理逻辑的类
            return new GenericEvaluate();
        }
    
        public static class GenericEvaluate extends GenericUDAFEvaluator{
    
            private LongWritable result;
            private PrimitiveObjectInspector inputIO1;
            private PrimitiveObjectInspector inputIO2;
    
            //这个方法map与reduce阶段都需要执行
            /**
             * map阶段:parameters长度与udaf输入的参数个数有关
             * reduce阶段:parameters长度为1
             */
            //初始化
            @Override
            public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                    throws HiveException {
                super.init(m, parameters);
    
                //返回最终的结果
                result = new LongWritable(0);
    
                inputIO1 = (PrimitiveObjectInspector) parameters[0];
                if (parameters.length>1) {
                    inputIO2 = (PrimitiveObjectInspector) parameters[1];
                }
    
                return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector;
            }
    
            //map阶段  iterate函数处理读入的行数据
            @Override
            public void iterate(AggregationBuffer agg, Object[] parameters)//agg缓存结果值
                    throws HiveException {
    
                assert(parameters.length==2);
    
                if(parameters==null || parameters[0]==null ||  parameters[1]==null){
                    return;
                }
    
                double base = PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputIO1);
                double tmp = PrimitiveObjectInspectorUtils.getDouble(parameters[1], inputIO2);
    
                if(base > tmp){
                    ((CountAgg)agg).count++;
                }
            }
    
            //获得一个聚合的缓冲对象,每个map执行一次
            @Override
            public AggregationBuffer getNewAggregationBuffer() throws HiveException {
    
                CountAgg agg = new CountAgg();
    
                reset(agg);
    
                return agg;
            }
    
            //自定义类用于计数
            public static class CountAgg implements AggregationBuffer{
                long count;//计数,保存每次临时的结果
            }
    
            //重置
            @Override
            public void reset(AggregationBuffer countagg) throws HiveException {
                CountAgg agg = (CountAgg)countagg;
                agg.count=0;
            }
    
            //该方法当做iterate执行后,部分结果返回。  terminatePartial 返回iterate处理的中间结果
            @Override
            public Object terminatePartial(AggregationBuffer agg)
                    throws HiveException {
    
                result.set(((CountAgg)agg).count);
    
                return result;
            }
    
    
    
            @Override    //合并处理结果
            public void merge(AggregationBuffer agg, Object partial)
                    throws HiveException {
                if(partial != null){
                    long p = PrimitiveObjectInspectorUtils.getLong(partial, inputIO1);
                    ((CountAgg)agg).count += p;
                }
            }
    
            @Override  //返回最终值
            public Object terminate(AggregationBuffer agg) throws HiveException {
                result.set(((CountAgg)agg).count);
                return result;
            }
        }
    }
  • 相关阅读:
    Leetcode 16.25 LRU缓存 哈希表与双向链表的组合
    Leetcode437 路径总和 III 双递归与前缀和
    leetcode 0404 二叉树检查平衡性 DFS
    Leetcode 1219 黄金矿工 暴力回溯
    Leetcode1218 最长定差子序列 哈希表优化DP
    Leetcode 91 解码方法
    Leetcode 129 求根到叶子节点数字之和 DFS优化
    Leetcode 125 验证回文串 双指针
    Docker安装Mysql记录
    vmware虚拟机---Linux配置静态IP
  • 原文地址:https://www.cnblogs.com/yin-fei/p/10879736.html
Copyright © 2011-2022 走看看