本章承接上一篇的手写数字识别，利用训练好的模型，结合 PyQt 画板，实现一个简易的手写数字输入法，为“hello world”例子增添一些乐趣。
PyQt 是用于开发图形界面的 Python 框架，安装方法和基础用法可以自行搜索相关资料了解。我搭建的环境是 PyCharm + PyQt5 + Qt Designer，配置好之后的界面如下：
在左侧的项目栏中右键单击某个文件，也可以打开 Qt 相关菜单。
具体如何设计界面这里不展开，直接看代码：
# coding: utf-8
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
import sys
sys.path.append(r'../ml/torch')
from digit_recog import Net
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
# Load the trained parameters. map_location lets a checkpoint that was saved on
# a GPU machine be loaded on a CPU-only machine (and vice versa).
nn_state = torch.load(os.path.join('../ml/torch/model/', 'net.pth'), map_location=device)
# Copy the parameters into the model and switch it to inference mode.
net.load_state_dict(nn_state)
net.eval()


def predict(img):
    """Classify the handwritten digit stored in the image file *img*.

    The image is resized to 28x28 and converted to 8-bit grayscale, and the
    raw 0-255 pixel values are fed to ``net``.
    NOTE(review): confirm this matches the preprocessing used at training time
    (e.g. whether inputs were scaled to [0, 1] or normalized).

    Args:
        img: a path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        int: index of the highest-scoring output of the network.
    """
    # Read the image, resize it to the network input size and convert it to
    # grayscale; the context manager closes the underlying file.
    with Image.open(img) as src:
        gray_image = src.resize((28, 28)).convert('L')
    # Debugging aid: uncomment to see exactly what the network receives.
    # plt.imshow(gray_image)
    # plt.show()
    im_data = torch.from_numpy(np.array(gray_image)).float()
    # Add batch and channel dimensions, (28, 28) -> (1, 1, 28, 28), and move
    # the tensor to the model's device; otherwise the forward pass fails
    # when the model lives on CUDA.
    im_data = im_data.view(1, 1, 28, 28).to(device)
    # Pure inference: no need to build an autograd graph.
    with torch.no_grad():
        outputs = net(im_data)
    # Pick the class with the highest score.
    _, pred = torch.max(outputs, 1)
    return pred.item()


class SimpleDrawingBoard(QWidget):
    """Frameless, always-on-top drawing board that recognises one digit.

    The left column holds the tool controls and a label that shows the
    prediction; the right part of the window is a black canvas on which the
    user draws in white with the left mouse button. Each time the left button
    is released, the canvas is saved to ``input.png`` and classified with
    :func:`predict`.
    """

    # The single instance managed by showWin(); None until first shown.
    win = None
    # Not used by this class; kept so existing references keep working.
    wins = []

    @classmethod
    def showWin(cls):
        """Show the board, creating it on first use; otherwise bring it to front."""
        if cls.win is None:
            cls.win = cls()
        # show() is needed too: close() only hides the widget, so a closed
        # board has to be made visible again before it can be activated.
        cls.win.show()
        cls.win.raise_()
        cls.win.activateWindow()

    def __init__(self, parent=None):
        super(SimpleDrawingBoard, self).__init__(parent)

        self.setWindowTitle(u"手写数字识别")
        self.setWindowFlags(Qt.WindowStaysOnTopHint)
        # Stored as winSize, not size: an attribute named ``size`` would
        # shadow QWidget.size().
        self.winSize = (400, 350)
        self.resize(*self.winSize)
        self.setWindowFlag(Qt.FramelessWindowHint)  # hide the window frame
        # Optional tweaks:
        # self.setWindowOpacity(0.9)  # semi-transparent window
        # self.setAttribute(Qt.WA_TranslucentBackground)  # transparent background

        # The canvas fills the right-hand part of the window; sizeOffset is
        # its top-left corner in widget coordinates.
        self.canvasSize = (280, 350)
        self.sizeOffset = [a - b for a, b in zip(self.winSize, self.canvasSize)]
        self.canvas = QPixmap(*self.canvasSize)
        self.canvas.fill(Qt.black)
        self.tempCanvas = QPixmap()
        self.lastPoint = QPoint()
        self.endPoint = QPoint()
        self.isDrawing = False
        self.penSize = 15

        self.initUI()

    def initUI(self):
        """Build the tool column: pen width, clear/close buttons, result label."""
        self.penSizeLabel = QLabel(u'画笔粗细')
        self.penSizeSpinBox = QSpinBox()
        # A width of 0 would yield a 1-pixel cosmetic pen, far too thin to be
        # recognised after downscaling to 28x28.
        self.penSizeSpinBox.setRange(1, 99)
        self.penSizeSpinBox.setValue(self.penSize)
        self.penSizeSpinBox.valueChanged.connect(self.penSizeSpinBox_valueChanged)
        self.penSizeSpinBox.setFixedWidth(80)

        self.clearButton = QPushButton(u'清空')
        self.clearButton.setFixedWidth(80)
        self.clearButton.clicked.connect(self.clearPainter)

        self.closeButton = QPushButton(u'关闭')
        self.closeButton.setFixedWidth(80)
        self.closeButton.clicked.connect(self.close)

        # Large label that displays the recognised digit.
        self.inputLabel = QLabel(self)
        self.inputLabel.setFixedSize(80, 200)
        self.inputLabel.setAutoFillBackground(True)
        self.inputLabel.setAlignment(Qt.AlignCenter)
        self.inputLabel.setStyleSheet('''QLabel{background:#F76677;border-radius:5px;font-size:60px;font-weight:bolder;}''')

        mainLayout = QVBoxLayout(self)

        # One widget per row, stacked in a single left-aligned column.
        toolbarLayout = QGridLayout()
        toolbarLayout.addWidget(self.penSizeLabel, 0, 0, 1, 1)
        toolbarLayout.addWidget(self.penSizeSpinBox, 1, 0, 1, 1)
        toolbarLayout.addWidget(self.clearButton, 2, 0, 1, 1)
        toolbarLayout.addWidget(self.closeButton, 3, 0, 1, 1)
        toolbarLayout.addWidget(self.inputLabel, 4, 0, 1, 1)
        toolbarLayout.setAlignment(Qt.AlignLeft)

        mainLayout.addLayout(toolbarLayout)
        mainLayout.addStretch(1)

    def penSizeSpinBox_valueChanged(self):
        """Use the spin box value as the pen width for subsequent strokes."""
        self.penSize = self.penSizeSpinBox.value()

    def _drawStroke(self):
        """Draw the segment lastPoint -> endPoint onto the canvas, then advance lastPoint."""
        if self.lastPoint != self.endPoint:
            # Mouse positions are in widget coordinates; the canvas starts at sizeOffset.
            offset = QPoint(*self.sizeOffset)
            pp = QPainter(self.canvas)
            # Round caps/joins keep consecutive thick segments free of notches.
            pp.setPen(QPen(QColor(255, 255, 255), self.penSize, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin))
            pp.drawLine(self.lastPoint - offset, self.endPoint - offset)
            pp.end()
        self.lastPoint = self.endPoint

    def paintEvent(self, event):
        """Blit the off-screen canvas onto the widget at its offset."""
        painter = QPainter(self)
        painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
        painter.end()

    def clearPainter(self):
        """Wipe the canvas and the prediction label."""
        print('clear...')
        self.canvas.fill(Qt.black)
        self.lastPoint = self.endPoint
        # Schedule a repaint; a QPainter on the widget itself is only valid
        # inside paintEvent.
        self.update()
        self.inputLabel.clear()

    def mousePressEvent(self, event):
        """Start a stroke when the left button is pressed."""
        if event.button() == Qt.LeftButton:
            self.lastPoint = event.pos()
            self.endPoint = self.lastPoint
            self.isDrawing = True

    def mouseMoveEvent(self, event):
        """Extend the current stroke to the cursor position."""
        if self.isDrawing:
            self.endPoint = event.pos()
            self._drawStroke()
            self.update()

    def mouseReleaseEvent(self, event):
        """Finish the stroke and classify the whole canvas."""
        if event.button() == Qt.LeftButton:
            self.isDrawing = False
            self.endPoint = event.pos()
            # Draw the final segment now so it is part of the saved image
            # (a queued paintEvent would run only after the save below).
            self._drawStroke()
            self.update()
            # predict() reads from a file, so round-trip the canvas through input.png.
            self.canvas.toImage().save('input.png')
            digit = predict('input.png')
            self.inputLabel.setText(str(digit))
            print('你输入的是{}'.format(digit))


if __name__ == '__main__':
    # Reuse an existing QApplication (e.g. inside an IDE console) if there is one.
    app = QApplication.instance()
    if not app:
        app = QApplication(sys.argv)
    SimpleDrawingBoard.showWin()
    sys.exit(app.exec_())
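上面的代码引入了前一章训练好的模型。由于模型代码（digit_recog 模块）位于另一个文件夹中，需要先把该目录加入模块搜索路径才能导入。注意这里的相对路径（以及加载 net.pth 时使用的相对路径）是相对于运行时的当前工作目录解析的，而不是相对于脚本所在目录，因此需要在脚本所在目录下启动程序。需要加上的代码如下：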
sys.path.append(r'../ml/torch')
运行效果如下：
上面写了两个数字，识别结果均正确！
“hello world”式的例子比较枯燥，通过动手与 AI 交互，可以增强学习的信心和乐趣。信心是一步步建立起来的，大的突破亦是如此。后续会继续围绕简单的例子，深入发掘 AI 的乐趣与应用场景。