zoukankan      html  css  js  c++  java
  • 【deep learning学习笔记】注释yusugomori的DA代码 --- dA.cpp -- 模型测试

    测试代码。能看到,训练的时候是单个样本、单个样本的训练的,在NN中是属于“stochastic gradient descent”,否则,一批样本在一起的,就是“standard gradient descent”。

    void test_dA() 
    {
    	srand(0);
      
    	double learning_rate = 0.1;
    	double corruption_level = 0.3;
    	int training_epochs = 100;
    
    	int train_N = 10;
    	int test_N = 2;
    	int n_visible = 20;
    	int n_hidden = 5;
    
    	// training data
    	int train_X[10][20] = {
    		{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0}
    	};
    
    	// construct dA
    	dA da(train_N, n_visible, n_hidden, NULL, NULL, NULL);
    
    	// train
    	for(int epoch=0; epoch<training_epochs; epoch++) 
    	{
    		// train it sample by sample
    		for(int i=0; i<train_N; i++) 
    		{
    			da.train(train_X[i], learning_rate, corruption_level);
    		}
    	}
    
    	// test data
    	int test_X[2][20] = 
    	  {
    		{1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0}
    	  };
    	double reconstructed_X[2][20];
    
    
    	// test
    	for(int i=0; i<test_N; i++) 
    	{
    		da.reconstruct(test_X[i], reconstructed_X[i]);
    		for(int j=0; j<n_visible; j++) 
    		{
    			printf("%.5f ", reconstructed_X[i][j]);
    		}
    		cout << endl;
    	}
    	 cout << endl;
    }
    
    
    
    int main() 
    {
    	test_dA();
    
    	getchar();
    	return 0;
    }
    

    程序运行结果:


  • 相关阅读:
    跨域(六)——window.name
    跨域(五)——postMessage
    跨域(四)——document.domain
    跨域(三)——JSONP
    Web安全颜色
    跨域(二)——WebSocket
    Win7下npm命令Error: ENOENT问题解决
    跨域(一)——CORS机制
    父组件传值给孙组件
    vue使用bus进行兄弟组件传值
  • 原文地址:https://www.cnblogs.com/xinyuyuanm/p/3206547.html
Copyright © 2011-2022 走看看