Skip to content

Commit

Permalink
update: assinments1_base
Browse files Browse the repository at this point in the history
  • Loading branch information
黄重庆 committed Dec 18, 2021
1 parent de2b346 commit 6b4538b
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Assignments/Assignment1/cs231n/classifiers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from cs231n.classifiers.k_nearest_neighbor import *
from cs231n.classifiers.k_nearest_neighbor import * # 等同于import KNearestNeighbor (即调用class KNearestNeighbor(object):)
from cs231n.classifiers.linear_classifier import *
# 查找路径
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from past.builtins import xrange


# KNN类名
class KNearestNeighbor(object):
""" a kNN classifier with L2 distance """

Expand All @@ -25,6 +25,7 @@ def train(self, X, y):
self.X_train = X
self.y_train = y

# 预测
def predict(self, X, k=1, num_loops=0):
"""
Predict labels for test data using this classifier.
Expand All @@ -51,6 +52,7 @@ def predict(self, X, k=1, num_loops=0):

return self.predict_labels(dists, k=k)

# 2循环:用了两个循环的算法实现(L2距离)
def compute_distances_two_loops(self, X):
"""
Compute the distance between each test point in X and each training point
Expand Down Expand Up @@ -83,6 +85,7 @@ def compute_distances_two_loops(self, X):
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
return dists

# 1循环:用了一个循环的算法实现(L2距离)# 使用了广播机制,省去了一个循环.
def compute_distances_one_loop(self, X):
"""
Compute the distance between each test point in X and each training point
Expand All @@ -107,6 +110,7 @@ def compute_distances_one_loop(self, X):
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
return dists

# 无循环:不用循环的算法实现(L2距离)
def compute_distances_no_loops(self, X):
"""
Compute the distance between each test point in X and each training point
Expand Down Expand Up @@ -137,6 +141,7 @@ def compute_distances_no_loops(self, X):
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
return dists

# 预测标签
def predict_labels(self, dists, k=1):
"""
Given a matrix of distances between test points and training points,
Expand Down
33 changes: 24 additions & 9 deletions Assignments/Assignment1/cs231n/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,39 @@
import platform

def load_pickle(f):
    """Unpickle and return the single object stored in file object *f*.

    Args:
        f: A binary file-like object positioned at the start of a pickle.

    Returns:
        The object reconstructed by ``pickle.load``.

    Raises:
        ValueError: If the running interpreter's major version is neither
            2 nor 3.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only call this
    # on files that come from a trusted source.
    version = platform.python_version_tuple()
    if version[0] == '2':
        return pickle.load(f)
    elif version[0] == '3':
        # 'latin1' lets Python 3 decode 8-bit strings found in pickles that
        # were written by Python 2 instead of failing on non-ASCII bytes.
        return pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

# Load a single pickled CIFAR batch file
def load_CIFAR_batch(filename):
    """Load a single CIFAR batch file.

    Args:
        filename: Path of the pickled batch file to read.

    Returns:
        A tuple ``(X, Y)``: ``X`` is a float array of shape (N, 32, 32, 3)
        and ``Y`` is ``np.array`` of the batch's ``'labels'`` entry.
    """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
    # assumes datadict['data'] is a NumPy array holding 3*32*32 values per
    # image -- TODO confirm against the dataset files
    X = datadict['data']
    Y = datadict['labels']
    # Reshape exactly once: (N, 3, 32, 32) -> transpose -> (N, 32, 32, 3),
    # i.e. channels move to the last axis. Using -1 infers the number of
    # images instead of hard-coding 10000, so any batch size is accepted.
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(Y)
    return X, Y

# Load CIFAR10
# Load CIFAR10 ========================================
def load_CIFAR10(ROOT):
    """Load all CIFAR-10 training batches and the test batch from *ROOT*.

    Args:
        ROOT: Directory containing ``data_batch_1`` .. ``data_batch_5`` and
            ``test_batch``.

    Returns:
        A tuple ``(Xtr, Ytr, Xte, Yte)``: the five training batches
        concatenated in order, followed by the test batch, each as returned
        by ``load_CIFAR_batch``.
    """
    xs = []
    ys = []
    for b in range(1, 6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)  # each batch is read exactly once
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)  # training set
    Ytr = np.concatenate(ys)
    # Drop the references to the last batch before loading the test data.
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))  # test set
    return Xtr, Ytr, Xte, Yte


def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=1000,
Expand Down Expand Up @@ -261,3 +262,17 @@ def load_imagenet_val(num=None):
X = X[:num]
y = y[:num]
return X, y, class_names

# 测试
if __name__ == "__main__":
    # Manual smoke test: load the CIFAR-10 batches from the default folder.
    print('测试(hcq)')
    cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'  # dataset directory
    # Cleaning up variables to prevent loading data multiple times
    # (which may cause memory issue)
    try:
        del X_train, y_train
        del X_test, y_test
        print('Clear previously loaded data.')
    except NameError:
        # Nothing was loaded before; deleting an unbound name raises
        # NameError, which is the only failure expected here.
        pass
    # Load the training and test arrays.
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
67 changes: 40 additions & 27 deletions Assignments/Assignment1/knn.ipynb

Large diffs are not rendered by default.

0 comments on commit 6b4538b

Please sign in to comment.