zoukankan      html  css  js  c++  java
  • golang调用tensorflow keras训练的音频分类模型

    1 实现场景分析

    业务在外呼中经常会遇到接听者因忙或者空号导致返回的回铃音被语音识别引擎识别并传递给业务流程解析,而这种情况会在外呼后的业务统计中导致接通率的统计较低,为了解决该问题,打算在回铃音进入语音识别引擎前进行识别,判断为非接通的则直接丢弃不在接入流程处理。
    经过对场景中的录音音频及语音识别的文字进行分析,发现大部分的误识别回铃音都是客户忙或者是空号,与正常接通的音频特征区分很明显,如下所示采用科大讯飞的语音识别引擎对失败的回铃音转写的结果
    image
    从转写结果统计也验证了我们的分析。(针对回铃音为视频彩铃等的暂时没有统计到,这里也不作为主要的失败音频分析)

    2 模型训练实现基于深度学习的声音分类

    实际实践参考 keras实现声音二分类 文章中有对音频特征mfcc的说明,流程分析也很详细,可以参考。这里主要贴下在验证中使用的代码,模型训练代码

    import os
    import keras
    import librosa
    import numpy as np
    import matplotlib.pyplot as plt
    from keras import Sequential
    from keras.utils import to_categorical
    from keras.layers import Dense
    from sklearn.model_selection import train_test_split
    import tensorflow as tf
    from keras import backend as  K
    
    DATA = 'data.npy'
    TARGET = 'target.npy'
    
    
    def load_label(label_path):
        """
        遍历当前给定的目录,便于后文进行标签加载
        :param label_path:
        :return:
        """
        label = os.listdir(label_path)
        return label
    
    
    # 提取 mfcc 参数
    def wav2mfcc(path, max_pad_size=11):
        """
        备注:由于我们拿到的音频文件,持续时间都不尽相同,所以提取到的 mfcc 大小是不相同的。
        但是神经网络要求待处理的矩阵大小要相同,所以这里我们用到了铺平操作。我们 mfcc 系数默认提取 20 帧,对于每一帧来说,
        如果帧长小于 11,我们就用 0 填满不满足要求的帧;如果帧长大于 11,我们就只选取前 11 个参数
        :param path:    音频文件地址
        :param max_pad_size:    帧长,最大设置为11
        :return:
        """
        # 读取音频文件,按照音频本身的采样率进行读取
        y, sr = librosa.load(path=path, sr=None, mono=True)
        y = y[::3]  # 不需要太高的采样率数据,这里进行每三个点选用一个
        audio_mac = librosa.feature.mfcc(y=y, sr=16000)
        y_shape = audio_mac.shape[1]
        if y_shape < max_pad_size:
            """
            函数numpy.pad(array, pad_width, mode),其中 array 是我们需要填充的矩阵,pad_width是各个维度上首尾填充的个数。
            举个例子,假定我们设置的 pad_width 是((0,0), (0,2)),而待处理的 mfcc 系数是 20 * 11 的矩阵。
            我们把 mfcc 系数看成 20 行 11 列的矩阵,进行 pad 操作,第一个(0,0)对行进行操作,
            表示每一行最前面和最后面增加的数个数为零,也就相当于总共增加了 0 列。第二个(0,2)对列操作,
            表示每一列最前面增加的数为 0 个,但最后面要增加两个数,也就相当于总共增加了 2 行。
            mode 设置为 ‘constant’,表明填充的是常数,且默认为 0 
            """
            pad_size = max_pad_size - y_shape
            audio_mac = np.pad(audio_mac, ((0, 0), (0, pad_size)), mode='constant')
        else:
            audio_mac = audio_mac[:, :max_pad_size]
        return audio_mac
    
    
    def save_data_to_array(label_path, max_pad_size=11):
        """
        存储处理过的数据,方便下一次的使用
        :param label_path:
        :param max_pad_size:
        :return:
        """
        mfcc_vectors = []
        target = []
        labels = load_label(label_path=label_path)
        for i, label in enumerate(labels):
            path = label_path + '/' + label
            wavfiles = [path + '/' + file for file in os.listdir(path)]
            for wavfile in wavfiles:
                wav = wav2mfcc(wavfile, max_pad_size=max_pad_size)
                mfcc_vectors.append(wav)
                target.append(i)
        np.save(DATA, mfcc_vectors)
        np.save(TARGET, target)
        # return mfcc_vectors, target
    
    
    def get_train_test(split_ratio=.6, random_state=42):
        """
        使用sklearn 中的train_test_split,把数据集分为训练集和验证集。其中训练集占 6 成,测试集占 4 成
        :param split_ratio:
        :param random_state:
        :return:
        """
        X = np.load(DATA)
        y = np.load(TARGET)
        assert X.shape[0] == y.shape[0]
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=(1 - split_ratio), random_state=random_state,
                                                            shuffle=True)
        return X_train, X_test, y_train, y_test
    
    
    def main():
        x_train, x_test, y_train, y_test = get_train_test()
        # 变成二维矩阵且第二个维度大小为 220
        x_train = x_train.reshape(-1, 220)
        x_test = x_test.reshape(-1, 220)
        # 使用kears中的onehot编码
        y_train_hot = to_categorical(y_train)
        y_test_hot = to_categorical(y_test)
        model = Sequential()
        model.add(Dense(64, activation='relu', input_shape=(220,), name="input_layer"))
        model.add(Dense(64, activation='relu', name="dropout_layer1"))
        model.add(Dense(64, activation='relu', name="dropout_layer2"))
        model.add(Dense(2, activation='softmax', name="output_layer"))
    
        # 模型训练
        sess = tf.Session()
        K.set_session(sess)
        # 这步找到input_layer和output_layer的完整路径,在golang中使用时需要用来定义输入输出node
        for n in sess.graph.as_graph_def().node:
            if 'input_layer' in n.name:
                print(n.name)
            if 'output_layer' in n.name:
                print(n.name)
    
        model.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=keras.optimizers.RMSprop(),
                      metrics=['accuracy'])
        history = model.fit(x_train, y_train_hot, batch_size=100, epochs=100, verbose=1,
                            validation_data=(x_test, y_test_hot))
    
        # 以下是关键代码
        # Use TF to save the graph model instead of Keras save model to load it in Golang
        builder = tf.saved_model.builder.SavedModelBuilder("cnnModel")
        # Tag the model, required for Go
        builder.add_meta_graph_and_variables(sess, ["myTag"])
        builder.save()
    
        model.save("classaud.h5")
        plot_history(history)
    
        sess.close()
    
    
    def save():
        label_path = 'F:\doc\项目\音频分类\audio-heji\'
        save_data_to_array(label_path, max_pad_size=11)
    
    
    def plot_history(history):
        plt.plot(history.history['acc'], label='train')
        plt.plot(history.history['val_acc'], label='validation')
        plt.legend()
        plt.show()
    
    
    if __name__ == "__main__":
        save_data_to_array("F:\doc\项目\音频分类\audio-heji\", max_pad_size=11)
        # save_data_to_array("/home/audio/audio-heji/", max_pad_size=11)
        main()
    

    音频文件重命名代码

    import os
    
    
    def rename(pic_path):
        """
        这里对两个目录下文件进行排序,第一个fail-audio设置成10001这种,第二个success-audio目录下设置成了90001这种。
        :param pic_path:
        :return:
        """
        piclist = os.listdir(pic_path)
        i = 1
        print("ok")
        for pic in piclist:
            if pic.endswith(".wav"):
                old_path = os.path.join(os.path.abspath(pic_path), pic)
                new_path = os.path.join(os.path.abspath(pic_path), str(
                    90000 + (int(i))) + '.wav')
                os.renames(old_path, new_path)
                print("把原命名格式:" + old_path + u"转换为新命名格式:" + new_path)
                i = i + 1
    
    
    # 加载标签
    def load_label(label_path):
        label = os.listdir(label_path)
        return label
    
    if __name__ == '__main__':
        rename("F:\doc\项目\音频分类\audio-heji\success-audio")
    

    测试代码

    import librosa
    import numpy as np
    import os
    from keras.models import load_model
    
    
    # 提取 mfcc 参数
    def wav2mfcc(path, max_pad_size=11):
        y, sr = librosa.load(path=path, sr=None, mono=1)
        y = y[::3]  # 每三个点选用一个
        audio_mac = librosa.feature.mfcc(y=y, sr=16000)
        y_shape = audio_mac.shape[1]
        if y_shape < max_pad_size:
            pad_size = max_pad_size - y_shape
            audio_mac = np.pad(audio_mac, ((0, 0), (0, pad_size)), mode='constant')
        else:
            audio_mac = audio_mac[:, :max_pad_size]
        return audio_mac
    
    
    def load_label(label_path):
        """
        遍历当前给定的目录,便于后文进行标签加载
        :param label_path:
        :return:
        """
        label = os.listdir(label_path)
        return label
    
    
    if __name__ == '__main__':
        # 加载模型
        model = load_model('classaud.h5')  # 加载训练模型
        wavs = [wav2mfcc("F:\doc\项目\音频分类\test\" + file, 11) for file in os.listdir("F:\doc\项目\音频分类\test")]
        X = np.array(wavs)
        X = X.reshape(-1, 220)
        print(X.shape)
    
        for j in range(X.shape[0]):
            print(j)
            print(X[j:j+1])
            result = model.predict(X[j:j+1])[0]  #
            print("识别结果", result)
            #  因为在训练的时候,标签集的名字 为:  0:fail-audio   1:success-audio
            name = ["fail-audio", "success-audio"]  # 创建一个跟训练时一样的标签集
            ind = 0  # 结果中最大的一个数
            for i in range(len(result)):
                if result[i] > result[ind]:
                    ind = 1
            print("识别的语音结果是:", name[ind])
    

    在模型训练中,为了后续golang中加载训练好的模型,增加了部分代码,主要是如下

    # 模型训练
        sess = tf.Session()
        K.set_session(sess)
        # 这步找到input_layer和output_layer的完整路径,在golang中使用时需要用来定义输入输出node
        for n in sess.graph.as_graph_def().node:
            if 'input_layer' in n.name:
                print(n.name)
            if 'output_layer' in n.name:
                print(n.name)
    
        model.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=keras.optimizers.RMSprop(),
                      metrics=['accuracy'])
        history = model.fit(x_train, y_train_hot, batch_size=100, epochs=100, verbose=1,
                            validation_data=(x_test, y_test_hot))
    
        # 以下是关键代码
        # Use TF to save the graph model instead of Keras save model to load it in Golang
        builder = tf.saved_model.builder.SavedModelBuilder("cnnModel")
        # Tag the model, required for Go
        builder.add_meta_graph_and_variables(sess, ["myTag"])
        builder.save()
    

    模型训练完成后会生成响应的模型,其中cnnModel文件夹包含pd模型及variables文件夹,为后续golang调度使用的模型,h5模型为这里测试使用的模型,结果如下图
    image
    另外在模型训练时,我们打印出了每层神经网络的名字,这块需要关注,因为在后续的golang环境中因为加载节点名错误导致的问题,打印如下

    WARNING:tensorflow:From E:pycharm
    lu-algorithm-packagevenvlibsite-packages	ensorflowpythonframeworkop_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
    Instructions for updating:
    Colocations handled automatically by placer.
    2020-08-27 19:06:19.437965: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
    input_layer_input
    input_layer/random_uniform/shape
    input_layer/random_uniform/min
    input_layer/random_uniform/max
    input_layer/random_uniform/RandomUniform
    input_layer/random_uniform/sub
    input_layer/random_uniform/mul
    input_layer/random_uniform
    input_layer/kernel
    input_layer/kernel/Assign
    input_layer/kernel/read
    input_layer/Const
    input_layer/bias
    input_layer/bias/Assign
    input_layer/bias/read
    input_layer/MatMul
    input_layer/BiasAdd
    input_layer/Relu
    output_layer/random_uniform/shape
    output_layer/random_uniform/min
    output_layer/random_uniform/max
    output_layer/random_uniform/RandomUniform
    output_layer/random_uniform/sub
    output_layer/random_uniform/mul
    output_layer/random_uniform
    output_layer/kernel
    output_layer/kernel/Assign
    output_layer/kernel/read
    output_layer/Const
    output_layer/bias
    output_layer/bias/Assign
    output_layer/bias/read
    output_layer/MatMul
    output_layer/BiasAdd
    output_layer/Softmax
    

    测试中的一条结果记录为

    25
    [[-1.66098328e+02 -2.12202209e+02 -4.31954193e+02 -4.47424835e+02
      -3.33869904e+02 -3.58775604e+02 -5.43935608e+02 -6.60088867e+02
      -2.30052383e+02 -6.87607422e+01 -1.21050777e+01  4.23143768e+01
       4.26393929e+01  4.03902130e+01  3.61016235e+01  3.74344673e+01
       3.76606216e+01  2.93862953e+01  1.91747296e+00 -9.67822647e+00
      -1.10491867e+01 -1.35342636e+01 -6.71506195e+01 -6.73550262e+01
      -6.63572083e+01 -4.63606071e+01 -4.91609383e+01 -5.08716202e+01
      -4.48681793e+01 -1.87636101e+00 -4.61852417e+01 -4.17765961e+01
      -2.01916142e+01 -9.79832268e+00 -9.91659927e+00 -1.03536911e+01
      -3.72100983e+01 -3.87253189e+01 -3.71894073e+01 -2.30876770e+01
       7.12325764e+00 -1.86448917e+01 -1.23296547e+01 -1.36949463e+01
       2.36775169e+01  2.40340843e+01  2.65298653e+01 -1.50987358e+01
      -1.75523911e+01 -1.85126324e+01 -1.25076523e+01  7.13483989e-01
       1.25870705e+01  9.64732361e+00  1.51968069e+01  1.57215271e+01
       1.60245552e+01  1.79179173e+01  8.83283997e+00  6.01854324e+00
       5.03554726e+00  7.35849476e+00  7.32028627e+00  2.61785851e+01
       2.36798325e+01  2.17165947e+01 -8.92044258e+00 -8.98554516e+00
      -9.36619282e+00  8.72706318e+00  7.63292933e+00  8.11320019e+00
       1.23650103e+01  4.05273724e+00 -2.74963975e+00  6.03793526e+00
       3.32019657e-01 -1.07288876e+01 -1.08059797e+01 -1.12256699e+01
       3.14207101e+00  2.12598038e+00  3.41194320e+00  1.52846613e+01
       3.58989859e+00 -7.22187281e+00 -2.92357159e+00 -1.40336580e+01
       1.25538235e+01  1.30331860e+01  1.53311596e+01  4.33768129e+00
       3.44467473e+00  3.40012264e+00  8.20392323e+00  4.26902437e+00
       2.08825417e+01  1.70654850e+01  6.18888092e+00  1.60812531e+01
       1.66913948e+01  1.89202042e+01 -3.88430738e+00 -3.84901094e+00
      -2.63745117e+00 -1.11108208e+00  6.36666417e-01 -2.25304246e+00
      -8.85197830e+00 -1.91202374e+01 -1.46577859e+00 -1.23925185e+00
       7.83551395e-01  2.47491241e+00  1.99022865e+00  1.92004573e+00
      -4.81319517e-01  5.03325987e+00 -3.67527580e+00  1.00166121e+01
       1.32022190e+01  1.41424477e+00  1.61442292e+00  1.70821917e+00
       1.69235497e+01  1.71120987e+01  1.50358353e+01  1.89076972e+00
       6.39788628e+00 -1.38607597e+00  1.02029294e-01 -1.02280178e+01
       4.86864090e+00  5.27380610e+00  7.52170134e+00 -5.28688669e+00
      -4.54772520e+00 -1.99729645e+00  9.84556198e+00  6.01770210e+00
       6.42514515e+00  5.05019569e+00 -3.58427215e+00  2.53060675e+00
       2.77301073e+00  3.36337566e+00 -3.74559736e+00 -3.88710737e+00
      -2.85192370e+00  8.00442696e+00  5.90060949e-01  4.97644138e+00
       8.34950066e+00  6.98132086e+00 -2.52411366e+00 -2.62051678e+00
      -2.87881041e+00  7.85154676e+00  7.58860874e+00  7.21512508e+00
       6.53992605e+00 -2.57507980e-01 -4.69269657e+00 -3.40787172e+00
       1.32537198e+00 -4.12547922e+00 -4.38115311e+00 -5.22326088e+00
       2.03997135e+00  2.89152718e+00  4.46722126e+00  6.32854557e+00
       5.27089882e+00  8.66948891e+00  5.42871141e+00  1.18013754e+01
      -1.06842923e+00 -1.21782100e+00 -1.62583649e+00  4.68027020e+00
       4.82862568e+00  5.17801666e+00  5.02924442e+00  3.57280898e+00
      -1.60658951e+01 -3.89933228e+00  5.28476810e+00  7.85575271e-01
       7.23506689e-01  9.12119806e-01  1.29786149e-01 -1.08789623e+00
      -2.01344442e+00 -3.12455368e+00  4.22501802e+00 -2.05132604e-01
       4.64199352e+00  1.28417645e+01 -2.14332151e+00 -2.28186941e+00
      -2.92554092e+00 -7.33241290e-02 -4.54517424e-01 -9.12135363e-01
      -1.92673039e+00  1.32999837e+00 -5.95955181e+00 -1.38899193e+01
      -1.53170991e+00 -3.57915735e+00 -3.69184279e+00 -3.76903296e+00
      -3.22209269e-01 -6.59340382e-01  5.86225927e-01  1.03645115e+01
       2.81656504e+00 -1.55326450e+00 -1.87907255e+00 -2.12706447e+00]]
    识别结果 [1.4359492e-26 1.0000000e+00]
    识别的语音结果是: success-audio
    

    后续的golang调用tensorflow模型中我们以此结果作为测试。
    备注:此处使用的tensorflow为1.13.1版本、kears为2.2.4、python3.6.0,音频特征抽取使用的librosa库,版本0.8.0。

    其他方式实现的音频分类,在实践中也参考了 Python Project – Music Genre Classification 该文章时使用的K近邻方式实现的,在实践中也做了此方式,准确率在70%左右,也可以参考。

    3 golang调用tensorflow/keras训练的模型

    安装 go 版 TensorFlow的问题记录 我们已经将相关环境设置好,此处只需要完成相关的go代码即可上线测试。代码验证参考了 golang调用tensorflow/keras训练的模型 此处先给出基于此文章实现的验证代码

    package main
     
    import (
            "fmt"
            tf "github.com/tensorflow/tensorflow/tensorflow/go"
    )
     
    func main() {
            // 特征长度
            const MAXLEN int = 220
            // 将文本转换为id序列,为了实验方便直接使用转换好的ID序列即可,此处是使用上文中测试中打印出来的音频特征
            input_data := [1][MAXLEN]float32{{-1.66098328e+02,-2.12202209e+02,-4.31954193e+02,-4.47424835e+02,-3.33869904e+02,-3.58775604e+02,-5.43935608e+02,-6.60088867e+02,-2.30052383e+02,-6.87607422e+01,-1.21050777e+01,4.23143768e+01,4.26393929e+01,4.03902130e+01,3.61016235e+01,3.74344673e+01,3.76606216e+01,2.93862953e+01,1.91747296e+00,-9.67822647e+00,-1.10491867e+01,-1.35342636e+01,-6.71506195e+01,-6.73550262e+01,-6.63572083e+01,-4.63606071e+01,-4.91609383e+01,-5.08716202e+01,-4.48681793e+01,-1.87636101e+00,-4.61852417e+01,-4.17765961e+01,-2.01916142e+01,-9.79832268e+00,-9.91659927e+00,-1.03536911e+01,-3.72100983e+01,-3.87253189e+01,-3.71894073e+01,-2.30876770e+01,7.12325764e+00,-1.86448917e+01,-1.23296547e+01,-1.36949463e+01,2.36775169e+01,2.40340843e+01,2.65298653e+01,-1.50987358e+01,-1.75523911e+01,-1.85126324e+01,-1.25076523e+01,7.13483989e-01,1.25870705e+01,9.64732361e+00,1.51968069e+01,1.57215271e+01,1.60245552e+01,1.79179173e+01,8.83283997e+00,6.01854324e+00,5.03554726e+00,7.35849476e+00,7.32028627e+00,2.61785851e+01,2.36798325e+01,2.17165947e+01,-8.92044258e+00,-8.98554516e+00,-9.36619282e+00,8.72706318e+00,7.63292933e+00,8.11320019e+00,1.23650103e+01,4.05273724e+00,-2.74963975e+00,6.03793526e+00,3.32019657e-01,-1.07288876e+01,-1.08059797e+01,-1.12256699e+01,3.14207101e+00,2.12598038e+00,3.41194320e+00,1.52846613e+01,3.58989859e+00,-7.22187281e+00,-2.92357159e+00,-1.40336580e+01,1.25538235e+01,1.30331860e+01,1.53311596e+01,4.33768129e+00,3.44467473e+00,3.40012264e+00,8.20392323e+00,4.26902437e+00,2.08825417e+01,1.70654850e+01,6.18888092e+00,1.60812531e+01,1.66913948e+01,1.89202042e+01,-3.88430738e+00,-3.84901094e+00,-2.63745117e+00,-1.11108208e+00,6.36666417e-01,-2.25304246e+00,-8.85197830e+00,-1.91202374e+01,-1.46577859e+00,-1.23925185e+00,7.83551395e-01,2.47491241e+00,1.99022865e+00,1.92004573e+00,-4.81319517e-01,5.03325987e+00,-3.67527580e+00,1.00166121e+01,1.32022190e+01,1.41424477e+00,1.61442292e+00,1.70821917e+00,1.69235497e+01,1.71120987e+01,1.50358353e+01,1.89076972e+00,6.39788628e+00,-1.38607597e+00,1.02029294e-01,-1.02280178e+01,4.86864090e+00,5.27380610e+00,7.52170134e+00,-5.28688669e+00,-4.54772520e+00,-1.99729645e+00,9.84556198e+00,6.01770210e+00,6.42514515e+00,5.05019569e+00,-3.58427215e+00,2.53060675e+00,2.77301073e+00,3.36337566e+00,-3.74559736e+00,-3.88710737e+00,-2.85192370e+00,8.00442696e+00,5.90060949e-01,4.97644138e+00,8.34950066e+00,6.98132086e+00,-2.52411366e+00,-2.62051678e+00,-2.87881041e+00,7.85154676e+00,7.58860874e+00,7.21512508e+00,6.53992605e+00,-2.57507980e-01,-4.69269657e+00,-3.40787172e+00,1.32537198e+00,-4.12547922e+00,-4.38115311e+00,-5.22326088e+00,2.03997135e+00,2.89152718e+00,4.46722126e+00,6.32854557e+00,5.27089882e+00,8.66948891e+00,5.42871141e+00,1.18013754e+01,-1.06842923e+00,-1.21782100e+00,-1.62583649e+00,4.68027020e+00,4.82862568e+00,5.17801666e+00,5.02924442e+00,3.57280898e+00,-1.60658951e+01,-3.89933228e+00,5.28476810e+00,7.85575271e-01,7.23506689e-01,9.12119806e-01,1.29786149e-01,-1.08789623e+00,-2.01344442e+00,-3.12455368e+00,4.22501802e+00,-2.05132604e-01,4.64199352e+00,1.28417645e+01,-2.14332151e+00,-2.28186941e+00,-2.92554092e+00,-7.33241290e-02,-4.54517424e-01,-9.12135363e-01,-1.92673039e+00,1.32999837e+00,-5.95955181e+00,-1.38899193e+01,-1.53170991e+00,-3.57915735e+00,-3.69184279e+00,-3.76903296e+00,-3.22209269e-01,-6.59340382e-01,5.86225927e-01,1.03645115e+01,2.81656504e+00,-1.55326450e+00,-1.87907255e+00,-2.12706447e+00}}
            tensor, err := tf.NewTensor(input_data)
            if err != nil {
                    fmt.Printf("Error NewTensor: err: %s", err.Error())
                    return
            }
            //读取模型
            model, err := tf.LoadSavedModel("cnnModel", []string{"myTag"}, nil)
            if err != nil {
                    fmt.Printf("Error loading Saved Model: %s
    ", err.Error())
                    return
            }
            // 识别
            result, err := model.Session.Run(
                    map[tf.Output]*tf.Tensor{
                            // python版tensorflow/keras中定义的输入层input_layer
                            model.Graph.Operation("input_layer").Output(0): tensor,
                    },
                    []tf.Output{
                            // python版tensorflow/keras中定义的输出层output_layer
                            model.Graph.Operation("output_layer/Softmax").Output(0),
                    },
                    nil,
            )
     
            if err != nil {
                    fmt.Printf("Error running the session with input, err: %s  ", err.Error())
                    return
            }
            // 输出结果,interface{}格式
            fmt.Printf("Result value: %v", result[0].Value())
    }
    

    将训练好的模型及验证程序上传到虚拟机的gopath目录

    [root@localhost gopath]# cd /home/gopath/
    [root@localhost gopath]# 
    [root@localhost gopath]# ll
    总用量 8
    -rw-r--r--. 1 root root 5072 8月  27 21:21 cnn.go
    drwxr-xr-x. 3 root root   45 8月  27 19:27 cnnModel
    drwxr-xr-x. 3 root root   25 8月  27 15:40 pkg
    drwxr-xr-x. 5 root root   65 8月  27 15:40 src
    

    执行程序后报错,如下所示

    [root@localhost gopath]# go run cnn.go 
    2020-08-27 21:22:55.306065: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: cnnModel
    2020-08-27 21:22:55.312739: I tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { myTag }
    2020-08-27 21:22:55.320111: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
    2020-08-27 21:22:55.335002: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1799995000 Hz
    2020-08-27 21:22:55.335970: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x1ba6ea0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
    2020-08-27 21:22:55.336091: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
    2020-08-27 21:22:55.364460: I tensorflow/cc/saved_model/loader.cc:202] Restoring SavedModel bundle.
    2020-08-27 21:22:55.449190: I tensorflow/cc/saved_model/loader.cc:311] SavedModel load for tags { myTag }; Status: success. Took 143138 microseconds.
    模型read成功panic: nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.
    
    goroutine 1 [running]:
    github.com/tensorflow/tensorflow/tensorflow/go.Output.c(...)
    	/home/gopath/src/github.com/tensorflow/tensorflow/tensorflow/go/operation.go:130
    github.com/tensorflow/tensorflow/tensorflow/go.newCRunArgs(0xc000097e50, 0xc000097e20, 0x1, 0x1, 0x0, 0x0, 0x0, 0xc000097668)
    	/home/gopath/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:369 +0x594
    github.com/tensorflow/tensorflow/tensorflow/go.(*Session).Run(0xc00000e0c0, 0xc000097e50, 0xc000097e20, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, ...)
    	/home/gopath/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:143 +0x1e2
    main.main()
    	/home/gopath/cnn.go:27 +0x1199
    exit status 2
    [root@localhost gopath]#
    

    错误显示找不到输出节点,可是我们在代码中已经设置了相关的输出节点信息,如下所示

    map[tf.Output]*tf.Tensor{
                            // python版tensorflow/keras中定义的输入层input_layer
                            model.Graph.Operation("input_layer").Output(0): tensor,
                    },
                    []tf.Output{
                            // python版tensorflow/keras中定义的输出层output_layer
                            model.Graph.Operation("output_layer/Softmax").Output(0),
                    },
    

    那找不到相关的节点是否是节点名字绑定问题,我们回看在模型训练时打印的各层节点名称,发现是以“input_layer_input”开始“output_layer/Softmax”结束,查看我们代码中是按照训练中设置的name进行标记的,故将输入的开始节点从“input_layer”修改为“input_layer_input”,如下

    map[tf.Output]*tf.Tensor{
                            // python版tensorflow/keras中定义的输入层input_layer
                            model.Graph.Operation("input_layer_input").Output(0): tensor,
                    },
                    []tf.Output{
                            // python版tensorflow/keras中定义的输出层output_layer
                            model.Graph.Operation("output_layer/Softmax").Output(0),
                    },
    

    重新运行程序,结果如下

    [root@localhost gopath]# go run cnn.go 
    2020-08-27 21:28:33.161277: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: cnnModel
    2020-08-27 21:28:33.166881: I tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { myTag }
    2020-08-27 21:28:33.173316: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
    2020-08-27 21:28:33.186007: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1799995000 Hz
    2020-08-27 21:28:33.187068: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x13fbea0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
    2020-08-27 21:28:33.187126: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
    2020-08-27 21:28:33.210363: I tensorflow/cc/saved_model/loader.cc:202] Restoring SavedModel bundle.
    2020-08-27 21:28:33.294067: I tensorflow/cc/saved_model/loader.cc:311] SavedModel load for tags { myTag }; Status: success. Took 132796 microseconds.
    模型read成功模型识别成功Result value: [[1.4359492e-26 1]]
    

    对比上文给出的测试中的识别结果是保持一致的,故这里golang调度keras训练的模型成功,后续就是实际场景如何进行使用的问题规划。

  • 相关阅读:
    Java 过滤器的作用
    TreeView的绑定
    设计模式(一)工厂模式Factory(创建型)
    【剑指offer】员工年龄排序
    Spring3.0 AOP 具体解释
    IT行业新名词--透明手机/OCR(光学字符识别)/夹背电池
    MYSQL C API 记录
    Hibernate的介绍
    数据绑定(八)使用Binding的RelativeSource
    一、ExtJS下载使用
  • 原文地址:https://www.cnblogs.com/yhzhou/p/13579496.html
Copyright © 2011-2022 走看看