zoukankan      html  css  js  c++  java
  • IMDB-二分类问题

    from keras.datasets import imdb
    from keras.utils.np_utils import to_categorical
    import numpy as np
    from keras import models
    from keras import layers
    import matplotlib.pyplot as plt
    def vectorize_sequences(sequences, dimension=10000):
        """Multi-hot encode a collection of integer index sequences.

        Args:
            sequences: sized iterable of index sequences (e.g. lists of word
                indices); every index must be a valid column index for
                ``dimension`` columns.
            dimension: number of columns in the output matrix.

        Returns:
            A float64 NumPy array of shape (len(sequences), dimension) where
            row r holds 1.0 at every index listed in sequences[r] and 0.0
            everywhere else.
        """
        encoded = np.zeros(shape=(len(sequences), dimension))
        row = 0
        for word_ids in sequences:
            # Fancy indexing sets all listed columns of this row in one step.
            encoded[row, word_ids] = 1
            row += 1
        return encoded
    # IMDB is a binary classification task: 50,000 reviews in total,
    # 25,000 for training and 25,000 for testing.
    # Each review is a list of word indices (words ranked by frequency);
    # num_words keeps only the NUM_WORDS most frequent words.
    NUM_WORDS = 10000
    EPOCHS = 5
    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=NUM_WORDS)

    # Decode the first training review back into readable text.
    data = x_train[0]
    # word_index maps word -> rank; invert it to look words up by index.
    word_index = imdb.get_word_index()
    index_word = {index: word for word, index in word_index.items()}
    # Offset by 3 because 0, 1 and 2 are reserved indices for "padding",
    # "start of sequence" and "unknown". Join with spaces so the decoded
    # words are not run together.
    data = ' '.join(index_word.get(i - 3, '?') for i in data)
    print(data)
    ######################################################
    # The network needs tensor input: multi-hot encode the index lists with
    # the same vocabulary size used when loading the data.
    x_train = vectorize_sequences(x_train, dimension=NUM_WORDS)
    x_test = vectorize_sequences(x_test, dimension=NUM_WORDS)
    # The labels are integers; cast them to float32 to match the float
    # output of the sigmoid layer and the binary_crossentropy loss.
    y_train = np.asarray(y_train).astype('float32')
    y_test = np.asarray(y_test).astype('float32')

    # Build the network: two 16-unit ReLU layers and one sigmoid output unit.
    network = models.Sequential()
    network.add(layers.Dense(16, activation='relu'))
    network.add(layers.Dense(16, activation='relu'))
    network.add(layers.Dense(1, activation='sigmoid'))

    # Choose the optimizer, loss function and evaluation metric.
    network.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

    # Train, holding out 20% of the training data for validation.
    history = network.fit(x_train, y_train, epochs=EPOCHS, batch_size=512,
                          validation_split=0.2)

    history_dict = history.history
    loss = history_dict['loss']
    val_loss = history_dict['val_loss']
    # Older Keras records metrics=['accuracy'] under 'acc'; Keras >= 2.3 and
    # tf.keras use 'accuracy'. Accept either so the script runs on both.
    acc_key = 'acc' if 'acc' in history_dict else 'accuracy'
    acc = history_dict[acc_key]
    val_acc = history_dict['val_' + acc_key]

    # Derive the x-axis from the recorded history so it always matches EPOCHS.
    epochs = range(1, len(loss) + 1)
    # Loss curves.
    plt.subplot(121)
    plt.plot(epochs, loss, 'g', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    # Show the legend.
    plt.legend()

    # Accuracy curves.
    plt.subplot(122)
    plt.plot(epochs, acc, 'g', label='Training accuracy')
    plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('accuracy')
    plt.legend()
    plt.show()

    # Predicted sigmoid outputs for the test set, printed next to the labels.
    pre = network.predict(x_test)
    print(pre)
    print(y_test)

  • 相关阅读:
    感悟优化——Netty对JDK缓冲区的内存池零拷贝改造
    由浅入深理解Java线程池及线程池的如何使用
    Http学习笔记
    zookeeper集群配置详细教程
    kafka学习笔记——基本概念与安装
    干货——详解Java中的关键字
    Java基础巩固——排序
    你可以这么理解五种I/O模型
    Java中的NIO基础知识
    Java基础巩固——异常
  • 原文地址:https://www.cnblogs.com/vshen999/p/10457180.html
Copyright © 2011-2022 走看看