zoukankan      html  css  js  c++  java
  • AI手写输入法

    本章承接上一篇的手写数字识别,利用训练好的模型,结合pyqt画板,实现简易手写输入法,为"hello world"例子增添乐趣。

    pyqt是开发图形界面的框架,可以百度查找相关资料了解安装及基础方法,我搭建的环境是pycharm+pyqt5+qtdesigner,配置好之后的界面长这样:

    在左边的项目中右键某个文件,也可以打开qt菜单

    具体怎么画界面不展开了,直接看下代码:

      1 # coding: utf-8
      2 from PyQt5.QtWidgets import *
      3 from PyQt5.QtGui import *
      4 from PyQt5.QtCore import *
      5 import sys
      6 sys.path.append(r'../ml/torch')
      7 from digit_recog import Net
      8 import torch
      9 import os
     10 import numpy as np
     11 import matplotlib.pyplot as plt
     12 from PIL import Image
     13 
     14 
     15 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     16 net = Net().to(device)
     17 # 加载参数
     18 nn_state = torch.load(os.path.join('../ml/torch/model/', 'net.pth'))
     19 # 参数加载到指定模型
     20 net.load_state_dict(nn_state)
     21 net.eval()
     22 
     23 
     24 def predict(img):
     25     # 读取图片并重设尺寸
     26     image = Image.open(img).resize((28, 28))
     27     # 灰度图
     28     gray_image = image.convert('L')
     29     # plt.imshow(gray_image)
     30     # plt.show()
     31     # 图片数据处理
     32     im_data = np.array(gray_image)
     33     im_data = torch.from_numpy(im_data).float()
     34     im_data = im_data.view(1, 1, 28, 28)
     35     # 神经网络运算
     36     outputs = net(im_data)
     37     # 取最大预测值
     38     _, pred = torch.max(outputs, 1)
     39     return pred.item()
     40 
     41 
     42 class SimpleDrawingBoard(QWidget):
     43     win = ''
     44     wins = []
     45 
     46     @classmethod
     47     def showWin(cls):
     48         # 聚焦到已有窗口
     49         if not cls.win:
     50             cls.win = cls()
     51             cls.win.show()
     52         else:
     53             cls.win.activateWindow()
     54 
     55     def __init__(self, parent=None):
     56         super(SimpleDrawingBoard, self).__init__(parent)
     57 
     58         self.setWindowTitle(u"手写数字识别")
     59         self.setWindowFlags(Qt.WindowStaysOnTopHint)
     60         self.size = (400, 350)
     61         self.resize(*self.size)
     62         self.setWindowFlag(Qt.FramelessWindowHint)  # 隐藏边框
     63         # self.setWindowOpacity(0.9)  # 设置窗口透明度
     64         # self.setAttribute(Qt.WA_TranslucentBackground)  # 设置窗口背景透明
     65 
     66         self.canvasSize = (280, 350)
     67         self.sizeOffset = [a - b for a, b in zip(self.size, self.canvasSize)]
     68         self.canvas = QPixmap(*self.canvasSize)
     69         self.canvas.fill(Qt.black)
     70         self.tempCanvas = QPixmap()
     71         self.lastPoint = QPoint()
     72         self.endPoint = QPoint()
     73         self.isDrawing = False
     74         self.penSize = 15
     75 
     76         self.initUI()
     77 
     78     def initUI(self):
     79         self.penSizeLabel = QLabel(u'画笔粗细')
     80         self.penSizeSpinBox = QSpinBox()
     81         self.penSizeSpinBox.setValue(self.penSize)
     82         self.penSizeSpinBox.valueChanged.connect(self.penSizeSpinBox_valueChanged)
     83         self.penSizeSpinBox.setFixedWidth(80)
     84 
     85         self.clearButton = QPushButton(u'清空')
     86         self.clearButton.setFixedWidth(80)
     87         self.clearButton.clicked.connect(self.clearPainter)
     88 
     89         self.closeButton = QPushButton(u'关闭')
     90         self.closeButton.setFixedWidth(80)
     91         self.closeButton.clicked.connect(self.close)
     92 
     93         self.inputLabel = QLabel(self)
     94         self.inputLabel.setFixedSize(80, 200)
     95         self.inputLabel.setAutoFillBackground(True)
     96         self.inputLabel.setAlignment(Qt.AlignCenter)
     97         self.inputLabel.setStyleSheet('''QLabel{background:#F76677;border-radius:5px;font-size:60px;font-weight:bolder;}''')
     98 
     99         mainLayout = QVBoxLayout(self)
    100 
    101         toolbarLayout = QGridLayout()
    102         # toolbarLayout.setSpacing(20)
    103         toolbarLayout.addWidget(self.penSizeLabel, 0, 0, 1, 1)
    104         toolbarLayout.addWidget(self.penSizeSpinBox, 1, 0, 1, 1)
    105         toolbarLayout.addWidget(self.clearButton, 2, 0, 1, 1)
    106         toolbarLayout.addWidget(self.closeButton, 3, 0, 1, 1)
    107         toolbarLayout.addWidget(self.inputLabel, 4, 0, 1, 1)
    108 
    109         toolbarLayout.setAlignment(Qt.AlignLeft)
    110 
    111         mainLayout.addLayout(toolbarLayout)
    112         mainLayout.addStretch(1)
    113 
    114     def penSizeSpinBox_valueChanged(self):
    115         # 设置画笔粗细
    116         self.penSize = self.penSizeSpinBox.value()
    117 
    118     def paintEvent(self, event):
    119         pp = QPainter(self.canvas)
    120         pen = QPen(QColor(255, 255, 255), self.penSize)
    121         pp.setPen(pen)
    122         if self.lastPoint != self.endPoint:
    123             pp.drawLine(self.lastPoint - QPoint(*self.sizeOffset), self.endPoint - QPoint(*self.sizeOffset))
    124         painter = QPainter(self)
    125         painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
    126         self.lastPoint = self.endPoint
    127 
    128     def clearPainter(self):
    129         print('clear...')
    130         self.canvas.fill(Qt.black)
    131         painter = QPainter(self)
    132         painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
    133         self.lastPoint = self.endPoint
    134         self.update()
    135         self.inputLabel.clear()
    136 
    137     def mousePressEvent(self, event):
    138         # 按下左键
    139         if event.button() == Qt.LeftButton:
    140             self.lastPoint = event.pos()
    141             self.endPoint = self.lastPoint
    142             self.isDrawing = True
    143 
    144     def mouseMoveEvent(self, event):
    145         if self.isDrawing:
    146             self.update()
    147             self.endPoint = event.pos()
    148 
    149     def mouseReleaseEvent(self, event):
    150         if event.button() == Qt.LeftButton:
    151             self.isDrawing = False
    152             self.endPoint = event.pos()
    153             self.update()
    154             self.canvas.toImage().save('input.png')
    155             input = predict('input.png')
    156             self.inputLabel.setText(str(input))
    157             print('你输入的是{}'.format(input))
    158 
    159 
    160 if __name__ == '__main__':
    161     app = QApplication.instance()
    162     if not app:
    163         app = QApplication(sys.argv)
    164     SimpleDrawingBoard.showWin()
    165     app.exec_()

    上面引入前一章训练好的模型,位于不同的文件夹内,需要加上这一行代码:

    sys.path.append(r'../ml/torch')
    

    看下运行效果:

    上面写了两个数字,识别输出正确!

    helloworld例子比较枯燥,通过动手参与与AI交互增强信心乐趣,信心是一步步建立起来的,而大的突破亦是如此,后面会持续围绕简单的例子,深入发掘AI的乐趣与应用场景。

  • 相关阅读:
    thinkphp在模型中自动完成session赋值
    highcharts实例教程二:结合php与mysql生成饼图
    程序员应该经常看看的网站
    highcharts实例教程一:结合php与mysql生成折线图
    2015-2-10 ecshop
    一个简单的javascript获取URL参数的代码
    table 西边框样式
    PHP 获取当前日期及格式化
    mysql 获取当前日期及格式化
    mysql时间int日期转换
  • 原文地址:https://www.cnblogs.com/migomiddle/p/11976094.html
Copyright © 2011-2022 走看看