zoukankan      html  css  js  c++  java
  • AI手写输入法

    本章承接上一篇的手写数字识别,利用训练好的模型,结合pyqt画板,实现简易手写输入法,为"hello world"例子增添乐趣。

    pyqt是开发图形界面的框架,可以百度查找相关资料了解安装及基础方法,我搭建的环境是pycharm+pyqt5+qtdesigner,配置好之后的界面长这样:

    在左边的项目中右键某个文件,也可以打开qt菜单

    具体怎么画界面不展开了,直接看下代码:

      1 # coding: utf-8
      2 from PyQt5.QtWidgets import *
      3 from PyQt5.QtGui import *
      4 from PyQt5.QtCore import *
      5 import sys
      6 sys.path.append(r'../ml/torch')
      7 from digit_recog import Net
      8 import torch
      9 import os
     10 import numpy as np
     11 import matplotlib.pyplot as plt
     12 from PIL import Image
     13 
     14 
     15 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     16 net = Net().to(device)
     17 # 加载参数
     18 nn_state = torch.load(os.path.join('../ml/torch/model/', 'net.pth'))
     19 # 参数加载到指定模型
     20 net.load_state_dict(nn_state)
     21 net.eval()
     22 
     23 
     24 def predict(img):
     25     # 读取图片并重设尺寸
     26     image = Image.open(img).resize((28, 28))
     27     # 灰度图
     28     gray_image = image.convert('L')
     29     # plt.imshow(gray_image)
     30     # plt.show()
     31     # 图片数据处理
     32     im_data = np.array(gray_image)
     33     im_data = torch.from_numpy(im_data).float()
     34     im_data = im_data.view(1, 1, 28, 28)
     35     # 神经网络运算
     36     outputs = net(im_data)
     37     # 取最大预测值
     38     _, pred = torch.max(outputs, 1)
     39     return pred.item()
     40 
     41 
     42 class SimpleDrawingBoard(QWidget):
     43     win = ''
     44     wins = []
     45 
     46     @classmethod
     47     def showWin(cls):
     48         # 聚焦到已有窗口
     49         if not cls.win:
     50             cls.win = cls()
     51             cls.win.show()
     52         else:
     53             cls.win.activateWindow()
     54 
     55     def __init__(self, parent=None):
     56         super(SimpleDrawingBoard, self).__init__(parent)
     57 
     58         self.setWindowTitle(u"手写数字识别")
     59         self.setWindowFlags(Qt.WindowStaysOnTopHint)
     60         self.size = (400, 350)
     61         self.resize(*self.size)
     62         self.setWindowFlag(Qt.FramelessWindowHint)  # 隐藏边框
     63         # self.setWindowOpacity(0.9)  # 设置窗口透明度
     64         # self.setAttribute(Qt.WA_TranslucentBackground)  # 设置窗口背景透明
     65 
     66         self.canvasSize = (280, 350)
     67         self.sizeOffset = [a - b for a, b in zip(self.size, self.canvasSize)]
     68         self.canvas = QPixmap(*self.canvasSize)
     69         self.canvas.fill(Qt.black)
     70         self.tempCanvas = QPixmap()
     71         self.lastPoint = QPoint()
     72         self.endPoint = QPoint()
     73         self.isDrawing = False
     74         self.penSize = 15
     75 
     76         self.initUI()
     77 
     78     def initUI(self):
     79         self.penSizeLabel = QLabel(u'画笔粗细')
     80         self.penSizeSpinBox = QSpinBox()
     81         self.penSizeSpinBox.setValue(self.penSize)
     82         self.penSizeSpinBox.valueChanged.connect(self.penSizeSpinBox_valueChanged)
     83         self.penSizeSpinBox.setFixedWidth(80)
     84 
     85         self.clearButton = QPushButton(u'清空')
     86         self.clearButton.setFixedWidth(80)
     87         self.clearButton.clicked.connect(self.clearPainter)
     88 
     89         self.closeButton = QPushButton(u'关闭')
     90         self.closeButton.setFixedWidth(80)
     91         self.closeButton.clicked.connect(self.close)
     92 
     93         self.inputLabel = QLabel(self)
     94         self.inputLabel.setFixedSize(80, 200)
     95         self.inputLabel.setAutoFillBackground(True)
     96         self.inputLabel.setAlignment(Qt.AlignCenter)
     97         self.inputLabel.setStyleSheet('''QLabel{background:#F76677;border-radius:5px;font-size:60px;font-weight:bolder;}''')
     98 
     99         mainLayout = QVBoxLayout(self)
    100 
    101         toolbarLayout = QGridLayout()
    102         # toolbarLayout.setSpacing(20)
    103         toolbarLayout.addWidget(self.penSizeLabel, 0, 0, 1, 1)
    104         toolbarLayout.addWidget(self.penSizeSpinBox, 1, 0, 1, 1)
    105         toolbarLayout.addWidget(self.clearButton, 2, 0, 1, 1)
    106         toolbarLayout.addWidget(self.closeButton, 3, 0, 1, 1)
    107         toolbarLayout.addWidget(self.inputLabel, 4, 0, 1, 1)
    108 
    109         toolbarLayout.setAlignment(Qt.AlignLeft)
    110 
    111         mainLayout.addLayout(toolbarLayout)
    112         mainLayout.addStretch(1)
    113 
    114     def penSizeSpinBox_valueChanged(self):
    115         # 设置画笔粗细
    116         self.penSize = self.penSizeSpinBox.value()
    117 
    118     def paintEvent(self, event):
    119         pp = QPainter(self.canvas)
    120         pen = QPen(QColor(255, 255, 255), self.penSize)
    121         pp.setPen(pen)
    122         if self.lastPoint != self.endPoint:
    123             pp.drawLine(self.lastPoint - QPoint(*self.sizeOffset), self.endPoint - QPoint(*self.sizeOffset))
    124         painter = QPainter(self)
    125         painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
    126         self.lastPoint = self.endPoint
    127 
    128     def clearPainter(self):
    129         print('clear...')
    130         self.canvas.fill(Qt.black)
    131         painter = QPainter(self)
    132         painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
    133         self.lastPoint = self.endPoint
    134         self.update()
    135         self.inputLabel.clear()
    136 
    137     def mousePressEvent(self, event):
    138         # 按下左键
    139         if event.button() == Qt.LeftButton:
    140             self.lastPoint = event.pos()
    141             self.endPoint = self.lastPoint
    142             self.isDrawing = True
    143 
    144     def mouseMoveEvent(self, event):
    145         if self.isDrawing:
    146             self.update()
    147             self.endPoint = event.pos()
    148 
    149     def mouseReleaseEvent(self, event):
    150         if event.button() == Qt.LeftButton:
    151             self.isDrawing = False
    152             self.endPoint = event.pos()
    153             self.update()
    154             self.canvas.toImage().save('input.png')
    155             input = predict('input.png')
    156             self.inputLabel.setText(str(input))
    157             print('你输入的是{}'.format(input))
    158 
    159 
    160 if __name__ == '__main__':
    161     app = QApplication.instance()
    162     if not app:
    163         app = QApplication(sys.argv)
    164     SimpleDrawingBoard.showWin()
    165     app.exec_()

    上面引入前一章训练好的模型,位于不同的文件夹内,需要加上这一行代码:

    sys.path.append(r'../ml/torch')
    

    看下运行效果:

    上面写了两个数字,识别输出正确!

    helloworld例子比较枯燥,而通过动手参与、与AI交互,可以增强信心和乐趣。信心是一步步建立起来的,大的突破亦是如此。后面会持续围绕简单的例子,深入发掘AI的乐趣与应用场景。

  • 相关阅读:
    【转载】[SMS]SMS内容的7bit和UCS2编码方式简介
    【转载】两篇关于字符编码的博文
    【IRA/GSM/UCS2】the difference of IRA/GSM/UCS2 character set
    【LTE】LTE中SINR的理解
    【LTE】为什么使用SNR来表征信道质量,而并不用RSRQ?这两者的区别是什么?
    【C++】C++为什么要引入引用这个复合类型?
    【HTML55】HTML5与CSS3基础教程
    python 三种单例模式
    python3.10 新增的 match case 语句
    Python pyqt5简单的图片过滤工具
  • 原文地址:https://www.cnblogs.com/migomiddle/p/11976094.html
Copyright © 2011-2022 走看看