zoukankan      html  css  js  c++  java
  • pytorch入门与实战学习->第一课复习(fizzuzzfizzbuzz小游戏)

    视频链接:https://www.bilibili.com/video/BV12741177Cu?from=search&seid=17209581732555565064

    视频上是用的jupyter notebook实现的,这次我是用的pycharm实现的代码。

    fizzuzzfizzbuzz小游戏的意思是:如果被3除尽打印fizz,被5除尽打印buzz,被15除尽打印fizzbuzz。这可以用一个函数实现,但是我们是学习神经网络,所以用一个二层神经网络实现,自己去学习,自己去玩,当然界面不实现

    主要有三个.py文件:utils.py存放工具函数,model.py训练模型,paragraph2.py:使用模型进行预测

    utils.py

    import numpy as np
    
    def binary_encode(i, num_digits):   # 转二进制计算
        return np.array([i >> d & 1 for d in range(num_digits)])[::-1]   # [::-1]是把arry倒过来,因为一开始转的是二进制反的
    
    def fizz_buzz_encode(i):
        if i % 15 == 0: return 3
        elif i % 5 == 0: return 2
        elif i % 3 == 0: return 1
        else: return 0
    
    def fizz_buzz_decode(i, prediction):
        return [str(i), 'fizz', 'buzz', 'fizzbuzz'][prediction]   #这是个很好玩的用法,我也是第一次见,各位可以打印一下试试

    model.py实现:

    import torch
    from p2.utils import binary_encode, fizz_buzz_encode
    NUM_DIGITS = 10
    
    trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])   # 训练数据, 101致以上,好像是923个
    trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])   # x可以是float类型,但是y是表示类别的,不行
    
    NUM_HIDDEN = 100
    model = torch.nn.Sequential(    # 模型定义,激活函数为ReLU
        torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
        torch.nn.ReLU(),
        torch.nn.Linear(NUM_HIDDEN, 4)
    )
    
    if torch.cuda.is_available():   # 模型转到gpu上运行
        model = model.cuda()
    
    loss_fn = torch.nn.CrossEntropyLoss() # 损失函数使用交叉熵损失函数
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)   # 优化算法选择SGD,可百度下SGD,是随机梯度下降法,torch封装了好几个优化算法,可以自行试试
    
    BATCH_SIZE = 128
    
    def __main__():
        for epoch in range(1000):    # 训练epoch是1000, 视频上老师训练是10000,我嫌太大了,慢,所以改为了1000,但是效果确实不如10000的,可以自己试试 
            for start in range(0, len(trX), BATCH_SIZE):   # 批量大小为BATCH_SIZE
                end = start + BATCH_SIZE
                batchX = trX[start:end]
                batchY = trY[start:end]
    
                if torch.cuda.is_available():   # 训练数据搬到gpu
                    batchX = batchX.cuda()
                    batchY = batchY.cuda()
    
                y_pred = model(batchX)
    
                loss = loss_fn(y_pred, batchY)
                print("Epoch", epoch, loss.item())
    
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
        torch.save(model, 'fbmodel.pkl')

    paragraph2.py实现

    import torch
    from p2.utils import binary_encode, fizz_buzz_decode
    
    model = torch.load('p2/fbmodel.pkl')
    
    NUM_DIGITS = 10
    
    testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
    if torch.cuda.is_available():
        testX = testX.cuda()
    
    with torch.no_grad():
        testY = model(testX)
    
    predictions = zip(range(1, 101), testY.max(1)[1].cpu().data.tolist())      # 非常有意思和技巧的一个东西,testY.max(1)[1].cpu().data.tolist()可以自己试试,打印
    print([fizz_buzz_decode(i, x) for i, x in predictions])

    训练epoch为1000的结果:

    ['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', '10', '11', 'fizz', '13', '14', 'fizz', '16', '17', 'fizz', '19', '20', 'fizz', '22', '23', 'fizz', '25', '26', 'fizz', '28', '29', 'fizz', '31', '32', 'fizz', '34', '35', 'fizz', '37', '38', 'fizz', '40', '41', '42', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', '50', 'fizz', '52', '53', 'fizz', '55', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', 'fizz', 'fizz', '70', '71', 'fizz', '73', '74', 'fizzbuzz', '76', 'buzz', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'fizz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', '100']
    训练epoch为10000的结果:

    ['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', 'fizz', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', '66', '67', '68', 'fizz', '70', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', '78', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', 'fizz', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
    训练数据多少还是有区别的。

    我是小白,虽然我黑。一起学习,一起探讨,加油。

  • 相关阅读:
    【转载】SAP_ECC6.0_EHP4或SAP_ECC6.0_EHP5_基于Windows_Server_2008R2_和SQL_server_2008下的安装
    使用delphi 开发多层应用(二十四)KbmMW 的消息方式和创建WIB节点
    使用delphi 开发多层应用(二十三)KbmMW 的WIB
    实现KbmMw web server 支持https
    KbmMW 服务器架构简介
    Devexpress VCL Build v2014 vol 14.1.1 beta发布
    使用delphi 开发多层应用(二十二)使用kbmMW 的认证管理器
    KbmMW 4.50.00 测试版发布
    Basic4android v3.80 beta 发布
    KbmMW 认证管理器说明(转载)
  • 原文地址:https://www.cnblogs.com/JadenFK3326/p/13113421.html
Copyright © 2011-2022 走看看