  • PyTorch for Beginners (7): Hands-On Kaggle Competition, House Price Prediction (K-Fold Cross-Validation, *args, **kwargs)

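    The code in this post comes from the PyTorch edition of Dive into Deep Learning (《动手学深度学习》). Like some earlier posts, it is heavy on code and light on explanation. Many of the methods used here were covered in my previous posts, so I do not repeat them. The other reason for the sparse commentary is that the code is readable on its own, provided you have a rough grasp of Python syntax.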

    K-fold cross-validation implementation

    import torch

    def get_k_fold_data(k, i, X, y):
        # Return the training and validation data for the i-th fold, kept separate:
        # X_train is the training data, X_valid is the validation data
        assert k > 1
        fold_size = X.shape[0] // k  # // is floor division
        X_train, y_train = None, None
        for j in range(k):
            idx = slice(j * fold_size, (j + 1) * fold_size)  # slice(start, stop, step)
            X_part, y_part = X[idx, :], y[idx]
            if j == i:
                X_valid, y_valid = X_part, y_part
            elif X_train is None:
                X_train, y_train = X_part, y_part
            else:
                X_train = torch.cat((X_train, X_part), dim=0)  # dim=0 adds rows (stacks vertically)
                y_train = torch.cat((y_train, y_part), dim=0)
        return X_train, y_train, X_valid, y_valid
    
    def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay, batch_size):
        train_l_sum, valid_l_sum = 0, 0
        for i in range(k):
            data = get_k_fold_data(k, i, X_train, y_train)  # training and validation data for fold i
            net = get_net(X_train.shape[1])  # get_net is a basic linear regression model here; see Appendix 1
            train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
                                       weight_decay, batch_size)  # train is defined in Appendix 2
            train_l_sum += train_ls[-1]
            valid_l_sum += valid_ls[-1]
            if i == 0:
                d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse',
                             range(1, num_epochs + 1), valid_ls,
                             ['train', 'valid'])  # plot with a logarithmic y-axis and a linear x-axis; see Appendix 3
            print('fold %d, train rmse %f, valid rmse %f' % (i, train_ls[-1], valid_ls[-1]))
        return train_l_sum / k, valid_l_sum / k
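
To make the slicing in get_k_fold_data easier to follow, here is a minimal sketch that repeats the same index arithmetic on plain Python lists, so it runs without torch. The helper name k_fold_indices is hypothetical and is not part of the original code.

```python
def k_fold_indices(k, i, n):
    """Return (train_idx, valid_idx) for fold i of k over n rows.

    Hypothetical helper: it mirrors the slice arithmetic of get_k_fold_data
    but works on index lists instead of tensors.
    """
    assert k > 1
    fold_size = n // k  # floor division, as in get_k_fold_data
    train_idx, valid_idx = [], []
    for j in range(k):
        part = list(range(j * fold_size, (j + 1) * fold_size))
        if j == i:
            valid_idx = part      # fold i is held out for validation
        else:
            train_idx += part     # every other fold goes into training
    return train_idx, valid_idx


train_idx, valid_idx = k_fold_indices(k=3, i=1, n=10)
print(valid_idx)  # [3, 4, 5]
print(train_idx)  # [0, 1, 2, 6, 7, 8]
```

Row 9 lands in neither list: 10 // 3 is 3, so only the first 9 rows are ever used. With this slicing, any rows left over after dividing by k are silently dropped.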
    

    *args: accepts any number of positional arguments and collects them into a tuple. For example, with def fun(*args): print(args), calling fun('fruit', 'animal', 'human') prints ('fruit', 'animal', 'human').
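
As a small runnable sketch of both directions of the star syntax: the first function collects positional arguments into a tuple, and the second call unpacks a tuple into positional arguments, which is what train(net, *data, ...) does in k_fold. The function train_stub and the placeholder strings are hypothetical stand-ins, not the real train or real tensors.

```python
def fun(*args):
    # args is a tuple holding every positional argument that was passed
    return args


print(fun('fruit', 'animal', 'human'))  # ('fruit', 'animal', 'human')


def train_stub(net, train_X, train_y, valid_X, valid_y, num_epochs):
    # Hypothetical stand-in for train(): it only reports what it received.
    return (net, train_X, train_y, valid_X, valid_y, num_epochs)


# Placeholder strings instead of tensors, in the order get_k_fold_data returns them.
data = ('X_train', 'y_train', 'X_valid', 'y_valid')

# *data spreads the four items over four consecutive positional parameters.
result = train_stub('net', *data, 5)
print(result)  # ('net', 'X_train', 'y_train', 'X_valid', 'y_valid', 5)
```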

    **kwargs: accepts any number of keyword arguments and collects them into a dictionary. For example:

    def fun(**kwargs):
        for key, value in kwargs.items():
            print("%s:%s" % (key, value))

    fun(a=1, b=2, c=3) prints a:1, b:2 and c:3, one per line.
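
A short sketch of the same idea that returns the formatted lines instead of printing them, followed by the reverse direction, where ** at a call site unpacks a dictionary into keyword arguments. The names describe, show and hyperparams, and the values in the dictionary, are illustrative assumptions only.

```python
def describe(**kwargs):
    # kwargs is a dict mapping each keyword name to its value
    return ["%s:%s" % (key, value) for key, value in kwargs.items()]


lines = describe(a=1, b=2, c=3)
print(lines)  # ['a:1', 'b:2', 'c:3']


def show(lr, weight_decay):
    return 'lr=%s, weight_decay=%s' % (lr, weight_decay)


# ** at the call site turns each dict entry into a keyword argument.
hyperparams = {'lr': 0.01, 'weight_decay': 0.0}  # illustrative values only
summary = show(**hyperparams)
print(summary)  # lr=0.01, weight_decay=0.0
```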

    Appendix 1

    import torch
    from torch import nn

    loss = torch.nn.MSELoss()  # mean squared error loss used by train
    
    def get_net(feature_num):
        net = nn.Linear(feature_num, 1)
        for param in net.parameters():
            nn.init.normal_(param, mean=0, std=0.01) 
        return net

    Appendix 2

    def train(net, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate, weight_decay, batch_size):
        train_ls, test_ls = [], []
        dataset = torch.utils.data.TensorDataset(train_features, train_labels)
        train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)  # see my earlier posts for how TensorDataset and DataLoader are used

        # The Adam optimization algorithm is used here
        optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate, weight_decay=weight_decay)
        net = net.float()
        for epoch in range(num_epochs):
            for X, y in train_iter:
                l = loss(net(X.float()), y.float())
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
            # record the error once per epoch, after all mini-batches have been processed
            train_ls.append(log_rmse(net, train_features, train_labels))
            if test_labels is not None:
                test_ls.append(log_rmse(net, test_features, test_labels))
        return train_ls, test_ls
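
The train function calls log_rmse, which this post does not define. A minimal sketch of one common definition follows, assuming the metric is the root-mean-squared error between the logarithm of the predictions and the logarithm of the labels, with predictions clipped to at least 1. Both the clipping threshold and the exact form are assumptions here, not taken from this post.

```python
import torch

loss = torch.nn.MSELoss()  # same kind of loss object as in Appendix 1


def log_rmse(net, features, labels):
    """RMSE between log(prediction) and log(label); labels must be positive."""
    with torch.no_grad():  # evaluation only, so no gradients are tracked
        # Assumption: clip predictions to at least 1 so the logarithm
        # is defined and never negative.
        clipped_preds = torch.clamp(net(features.float()), min=1.0)
        rmse = torch.sqrt(loss(clipped_preds.log(), labels.float().log()))
    return rmse.item()


# Tiny check with a linear layer whose output is forced to the constant 2.
net = torch.nn.Linear(2, 1)
with torch.no_grad():
    net.weight.zero_()
    net.bias.fill_(2.0)
features = torch.ones(4, 2)
labels = torch.full((4, 1), 2.0)
error = log_rmse(net, features, labels)  # predictions equal labels, so the error is 0
```

Working in log space means a prediction that is off by a given ratio contributes the same error whether the house is cheap or expensive.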
    

    Appendix 3

    from matplotlib import pyplot as plt

    def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None, legend=None, figsize=(3.5, 2.5)):
        set_figsize(figsize)  # set_figsize is a helper from the book's d2l package
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.semilogy(x_vals, y_vals)
        if x2_vals and y2_vals:
            plt.semilogy(x2_vals, y2_vals, linestyle=':')
            plt.legend(legend)

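    Note: I have other tasks at the moment, so this post was written in a hurry. I will flesh it out when I have time and may add more detailed explanations.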

  • Original post: https://www.cnblogs.com/JadenFK3326/p/12164519.html