zoukankan      html  css  js  c++  java
  • 天气预测(CNN)

      1 import torch
      2 import torch.nn as nn
      3 import torch.utils.data as Data
      4 import numpy as np
      5 import pymysql
      6 import datetime
      7 import csv
      8 import time
      9 
     10 
     11 EPOCH = 100
     12 BATCH_SIZE = 50
     13 
     14 
     15 class MyNet(nn.Module):
     16     def __init__(self):
     17         super(MyNet, self).__init__()
     18         self.con1 = nn.Sequential(
     19             nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
     20             nn.MaxPool1d(kernel_size=1),
     21             nn.ReLU(),
     22         )
     23         self.con2 = nn.Sequential(
     24             nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
     25             nn.MaxPool1d(kernel_size=1),
     26             nn.ReLU(),
     27         )
     28         self.fc = nn.Sequential(
     29             # 线性分类器
     30             nn.Linear(128*6*1, 128),  # 修改大小后要重新计算
     31             nn.ReLU(),
     32             nn.Linear(128, 6),
     33             # nn.Softmax(dim=1),
     34         )
     35         self.mls = nn.MSELoss()
     36         self.opt = torch.optim.Adam(params=self.parameters(), lr=1e-3)
     37         self.start = datetime.datetime.now()
     38 
     39     def forward(self, inputs):
     40         out = self.con1(inputs)
     41         out = self.con2(out)
     42         out = out.view(out.size(0), -1)  # 展开成一维
     43         out = self.fc(out)
     44         # out = F.log_softmax(out, dim=1)
     45         return out
     46 
     47     def train(self, x, y):
     48         out = self.forward(x)
     49         loss = self.mls(out, y)
     50         print('loss: ', loss)
     51         self.opt.zero_grad()
     52         loss.backward()
     53         self.opt.step()
     54 
     55     def test(self, x):
     56         out = self.forward(x)
     57         return out
     58 
     59     def get_data(self):
     60         with open('aaa.csv', 'r') as f:
     61             results = csv.reader(f)
     62             results = [row for row in results]
     63             results = results[1:1500]
     64         inputs = []
     65         labels = []
     66         for result in results:
     67             # 手动独热编码
     68             one_hot = [0 for i in range(6)]
     69             index = int(result[6])-1
     70             one_hot[index] = 1
     71             # labels.append(label)
     72             # one_hot = []
     73             # label = result[6]
     74             # for i in range(6):
     75             #     if str(i) == label:
     76             #         one_hot.append(1)
     77             #     else:
     78             #         one_hot.append(0)
     79             labels.append(one_hot)
     80             input = result[:6]
     81             input = [float(x) for x in input]
     82             # label = [float(y) for y in label]
     83             inputs.append(input)
     84         # print(labels)  # [[0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1],
     85         time.sleep(10)
     86         inputs = np.array(inputs)
     87         labels = np.array(labels)
     88         inputs = torch.from_numpy(inputs).float()
     89         inputs = torch.unsqueeze(inputs, 1)
     90 
     91         labels = torch.from_numpy(labels).float()
     92         return inputs, labels
     93 
     94     def get_test_data(self):
     95         with open('aaa.csv', 'r') as f:
     96             results = csv.reader(f)
     97             results = [row for row in results]
     98             results = results[1500: 1817]
     99         inputs = []
    100         labels = []
    101         for result in results:
    102             label = [result[6]]
    103             input = result[:6]
    104             input = [float(x) for x in input]
    105             label = [float(y) for y in label]
    106             inputs.append(input)
    107             labels.append(label)
    108         inputs = np.array(inputs)
    109         # labels = np.array(labels)
    110         inputs = torch.from_numpy(inputs).float()
    111         inputs = torch.unsqueeze(inputs, 1)
    112         labels = np.array(labels)
    113         labels = torch.from_numpy(labels).float()
    114         return inputs, labels
    115 
    116 
    117 if __name__ == '__main__':
    118     # 训练数据
    119     # net = MyNet()
    120     # x_data, y_data = net.get_data()
    121     # torch_dataset = Data.TensorDataset(x_data, y_data)
    122     # loader = Data.DataLoader(
    123     #     dataset=torch_dataset,
    124     #     batch_size=BATCH_SIZE,
    125     #     shuffle=True,
    126     #     num_workers=2,
    127     # )
    128     # for epoch in range(EPOCH):
    129     #     for step, (batch_x, batch_y) in enumerate(loader):
    130     #         print(step)
    131     #         # print('batch_x={};  batch_y={}'.format(batch_x, batch_y))
    132     #         net.train(batch_x, batch_y)
    133     # # 保存模型
    134     # torch.save(net, 'net.pkl')
    135 
    136 
    137     # 测试数据
    138     net = MyNet()
    139     net.get_test_data()
    140     # 加载模型
    141     net = torch.load('net.pkl')
    142     x_data, y_data = net.get_test_data()
    143     torch_dataset = Data.TensorDataset(x_data, y_data)
    144     loader = Data.DataLoader(
    145         dataset=torch_dataset,
    146         batch_size=100,
    147         shuffle=False,
    148         num_workers=1,
    149     )
    150     num_success = 0
    151     num_sum = 317
    152     for step, (batch_x, batch_y) in enumerate(loader):
    153         # print(step)
    154         output = net.test(batch_x)
    155         # output = output.detach().numpy()
    156         y = batch_y.detach().numpy()
    157         for index, i in enumerate(output):
    158             i = i.detach().numpy()
    159             i = i.tolist()
    160             j = i.index(max(i))
    161             print('输出为{}标签为{}'.format(j+1, y[index][0]))
    162             loss = j+1-y[index][0]
    163             if loss == 0.0:
    164                 num_success += 1
    165     print('正确率为{}'.format(num_success/num_sum))
  • 相关阅读:
    find命令 -- 之查找指定时间内修改过的文件
    nginx
    lighttpd 搭建
    mysql主从复制5.6基于GTID及多线程的复制笔记
    centos下MySQL主从同步配置
    数据库集群搭建
    linux 系统监控、诊断工具之 top 详解
    Linux下Apache并发连接数和带宽控制
    DXGI屏幕捕捉
    CUDA以及CUDNN安装配置(WIN10为例)
  • 原文地址:https://www.cnblogs.com/MC-Curry/p/10529566.html
Copyright © 2011-2022 走看看