zoukankan      html  css  js  c++  java
  • TensorFlow(一) 鸢尾花采用批量数据进行线性模拟

    import matplotlib.pyplot as plt
    import  numpy as np
    from sklearn import  datasets
    import tensorflow as tf
    sess=tf.Session()
    iris=datasets.load_iris()
    #print(iris)
    target=np.array([1. if x==0 else 0. for x in iris.target  ])
    #print(target.shape)
    iris_data=np.array([ [x[2],x[3]]  for x in iris.data] ) #shape[none,2]
    
    #声明批量
    batch_size=20
    #宽度长度 均为【NOne,1】
    x1_data=tf.placeholder(tf.float32,shape=[None,1])
    x2_data=tf.placeholder(tf.float32,shape=[None,1])
    y_target=tf.placeholder(tf.float32,shape=[None,1])
    #初始化类型为1,1  可以和x1 相乘
    A=tf.Variable(tf.random_normal(shape=[1,1]))
    b=tf.Variable(tf.random_normal(shape=[1,1]))
    
    #线性模型 x1=x2*A+b  --》 f=x1-x2*A-b
    my_mult=tf.matmul(x2_data,A)
    my_add=tf.add(my_mult,b)
    my_output=tf.subtract(x1_data,my_add)
    
    #损失函数 (交叉熵损失函数 非归一化 常用于两类验证)
    sigmoid_logits=tf.nn.sigmoid_cross_entropy_with_logits(labels=y_target,logits=my_output)
    
    #梯度下降取最小值 (选择学习率0.05)
    my_opt=tf.train.GradientDescentOptimizer(0.05)
    train_step=my_opt.minimize(sigmoid_logits)
    
    #初始化所有声明的变量
    init=tf.global_variables_initializer()
    sess.run(init)
    
    #迭代100次 训练模型 传入三种数据 长度 宽度 和目标
    for i in range(1500):
        #随机获取批量数据 根据(iris_data)的长度已经确定
        rand_index=np.random.choice(len(iris_data) ,batch_size)
        #shape=[batchsize,1]
        x1_rand= np.array([[iris_data[x][0]]  for x in rand_index],dtype=np.float32)
        x2_rand = np.array([[iris_data[x][1]] for x in rand_index],dtype=np.float32)
        y_rand=np.array([[target[x]] for x in rand_index],dtype=np.float32)
    
        sess.run(train_step,feed_dict={x1_data:x1_rand,x2_data:x2_rand,y_target:y_rand})
        if (i+1)%200 ==0:
            print('Step %s :A= %s  ; b=%s   ' % ( i+1,str(sess.run(A)), str(sess.run(b))   ))
    
    #保存A,b
    [[slope]]=sess.run(A)
    [[intercept]]=sess.run(b)
    x=np.linspace(0,3,num=50)
    abline=[]
    for i in x:
        abline.append(slope*i+intercept)
    #重新选取数据  从目标1中选取 长度 宽度
    set1_x=[ a[1] for i,a in enumerate(iris_data) if target[i] == 1]
    set1_y=[ a[0] for i,a in enumerate(iris_data) if target[i] == 1]
    #重新选取数据 从目标0中 选取长度宽度
    no_set1_x=[ a[1] for i,a in enumerate(iris_data) if target[i] == 0]
    no_set1_y=[ a[0] for i,a in enumerate(iris_data) if target[i] == 0]
    
    plt.plot(set1_x,set1_y,'rx',ms=10,mew=2,label='set1')
    #plt.clabel('set1')
    
    plt.plot(no_set1_x,no_set1_y,'ro',label='set0')
    #plt.clabel('set0')
    
    plt.plot(x,abline,'b-',label='my')
    plt.xlim([0.0,2.7])
    plt.ylim([0.0,7.1])
    plt.xlabel('length')
    plt.ylabel('width')
    plt.legend(loc='lower right')
    plt.show()

  • 相关阅读:
    【ElasticSearch】异常错误
    【CentOS7】系统设置
    【Ubuntu 18.04.03_64】系统配置
    【MySql】语法学习
    【ElasticSearch】聚合使用学习
    【Spring Boot】Spring Security登陆异常出路
    【ElasticSearch】查询使用学习
    Spring boot X-Frame-Options 异常 a frame because it set 'X-Frame-Options' to 'deny'
    【Thymeleaf】使用学习
    【MySql】日期时间
  • 原文地址:https://www.cnblogs.com/x0216u/p/9167229.html
Copyright © 2011-2022 走看看