前文基于训练数据集完成了模型训练,下面对模型在测试数据集上的准确率进行评估。
核心代码如下:
model.eval() #将模型设置为评估模式
correct = 0 #用于记录正确预测的数量
total = 0 #用于记录总的样本数量
with torch.no_grad(): #暂时禁用梯度计算,在评估模型时使用可以有效减少内存消耗
for inputs,labels in test_loader:
inputs, labels = inputs.to(device),labels.to(device)
outputs = model(inputs)
_,predicted = torch.max(outputs.data,1)
total+=labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct/total*100
print('Accuray:{:.2f}%'.format(accuracy))
swanlab.log({"test_acc":accuracy})
循环遍历test_loader,对测试数据集中的每一批数据执行相关操作。
outputs = model(inputs)
:通过模型前向传播得到输出。_, predicted = torch.max(outputs.data, 1)
:使用torch.max
函数找到输出中最大值的索引,即模型的预测结果。total += labels.size(0)
:更新总样本数量。correct += (predicted == labels).sum().item()
:计算这一批数据中预测正确的数量,并累加到correct
中。
执行完成后,在swanlab网站就可以查看到本次训练的执行结果,以可视化图表形式展示: