zoukankan      html  css  js  c++  java
  • 使用ML.NET实现德州扑克牌型分类器

    导读:ML.NET系列文章

    本文将基于ML.NET v0.2预览版,重点介绍提取特征的思路和方法,实现德州扑克牌型分类器。

    先介绍一下德州扑克的基本牌型,一手完整的牌共有五张扑克,10种牌型分别是:

    1. 高牌,花色和点数同时没有相同的牌。

    2. 一对,点数有且仅有两张相同的牌。

    3. 两对,两张相同点数的牌,加另外两张相同点数的牌。

    4. 三条,有三张同一点数的牌。

    5. 顺子,五张顺连的牌。

    6. 同花,五张同一花色的牌。

    7. 葫芦,三张同一点数的牌,加一对其他点数的牌。

    8. 四条,有四张同一点数的牌。

    9. 同花顺,同一花色五张顺连的牌。

    10. 皇家同花顺,最高点数是A的同花顺的牌。

    这一次我们将使用逻辑回归模型,来训练数据完成我们想要的分类模型。

    准备数据集


    数据来源在Poker Hand Data Set,下载链接为:poker-hand-testing.datapoker-hand-training-true.data。内容类似如下:

    3,92,3,3,2,2,9,3,5,1
    4,4,1,11,2,9,4,13,2,7,0
    1,5,1,9,2,8,2,4,4,3,0
    4,12,4,7,4,5,2,10,2,2,0
    4,3,2,4,4,13,3,6,4,12,0
    2,5,4,5,4,1,4,9,2,7,1
    2,12,3,12,3,7,3,11,2,7,2
    4,13,2,6,4,6,4,10,4,9,1
    ...

    说明一下每一行的格式:

    第1张花色,第1张点数,第2张花色,第2张点数,第3张花色,第3张点数,第4张花色,第4张点数,第5张花色,第5张点数,牌型

    花色是1-4代表红心,黑桃,方块,梅花。点数1表示A,2-10保持不变,11表示J,12表示Q,13表示K。

    特征分析


    前几篇数据集的内容,基本上分割好就是特征了,这一次不同,每一行的数值仅仅是元数据,也就是说,通过五张牌的花色和点数值是不能直接和牌型形成一一对应的联系,需要先按本文开头介绍的10种牌型的描述,找到关键可数值化的字段。因此,我选择了这样一些字段:对子数,是否是三条,是否是四条,是否是顺子,是否同花。通过这5个字段值的组合,一定能判断出牌型。

    于是,我定义出我想要的数据类型:

    public class PokerHandData
    {
        [Column(ordinal: "0")]
        public float S1;
        [Column(ordinal: "1")]
        public float C1;
        [Column(ordinal: "2")]
        public float S2;
        [Column(ordinal: "3")]
        public float C2;
        [Column(ordinal: "4")]
        public float S3;
        [Column(ordinal: "5")]
        public float C3;
        [Column(ordinal: "6")]
        public float S4;
        [Column(ordinal: "7")]
        public float C4;
        [Column(ordinal: "8")]
        public float S5;
        [Column(ordinal: "9")]
        public float C5;
        [Column(ordinal: "10", name: "Label")]
        public float Power;
    [Column(ordinal:
    "11")] public float IsSameSuit; [Column(ordinal: "12")] public float IsStraight; [Column(ordinal: "13")] public float FourOfKind; [Column(ordinal: "14")] public float ThreeOfKind; [Column(ordinal: "15")] public float PairsCount; }

    S表示花色,C表示点数,Power表示牌型,PairsCount表示对子数,ThreeOfKind表示是否是三条,FourOfKind表示是否是四条,IsStraight表示是否顺子,IsSameSuit表示是否同花。

    判断是否同花,只需要比较S1-S5的值即可。

    public float GetIsSameSuit()
    {
        if (S1 == S2 && S2 == S3 && S3 == S4 && S4 == S5)
            return 1;
        else
            return 0;
    }

    判断其它几个特征,我需要一个通用方法,先统计出每一行的C1-C5,每种点数出现的次数。

    static Dictionary<int, int> GetValueCountsOfCondition(IEnumerable<int> cards)
    {
        var dic = new Dictionary<int, int>();
    
        foreach (var item in cards)
        {
            if (dic.ContainsKey(item))
            {
                dic[item] += 1;
            }
            else
            {
                dic.Add(item, 1);
            }
        }
        return dic;
    }

    然后再按特征涵义计算值。

    public float GetFourOfKind()
    {
        return GetCountOfCondition(4);
    }
    
    public float GetThreeOfKind()
    {
        return GetCountOfCondition(3);
    }
    
    public float GetPairsCount()
    {
        return GetCountOfCondition(2);
    }
    
    private IEnumerable<int> GetCards()
    {
        if (cards == null)
        {
            cards = new[] { Convert.ToInt32(C1), Convert.ToInt32(C2), Convert.ToInt32(C3), Convert.ToInt32(C4), Convert.ToInt32(C5) };
        }
    
        return cards;
    }
    
    private float GetCountOfCondition(int target)
    {
        if (valueCounts == null)
        {
            valueCounts = GetValueCountsOfCondition(GetCards());
        }
    
        return valueCounts.Count(i => i.Value == target);
    }

    判断是否为顺子的方法,简单而直接,就是看间隔差是不是为1,或者最高点有A剩下的必须是10、J、Q、K,都算顺子。

    public float GetIsStraight()
    {
        var keys = GetCards().ToArray();
        Array.Sort(keys);
        if (keys[1] - keys[0] == keys[2] - keys[1] && keys[2] - keys[1] == keys[3] - keys[2] && keys[3] - keys[2] == keys[4] - keys[3] && keys[4] - keys[3] == 1)
        {
            return 1;
        }
        else if (keys[0] == 1 && keys[1] == 10 && keys[2] == 11 && keys[3] == 12 && keys[4] == 13)
        {
            return 1;
        }
        else
        {
            return 0;
        }
    }

    加载数据


    这次由于使用了ML.NET v0.2,该版本的LearningPipeline新增了一种支持集合类型的数据源。因此,我将示范一种全新的载入数据集的方法,先以文件载入元数据,然后直接初始化特征的值。

    static IEnumerable<PokerHandData> LoadData(string path)
    {
        using (var environment = new TlcEnvironment())
        {
            var pokerHandData = new List<PokerHandData>();
            var textLoader = new Microsoft.ML.Data.TextLoader(path).CreateFrom<PokerHandData>(useHeader: false, separator: ',', trimWhitespace: false);
            var experiment = environment.CreateExperiment();
            var output = textLoader.ApplyStep(null, experiment) as ILearningPipelineDataStep;
    
            experiment.Compile();
            textLoader.SetInput(environment, experiment);
            experiment.Run();
    
            var data = experiment.GetOutput(output.Data);
    
            using (var cursor = data.GetRowCursor((a => true)))
            {
                var getters = new ValueGetter<float>[]{
                    cursor.GetGetter<float>(0),
                    cursor.GetGetter<float>(1),
                    cursor.GetGetter<float>(2),
                    cursor.GetGetter<float>(3),
                    cursor.GetGetter<float>(4),
                    cursor.GetGetter<float>(5),
                    cursor.GetGetter<float>(6),
                    cursor.GetGetter<float>(7),
                    cursor.GetGetter<float>(8),
                    cursor.GetGetter<float>(9),
                    cursor.GetGetter<float>(10)
                };
    
                while (cursor.MoveNext())
                {
                    float value0 = 0;
                    float value1 = 0;
                    float value2 = 0;
                    float value3 = 0;
                    float value4 = 0;
                    float value5 = 0;
                    float value6 = 0;
                    float value7 = 0;
                    float value8 = 0;
                    float value9 = 0;
                    float value10 = 0;
                    getters[0](ref value0);
                    getters[1](ref value1);
                    getters[2](ref value2);
                    getters[3](ref value3);
                    getters[4](ref value4);
                    getters[5](ref value5);
                    getters[6](ref value6);
                    getters[7](ref value7);
                    getters[8](ref value8);
                    getters[9](ref value9);
                    getters[10](ref value10);
    
                    var hands = new PokerHandData()
                    {
                        S1 = value0,
                        C1 = value1,
                        S2 = value2,
                        C2 = value3,
                        S3 = value4,
                        C3 = value5,
                        S4 = value6,
                        C4 = value7,
                        S5 = value8,
                        C5 = value9,
                        Power = value10
                    };
                    hands.Init();
                    pokerHandData.Add(hands);
                }
            }
    
            return pokerHandData;
        }
    }

    其中PokerHandData类增加一个初始化的方法。

    public void Init()
    {
        IsSameSuit = GetIsSameSuit();
        IsStraight = GetIsStraight();
        FourOfKind = GetFourOfKind();
        ThreeOfKind = GetThreeOfKind();
        PairsCount = GetPairsCount();
    }

    训练模型


    预测的结构定义,以计分为目标,float[]类型表示是对每一种牌型有一个得分,分值越高属于那一种牌型的概率越大。

    public class PokerHandPrediction
    {
        [ColumnName("Score")]
        public float[] PredictedPower;
    }

    模型的选择是LogisticRegressionClassifier,CollectionDataSource就是用来创建集合类型数据载入的对象。而特征的指定不再是全部字段,而是之前增加的那几个。

    public static PredictionModel<PokerHandData, PokerHandPrediction> Train(IEnumerable<PokerHandData> data)
    {
        var pipeline = new LearningPipeline();
        var collection = CollectionDataSource.Create(data);
        pipeline.Add(collection);
        pipeline.Add(new ColumnConcatenator("Features", "IsSameSuit", "IsStraight", "FourOfKind", "ThreeOfKind", "PairsCount"));
        pipeline.Add(new LogisticRegressionClassifier());
        var model = pipeline.Train<PokerHandData, PokerHandPrediction>();
        return model;
    }

    预测结果


    首先,对预测的得分,我们需要判断一个概率倾向。

    static string GetPower(float[] nums)
    {
        var index = -1;
        var last = 0F;
        for (int i = 0; i < nums.Length; i++)
        {
            if (nums[i] > last)
            {
                index = i;
                last = nums[i];
            }
        }
    var suit = string.Empty; switch (index) { case 0: suit = "高牌"; break; case 1: suit = "一对"; break; case 2: suit = "两对"; break; case 3: suit = "三条"; break; case 4: suit = "顺子"; break; case 5: suit = "同花"; break; case 6: suit = "葫芦"; break; case 7: suit = "四条"; break; case 8: suit = "同花顺"; break; case 9: suit = "皇家同花顺"; break; } return suit; }

    最后就是进行预测的部分了。

    public static void Predict(PredictionModel<PokerHandData, PokerHandPrediction> model)
    {
        var test1 = new PokerHandData
        {
            S1 = 1,
            C1 = 2,
            S2 = 1,
            C2 = 3,
            S3 = 3,
            C3 = 4,
            S4 = 4,
            C4 = 5,
            S5 = 2,
            C5 = 6
        };
    
        var test2 = new PokerHandData
        {
            S1 = 4,
            C1 = 5,
            S2 = 1,
            C2 = 5,
            S3 = 3,
            C3 = 5,
            S4 = 2,
            C4 = 12,
            S5 = 4,
            C5 = 7
        };
        test1.Init();
        test2.Init();
        IEnumerable<PokerHandData> pokerHands = new[]
        {
            test1,
            test2
        };
        IEnumerable<PokerHandPrediction> predictions = model.Predict(pokerHands);
        Console.WriteLine();
        Console.WriteLine("PokerHand Predictions");
        Console.WriteLine("---------------------");
    
        var pokerHandsAndPredictions = pokerHands.Zip(predictions, (pokerHand, prediction) => (pokerHand, prediction));
        foreach (var (pokerHand, prediction) in pokerHandsAndPredictions)
        {
            Console.WriteLine($"PokerHand: {ShowHand(pokerHand)} | Prediction: { GetPower(prediction.PredictedPower)}");
        }
        Console.WriteLine();
    
    }

    创建项目的步骤请参看本文开头导读给出的文章链接,不再赘述,运行结果如下:

    最后放出源代码文件:下载

    希望读者们保持对ML.NET的持续关注,相信新的特性一定能实现更复杂有趣的场景。

  • 相关阅读:
    STL目录
    Hola!
    SWPUCTF 2019总结以及部分WP
    SQL手工注入基础篇
    JDK11,JDK12没有JRE的解决方法
    FJUT2019暑假周赛三部分题解
    FJUT2019暑假周赛一题解
    随笔1
    关于针对本校教务系统漏洞的一次信息检索
    KMP算法讲解
  • 原文地址:https://www.cnblogs.com/BeanHsiang/p/9080358.html
Copyright © 2011-2022 走看看