zoukankan      html  css  js  c++  java
  • 【deep learning学习笔记】注释yusugomori的DA代码 --- dA.cpp -- 模型测试

    测试代码。能看到,训练的时候是单个样本、单个样本的训练的,在NN中是属于“stochastic gradient descent”,否则,一批样本在一起的,就是“standard gradient descent”。

    void test_dA() 
    {
    	srand(0);
      
    	double learning_rate = 0.1;
    	double corruption_level = 0.3;
    	int training_epochs = 100;
    
    	int train_N = 10;
    	int test_N = 2;
    	int n_visible = 20;
    	int n_hidden = 5;
    
    	// training data
    	int train_X[10][20] = {
    		{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0}
    	};
    
    	// construct dA
    	dA da(train_N, n_visible, n_hidden, NULL, NULL, NULL);
    
    	// train
    	for(int epoch=0; epoch<training_epochs; epoch++) 
    	{
    		// train it sample by sample
    		for(int i=0; i<train_N; i++) 
    		{
    			da.train(train_X[i], learning_rate, corruption_level);
    		}
    	}
    
    	// test data
    	int test_X[2][20] = 
    	  {
    		{1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    		{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0}
    	  };
    	double reconstructed_X[2][20];
    
    
    	// test
    	for(int i=0; i<test_N; i++) 
    	{
    		da.reconstruct(test_X[i], reconstructed_X[i]);
    		for(int j=0; j<n_visible; j++) 
    		{
    			printf("%.5f ", reconstructed_X[i][j]);
    		}
    		cout << endl;
    	}
    	 cout << endl;
    }
    
    
    
    int main() 
    {
    	test_dA();
    
    	getchar();
    	return 0;
    }
    

    程序运行结果:


  • 相关阅读:
    Windows Phone 7 中常用Task
    设置Highchart柱子最大宽度( 让 highcharts支持maxPointWidth属性)
    Asp.Net MVC 使用FileResult导出Excel数据文件
    js获取网页高度
    使用window.addEventListener 和 window.attachEvent 判断浏览器
    slimscroll滚动条插件简单用法
    js中如何快速获取数组中的最大值最小值
    js 判断浏览器类型
    python使用ldap进行用户认证
    关于go声明切片的一些疑问
  • 原文地址:https://www.cnblogs.com/xinyuyuanm/p/3206547.html
Copyright © 2011-2022 走看看