zoukankan      html  css  js  c++  java
  • Chess

     1 # coding=utf-8
     2 import pandas as pd
     3 import numpy as np
     4 from sklearn import cross_validation
     5 import tensorflow as tf
     6 
     7 global flag
     8 flag=0
     9 
    10 def DataPreprocessing():
    11     abalone = pd.read_csv("train_data.csv", sep=',', header=0, keep_default_na=True)
    12     X_train=np.array(abalone.iloc[:,:6])
    13     Y_train_=np.array(abalone.iloc[:,6:])
    14     print(X_train)
    15     Y_train=[]
    16     for i in range(len(X_train)):
    17 
    18         X_train[i][0] = ord(X_train[i][0])-97
    19         X_train[i][2] = ord(X_train[i][2])-97
    20         X_train[i][4] = ord(X_train[i][4])-97
    21 
    22     # for i in range (len(X_train)):
    23     #     for j in range(6):
    24     #         X_train[i][j]=X_train[i][j]-0.0
    25     #
    26     #X_train.astype(np.float64)
    27   #  print(X_train,type(X_train),X_train[0][0],type(X_train[0][0]))
    28 
    29     #binary classifier
    30     for i in range(len(Y_train_)):
    31 
    32         if Y_train_[i][0]=="draw":
    33             Y_train.append(0)
    34         else:
    35             Y_train.append(1)
    36 
    37 
    38     # multiple classifer
    39 
    40     return cross_validation.train_test_split(X_train,Y_train,test_size=0.25,random_state=0,stratify=Y_train)
    41 
    42 
    43 def GetInputs():
    44     global flag
    45     X_train, X_test, Y_train, Y_test = DataPreprocessing()
    46 
    47     #print(type(X_train),type(X_train[0][0]))
    48     #print(X_train)
    49     # print(len(X_test))
    50     # print(len(Y_train))
    51     # print(len(Y_test))
    52 
    53 
    54     #X_train[X_train.isnull().any(axis=1)]
    55     #X_train.fillna('',inplace=True)
    56 
    57     # print(X_train)
    58     # print(Y_test)
    59 
    60     x_train=tf.constant(X_train)
    61     y_train=tf.constant(Y_train)
    62     x_test=tf.constant(X_test)
    63     y_test=tf.constant(Y_test)
    64     #
    65     # print(x_train)
    66     # print(y_train)
    67     # print(x_test)
    68     # print(y_test)
    69 
    70     if flag==0:
    71         return x_train,y_train
    72     else:
    73         return x_test,y_test
    74 
    75 
    76 def Main():
    77 
    78     global flag
    79 
    80     feature_columns=[tf.contrib.layers.real_valued_column("",dimension=6)]
    81 
    82     clf=tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,hidden_units=[20,40,20],n_classes=2,model_dir="/home/jiangjing/TensorflowModel/chess")
    83 
    84     clf.fit(input_fn=GetInputs,steps=2000)
    85 
    86     flag=1
    87     accuracy_score=clf.evaluate(input_fn=GetInputs,steps=1)["accuracy"]
    88 
    89     print("nTest Accuracy:{0:f}".format(accuracy_score))
    90 
    91 if __name__ =="__main__":
    92     #DataPreprocessing()
    93 
    94     Main()
    95 
    96 exit(0)
    View Code
  • 相关阅读:
    Spyder | 关于报错No module named 'PyQt5.QtWebKitWidgets'
    Java基础(11) | 接口
    Java基础(10) | 抽象
    Java基础(9) | 继承
    Java基础(7) | String
    Java基础(6) | ArrayList
    CodeBlocks17.12配置GNU GCC + 汉化
    图片懒加载
    Mac安装Mysql 超详细(转载)
    剑指 Offer 03. 数组中重复的数字
  • 原文地址:https://www.cnblogs.com/acm-jing/p/9098996.html
Copyright © 2011-2022 走看看