zoukankan      html  css  js  c++  java
  • Kernel Regression from Nando's Deep Learning lecture 5

    require 'torch'
    require 'gnuplot'
    
    -- Training set: nData evenly spaced inputs on [-1, 1] with targets y = x^2.
    local nData = 10
    -- Kernel width used by phi(): kernel values decay as exp(-||x - y||^2 / kWidth).
    local kWidth = 1
    -- Standard deviation of the Gaussian noise added to the training targets.
    local noiseStd = 0.1

    local xTrain = torch.linspace(-1, 1, nData)
    local yTrain = torch.pow(xTrain, 2)
    print(xTrain)
    print(yTrain)
    -- Corrupt the targets with zero-mean Gaussian noise. Reassign the existing
    -- local instead of redeclaring it (`local yTrain = yTrain + ...` shadowed
    -- the first declaration and is flagged by luacheck as a redefined local).
    yTrain = yTrain + torch.randn(nData):mul(noiseStd)
    print(yTrain)
    
    --- Gaussian (RBF) kernel between two inputs.
    -- Computes exp(-(1/kWidth) * sum((x - y)^2)) using the file-level kWidth.
    -- @param x tensor input
    -- @param y tensor input, same size as x
    -- @return the kernel value as a scalar number
    local function phi(x, y)
        local scale = -(1 / kWidth)
        -- (x - y) is a fresh tensor, so squaring it in place is safe.
        local sqDist = (x - y):pow(2):sum()
        return torch.exp(scale * sqDist)
    end
    
    -- Ridge-regularization strength added to the diagonal of Phi'Phi.
    local lambda = 0.001

    -- Gram matrix over the training inputs: Phi[i][j] = phi(x_i, x_j).
    -- phi is symmetric in its arguments ((x - y)^2 == (y - x)^2 exactly), so
    -- only the upper triangle is evaluated and then mirrored.
    local Phi = torch.Tensor(nData, nData)
    for i = 1, nData do
        local xi = xTrain[{{i}}]
        for j = i, nData do
            local k = phi(xi, xTrain[{{j}}])
            Phi[i][j] = k
            Phi[j][i] = k
        end
    end

    -- Regularized least squares: theta solves (Phi'Phi + lambda*I) theta = Phi' y.
    -- Solve the linear system directly instead of forming an explicit inverse,
    -- which is both slower and less numerically accurate.
    local regularizer = torch.eye(nData):mul(lambda)
    local lhs = Phi:t() * Phi + regularizer
    -- torch.gesv expects a 2-D right-hand side (m x k); use a single column.
    local rhs = (Phi:t() * yTrain):view(nData, 1)
    -- gesv returns (X, LU); take column 1 of X so theta stays a 1-D vector.
    local theta = torch.gesv(rhs, lhs):select(2, 1)
    
    -- Evaluate the fitted model on a dense grid over the same [-1, 1] interval.
    local nTestData = 100
    local xTest = torch.linspace(-1, 1, nTestData)

    -- PhiTest[row][col] = phi(row-th training input, col-th test input).
    local PhiTest = torch.Tensor(nData, nTestData)
    for row = 1, nData do
        local xRow = xTrain[{{row}}]
        for col = 1, nTestData do
            PhiTest[row][col] = phi(xRow, xTest[{{col}}])
        end
    end

    -- Each prediction is the theta-weighted sum of kernel values against
    -- every training input.
    local yPred = PhiTest:t() * theta

    gnuplot.plot(
        {'Data', xTrain, yTrain, '+'},
        {'Prediction', xTest, yPred, '-'}
    )

  • 相关阅读:
    java中finally的使用
    String基本方法
    java文件读写常用方法
    java笔试面试(转载)
    单链表的反转
    单链表的冒泡排序
    Java快速教程
    Java快速教程
    后海日记(4)
    后海日记(3)
  • 原文地址:https://www.cnblogs.com/devai/p/4839923.html
Copyright © 2011-2022 走看看