zoukankan      html  css  js  c++  java
  • BankNote

     1 # coding=utf-8
     2 import pandas as pd
     3 import numpy as np
     4 from sklearn import cross_validation
     5 import tensorflow as tf
     6 
     7 global flag
     8 flag=0
     9 
    10 def DataPreprocessing():
    11     abalone = pd.read_csv("ceshi.csv", sep=',', header=0, keep_default_na=True,na_values=[])
    12     X_train=np.array(abalone.iloc[:,:4])
    13     Y_train=np.array(abalone.iloc[:,4:])
    14     # Y_train=[]
    15     # for i in range(len(X_train)):
    16     #     if X_train[i][0] == 'M':
    17     #         X_train[i][0]=0
    18     #     elif X_train[i][0]=='F':
    19     #         X_train[i][0]=1
    20     #     else:
    21     #         X_train[i][0]=2
    22     #
    23     # for i in range(len(Y_train_)):
    24     #
    25     #     #print(Y_train[i][0])
    26     #     Y_train.append(Y_train_[i][0])
    27 
    28     # print(X_train)
    29     # print(len(X_train))
    30     # print(Y_train)
    31     # print(len(Y_train))
    32    # print(min(Y_train))
    33    # print(max(Y_train))
    34 
    35     return cross_validation.train_test_split(X_train,Y_train,test_size=0.25,random_state=0,stratify=Y_train)
    36 
    37 
    38 def GetInputs():
    39     global flag
    40     X_train, X_test, Y_train, Y_test = DataPreprocessing()
    41 
    42     #print(X_train)
    43     # print(len(X_test))
    44     # print(len(Y_train))
    45     # print(len(Y_test))
    46 
    47 
    48     #X_train[X_train.isnull().any(axis=1)]
    49     #X_train.fillna('',inplace=True)
    50 
    51     print(X_train)
    52     print(Y_test)
    53 
    54     x_train=tf.constant(X_train)
    55     y_train=tf.constant(Y_train)
    56     x_test=tf.constant(X_test)
    57     y_test=tf.constant(Y_test)
    58 
    59     print(x_train)
    60     print(y_train)
    61     print(x_test)
    62     print(y_test)
    63 
    64     if flag==0:
    65         return x_train,y_train
    66     else:
    67         return x_test,y_test
    68 
    69 
    70 def Main():
    71 
    72     global flag
    73 
    74     feature_columns=[tf.contrib.layers.real_valued_column("",dimension=4)]
    75 
    76     clf=tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,hidden_units=[10,20,10],n_classes=2,model_dir="/home/jiangjing/TensorflowModel/banknote")
    77 
    78     clf.fit(input_fn=GetInputs,steps=2000)
    79 
    80     flag=1
    81     accuracy_score=clf.evaluate(input_fn=GetInputs,steps=1)["accuracy"]
    82 
    83     print("nTest Accuracy:{0:f}".format(accuracy_score))
    84 
    85 if __name__ =="__main__":
    86     #DataPreprocessing()
    87 
    88     Main()
    89 
    90 exit(0)
  • 相关阅读:
    Java实现 LeetCode 69 x的平方根
    Java实现 LeetCode 68 文本左右对齐
    Java实现 LeetCode 68 文本左右对齐
    Java实现 LeetCode 68 文本左右对齐
    Java实现 LeetCode 67 二进制求和
    Java实现 LeetCode 67 二进制求和
    Java实现 LeetCode 67 二进制求和
    Java实现 LeetCode 66 加一
    Java实现 LeetCode 66 加一
    CxSkinButton按钮皮肤类
  • 原文地址:https://www.cnblogs.com/acm-jing/p/9097373.html
Copyright © 2011-2022 走看看