Bert+BiLSTM+CRF with kashgari
kashgari:
- A TensorFlow-based library for building BERT+LSTM models.
- The kashgari releases built on TensorFlow 2.0 do not include CRF, so the versions below are used.
Environment:
- conda create --name myTestNER python=3.6
- conda activate myTestNER
- pip install tensorflow==1.14.0
- pip install kashgari==1.1.5
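A quick sanity check that the pinned versions are in place (a minimal sketch, assuming kashgari exposes __version__):

import tensorflow as tf
import kashgari
print(tf.__version__)        # expect 1.14.0
print(kashgari.__version__)  # expect 1.1.5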
Data preparation:
train_x = []
train_y = []
with open('data_file_name', encoding='utf-8') as f:
    for line in f.readlines():
        # split the line into fields, then turn each field into a per-character list
        cur = line.strip().split()
        train_x.append(list(cur[0]))
        train_y.append(list(cur[1]))
# train_x = [seq1, seq2, seq3, ...]
# train_y = [tag1, tag2, tag3, ...]
# seq1 = ['我', '爱', '北', '京']
# tag1 = ['O', 'O', 'B-D', 'I-D']
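To try the pipeline without a data file, a tiny hand-built dataset in the same shape also works (the sentences and tags here are made-up examples; real training needs a proper corpus):

train_x = [['我', '爱', '北', '京'], ['我', '在', '上', '海']]
train_y = [['O', 'O', 'B-D', 'I-D'], ['O', 'O', 'B-D', 'I-D']]
assert all(len(x) == len(y) for x, y in zip(train_x, train_y))  # tags must align with characters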
A pretrained BERT model is needed for training; download one from https://github.com/ymcui/Chinese-BERT-wwm and unzip it. The unzipped folder (with bert_config.json, vocab.txt, and the bert_model.ckpt.* checkpoint files) is what gets passed to BERTEmbedding as model_folder below.
Training itself is simple; for hyperparameter tuning, see: https://github.com/BrikerMan/Kashgari/blob/v2-trunk/docs/tutorial/text-labeling.md
Training:
import kashgari
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.labeling import BiLSTM_CRF_Model

# Wrap the downloaded pretrained weights as the embedding layer for a labeling task
bert = BERTEmbedding(model_folder="chinese_roberta_wwm_ext_L-12_H-768_A-12", sequence_length=256, task=kashgari.LABELING)
model = BiLSTM_CRF_Model(bert)
# For this demo the training set doubles as the validation set
model.fit(train_x, train_y, x_validate=train_x, y_validate=train_y, epochs=2, batch_size=32)
model.save('my_bert_crf.h5')  # kashgari 1.x saves a directory of model files at this path
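For hyperparameter tuning (see the tutorial linked above), kashgari 1.x models expose their default hyperparameters as a dict that can be modified and passed back to the constructor. A minimal sketch, assuming the get_default_hyper_parameters() classmethod and hyper_parameters argument behave as in the 1.x docs:

hyper = BiLSTM_CRF_Model.get_default_hyper_parameters()
print(hyper)  # inspect the real key names for this version before editing
# e.g. hyper['layer_blstm']['units'] = 64   # key name is an assumption; adjust per the printed dict
model = BiLSTM_CRF_Model(bert, hyper_parameters=hyper)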
Prediction:
import kashgari
model = kashgari.utils.load_model('my_bert_crf.h5')
predict = model.predict([['我', '爱', '北', '京']])  # input is a 2-D list: a batch of character sequences
print(predict)
# [['O', 'O', 'B-D', 'I-D']]
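To measure tagging quality on held-out data, the labeling models in this version also provide an evaluate method that prints a per-entity classification report (a minimal sketch; test_x / test_y are hypothetical held-out lists in the same format as train_x / train_y):

test_x = [['我', '爱', '北', '京']]
test_y = [['O', 'O', 'B-D', 'I-D']]
model.evaluate(test_x, test_y)  # prints precision / recall / F1 per entity type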