Paper Category Classification
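4.1 Task Description
- Topic: paper classification (a data-modeling task): build a model from existing data and use it to assign categories to new papers;
- Task: predict a paper's category from its text (title and abstract);
- Outcome: learn the basic text-classification methods (TF-IDF, FastText, Word2Vec, BERT)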
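4.2 Data Processing Steps
In the raw arXiv data every paper has one or more categories, which are chosen by the authors. In this task we use each paper's title and abstract to:
* preprocess the titles and abstracts;
* preprocess the paper categories;
* build a text classification model.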
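4.3 Text Classification Approaches
- Approach 1: TF-IDF + a machine-learning classifier
Extract features directly from the text with TF-IDF, then classify them with a standard classifier such as an SVM, logistic regression (LR) or XGBoost.
- Approach 2: FastText
FastText is the entry-level word-vector option: with the fastText tool released by Facebook, a classifier can be built quickly.
- Approach 3: Word2Vec + a deep-learning classifier
Word2Vec is the intermediate word-vector option; classification is done by a deep-learning network built on top of it, such as TextCNN, TextRNN or a BiLSTM.
- Approach 4: BERT embeddings
BERT is the high-end option, with strong representation-learning capacity.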
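Approach 1 takes only a few lines of scikit-learn. The following is a minimal sketch on a made-up four-document corpus (the texts and labels are hypothetical, not taken from the arXiv data). It wraps logistic regression in `OneVsRestClassifier` so that one binary classifier is trained per label column; an SVM such as `LinearSVC` could be swapped in the same way.

```python
# Sketch of approach 1 (TF-IDF features + a linear classifier) on a tiny,
# hypothetical multi-label corpus.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer

texts = [
    "graph decompositions and sparse matrices",
    "prompt diphoton production at the lhc",
    "quantum chromodynamics at hadron colliders",
    "counting acyclic graphs with determinants",
]
labels = [["math", "cs"], ["hep-ph"], ["hep-ph"], ["math"]]

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(labels)          # one 0/1 column per label
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(texts)    # sparse TF-IDF feature matrix

# One binary logistic regression per label column.
clf = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X, y)
pred = clf.predict(vectorizer.transform(["diphoton production at colliders"]))
print(mlb.inverse_transform(pred))     # predicted label names (may be empty on toy data)
```

On a corpus this small the predicted set can easily be empty; the point is only the shape of the pipeline.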
4.4 Code
import seaborn as sns
from bs4 import BeautifulSoup
import re
import requests
import json
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
data=pd.read_csv("data.csv")
data=data[['title','categories','abstract']]
data=data.iloc[:200000]  # keep only the first 200,000 papers
data.head()
D:\anaconda3\lib\site-packages\IPython\core\interactiveshell.py:3146: DtypeWarning: Columns (0) have mixed types.Specify dtype option on import or set low_memory=False.
  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
| | title | categories | abstract |
|---|---|---|---|
| 0 | Calculation of prompt diphoton production cros... | hep-ph | A fully differential calculation in perturba... |
| 1 | Sparsity-certifying Graph Decompositions | math.CO cs.CG | We describe a new algorithm, the $(k,\ell)$-... |
| 2 | The evolution of the Earth-Moon system based o... | physics.gen-ph | The evolution of Earth-Moon system is descri... |
| 3 | A determinant of Stirling cycle numbers counts... | math.CO | We show that a determinant of Stirling cycle... |
| 4 | From dyadic $\Lambda_{\alpha}$ to $\Lambda_{a... | math.CA math.FA | In this paper we show how to compute the $L... |
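# Concatenate title and abstract into one text field
data['text']=data['title']+data['abstract']
# Remove line breaks from the text
data['text']=data['text'].apply(lambda x:x.replace('\n',''))
# Lower-case everything
data['text']=data['text'].apply(lambda x:x.lower())
data=data.drop(['abstract','title'],axis=1)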
data.head()
| | categories | text |
|---|---|---|
| 0 | hep-ph | calculation of prompt diphoton production cros... |
| 1 | math.CO cs.CG | sparsity-certifying graph decompositions we d... |
| 2 | physics.gen-ph | the evolution of the earth-moon system based o... |
| 3 | math.CO | a determinant of stirling cycle numbers counts... |
| 4 | math.CA math.FA | from dyadic $\lambda_{\alpha}$ to $\lambda_{a... |
data.iloc[0,1]
'calculation of prompt diphoton production cross sections at tevatron and lhc energies a fully differential calculation in perturbative quantum chromodynamics ispresented for the production of massive photon pairs at hadron colliders. allnext-to-leading order perturbative contributions from quark-antiquark,gluon-(anti)quark, and gluon-gluon subprocesses are included, as well asall-orders resummation of initial-state gluon radiation valid atnext-to-next-to-leading logarithmic accuracy. the region of phase space isspecified in which the calculation is most reliable. good agreement isdemonstrated with data from the fermilab tevatron, and predictions are made formore detailed tests with cdf and do data. predictions are shown fordistributions of diphoton pairs produced at the energy of the large hadroncollider (lhc). distributions of the diphoton pairs from the decay of a higgsboson are contrasted with those produced from qcd processes at the lhc, showingthat enhanced sensitivity to the signal can be obtained with judiciousselection of events.'
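Replacing `'\n'` with an empty string deletes line breaks outright, which is why glued tokens such as `ispresented` and `allnext-to-leading` appear in the text above (and `ofthe`, `inthe` later in the vocabulary). A small sketch of the alternative, collapsing all whitespace to single spaces; the sample string is made up:

```python
# Compare deleting line breaks with collapsing whitespace.
raw = "a fully differential calculation is\npresented for massive photon pairs"

merged = raw.replace("\n", "")   # what the cell above does: words get glued
spaced = " ".join(raw.split())   # split on any whitespace, rejoin with one space

print(merged)  # a fully differential calculation ispresented for massive photon pairs
print(spaced)  # a fully differential calculation is presented for massive photon pairs
```

Applied to the DataFrame this would be something like `data['text'].apply(lambda x: ' '.join(x.split()))`; the outputs shown in this section were produced with the original `replace('\n','')`.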
data['categories']=data['categories'].apply(lambda x:x.split(' '))
data['categories_big']=data['categories'].apply(lambda x:[xx.split('.')[0] for xx in x])
data.head()
| | categories | text | categories_big |
|---|---|---|---|
| 0 | [hep-ph] | calculation of prompt diphoton production cros... | [hep-ph] |
| 1 | [math.CO, cs.CG] | sparsity-certifying graph decompositions we d... | [math, cs] |
| 2 | [physics.gen-ph] | the evolution of the earth-moon system based o... | [physics] |
| 3 | [math.CO] | a determinant of stirling cycle numbers counts... | [math] |
| 4 | [math.CA, math.FA] | from dyadic $\lambda_{\alpha}$ to $\lambda_{a... | [math, math] |
from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer()
data_label = mlb.fit_transform(data['categories_big'].iloc[:])
data_label
array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 1, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 1, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 1, 0, ..., 0, 0, 0]])
print(data_label[1])
print(data_label[3])
[0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
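Each column of `data_label` corresponds to one entry of `mlb.classes_`, which `MultiLabelBinarizer` sorts alphabetically; repeated labels such as `[math, math]` collapse to a single 1. A toy illustration with three hand-written category lists (not the real data):

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Lists shaped like data['categories_big']; the last row repeats "math".
categories_big = [["hep-ph"], ["math", "cs"], ["math", "math"]]

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(categories_big)

print(mlb.classes_)              # ['cs' 'hep-ph' 'math']
print(y.tolist())                # [[0, 1, 0], [1, 0, 1], [0, 0, 1]]
print(mlb.inverse_transform(y))  # [('hep-ph',), ('cs', 'math'), ('math',)]
```

On the real data, printing `mlb.classes_` gives the names behind the label indices 0 to 18 used in the reports below.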
TF-IDF
For the underlying principle, see: https://blog.csdn.net/zrc199021/article/details/53728499
from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer(max_features=4000)
data_tfidf = vectorizer.fit_transform(data['text'].iloc[:])
# Split into training and validation sets
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(data_tfidf, data_label, test_size=0.2, random_state=1)
# Build a multi-label classifier: one MultinomialNB per label column
from sklearn.multioutput import MultiOutputClassifier
from sklearn.naive_bayes import MultinomialNB
clf = MultiOutputClassifier(MultinomialNB()).fit(x_train, y_train)
For MultiOutputClassifier and MultinomialNB (naive Bayes that models the features with a multinomial distribution), see:
- https://blog.csdn.net/Islotus/article/details/78671238
- https://blog.csdn.net/TeFuirnever/article/details/100125386
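`MultiOutputClassifier` fits one independent copy of the wrapped estimator per label column, so the model above is really 19 separate naive Bayes classifiers. A toy sketch with hypothetical count features (4 documents, 3 terms) and 2 labels:

```python
import numpy as np
from sklearn.multioutput import MultiOutputClassifier
from sklearn.naive_bayes import MultinomialNB

# Hypothetical non-negative term counts and a 2-column label matrix.
X = np.array([[3, 0, 1], [2, 0, 0], [0, 4, 1], [0, 3, 2]])
Y = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])

clf = MultiOutputClassifier(MultinomialNB()).fit(X, Y)

print(len(clf.estimators_))            # 2: one MultinomialNB per label column
print(clf.predict([[5, 0, 0]]))        # [[1 0]]
proba = clf.predict_proba([[5, 0, 0]]) # a list with one (1, 2) array per label
```

Because the labels are modeled independently, nothing prevents a paper from receiving zero labels or many labels.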
# Evaluate the model on the validation set
from sklearn.metrics import classification_report
print(classification_report(y_test, clf.predict(x_test)))
precision recall f1-score support
0 0.95 0.84 0.89 7925
1 0.86 0.78 0.82 7339
2 0.77 0.70 0.73 2944
3 0.00 0.00 0.00 4
4 0.73 0.44 0.55 2123
5 0.52 0.64 0.58 987
6 0.85 0.33 0.47 544
7 0.71 0.67 0.69 3649
8 0.77 0.58 0.66 3388
9 0.85 0.88 0.86 10745
10 0.46 0.10 0.16 1757
11 0.90 0.04 0.07 729
12 0.45 0.31 0.37 507
13 0.55 0.32 0.41 1083
14 0.68 0.12 0.20 3441
15 0.82 0.16 0.27 655
16 0.93 0.14 0.24 268
17 0.87 0.40 0.55 2484
18 0.84 0.34 0.49 692
micro avg 0.82 0.63 0.71 51264
macro avg 0.71 0.41 0.47 51264
weighted avg 0.80 0.63 0.68 51264
samples avg 0.71 0.70 0.69 51264
D:\anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
D:\anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
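About classification_report: the rows numbered 0 to 18 are the label columns of data_label, in the order of mlb.classes_. The micro average pools true and false positives over all labels before computing the scores; the macro average is the unweighted mean of the per-label scores, so rare categories count as much as common ones; the weighted average weights each label by its support; and the samples average computes the scores per paper and then averages them. The warnings appear because at least one label (most likely label 3, which has only 4 test samples) is never predicted and some papers receive no label at all; passing zero_division=0 (or 1) to classification_report sets these values explicitly and silences the warnings.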
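The gap between micro and macro averages in the report is easiest to see on a tiny example. The indicator matrices below are made up (3 papers by 3 labels); the comments give the counts that the two averages are built from.

```python
import numpy as np
from sklearn.metrics import f1_score

# Hypothetical indicator matrices: 3 papers x 3 labels.
y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
y_pred = np.array([[1, 0, 1], [0, 0, 1], [1, 1, 0]])

# Micro: pool TP/FP/FN over all labels (TP=4, FP=1, FN=1), then compute one F1.
micro = f1_score(y_true, y_pred, average="micro")
# Macro: per-label F1 (1, 2/3, 2/3), then an unweighted mean.
macro = f1_score(y_true, y_pred, average="macro")

print(round(micro, 4), round(macro, 4))  # 0.8 0.7778
```

In the naive Bayes report above, the macro F1 (0.47) is far below the micro F1 (0.71) because many small categories have very low recall.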
LSTM (Keras implementation)
Official Keras documentation (Chinese edition): https://keras.io/zh/
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(data['text'].iloc[:], data_label,test_size = 0.2,random_state = 1)
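# Hyperparameters
max_features = 500   # num_words for the Tokenizer: only the most frequent words are kept
max_len = 150        # every text is padded or truncated to 150 tokens
embed_size = 100     # embedding dimension
batch_size = 128
epochs = 1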
from keras.preprocessing.text import Tokenizer
from keras.preprocessing import sequence
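# Text preprocessing: build the vocabulary (here from train and test text together) and map each text to a list of word indices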
tokens = Tokenizer(num_words = max_features)
tokens.fit_on_texts(list(x_train)+list(x_test))
x_sub_train = tokens.texts_to_sequences(x_train)
x_sub_test = tokens.texts_to_sequences(x_test)
tokens.word_index
{'the': 1,
'of': 2,
'a': 3,
'and': 4,
'in': 5,
'to': 6,
'we': 7,
'is': 8,
'for': 9,
'with': 10,
'that': 11,
'on': 12,
'are': 13,
'by': 14,
'this': 15,
'as': 16,
'an': 17,
'from': 18,
'at': 19,
'be': 20,
'1': 21,
'2': 22,
'which': 23,
'0': 24,
'model': 25,
'can': 26,
'n': 27,
'two': 28,
'it': 29,
'x': 30,
'field': 31,
'these': 32,
'results': 33,
'show': 34,
'quantum': 35,
'our': 36,
'also': 37,
'3': 38,
'energy': 39,
'using': 40,
'have': 41,
'one': 42,
'time': 43,
'theory': 44,
'or': 45,
'between': 46,
'k': 47,
'study': 48,
'non': 49,
'mass': 50,
'data': 51,
'm': 52,
'has': 53,
's': 54,
'such': 55,
'not': 56,
'system': 57,
't': 58,
'i': 59,
'new': 60,
'p': 61,
'high': 62,
'based': 63,
'paper': 64,
'present': 65,
'e': 66,
'order': 67,
'state': 68,
'space': 69,
'models': 70,
'd': 71,
'phase': 72,
'c': 73,
'large': 74,
'spin': 75,
'r': 76,
'magnetic': 77,
'its': 78,
'g': 79,
'than': 80,
'all': 81,
'4': 82,
'systems': 83,
'find': 84,
'function': 85,
'well': 86,
'b': 87,
'some': 88,
'properties': 89,
'where': 90,
'first': 91,
'density': 92,
'their': 93,
'dimensional': 94,
'both': 95,
'5': 96,
'number': 97,
'structure': 98,
'case': 99,
'states': 100,
'method': 101,
'h': 102,
'low': 103,
'been': 104,
'when': 105,
'type': 106,
'ray': 107,
'if': 108,
'analysis': 109,
'f': 110,
'10': 111,
'z': 112,
'used': 113,
'different': 114,
'temperature': 115,
'but': 116,
'problem': 117,
'l': 118,
'over': 119,
'into': 120,
'only': 121,
'observed': 122,
'galaxies': 123,
'more': 124,
'stars': 125,
'finite': 126,
'star': 127,
'obtained': 128,
'group': 129,
'distribution': 130,
'found': 131,
'three': 132,
'equation': 133,
'approach': 134,
'other': 135,
'gamma': 136,
'may': 137,
'effect': 138,
'dynamics': 139,
'ofthe': 140,
'general': 141,
'use': 142,
'point': 143,
'result': 144,
'functions': 145,
'q': 146,
'emission': 147,
'set': 148,
'range': 149,
'equations': 150,
'single': 151,
'effects': 152,
'then': 153,
'matter': 154,
'there': 155,
'due': 156,
'linear': 157,
'local': 158,
'fields': 159,
'scale': 160,
'form': 161,
'given': 162,
'transition': 163,
'small': 164,
'evolution': 165,
'observations': 166,
'terms': 167,
'potential': 168,
'limit': 169,
'very': 170,
'wave': 171,
'optical': 172,
'parameters': 173,
'surface': 174,
'any': 175,
'shown': 176,
'gas': 177,
'formation': 178,
'solutions': 179,
'light': 180,
'within': 181,
'particular': 182,
'was': 183,
'under': 184,
'up': 185,
'black': 186,
'spectrum': 187,
'rate': 188,
'dark': 189,
'most': 190,
'electron': 191,
'v': 192,
'6': 193,
'here': 194,
'consider': 195,
'out': 196,
'prove': 197,
'known': 198,
'through': 199,
'alpha': 200,
'strong': 201,
'possible': 202,
'they': 203,
'class': 204,
'parameter': 205,
'simple': 206,
'will': 207,
'power': 208,
'lattice': 209,
'like': 210,
'discuss': 211,
'while': 212,
'galaxy': 213,
'each': 214,
'how': 215,
'particle': 216,
'give': 217,
'process': 218,
'conditions': 219,
'interaction': 220,
'however': 221,
'free': 222,
'line': 223,
'work': 224,
'u': 225,
'random': 226,
'symmetry': 227,
'no': 228,
'coupling': 229,
'simulations': 230,
'region': 231,
'measurements': 232,
'solution': 233,
'about': 234,
'o': 235,
'level': 236,
'current': 237,
'classical': 238,
'standard': 239,
'recent': 240,
'spectral': 241,
'j': 242,
'so': 243,
'same': 244,
'information': 245,
'many': 246,
'cluster': 247,
'provide': 248,
'pi': 249,
'value': 250,
'algorithm': 251,
'values': 252,
'stellar': 253,
'scattering': 254,
'matrix': 255,
'gauge': 256,
'hole': 257,
'investigate': 258,
'associated': 259,
'8': 260,
'near': 261,
'size': 262,
'complex': 263,
'constant': 264,
'times': 265,
'several': 266,
'numerical': 267,
'effective': 268,
'long': 269,
'critical': 270,
'behavior': 271,
'spectra': 272,
'studied': 273,
'second': 274,
'obtain': 275,
'methods': 276,
'higher': 277,
'were': 278,
'proposed': 279,
'zero': 280,
'groups': 281,
'source': 282,
'velocity': 283,
'self': 284,
'sources': 285,
'even': 286,
'experimental': 287,
'particles': 288,
'frequency': 289,
'lambda': 290,
'sample': 291,
'7': 292,
'decay': 293,
'consistent': 294,
'presented': 295,
'interactions': 296,
'sigma': 297,
'gravity': 298,
'those': 299,
'similar': 300,
'flow': 301,
'algebra': 302,
'bound': 303,
'dependence': 304,
'physics': 305,
'way': 306,
'disk': 307,
'derived': 308,
'flux': 309,
'clusters': 310,
'radio': 311,
'delta': 312,
'processes': 313,
'quark': 314,
'mean': 315,
'main': 316,
'network': 317,
'theorem': 318,
'via': 319,
'presence': 320,
'boundary': 321,
'induced': 322,
'correlation': 323,
'w': 324,
'including': 325,
'mu': 326,
'band': 327,
'related': 328,
'charge': 329,
'ii': 330,
'dynamical': 331,
'corresponding': 332,
'dependent': 333,
'real': 334,
'momentum': 335,
'discussed': 336,
'ratio': 337,
'networks': 338,
'scalar': 339,
'weak': 340,
'existence': 341,
'masses': 342,
'structures': 343,
'inthe': 344,
'spaces': 345,
'production': 346,
'experiments': 347,
'physical': 348,
'solar': 349,
'vector': 350,
'relation': 351,
'thus': 352,
'lower': 353,
'thermal': 354,
'approximation': 355,
'measure': 356,
'describe': 357,
'initial': 358,
'compared': 359,
'derive': 360,
'important': 361,
'framework': 362,
'noise': 363,
'factor': 364,
'cases': 365,
'theoretical': 366,
'motion': 367,
'massive': 368,
'certain': 369,
'plane': 370,
'law': 371,
'four': 372,
'finally': 373,
'cross': 374,
'propose': 375,
'calculations': 376,
'operators': 377,
'problems': 378,
'invariant': 379,
'gravitational': 380,
'allows': 381,
'various': 382,
'recently': 383,
'algebras': 384,
'modes': 385,
'previous': 386,
'during': 387,
'considered': 388,
'y': 389,
'measured': 390,
'component': 391,
'part': 392,
'photon': 393,
'us': 394,
'length': 395,
'without': 396,
'molecular': 397,
'could': 398,
'distance': 399,
'generalized': 400,
'defined': 401,
'multi': 402,
'applications': 403,
'compact': 404,
'scheme': 405,
'total': 406,
'regions': 407,
'symmetric': 408,
'applied': 409,
'independent': 410,
'constraints': 411,
'objects': 412,
'neutrino': 413,
'transport': 414,
'galactic': 415,
'channel': 416,
'around': 417,
'points': 418,
'coupled': 419,
'along': 420,
'dimension': 421,
'qcd': 422,
'distributions': 423,
'detection': 424,
'regime': 425,
'probability': 426,
'resolution': 427,
'mode': 428,
'role': 429,
'cosmic': 430,
'survey': 431,
'omega': 432,
'nonlinear': 433,
'shows': 434,
'evidence': 435,
'geometry': 436,
'beta': 437,
'theories': 438,
'does': 439,
'background': 440,
'binary': 441,
'operator': 442,
'infrared': 443,
'studies': 444,
'exact': 445,
'universe': 446,
'mechanism': 447,
'report': 448,
'central': 449,
'loop': 450,
'features': 451,
'cosmological': 452,
'luminosity': 453,
'entropy': 454,
'nuclear': 455,
'al': 456,
'them': 457,
'demonstrate': 458,
'dimensions': 459,
'do': 460,
'term': 461,
'strongly': 462,
'lines': 463,
'measurement': 464,
'graph': 465,
'extended': 466,
'fluctuations': 467,
'above': 468,
'ground': 469,
'degree': 470,
'relativistic': 471,
'after': 472,
'determine': 473,
'provides': 474,
'complete': 475,
'heavy': 476,
'radiation': 477,
'stable': 478,
'fermi': 479,
'application': 480,
'action': 481,
'control': 482,
'called': 483,
'series': 484,
'body': 485,
'redshift': 486,
'gev': 487,
'expansion': 488,
'described': 489,
'positive': 490,
'fixed': 491,
'further': 492,
'leads': 493,
'short': 494,
'bar': 495,
'larger': 496,
'differential': 497,
'description': 498,
'direct': 499,
'close': 500,
'waves': 501,
'scaling': 502,
'agreement': 503,
'optimal': 504,
'dust': 505,
'et': 506,
'condition': 507,
'core': 508,
'entanglement': 509,
'signal': 510,
'global': 511,
'expected': 512,
'phi': 513,
'pair': 514,
'neutron': 515,
'search': 516,
'construct': 517,
'significant': 518,
'test': 519,
'polarization': 520,
'equilibrium': 521,
'collisions': 522,
'open': 523,
'higgs': 524,
'respect': 525,
'account': 526,
'technique': 527,
'stability': 528,
'holes': 529,
'scales': 530,
'9': 531,
'string': 532,
'spatial': 533,
'sets': 534,
'determined': 535,
'upper': 536,
'simulation': 537,
'sequence': 538,
'example': 539,
'correlations': 540,
'much': 541,
'nu': 542,
'energies': 543,
'examples': 544,
'population': 545,
'whose': 546,
'addition': 547,
'quasi': 548,
'rates': 549,
'tau': 550,
'statistical': 551,
'multiple': 552,
'graphene': 553,
'leading': 554,
'proof': 555,
'curves': 556,
'metric': 557,
'estimate': 558,
'investigated': 559,
'nature': 560,
'let': 561,
'accretion': 562,
'arbitrary': 563,
'lie': 564,
'medium': 565,
'gap': 566,
'sum': 567,
'generated': 568,
'fe': 569,
'normal': 570,
'full': 571,
'double': 572,
'atoms': 573,
'growth': 574,
'graphs': 575,
'components': 576,
'minimal': 577,
'jet': 578,
'co': 579,
'special': 580,
'formula': 581,
'angular': 582,
'force': 583,
'asymptotic': 584,
'de': 585,
'detected': 586,
'pressure': 587,
'domain': 588,
'integral': 589,
'few': 590,
'moreover': 591,
'topological': 592,
'maximum': 593,
'good': 594,
'developed': 595,
'fraction': 596,
'respectively': 597,
'since': 598,
'calculate': 599,
'discrete': 600,
'infinite': 601,
'representation': 602,
'elements': 603,
'continuous': 604,
'resonance': 605,
'being': 606,
'relative': 607,
'log': 608,
'introduce': 609,
'techniques': 610,
'mixing': 611,
'stochastic': 612,
'bounds': 613,
'experiment': 614,
'fundamental': 615,
'below': 616,
'specific': 617,
'maps': 618,
'gaussian': 619,
'decays': 620,
'performance': 621,
'means': 622,
'error': 623,
'radius': 624,
'20': 625,
'diffusion': 626,
'average': 627,
'numbers': 628,
'should': 629,
'comparison': 630,
'product': 631,
'su': 632,
'calculated': 633,
'corrections': 634,
'einstein': 635,
'either': 636,
'context': 637,
'surfaces': 638,
'predictions': 639,
'less': 640,
'period': 641,
'chiral': 642,
'hamiltonian': 643,
'algorithms': 644,
'basis': 645,
'functional': 646,
'among': 647,
'ring': 648,
'suggest': 649,
'introduced': 650,
'compute': 651,
'explicit': 652,
'rotation': 653,
'magnitude': 654,
'closed': 655,
'index': 656,
'transitions': 657,
'amplitude': 658,
'driven': 659,
'telescope': 660,
'construction': 661,
'metal': 662,
'conjecture': 663,
'atomic': 664,
'early': 665,
'analyze': 666,
'compare': 667,
'map': 668,
'origin': 669,
'family': 670,
'periodic': 671,
'electronic': 672,
'absorption': 673,
'curvature': 674,
'made': 675,
'electrons': 676,
'bulk': 677,
'natural': 678,
'orbital': 679,
'estimates': 680,
'performed': 681,
'change': 682,
'manifolds': 683,
'plasma': 684,
'100': 685,
'least': 686,
'relations': 687,
'tensor': 688,
'variables': 689,
'would': 690,
'transfer': 691,
'variable': 692,
'types': 693,
'classes': 694,
'resulting': 695,
'electric': 696,
'wide': 697,
'negative': 698,
'mathbb': 699,
'contribution': 700,
'gives': 701,
'polynomial': 702,
'far': 703,
'algebraic': 704,
'lhc': 705,
'hard': 706,
'shape': 707,
'interacting': 708,
'furthermore': 709,
'universal': 710,
'vacuum': 711,
'limits': 712,
'future': 713,
'because': 714,
'manifold': 715,
'sim': 716,
'pm': 717,
'coefficients': 718,
'property': 719,
'design': 720,
'fluid': 721,
'previously': 722,
'geometric': 723,
'temperatures': 724,
'events': 725,
'breaking': 726,
'tothe': 727,
'phases': 728,
'ion': 729,
'monte': 730,
'rm': 731,
'matrices': 732,
'therefore': 733,
'rho': 734,
'volume': 735,
'partial': 736,
'although': 737,
'12': 738,
'theta': 739,
'lead': 740,
'spectroscopy': 741,
'pairs': 742,
'review': 743,
'almost': 744,
'analytic': 745,
'strength': 746,
'superconducting': 747,
'radial': 748,
'curve': 749,
'characteristic': 750,
'available': 751,
'detector': 752,
'liquid': 753,
'chain': 754,
'edge': 755,
'agn': 756,
'code': 757,
'halo': 758,
'carlo': 759,
'angle': 760,
'produced': 761,
'extension': 762,
'beam': 763,
'charged': 764,
'increase': 765,
'version': 766,
'equivalent': 767,
'key': 768,
'efficient': 769,
'layer': 770,
'apply': 771,
'cm': 772,
'orbit': 773,
'significantly': 774,
'oscillations': 775,
'smooth': 776,
'formalism': 777,
'peak': 778,
'nuclei': 779,
'down': 780,
'observation': 781,
'analytical': 782,
'center': 783,
'images': 784,
'changes': 785,
'dual': 786,
'15': 787,
'difference': 788,
'scenario': 789,
'every': 790,
'coherent': 791,
'infty': 792,
'loss': 793,
'response': 794,
'chemical': 795,
'exchange': 796,
'section': 797,
'generation': 798,
'detailed': 799,
'external': 800,
'active': 801,
'direction': 802,
'additional': 803,
'channels': 804,
'flat': 805,
'laser': 806,
'fast': 807,
'principle': 808,
'explain': 809,
'forms': 810,
'reduced': 811,
'diagram': 812,
'half': 813,
'heat': 814,
'highly': 815,
'depends': 816,
'statistics': 817,
'off': 818,
'factors': 819,
'area': 820,
'novel': 821,
'fact': 822,
'allow': 823,
'calculation': 824,
'semi': 825,
'bounded': 826,
'able': 827,
'smaller': 828,
'complexity': 829,
'codes': 830,
'transverse': 831,
'sun': 832,
'dynamic': 833,
'abelian': 834,
'eta': 835,
'boson': 836,
'thin': 837,
'influence': 838,
'predicted': 839,
'connected': 840,
'article': 841,
'rather': 842,
'relevant': 843,
'environment': 844,
'make': 845,
'best': 846,
'continuum': 847,
'sub': 848,
'provided': 849,
'include': 850,
'contrast': 851,
'dirac': 852,
'mathcal': 853,
'fit': 854,
'years': 855,
'au': 856,
'perturbation': 857,
'beyond': 858,
'soft': 859,
'end': 860,
'degrees': 861,
'modified': 862,
'imaging': 863,
'square': 864,
'proton': 865,
'develop': 866,
'dispersion': 867,
'reduction': 868,
'material': 869,
'forthe': 870,
'minimum': 871,
'following': 872,
'possibility': 873,
'cp': 874,
'regular': 875,
'relaxation': 876,
'cannot': 877,
'rays': 878,
'polynomials': 879,
'instability': 880,
'dwarf': 881,
'extend': 882,
'onthe': 883,
'densities': 884,
'define': 885,
'must': 886,
'pure': 887,
'taken': 888,
'ir': 889,
'threshold': 890,
'supersymmetric': 891,
'variety': 892,
'representations': 893,
'brane': 894,
'parallel': 895,
'next': 896,
'km': 897,
'latter': 898,
'impact': 899,
'intermediate': 900,
'photons': 901,
'required': 902,
'understanding': 903,
'mechanics': 904,
'path': 905,
'identify': 906,
'increasing': 907,
'identified': 908,
'potentials': 909,
'hot': 910,
'accuracy': 911,
'11': 912,
'better': 913,
'still': 914,
'final': 915,
'procedure': 916,
'30': 917,
'towards': 918,
'conformal': 919,
'uniform': 920,
'constructed': 921,
'crystal': 922,
'approaches': 923,
'samples': 924,
'width': 925,
'increases': 926,
'analyzed': 927,
'atom': 928,
'inner': 929,
'ngc': 930,
'mev': 931,
'homogeneous': 932,
'seen': 933,
'color': 934,
'convergence': 935,
'electromagnetic': 936,
'efficiency': 937,
'inverse': 938,
'bose': 939,
'equal': 940,
'image': 941,
'probe': 942,
'cloud': 943,
'ads': 944,
'necessary': 945,
'fully': 946,
'profile': 947,
'step': 948,
'what': 949,
'static': 950,
'exist': 951,
'top': 952,
'speed': 953,
'position': 954,
'view': 955,
'andthe': 956,
'shock': 957,
'likely': 958,
'correlated': 959,
'epsilon': 960,
'psi': 961,
'exhibit': 962,
'together': 963,
'basic': 964,
'modeling': 965,
'per': 966,
'young': 967,
'disks': 968,
'behaviour': 969,
'whether': 970,
'object': 971,
'note': 972,
'rich': 973,
'tev': 974,
'ratios': 975,
'common': 976,
'estimation': 977,
'focus': 978,
'rank': 979,
'deep': 980,
'connection': 981,
'giant': 982,
'excitation': 983,
'assuming': 984,
'fermions': 985,
'forming': 986,
'appear': 987,
'combined': 988,
'thatthe': 989,
'contributions': 990,
'measures': 991,
'spacetime': 992,
'exists': 993,
'feature': 994,
'sqrt': 995,
'shell': 996,
'question': 997,
'nucleon': 998,
'flavor': 999,
'flows': 1000,
...}
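The `word_index` above lists words from most to least frequent, starting at index 1 (index 0 is left free for padding). The sketch below is a plain-Python approximation of the `Tokenizer` defaults, not the Keras source, and the helper names are hypothetical. Text is lower-cased, and punctuation such as `-`, `$`, `\`, `_`, `{` and `}` becomes spaces, which is why `lambda` and `alpha` show up as separate tokens. When `num_words` is set, only indices smaller than `num_words` are kept, so `num_words=500` keeps the 499 most frequent words.

```python
from collections import Counter

# Approximation of the Tokenizer's default filter characters.
FILTERS = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'
TABLE = str.maketrans(FILTERS, " " * len(FILTERS))

def tokenize(text):
    """Lower-case, turn filter characters into spaces, split on whitespace."""
    return text.lower().translate(TABLE).split()

def build_word_index(texts):
    """Rank words by descending count (ties keep first-seen order), indices from 1."""
    counts = Counter()
    for t in texts:
        counts.update(tokenize(t))
    ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
    return {w: i for i, (w, _) in enumerate(ranked, start=1)}

def texts_to_sequences(texts, word_index, num_words=None):
    """Map words to indices, dropping unknown words and indices >= num_words."""
    return [[word_index[w] for w in tokenize(t)
             if w in word_index and (num_words is None or word_index[w] < num_words)]
            for t in texts]

texts = ["the evolution of the Earth-Moon system", "the $\\lambda_{\\alpha}$ model"]
wi = build_word_index(texts)
print(wi)  # {'the': 1, 'evolution': 2, 'of': 3, 'earth': 4, 'moon': 5, 'system': 6, 'lambda': 7, 'alpha': 8, 'model': 9}
print(texts_to_sequences(texts, wi, num_words=4))  # [[1, 2, 3, 1], [1]]
```

With a vocabulary capped at 499 words, most content words in an abstract are dropped, which is one plausible reason the model below struggles.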
x_sub_train=sequence.pad_sequences(x_sub_train, maxlen=max_len)
x_sub_test=sequence.pad_sequences(x_sub_test, maxlen=max_len)
len(x_sub_train[0])
150
For Keras text preprocessing, see: https://zhuanlan.zhihu.com/p/55412623
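`pad_sequences` defaults to `padding='pre'` and `truncating='pre'` with a fill value of 0: short sequences get zeros in front, and long sequences lose tokens from the front. The following NumPy sketch (a hypothetical helper, not the Keras implementation) reproduces that behavior. One consequence for this task is that when title plus abstract exceeds 150 tokens, the title at the start is the part that gets cut; passing `truncating='post'` to `pad_sequences` would keep it instead.

```python
import numpy as np

def pad_pre(sequences, maxlen, value=0):
    """Left-pad short sequences and keep only the last `maxlen` tokens of long ones,
    mimicking pad_sequences' defaults (padding='pre', truncating='pre')."""
    out = np.full((len(sequences), maxlen), value, dtype="int32")
    for i, seq in enumerate(sequences):
        tail = seq[-maxlen:]              # 'pre' truncation drops tokens from the front
        if len(tail):
            out[i, -len(tail):] = tail    # 'pre' padding puts the fill value in front
    return out

print(pad_pre([[1, 2], [3, 4, 5, 6, 7]], maxlen=4))
# [[0 0 1 2]
#  [4 5 6 7]]
```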
# Import the required modules
from keras.layers import Input, Embedding, SpatialDropout1D, Bidirectional, GRU, Conv1D
from keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D, concatenate, Dense
from keras.models import Model
from keras.optimizers import Adam
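# Build the model. Despite the section title, the recurrent layer is a
# bidirectional GRU (followed by a 1-D convolution), not an LSTM.
sequence_input = Input(shape=(max_len, ))
# Note: trainable=False without pretrained `weights` freezes the embedding at its
# random initialization; trainable=True or pretrained vectors would likely help.
x = Embedding(max_features, embed_size, trainable=False)(sequence_input)
x = SpatialDropout1D(0.2)(x)
x = Bidirectional(GRU(128, return_sequences=True, dropout=0.1, recurrent_dropout=0.1))(x)
x = Conv1D(64, kernel_size=3, padding="valid", kernel_initializer="glorot_uniform")(x)
avg_pool = GlobalAveragePooling1D()(x)
max_pool = GlobalMaxPooling1D()(x)
x = concatenate([avg_pool, max_pool])
# One sigmoid unit per top-level category (19 in total), so each label is predicted independently
preds = Dense(19, activation="sigmoid")(x)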
For background on LSTM models, see: https://zhuanlan.zhihu.com/p/139617364
print(x_sub_train.shape)
print(y_train.shape)
(160000, 150)
(160000, 19)
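# Fit the model
model = Model(sequence_input, preds)
# Multi-label output: independent sigmoid units trained with binary cross-entropy.
# `learning_rate` replaces the deprecated `lr` argument of Adam.
model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-3), metrics=['accuracy'])
model.fit(x_sub_train, y_train, batch_size=batch_size, epochs=epochs)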
1250/1250 [==============================] - 2934s 2s/step - loss: 0.1984 - accuracy: 0.3651
<tensorflow.python.keras.callbacks.History at 0x18c48aa9c70>
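Binary cross-entropy treats each of the 19 sigmoid outputs as its own yes/no problem and averages the per-label losses. The NumPy sketch below uses hypothetical probabilities and clips them with a small epsilon, as Keras does, to avoid log(0). Its second example is a rough illustration, not a measurement on this data: a model that outputs 0.05 for every label of a paper with one true category already scores about 0.206. So the reported training loss of 0.1984 says little on its own.

```python
import numpy as np

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean over all labels of -[y*log(p) + (1-y)*log(1-p)]."""
    p = np.clip(y_prob, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

# Hypothetical 4-label example.
y_true = np.array([[1, 0, 0, 1]])
y_prob = np.array([[0.9, 0.1, 0.2, 0.6]])
print(round(binary_crossentropy(y_true, y_prob), 4))     # 0.2362

# Constant predictor: 0.05 for all 19 labels, one true label.
one_hot = np.zeros((1, 19))
one_hot[0, 0] = 1
constant = np.full((1, 19), 0.05)
print(round(binary_crossentropy(one_hot, constant), 4))  # 0.2063
```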
# Predict label probabilities for the test set with the trained model
prediction=model.predict(x_sub_test, batch_size=batch_size)
prediction
array([[0.02252209, 0.21710196, 0.05228466, ..., 0.00359941, 0.04740071,
0.00767908],
[0.06721482, 0.07629076, 0.1790415 , ..., 0.0033555 , 0.04327598,
0.01796243],
[0.03353414, 0.19576049, 0.06734261, ..., 0.00558475, 0.02856755,
0.00945702],
...,
[0.01758364, 0.7247476 , 0.0205844 , ..., 0.00262681, 0.12628657,
0.0036349 ],
[0.02018005, 0.04203317, 0.03764346, ..., 0.00153542, 0.00649127,
0.01208609],
[0.01772627, 0.1940268 , 0.06325474, ..., 0.00296167, 0.03129855,
0.01045477]], dtype=float32)
prediction.shape
(40000, 19)
y_test[0]
array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
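import numpy as np

# Apply a threshold to the predicted probabilities to get labels (yuzhi = threshold).
# A row with no probability above the threshold gets its single most probable label.
def lastprocess(prediction, yuzhi):
    myarray = np.zeros(prediction.shape)
    for i in range(0, len(prediction)):
        for j in range(0, len(prediction[0])):
            if prediction[i][j] >= yuzhi:
                myarray[i][j] = 1
            else:
                myarray[i][j] = 0
        if sum(myarray[i]) == 0:
            myarray[i][np.argmax(prediction[i])] = 1
    return myarray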
lastprocess(prediction,0.15)
array([[0., 1., 0., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
...,
[0., 1., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.]])
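The double loop in `lastprocess` visits all 40,000 x 19 entries in Python. The same rule can be written with NumPy array operations, which is typically much faster. A sketch with a hypothetical name `threshold_predictions` and made-up probabilities:

```python
import numpy as np

def threshold_predictions(prediction, threshold):
    """Vectorized version of the lastprocess rule: 1 where prob >= threshold;
    rows with no label above the threshold fall back to their argmax label."""
    prediction = np.asarray(prediction)
    labels = (prediction >= threshold).astype(float)
    rows = np.flatnonzero(labels.sum(axis=1) == 0)      # rows with no predicted label
    labels[rows, prediction[rows].argmax(axis=1)] = 1
    return labels

probs = np.array([[0.02, 0.72, 0.20], [0.05, 0.10, 0.08]])
print(threshold_predictions(probs, 0.15))
# [[0. 1. 1.]
#  [0. 1. 0.]]
```

For the real data, `threshold_predictions(prediction, 0.15)` should give the same array as `lastprocess(prediction, 0.15)`.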
# Evaluate the thresholded predictions
from sklearn.metrics import classification_report
print(classification_report(y_test, lastprocess(prediction,0.15)))
precision recall f1-score support
0 0.64 0.79 0.71 7925
1 0.39 0.77 0.52 7339
2 0.28 0.57 0.38 2944
3 0.00 0.00 0.00 4
4 0.48 0.44 0.46 2123
5 0.32 0.35 0.34 987
6 0.33 0.12 0.18 544
7 0.39 0.56 0.46 3649
8 0.44 0.43 0.44 3388
9 0.43 0.96 0.59 10745
10 0.09 0.00 0.00 1757
11 0.00 0.00 0.00 729
12 0.04 0.00 0.00 507
13 0.15 0.07 0.09 1083
14 0.22 0.05 0.08 3441
15 0.00 0.00 0.00 655
16 0.00 0.00 0.00 268
17 0.39 0.63 0.48 2484
18 0.00 0.00 0.00 692
micro avg 0.43 0.60 0.50 51264
macro avg 0.24 0.30 0.25 51264
weighted avg 0.39 0.60 0.45 51264
samples avg 0.50 0.65 0.53 51264
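This model was trained for only one epoch and needs a hand-picked threshold, and its predictions are poor: its micro-averaged F1 is 0.50, compared with 0.71 for the TF-IDF + naive Bayes model.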
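Rather than fixing the threshold at 0.15 by hand, one option is to scan a grid of thresholds and keep the one with the best micro-F1. Ideally this is done on a held-out validation split, not on the test set used for the final report. A minimal sketch; the function name and the toy data are hypothetical:

```python
import numpy as np
from sklearn.metrics import f1_score

def best_threshold(y_true, y_prob, candidates=np.arange(0.05, 0.95, 0.05)):
    """Return the single global threshold with the highest micro-F1, and that score."""
    scores = [f1_score(y_true, (y_prob >= t).astype(int), average="micro", zero_division=0)
              for t in candidates]
    i = int(np.argmax(scores))
    return float(candidates[i]), scores[i]

# Hypothetical validation labels and predicted probabilities.
y_val = np.array([[1, 0], [0, 1], [1, 1]])
p_val = np.array([[0.62, 0.18], [0.27, 0.44], [0.71, 0.38]])

t, score = best_threshold(y_val, p_val)
print(t, score)  # a threshold between 0.27 and 0.38 separates this toy data perfectly
```

A per-label threshold (one scan per column) is a natural extension when categories differ greatly in frequency.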