zoukankan      html  css  js  c++  java
  • DGA model training and testing code

    # _*_coding:UTF-8_*_
    
    import operator
    import tldextract
    import random
    import pickle
    import os
    import tflearn
    
    from math import log
    from tflearn.data_utils import to_categorical, pad_sequences
    from tflearn.layers.core import input_data, dropout, fully_connected
    from tflearn.layers.conv import conv_1d, max_pool_1d
    from tflearn.layers.estimator import regression
    from tflearn.layers.normalization import batch_normalization
    from sklearn.model_selection import train_test_split
    
    
    def get_cnn_model(max_len, volcab_size=None):
        """Build the character-level CNN used to classify domains into two classes.

        Architecture: embedding (32-dim) -> two stages of conv_1d(64 filters,
        width 3, ReLU, L2 regulariser) + max_pool_1d(2) -> batch normalisation
        -> fully connected(64, ReLU) -> dropout(0.5) -> fully connected(2,
        softmax). It is trained with categorical cross-entropy and tflearn SGD
        (learning_rate=0.1, lr_decay=0.96, decay_step=1000).

        Args:
            max_len: length of each padded input index sequence.
            volcab_size: embedding input dimension (number of distinct indices).
                When None, a fallback of 10240000 is used.

        Returns:
            tflearn.DNN: an untrained model wrapper around the network.
        """
        vocab = 10240000 if volcab_size is None else volcab_size

        # Input is a batch of integer index sequences, each of length max_len.
        net = tflearn.input_data(shape=[None, max_len], name='input')
        net = tflearn.embedding(net, input_dim=vocab, output_dim=32)

        # Two identical convolution + pooling stages.
        for _stage in range(2):
            net = conv_1d(net, 64, 3, activation='relu', regularizer="L2")
            net = max_pool_1d(net, 2)

        net = batch_normalization(net)
        net = fully_connected(net, 64, activation='relu')
        net = dropout(net, 0.5)
        net = fully_connected(net, 2, activation='softmax')

        optimizer = tflearn.SGD(learning_rate=0.1, lr_decay=0.96, decay_step=1000)
        net = regression(net, optimizer=optimizer, loss='categorical_crossentropy')

        return tflearn.DNN(net, tensorboard_verbose=0)
    
    
    def get_data_from(file_name):
        """Return the non-blank, whitespace-stripped lines of ``file_name``.

        Each remaining line is taken verbatim as one domain string. No CSV
        parsing or domain extraction is performed. Blank lines (e.g. a trailing
        empty line) are skipped so they do not become empty "domains", which
        would encode to all-padding samples.

        Args:
            file_name: path of a UTF-8 text file.

        Returns:
            list[str]: the stripped, non-empty lines in file order.
        """
        with open(file_name, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [domain for domain in stripped if domain]
    
    
    def get_local_data(tag="labeled"):
        """Load the black-list (DGA) and white-list (benign) domain files.

        ``get_data`` labels black-list domains 1 and white-list domains 0.
        ``test_model`` reports a domain as DGA when its class-1 probability is
        high, so the DGA feed must be returned as the black list.

        Args:
            tag: unused; kept so existing callers keep working.

        Returns:
            tuple[list[str], list[str]]: ``(black_data, white_data)``.
        """
        # The DGA feed is the black list (label 1) and the top-1m list is the
        # white list (label 0); previously these were swapped, inverting the
        # labels. Models trained before this change must be retrained.
        # NOTE(review): each file line is used verbatim as a domain; if
        # top-1m.csv is "rank,domain" formatted the rank is included -- confirm.
        black_data = get_data_from(file_name="dga_360_sorted.txt")
        white_data = get_data_from(file_name="top-1m.csv")
        return black_data, white_data
    
    
    def get_data():
        """Build the encoded training set and persist the vocabulary.

        Domains come from ``get_local_data``. Black-list domains are labelled 1
        and white-list domains 0. Every character in the corpus is mapped to an
        integer index starting at 1; index 0 is used for padding. Sequences are
        padded/truncated by ``pad_sequences`` to the longest domain length,
        capped at 256.

        The vocabulary, the padded length and the vocabulary size are pickled
        to ``volcab.pkl`` so new domains can later be encoded identically.

        Returns:
            tuple: ``(X, Y, maxlen, max_features)``.
                - ``X``: padded index sequences from ``pad_sequences``.
                - ``Y``: one-hot labels (2 classes) from ``to_categorical``.
                - ``maxlen``: the padded sequence length.
                - ``max_features``: the vocabulary size + 1, to account for
                  the padding index.

        Raises:
            ValueError: if no domains were loaded.
        """
        black_x, white_x = get_local_data()
        X = black_x + white_x
        labels = [1] * len(black_x) + [0] * len(white_x)
        if not X:
            raise ValueError("no training domains were loaded")

        # Sort the character set so the char->index mapping is reproducible:
        # iteration order of a set of str changes between interpreter runs
        # because of string hash randomisation.
        valid_chars = {ch: idx + 1 for idx, ch in enumerate(sorted(set(''.join(X))))}

        max_features = len(valid_chars) + 1
        print("max_features:", max_features)
        maxlen = max(len(x) for x in X)
        print("max_len:", maxlen)
        maxlen = min(maxlen, 256)

        # Convert characters to int indices and pad with 0.
        X = [[valid_chars[ch] for ch in x] for x in X]
        X = pad_sequences(X, maxlen=maxlen, value=0.)

        # One-hot encode the 0/1 labels.
        Y = to_categorical(labels, nb_classes=2)

        volcab_file = "volcab.pkl"
        data = {"valid_chars": valid_chars, "max_len": maxlen, "volcab_size": max_features}
        # Uses pickle's default protocol; `with` closes the file even if
        # pickling fails.
        with open(volcab_file, 'wb') as output:
            pickle.dump(data, output)

        return X, Y, maxlen, max_features
    
    
    def train_model():
        """Train the CNN on an 80/20 split of the local data and save it.

        The held-out 20% (``random_state=42``) is used as the validation set
        during ``fit``. The weights are saved to ``finalized_model.tflearn``
        and reloaded. Predictions for the first three held-out samples are
        printed.
        """
        X, Y, max_len, volcab_size = get_data()
        print("X len:", len(X), "Y len:", len(Y))

        split = train_test_split(X, Y, test_size=0.2, random_state=42)
        trainX, testX, trainY, testY = split

        # Sanity check: show the first training sample/label and the last
        # held-out sample/label.
        for preview in (trainX[:1], trainY[:1], testX[-1:], testY[-1:]):
            print(preview)

        model = get_cnn_model(max_len, volcab_size)
        model.fit(trainX, trainY,
                  validation_set=(testX, testY),
                  show_metric=True,
                  batch_size=1024)

        model_path = 'finalized_model.tflearn'
        model.save(model_path)
        model.load(model_path)

        print("Just review 3 sample data test result:")
        predictions = model.predict(testX[0:3])
        print(predictions)
    
    
    def test_model(domain_file="dga_360_sorted.txt", threshold=.5, batch_size=1000):
        """Score domains with the saved CNN and print those flagged as DGA.

        Loads the vocabulary pickled to ``volcab.pkl`` and the weights in
        ``finalized_model.tflearn``. It then scores ``domain_file`` (one domain
        per line) in batches. Every domain whose class-1 probability exceeds
        ``threshold`` is printed and collected.

        Args:
            domain_file: text file to score; the default is the file this
                function always read before it was parameterised.
            threshold: class-1 probability above which a domain is reported.
            batch_size: number of lines encoded and predicted per call.

        Returns:
            list[str]: the stripped domains whose score exceeded ``threshold``.

        Raises:
            FileNotFoundError: if ``volcab.pkl`` does not exist.
        """
        volcab_file = "volcab.pkl"
        # Explicit check rather than `assert`, which `python -O` removes.
        if not os.path.exists(volcab_file):
            raise FileNotFoundError("vocabulary file not found: " + volcab_file)
        # NOTE: pickle.load can execute arbitrary code; only load a vocabulary
        # file produced by this script, never one from an untrusted source.
        with open(volcab_file, 'rb') as pkl_file:
            data = pickle.load(pkl_file)
        valid_chars = data["valid_chars"]
        max_document_length = data["max_len"]
        max_features = data["volcab_size"]

        print("max_features:", max_features)
        print("max_len:", max_document_length)

        cnn_model = get_cnn_model(max_document_length, max_features)
        filename = 'finalized_model.tflearn'
        cnn_model.load(filename)
        print("predict domains:")
        bls = []

        with open(domain_file) as f:
            lines = f.readlines()
        print("domain_list len:", len(lines))

        for start in range(0, len(lines), batch_size):
            batch = lines[start:start + batch_size]
            domain_list = [line.strip() for line in batch]

            # Convert characters to int indices; characters unseen during
            # training map to the padding index 0. Then pad.
            X = [[valid_chars.get(ch, 0) for ch in domain] for domain in domain_list]
            X = pad_sequences(X, maxlen=max_document_length, value=0.)

            result = cnn_model.predict(X)
            for raw_line, domain, scores in zip(batch, domain_list, result):
                if scores[1] > threshold:
                    print(raw_line.strip() + "\t" + domain, scores[1])
                    bls.append(domain)

        print(len(bls), "dga found!")
        return bls
    
    
    if __name__ == "__main__":
        # Run the pipeline: train (and save) the model, then score domains.
        for banner, stage in (("train model...", train_model),
                              ("test model...", test_model)):
            print(banner)
            stage()
    
  • 相关阅读:
    Gradle Gretty进行runAppDebug的Listening for transport dt_socket at address: 5005 的后续配置
    Oracle :value too large for column "SCHEMA"."TABLE"."COLUMN" (actual: 519, maximum: 500)的解决方案
    js file对象 文件大小转换可视容易阅读的单位
    JS的Event各种属性级target/currentTarget/relatedTarget各种目录的解释
    浏览器控制台是否打开的一些措施的讨论
    eclipse启动指定jvm的版本
    IDEA terminal无法从vim的编辑模式转换为命令模式
    win7 64位系统在IronPython2.7 rc安装后运行出现"ipy64/ipy.exe"does not exist解决办法
    VS2010 插件 CSS3 IS 2.1.1 在win7 64位机子上安装小记
    Asp.net ajax 1.0 绑定drowdownlist时取值问题
  • 原文地址:https://www.cnblogs.com/bonelee/p/11958214.html
Copyright © 2011-2022 走看看