def predict_labels(self, dists, k=1):
print("predict_labels")
"""
Given a matrix of distances between test points and training points,
predict a label for each test point.
Inputs:
- dists: A numpy array of shape (num_test, num_train) where dists[i, j]
gives the distance betwen the ith test point and the jth training point.
Returns:
- y: A numpy array of shape (num_test,) containing predicted labels for the
test data, where y[i] is the predicted label for the test point X[i].
"""
num_test = dists.shape[0]
#获取测试数据的个数
print(num_test)
y_pred = np.zeros(num_test)
#利用测试集的个数创建一个0的向量
for i in range(num_test):
# A list of length k storing the labels of the k nearest neighbors to
# the ith test point.
closest_y = []
#这个数组保存与该测试集相似的图片所属的类别
#########################################################################
# TODO: #
# Use the distance matrix to find the k nearest neighbors of the ith #
# testing point, and use self.y_train to find the labels of these #
# neighbors. Store these labels in closest_y. #
# Hint: Look up the function numpy.argsort. #
#########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
idx = np.argsort(dists[i])[:k]#idx是比较最小的下标
#找到与该测试图片相似的最小的k个位置
closest_y = self.y_train[idx]
#通过位置找到与该测试图片相似的训练图片所属的类别
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
#########################################################################
# TODO: #
# Now that you have found the labels of the k nearest neighbors, you #
# need to find the most common label in the list closest_y of labels. #
# Store this label in y_pred[i]. Break ties by choosing the smaller #
# label. #
#########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
#self.y_train的值是属于哪个分类
counter = np.zeros(np.max(self.y_train) + 1) # (C,)
'''for j in closest_y:
counter[j] += 1'''
np.add.at(counter, closest_y, 1)
y_pred[i] = np.argmax(counter)
#看哪个类别多,就预测哪个类别 投票机制
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
return y_pred```
开始感觉很难 其实知道后 很简单