Skip to content

Commit

Permalink
savez knn_data.npz
Browse files Browse the repository at this point in the history
  • Loading branch information
makelove committed Aug 8, 2017
1 parent 2797421 commit cf5048a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
8 changes: 6 additions & 2 deletions ch46-机器学习-K近邻/2-使用kNN对手写数字OCR.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,18 @@
accuracy = correct * 100.0 / result.size
print('准确率', accuracy) # 准确率91%

#
''''''
# save the data
np.savez('knn_data.npz', train=train, train_labels=train_labels)
np.savez('knn_data.npz', train=train, train_labels=train_labels,test=test,test_labels=test_labels)
# Now load the data
with np.load('knn_data.npz') as data:
print(data.files)
train = data['train']
train_labels = data['train_labels']
test = data['test']
test_labels = data['test_labels']


#TODO 怎样预测数字?
# knn.predict?
# Docstring: predict(samples[, results[, flags]]) -> retval, results
24 changes: 24 additions & 0 deletions ch46-机器学习-K近邻/预测手写数字1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# @Time : 2017/8/8 11:57
# @Author : play4fun
# @File : 预测手写数字1.py
# @Software: PyCharm

"""
预测手写数字1.py:
"""

import numpy as np
import cv2
from matplotlib import pyplot as plt

with np.load('knn_data.npz') as data:
print(data.files)
train = data['train']
train_labels = data['train_labels']
test = data['test']
test_labels = data['test_labels']

knn = cv2.ml.KNearest_create()
knn.train(train, cv2.ml.ROW_SAMPLE, train_labels)
ret, result, neighbours, dist = knn.findNearest(test, k=5)

0 comments on commit cf5048a

Please sign in to comment.