zoukankan      html  css  js  c++  java
  • 机器学习笔记:sklearn.model_selection.train_test_split切分训练、测试集

    一、背景

    接上所叙,在对比训练集、验证集、测试集之后,实战中需要对数据进行划分。

    通常将原始数据按比例划分为:训练集、测试集。

    可以利用 sklearn.model_selection.train_test_split 方法实现。

    二、介绍

    使用语法为:

    x_train, x_test, y_train, y_test = sklearn.model_selection.train_test_split(
        data, 
        target, 
        test_size=None,
        train_size=None, 
        random_state=None,
        shuffle=True,
        stratify=None
    )
    

    参数解释:

    data  -- 所要划分的样本特征集
    target -- 样本结果
    test_size -- 测试集样本占比(如果是整数,就是样本数量)
    train_size -- 同上 默认 0.75
    random_state -- 随机种子(保证可复现)
    shuffle -- 是否洗牌 打散数据
    stratify -- 保持类分布一致
    

    三、实操

    1.举个例子

    # 生成测试数据
    import numpy as np
    from sklearn.model_selection import train_test_split
    X, y = np.arange(10).reshape(5, 2), range(5)
    y = list(y)
    
    # 切分
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
    

    2.test_size & train_size

    float or int, default=None 默认空缺,浮点数或者整型。

    代表测试集的大小,如果是小数的话,处于(0, 1)之间,表示测试集所占比例。

    如果是整数,表示测试集样本个数。

    如果 train_size 也为空,则默认 test_size=0.25

    import numpy as np
    from sklearn.model_selection import train_test_split
    x = np.random.randint(1, 100, 20).reshape((10, 2))
    x_train, x_test = train_test_split(x) # 7 3
    x_train, x_test = train_test_split(x, train_size=4) # 4 6
    x_train, x_test = train_test_split(x, train_size=0.8) # 8 2
    

    3.random_state

    表示随机状态,因为每次划分都是随机的,为保证模型可反复实验,需要控制该参数不变。

    具体不实验了,多操作几遍即可得知。

    4.shuffle

    在划分数据之前,是否重洗数据,即将数据打散重新排序。

    默认重洗。

    # 按原始顺序抽取
    x_train, x_test = train_test_split(x, train_size=0.8, shuffle=False)
    

    5.stratify

    结合结果集使用,保证分类分布一致。

    假设原始结果集中有2个分类,A:B=1:2,随机切分时,无法保证训练集和测试集中A与B的比例,此时,通过设置 stratify=y 参数,即可控制分布一致。

    注:此处的y为原始结果集。

    通常在种类分布不平衡的情况下使用该参数。

    将 stratify=X 就是按照X中的比例分配

    将 stratify=y 就是按照y中的比例分配

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42, stratify=y)
    

    参考链接:sklearn之train_test_split()函数各参数含义(非常全)

    参考链接:sklearn.model_selection.train_test_split

    参考链接:sklearn函数:train_test_split(分割训练集和测试集)

  • 相关阅读:
    MySQL视图和存储过程
    MySQL数据操作
    Pair RDD编程
    HDFS组成架构和读写数据流程
    RDD编程
    MySQL数据查询和函数
    数据库学习(二)
    玩爆你的手机联系人--T9搜索(一)
    POJ3259 Wormholes 【Bellmanford推断是否存在负回路】
    Java 小技巧和在Java避免NullPonintException的最佳方法(翻译)
  • 原文地址:https://www.cnblogs.com/hider/p/15785119.html
Copyright © 2011-2022 走看看