AI研习社最近举办了一个比赛——微博立场检测,实际上就是一个NLP文本分类的比赛
Baseline—FastText
我的Baseline方法用的是pkuseg分词+FastText,最好成绩是60,下面是我几次提交的得分截图
Load Data & Preprocess
先import之后要用到的库
import pkuseg
import random
import pandas as pd
import fasttext
df = pd.read_csv('train.csv', delimiter=' ')
官方给的数据,虽然是csv文件,但是字段之间用的是
隔开的,所以读取的时候注意一下就行了。数据样式如下
stance
字段有三类,分别是FAVOR
、AGAINST
、NONE
,这也是你需要最终预测的值。但是通过仔细分析数据可以发现,stance
字段除了上面三个值以外还有别的值,所以要先把其它的数据剔除掉
drop_list = []
for i in range(len(df)):
if df.stance[i] != 'FAVOR' and df.stance[i] != 'AGAINST' and df.stance[i] != 'NONE':
drop_list.append(i)
df.drop(drop_list, inplace=True)
FastText读取的数据应该满足__lable__xx text
,例如
__label__A 我 喜欢 打 篮球
__label__B 我 喜欢 鲲鲲
__label__A 我 喜欢 踢 足球
也就是说,每一行表示一个样本,并且标签在前,文本在后,两者之间用空格隔开。标签必须以__label__
开头。所以我们要先把原始数据的标签进行一个转换,即FAVOR
变成__label__A
,AGAINST
变成__label__B
,NONE
变成__label__C
mapping = {'FAVOR':'__label__A', 'AGAINST':'__label__B', 'NONE':'__label__C'}
df['stance'] = df['stance'].map(mapping)
这些都做完以后最好shuffle一下数据
df = df.sample(frac=1)
sample(frac=p)
其中(pin[0,1]),意思是随机sample出原始数据的百分之多少,如果(p=1),则表示随机sample出原始数据的全部,并且由于是随机sample的,所以原始数据的顺序就被打乱了
Split Train & Validation Data
这里我以7:3的比例将数据集拆分成Train Data和Valid Data
train_len = int(len(df) * 0.7)
df_train = df.loc[:train_len]
df_val = df.loc[train_len:]
Word Segmentation
从FastText读取数据的样式可以看出,我们需要对一句话进行分词。这里我用的是pkuseg,因为我看它官方API介绍的时候,里面提到它有一个web语料库
在分词前,我先从网上找了一些常见的中英文停用词
stopwords = []
for line in open('stopwords.txt', encoding='utf-8'):
stopwords.append(line)
stopwords.append('
')
stopwords = set(stopwords)
停用词表我就不提供了,网上有很多,自己下载即可
然后是一行一行读取数据并分词,分完词再过滤。这些都做完以后,按照FastText要求的格式,拼接字符串,保存到文件中
def dump_file(df, filename, mode='train'):
seg = pkuseg.pkuseg(model_name='web')
with open(filename, 'w',encoding='utf-8') as f:
for i in df.index.values:
segs = seg.cut(df.text[i])
segs = filter(lambda x:x not in stopwords, segs) #去掉停用词
# segs = filter(lambda x:len(x)>1,segs)
segs = filter(lambda x:x.startswith('http')==False, segs)
segs = filter(lambda x:x.startswith('.')==False, segs)
segs = filter(lambda x:x.startswith('-')==False, segs)
segs = filter(lambda x:x.startswith(',')==False, segs)
segs = filter(lambda x:x.startswith('。')==False, segs)
segs = filter(lambda x:x.startswith('…')==False, segs)
segs = filter(lambda x:x.startswith('/')==False, segs)
segs = filter(lambda x:x.startswith('—')==False, segs)
segs = filter(lambda x:x.startswith('、')==False, segs)
segs = filter(lambda x:x.startswith(':')==False, segs)
segs = filter(lambda x:x.startswith('~')==False, segs)
segs = filter(lambda x:x.startswith('[')==False, segs)
segs = filter(lambda x:x.startswith(']')==False, segs)
segs = filter(lambda x:(x.isalpha() and len(x) == 7) == False, segs)
string = ''
for j in segs:
string = string + ' ' + j
if mode == 'test':
string = string.lstrip()
else:
string = df.stance[i] + ' ' + string
string = string.lstrip()
f.write(string + '
')
dump_file(df_train, 'train.txt', 'train')
dump_file(df_val, 'val.txt', 'train')
FastText
首先从它官方的github仓库中clone dev版本(直接使用pip install fasttext是稳定版)
$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ pip install .
因为最新的dev版本中有一个参数autotuneValidationFile
可以在训练过程中自动搜索使得acc最大的参数。fastText使用也很简单
clf = fasttext.train_supervised(input='train.txt', autotuneValidationFile='val.txt')
指定训练集以及用于帮助寻找最优参数的测试集的路径即可。如果要保存模型就用
clf.save_model('fasttext_model')
Predict & Submit
基本上如果你按照我的方法一路做下来,到现在为止在验证集上的最大分数也就60左右
然后就是对test集进行预测,预测完了提交就行了
test = pd.read_csv('test.csv', delimiter=' ')
dump_file(test, 'test.txt', 'test')
labels = []
for line in open('test.txt', encoding='utf-8'):
if line != '':
line = line.strip('
')
labels.append(clf.predict(line)[0][0])
test['idx'] = range(len(test))
test['stance'] = labels
mapping = {'__label__A':'FAVOR','__label__B':'AGAINST','__label__C':'NONE'}
test['stance'] = test['stance'].map(mapping)
test = test.drop(['target', 'text'], axis=1)
test.to_csv('test_pred.csv',index=False,header=False)
Improve
- 我的做法只用了
text
和stance
这两列,target
我觉得可以用也可以不用 - 仔细观察数据集会发现,其实样本分布及其不均匀,
stance
列中FAVOR
和AGAINST
两个值特别多,NONE
特别少,这就涉及到不均衡样本的训练问题,可以通过sample,将它们的比例设置的比较均衡了再训练 - 过滤词设置的更详细一点。如果你仔细查看了分词后的数据集,你应该能发现里面其实还有很多垃圾词,比方说网址、7位验证码、表情之类的
- 直接上BERT