zoukankan      html  css  js  c++  java
  • 利用朴素贝叶斯根据名字判断性别

    import pandas as pd
    from collections import defaultdict
    import math


    # Load the three input tables. Later code reads the 'name' and 'gender'
    # columns of `train`, the 'name' column of `test`, and writes a 'gender'
    # column into `submit` before saving it.
    train=pd.read_csv('train.txt')
    test=pd.read_csv('test.txt')
    submit=pd.read_csv('sample_submit.csv')

    # Preview the first ten training rows. The result is discarded here, so it
    # is only visible when the statement is evaluated interactively.
    train.head(10)

    #%%
    # Split the training rows by label: gender == 0 is female, gender == 1 is male.
    names_female = train[train['gender'] == 0]
    names_male = train[train['gender'] == 1]

    # totals holds the number of female and male training names.
    totals = dict(f=len(names_female), m=len(names_male))

    # Per-gender character frequency: every occurrence of a character adds
    # 1 / (number of names of that gender) to its entry.
    frequency_list_f = defaultdict(int)
    for char in (c for name in names_female['name'] for c in name):
        frequency_list_f[char] += 1. / totals['f']

    frequency_list_m = defaultdict(int)
    for char in (c for name in names_male['name'] for c in name):
        frequency_list_m[char] += 1. / totals['m']

    # Spot-check two characters. Indexing a defaultdict creates the entry
    # (with value 0) when the character was never seen.
    print(frequency_list_f['娟'])
    print(frequency_list_m['钢'])

    #%%
    def LaplaceSmooth(char, frequency_list, total, alpha=1.0):
        """Return the additively smoothed relative frequency of ``char``.

        Args:
            char: Character to look up.
            frequency_list: Mapping from character to its relative frequency
                (occurrences divided by ``total``).
            total: Number of names the stored frequencies were normalised by.
            alpha: Smoothing strength added to every count.

        Returns:
            ``(count + alpha) / (total + distinct_chars * alpha)``, where
            ``count`` is the occurrence count recovered from the stored
            frequency (0 for a character not in the mapping) and
            ``distinct_chars`` is the number of characters in the mapping.
        """
        # Use .get() instead of indexing: indexing a defaultdict would insert
        # every unseen character, silently growing the table and making later
        # results depend on the order in which characters were looked up.
        # It also lets a plain dict be passed without raising KeyError.
        count = frequency_list.get(char, 0) * total
        distinct_chars = len(frequency_list)
        return (count + alpha) / (total + distinct_chars * alpha)

    #%%
    # Baseline log-score per gender: the log of the class share plus, for every
    # character recorded for that gender, log(1 - stored frequency).
    # The mean of the 'gender' column is the male share if labels are 0/1.
    male_share = train['gender'].mean()

    base_f = math.log(1 - male_share) + sum(
        math.log(1 - freq) for freq in frequency_list_f.values()
    )
    base_m = math.log(male_share) + sum(
        math.log(1 - freq) for freq in frequency_list_m.values()
    )

    bases = dict(f=base_f, m=base_m)
    #%%
    def GetLogProb(char, frequency_list, total):
        """Return the log-odds of ``char`` from its smoothed frequency.

        The smoothed frequency ``p`` comes from ``LaplaceSmooth``; the result
        is ``log(p) - log(1 - p)``.
        """
        p = LaplaceSmooth(char, frequency_list, total)
        log_present = math.log(p)
        log_absent = math.log(1 - p)
        return log_present - log_absent
    #%%
    def ComputeLogProb(name, bases, totals, frequency_list_m, frequency_list_f):
        """Score ``name`` under the male and the female model.

        Each score starts from the matching entry of ``bases`` ('m' / 'f') and
        adds ``GetLogProb`` for every character of ``name``, using that
        gender's frequency table and name total.

        Returns:
            A dict with the keys 'male' and 'female' holding the two scores.
        """
        male_score, female_score = bases['m'], bases['f']
        for ch in name:
            male_score = male_score + GetLogProb(ch, frequency_list_m, totals['m'])
            female_score = female_score + GetLogProb(ch, frequency_list_f, totals['f'])
        return dict(male=male_score, female=female_score)

    def GetGender(LogProbs):
        """Return True when the 'male' score strictly exceeds the 'female' score.

        A tie yields False.
        """
        female_score = LogProbs['female']
        male_score = LogProbs['male']
        return male_score > female_score

    # Classify every test name in order: 1 when the male score wins, else 0.
    result = [
        int(GetGender(ComputeLogProb(name, bases, totals, frequency_list_m, frequency_list_f)))
        for name in test['name']
    ]

    # Store the predictions in the submission table and write it out without
    # the index column.
    submit['gender'] = result

    submit.to_csv('my_NB_prediction.csv', index=False)
  • 相关阅读:
    Thinkphp3.2.3关于开启DEBUG正常,关闭DEBUG就报错模版无法找到,页面错误!请稍后再试~
    Apache 工作模式的正确配置
    TIME_WAIT 你好!
    对称加密实现重要日志上报Openresty接口服务
    阿里nas挂载错误
    机器装多个版本php,并安装redis插件报错【已解决】
    find 删除文件
    从头认识java-6.7 初始化与类的加载
    从头认识java-6.6 final(4)-类与忠告
    从头认识java-6.6 final(3)-方法
  • 原文地址:https://www.cnblogs.com/liujinxin123/p/12499636.html
Copyright © 2011-2022 走看看