zoukankan      html  css  js  c++  java
  • PyTorch学习笔记(3)

    # FizzBuzz
    def fizz_buzz_encode(i):
        """Classify ``i`` into one of the four FizzBuzz labels.

        Returns 3 for multiples of 15, 2 for other multiples of 5,
        1 for other multiples of 3, and 0 for everything else.
        """
        # Test 15 first so numbers divisible by both 3 and 5 are not
        # captured by the single-divisor checks below.
        if i % 15 == 0:
            return 3
        if i % 5 == 0:
            return 2
        if i % 3 == 0:
            return 1
        return 0
    def fizz_buzz_decode(i, prediction):
        """Turn a FizzBuzz class index back into its printed form.

        ``prediction`` indexes the sequence (str(i), 'fizz', 'buzz', 'fizzbuzz'):
        index 0 yields the number itself as a string.
        """
        labels = (str(i), 'fizz', 'buzz', 'fizzbuzz')
        return labels[prediction]
    def helper(i):
        """Print the ground-truth FizzBuzz output for ``i``."""
        label = fizz_buzz_decode(i, fizz_buzz_encode(i))
        print(label)
    # Sanity check: print the reference answers for 1..15.
    for number in range(1, 16):
        helper(number)
    
    import numpy as np
    import torch
    NUM_DIGITS = 10  # number of bits in each binary-encoded input
    def fizz_buzz_encode(i):
        """Return the FizzBuzz class label for integer ``i``.

        Labels: 0 = plain number, 1 = fizz (divisible by 3 only),
        2 = buzz (divisible by 5 only), 3 = fizzbuzz (divisible by 15).
        """
        # Divisibility by 3 contributes 1 and divisibility by 5 contributes 2,
        # so a multiple of 15 (divisible by both) sums to 3.
        return int(i % 3 == 0) + 2 * int(i % 5 == 0)
    # The network input is represented in binary.
    def binary_encode(i, num_digits):
        """Return the low ``num_digits`` bits of ``i`` as a NumPy array, most significant bit first."""
        # Walk bit positions from high to low so the result reads left-to-right like the binary literal.
        bits = [(i >> shift) & 1 for shift in reversed(range(num_digits))]
        return np.array(bits)
    # Train on 101 .. 2**NUM_DIGITS - 1 so that 1..100 remain unseen for the final test.
    train_numbers = range(101, 2 ** NUM_DIGITS)
    # Stack the per-number bit vectors into a single ndarray before converting:
    # building a tensor from a Python list of ndarrays is slow and PyTorch warns about it.
    trX = torch.tensor(
        np.stack([binary_encode(i, NUM_DIGITS) for i in train_numbers]),
        dtype=torch.float32,
    )
    # Integer class labels (0..3), as CrossEntropyLoss expects.
    trY = torch.tensor([fizz_buzz_encode(i) for i in train_numbers], dtype=torch.long)
    
    NUM_HIDDEN = 100
    # Two-layer MLP: NUM_DIGITS input bits -> NUM_HIDDEN ReLU units -> 4 class logits
    # (one per FizzBuzz label).
    model = torch.nn.Sequential(
        torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
        torch.nn.ReLU(),
        torch.nn.Linear(NUM_HIDDEN, 4),
    )
    # Run on the GPU when one is present.
    model = model.cuda() if torch.cuda.is_available() else model
    
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    BATCH_SIZE = 128

    # Move the whole training set to the GPU once, rather than re-copying every
    # batch on every epoch.
    train_x, train_y = trX, trY
    if torch.cuda.is_available():
        train_x, train_y = trX.cuda(), trY.cuda()

    for epoch in range(1000):
        epoch_loss = 0.0
        for start in range(0, len(train_x), BATCH_SIZE):
            end = start + BATCH_SIZE
            batchX = train_x[start:end]
            batchY = train_y[start:end]
            y_pred = model(batchX)
            loss = loss_fn(y_pred, batchY)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size: the last slice can be shorter than BATCH_SIZE,
            # so this keeps the epoch figure a true per-sample mean.
            epoch_loss += loss.item() * len(batchY)
        # One progress line per epoch instead of one per batch.
        print('Epoch', epoch, epoch_loss / len(train_x))
            
    
    # Test: predict labels for 1..100, which lie outside the training range.
    testX = torch.tensor(
        np.stack([binary_encode(i, NUM_DIGITS) for i in range(1, 101)]),
        dtype=torch.float32,
    )
    if torch.cuda.is_available():
        testX = testX.cuda()
    # Switch to inference mode before evaluating.
    model.eval()
    with torch.no_grad():
        testY = model(testX)
    # The index of the largest logit in each row is the predicted FizzBuzz label.
    predicts = zip(range(1, 101), testY.argmax(dim=1).cpu().tolist())
    print([fizz_buzz_decode(i, x) for i, x in predicts])
    

      

      

  • 相关阅读:
    thymeleaf时间戳转换
    alerttemplate 时间戳转换
    jQuery.Deferred exception: a.indexOf is not a function TypeError: a.indexOf is not a function
    区分数据是对象还是字符串
    eclipse中选取一列快捷键
    图片拉伸不变型
    这里ajax需要改成同步
    idea如何整理代码格式
    20170311-起早床
    20190310-解决头屑
  • 原文地址:https://www.cnblogs.com/Turing-dz/p/13226663.html
Copyright © 2011-2022 走看看