0
点赞
收藏
分享

微信扫一扫

通俗举例讲解评价指标—recall、precision、ndcg,hit,auc

这篇文章整理自己在学习过程中所用到的评价指标,当然先陆续更新几个常用的,其他的后续再补上。言归正传,用大白话讲解这些,相信大家都能看懂!

直接步入正文:

大家都知道,我们在训练网络时,最后的评价标准就是通过各种评价指标的对比来直观得到此模型的效果好与不好,所以,这里也是最后的重要一步

为了便于理解,我们需要举个例子,便于在后续的理解,所以此处是开头!!请细细品味:

此处假设我们的网络由2000user1万item,假设我们已经把训练集的数据训练了几个epoch了,下面进入测试集中,看看我们模型训练的效果怎么样

users_to_test = list(data_generator.test_set.keys())   # 测试集中列表的索引(用户个数2000)

ret = test(model, users_to_test, drop_flag=False)

.......................

rate_batch = model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()

通过model.rating我们可以得到一个2000*1W的评分矩阵,保存的是每个用户对每个物品的评分

接下来就是遍历测试集中的样本了,针对每个用户进行对比

首先我们会有一个通用的参数Ks(大多数都这么定义一般取值为20、40、60、80、100),这里取值的意思是Ks数代表我们需要推荐前Ks个评分最高的物品, 这里Ks我们取20

我们会步入rating = x[0]、u = x[1],这个代表找到当前用户,以及当前用户对所有物品的评分,

test_items = list(all_items - set(training_items))  此段代表当前用户对所有的物品评分减去训练集中物品评分,(也就是说训练集中的物品数据只参与训练不参与评分)

然后我们取出该用户对于物品前20个评分较高的,然后我们依次遍历,与测试集中该用户所对应的物品进行对比,会得到0-1组成的20个数(包含则置为1,否则置为0,也就是说检测我们所推荐的20个物品在不在测试集的正样本中)

综上所述我们得到了评分最高的20物品,以及并判别了它们是否在测试集中,用0-1表示

一、precision指标

我们上面得到的矩阵,直接进行0-1的求和取均值即可:代码如下

def precision_at_k(r, k):
    """Score is precision @ k
    Relevance is binary (nonzero is relevant).
    Returns:
        Precision @ k
    Raises:
        ValueError: len(r) must be >= k
    """
    assert k >= 1
    r = np.asarray(r)[:k]    #在我们得到的[0-1]矩阵中  切取前20个
    return np.mean(r)        #对于这些0-1取均值

二、recall指标

我们上面得到的矩阵,求和得到预测正确的物品数然后再除测试集中该用户对应的物品数

def recall_at_k(r, k, all_pos_num):
    # if all_pos_num == 0:
    #     return 0
    r = np.asfarray(r)[:k]
    return np.sum(r) / all_pos_num    #前20中存在正确的物品数/测试集中对用的所有物品数

三、NDCG指标

这个三言两语说不太清,道道有点多,但很实用,请自行看链接

NDCG的理解_Luna2137的博客-CSDN博客_ndcg

def dcg_at_k(r, k, method=1):
    r = np.asfarray(r)[:k]  
    if r.size:
        if method == 0:
            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
        elif method == 1:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        else:
            raise ValueError('method must be 0 or 1.')
    return 0.


def ndcg_at_k(r, k, ground_truth, method=1):
    """Score is normalized discounted cumulative gain (ndcg)
    Relevance is positive real values.  Can use binary
    as the previous methods.
    Returns:
        Normalized discounted cumulative gain

        Low but correct defination
    """
    GT = set(ground_truth)
    if len(GT) > k :
        sent_list = [1.0] * k   #如果测试集的样本数大于 我们取得的前K个物品数 (得到20个1)
    else:
        sent_list = [1.0]*len(GT) + [0.0]*(k-len(GT))  #如果测试集的样本数小于 我们取得的前K个物品数   k个1后面补零(凑够样本数)
    dcg_max = dcg_at_k(sent_list, k, method)  
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k, method) / dcg_max

四、hit指标

我们上面得到的矩阵,表示我们在遍历前20个物品时,是否存在测试集中出现的物品,通俗点说就是我们预测推荐的20个物品中,有没有对的物品

def hit_at_k(r, k):
    r = np.array(r)[:k]
    if np.sum(r) > 0:
        return 1.
    else:
        return 0.
举报

相关推荐

0 条评论