zoukankan      html  css  js  c++  java
  • 李宏毅老师hw2 _classification模型

    一、数据处理

    针对原始数据中顺序数据和分类数据我们采用one-hot模型将其转化为数值类型数据

     1 '''
     2 将顺序数据、分类数据采用one-hot编码
     3 '''
     4 def deal_one_hot_coding(train_data):
     5     '''需要编码的object类型属性'''
     6     object_attribute_trian_data=train_data.select_dtypes(include=[object])
     7     '''不需要编码的int类型属性'''
     8     int_attribute_trian_data_df=train_data.select_dtypes(exclude=[object])
     9     '''转化'''
    10     enc=OneHotEncoder(sparse=False)
    11     new_object_attribute_trian_data=enc.fit_transform(object_attribute_trian_data)#转化后的类型为ndarray
    12     '''将ndarray转化为dataframe'''
    13     new_object_attribute_trian_data_df=pd.DataFrame(new_object_attribute_trian_data,index=range(54256),columns=range(417))
    14     '''拼接转化数据和不需要转化的数据'''
    15     new_train_data=pd.concat([int_attribute_trian_data_df,new_object_attribute_trian_data_df], axis=1)
    16     return new_train_data #(54256, 424)

    数据预处理步骤

    '''
    数据预处理方法
    '''
    def data_pre_processing():
        '''将y转码为true false'''
        train_data=deal_trage_Y(original_data)
        '''拆分出XY数据部分'''
        train_data_X=train_data.iloc[:,range(29)]
        train_data_Y = train_data.iloc[:, range(30,31)]
        '''X数据one_hot编码'''
        train_data_X=deal_one_hot_coding(train_data_X)
        return train_data_X,train_data_Y

    二、模型主体部分

    手写sigmod函数

     1 '''
     2 sigmod函数
     3 '''
     4 def sigmod(y_model):
     5     sigmod_y_model=list()
     6     for ele in y_model[0]:
     7         sigmod_value=1/(1+math.exp(-ele))
     8         sigmod_y_model.append(sigmod_value)
     9     sigmod_y_model=[ 0.999999 if i>0.999999999 else i for i in sigmod_y_model]#解决sigmod函数敏感度问题
    10     return np.array(sigmod_y_model)

    在这里有一个sigmod函数敏感度问题当你的z过大时sig(z)=1/(1+exp(1))是无限接近1的虽然不是1,但是由于计算机精度问题导致他在存储的时候回直接当做1来存储,如果有这种情况在计算loss值的时候会产生数学计算错误,因此使用0.99999代替1

    模型主体函数

     1 '''
     2 计算loss值
     3 形式为两个伯努利分布的交叉熵
     4 '''
     5 def get_Loss(Y_model,Y):
     6     #计算工资大于5k的概率
     7     loss_list=list()
     8     for ele in zip(Y_model,Y.T.values[0]):
     9         if ele[1]==True:
    10             loss_list.append(-math.log(ele[0],math.e))
    11         else:
    12             loss_list.append(-math.log(1-ele[0],math.e))
    13     return loss_list
    14 
    15 '''
    16 更新参数
    17 '''
    18 def updata_W_b(loss_list,W,b,Y,y_model,X):
    19     learnning_rate=0.75
    20     sigmod_Z=y_model
    21     for ele in zip(loss_list,Y.T.values[0]):
    22         if ele[1]==True:
    23             updata=1-sigmod_Z
    24         else:
    25             updata=-sigmod_Z
    26         #更新W
    27         for index in range(424):
    28             new_Wi=W[0][index]-learnning_rate*updata*X.iloc[0,index]
    29             W[0][index]=new_Wi
    30         #更新b
    31         b=b-updata
    32     return W,b
    33 
    34 '''
    35 模型主体
    36 计算线性回归和sigmoid函数
    37 '''
    38 def model(train_X,train_Y,W0,b0):
    39     W_train=W0
    40     b_train=b0
    41     loss_value_list=list()
    42     '''
    43     batch_size 设置为8
    44     '''
    45     for index in range(0,54256,1):
    46         #创建XY数据
    47         X=train_X.iloc[range(index,index+1),:]
    48         Y_true=train_Y.iloc[range(index,index+1),:]
    49         #线性部分计算值
    50         y_model=linear_model(X,W_train,b_train)
    51         #sigmod函数
    52         y_model=sigmod(y_model)
    53         # 计算Loss值
    54         loss_value=get_Loss(y_model,Y_true)
    55         loss_value_list.append(loss_value)
    56         print(''+str(index)+"次loss值为:"+str(loss_value[0]))
    57         #更新参数
    58         W_train,b_train=updata_W_b(loss_value,W_train,b_train,Y_true,y_model,X)
    59     x_index = [i for i in range(len(loss_value_list))]
    60     plt.plot(x_index, loss_value_list, color='red', linewidth=2.0, linestyle='-')
    61     plt.title('54256次每一次迭代的Loss值')
    62     plt.ylabel('Loss值')
    63     plt.xlabel('次数')
    64     plt.show()

    三、训练结果

    正确率在78%左右

  • 相关阅读:
    狗 日 的 360
    Django搭建简单的站点
    ZOJ 3675 Trim the Nails(bfs)
    Qt移动应用开发(二):使用动画框架
    SPOJ QTREE2 lct
    [Phonegap+Sencha Touch] 移动开发77 Cordova Hot Code Push插件实现自己主动更新App的Web内容
    Bitmap工具类BitmapHelper
    Material-design icon生成插件
    闲聊ROOT权限——ROOT权限的前世今生
    Java深入浅出系列(四)——深入剖析动态代理--从静态代理到动态代理的演化
  • 原文地址:https://www.cnblogs.com/SAM-CJM/p/13918101.html
Copyright © 2011-2022 走看看