zoukankan      html  css  js  c++  java
  • 人工智能-根据姓名判断性别

    本帖训练一个可以根据姓名判断性别的CNN模型;我使用自己爬取的35万中文姓名进行训练。

    使用同样的数据集还可以训练起名字模型,参看:

    • TensorFlow练习7: 基于RNN生成古诗词

    • https://github.com/tensorflow/models/tree/master/namignizer

    • TensorFlow练习13: 制作一个简单的聊天机器人

    准备姓名数据集

    我上网找了一下,并没有找到现成的中文姓名数据集,额,看来只能自己动手了。

    我写了一个简单的Python脚本,爬取了上万中文姓名,格式整理如下:

    [python] view plain copy

    1. 姓名,性别  

    2. 安镶怡,女  

    3. 饶黎明,男  

    4. 段焙曦,男  

    5. 苗芯萌,男  

    6. 覃慧藐,女  

    7. 芦玥微,女  

    8. 苏佳琬,女  

    9. 王旎溪,女  

    10. 彭琛朗,男  

    11. 李昊,男  

    12. 利欣怡,女  

    13. # 貌似有很多名字男女通用  

    数据集:https://pan.baidu.com/s/1hsHTEU4。

    训练模型

    [python] view plain copy

    1. import tensorflow as tf  

    2. import numpy as np  

    3.    

    4. name_dataset = 'name.csv'  

    5.    

    6. train_x = []  

    7. train_y = []  

    8. with open(name_dataset, 'r') as f:  

    9.     first_line = True  

    10.     for line in f:  

    11.         if first_line is True:  

    12.             first_line = False  

    13.             continue  

    14.         sample = line.strip().split(',')  

    15.         if len(sample) == 2:  

    16.             train_x.append(sample[0])  

    17.             if sample[1] == '男':  

    18.                 train_y.append([01])  # 男  

    19.             else:  

    20.                 train_y.append([10])  # 女  

    21.    

    22. max_name_length = max([len(name) for name in train_x])  

    23. print("最长名字的字符数: ", max_name_length)  

    24. max_name_length = 8  

    25.    

    26. # 数据已shuffle  

    27. #shuffle_indices = np.random.permutation(np.arange(len(train_y)))  

    28. #train_x = train_x[shuffle_indices]  

    29. #train_y = train_y[shuffle_indices]  

    30.    

    31. # 词汇表(参看聊天机器人练习)  

    32. counter = 0  

    33. vocabulary = {}  

    34. for name in train_x:  

    35.     counter += 1  

    36.     tokens = [word for word in name]  

    37.     for word in tokens:  

    38.         if word in vocabulary:  

    39.             vocabulary[word] += 1  

    40.         else:  

    41.             vocabulary[word] = 1  

    42.    

    43. vocabulary_list = [' '] + sorted(vocabulary, key=vocabulary.get, reverse=True)  

    44. print(len(vocabulary_list))  

    45.    

    46. # 字符串转为向量形式  

    47. vocab = dict([(x, y) for (y, x) in enumerate(vocabulary_list)])  

    48. train_x_vec = []  

    49. for name in train_x:  

    50.     name_vec = []  

    51.     for word in name:  

    52.         name_vec.append(vocab.get(word))  

    53.     while len(name_vec) < max_name_length:  

    54.         name_vec.append(0)  

    55.     train_x_vec.append(name_vec)  

    56.    

    57. #######################################################  

    58.    

    59. input_size = max_name_length  

    60. num_classes = 2  

    61.    

    62. batch_size = 64  

    63. num_batch = len(train_x_vec) // batch_size  

    64.    

    65. X = tf.placeholder(tf.int32, [None, input_size])  

    66. Y = tf.placeholder(tf.float32, [None, num_classes])  

    67.    

    68. dropout_keep_prob = tf.placeholder(tf.float32)  

    69.    

    70. def neural_network(vocabulary_size, embedding_size=128, num_filters=128):  

    71.     # embedding layer  

    72.     with tf.device('/cpu:0'), tf.name_scope("embedding"):  

    73.         W = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.01.0))  

    74.         embedded_chars = tf.nn.embedding_lookup(W, X)  

    75.         embedded_chars_expanded = tf.expand_dims(embedded_chars, -1)  

    76.     # convolution + maxpool layer  

    77.     filter_sizes = [3,4,5]  

    78.     pooled_outputs = []  

    79.     for i, filter_size in enumerate(filter_sizes):  

    80.         with tf.name_scope("conv-maxpool-%s" % filter_size):  

    81.             filter_shape = [filter_size, embedding_size, 1, num_filters]  

    82.             W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1))  

    83.             b = tf.Variable(tf.constant(0.1, shape=[num_filters]))  

    84.             conv = tf.nn.conv2d(embedded_chars_expanded, W, strides=[1111], padding="VALID")  

    85.             h = tf.nn.relu(tf.nn.bias_add(conv, b))  

    86.             pooled = tf.nn.max_pool(h, ksize=[1, input_size - filter_size + 111], strides=[1111], padding='VALID')  

    87.             pooled_outputs.append(pooled)  

    88.    

    89.     num_filters_total = num_filters * len(filter_sizes)  

    90.     h_pool = tf.concat(3, pooled_outputs)  

    91.     h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])  

    92.     # dropout  

    93.     with tf.name_scope("dropout"):  

    94.         h_drop = tf.nn.dropout(h_pool_flat, dropout_keep_prob)  

    95.     # output  

    96.     with tf.name_scope("output"):  

    97.         W = tf.get_variable("W", shape=[num_filters_total, num_classes], initializer=tf.contrib.layers.xavier_initializer())  

    98.         b = tf.Variable(tf.constant(0.1, shape=[num_classes]))  

    99.         output = tf.nn.xw_plus_b(h_drop, W, b)  

    100.           

    101.     return output  

    102. # 训练  

    103. def train_neural_network():  

    104.     output = neural_network(len(vocabulary_list))  

    105.    

    106.     optimizer = tf.train.AdamOptimizer(1e-3)  

    107.     loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(output, Y))  

    108.     grads_and_vars = optimizer.compute_gradients(loss)  

    109.     train_op = optimizer.apply_gradients(grads_and_vars)  

    110.    

    111.     saver = tf.train.Saver(tf.global_variables())  

    112.     with tf.Session() as sess:  

    113.         sess.run(tf.global_variables_initializer())  

    114.    

    115.         for e in range(201):  

    116.             for i in range(num_batch):  

    117.                 batch_x = train_x_vec[i*batch_size : (i+1)*batch_size]  

    118.                 batch_y = train_y[i*batch_size : (i+1)*batch_size]  

    119.                 _, loss_ = sess.run([train_op, loss], feed_dict={X:batch_x, Y:batch_y, dropout_keep_prob:0.5})  

    120.                 print(e, i, loss_)  

    121.             # 保存模型  

    122.             if e % 50 == 0:  

    123.                 saver.save(sess, "name2sex.model", global_step=e)  

    124.    

    125. train_neural_network()  

    126.    

    127. # 使用训练的模型  

    128. def detect_sex(name_list):  

    129.     x = []  

    130.     for name in name_list:  

    131.         name_vec = []  

    132.         for word in name:  

    133.             name_vec.append(vocab.get(word))  

    134.         while len(name_vec) < max_name_length:  

    135.             name_vec.append(0)  

    136.         x.append(name_vec)  

    137.    

    138.     output = neural_network(len(vocabulary_list))  

    139.    

    140.     saver = tf.train.Saver(tf.global_variables())  

    141.     with tf.Session() as sess:  

    142.         # 恢复前一次训练  

    143.         ckpt = tf.train.get_checkpoint_state('.')  

    144.         if ckpt != None:  

    145.             print(ckpt.model_checkpoint_path)  

    146.             saver.restore(sess, ckpt.model_checkpoint_path)  

    147.         else:  

    148.             print("没找到模型")  

    149.    

    150.         predictions = tf.argmax(output, 1)  

    151.         res = sess.run(predictions, {X:x, dropout_keep_prob:1.0})  

    152.    

    153.         i = 0  

    154.         for name in name_list:  

    155.             print(name, '女' if res[i] == 0 else '男')  

    156.             i += 1  

    157.    

    158. detect_sex(["白富美""高帅富""王婷婷""田野"])  

    执行结果:

    640?wx_fmt=png

    Share the post "TensorFlow练习18: 根据姓名判断性别"

    文章已获作者授权转载,原文链接如下:

    http://blog.csdn.net/u014365862/article/details/53869732

  • 相关阅读:
    Springboot + Caffeine 实现本地缓存
    springboot + mybatis-plus + sharding-jdbc 实现单库分表
    工厂模式+策略模式 使用
    JAVA 金额自动除以100,精确到分
    spring aop + 自定义注解实现本地缓存
    springboot 使用 retry重试机制
    Mybatis-plus 自动注入公共字段
    docker 安装kafka
    ES 实现聚合分页
    Authentication token manipulation error 及 mongodb WiredTigerLAS.wt 文件过大问题
  • 原文地址:https://www.cnblogs.com/finer/p/11895143.html
Copyright © 2011-2022 走看看