0
点赞
收藏
分享

微信扫一扫

MindSpore报ValueError:`x rank` in `NLLLoss` should be int and must in [1, 2], but got `4` with type

奔跑的酆 2022-04-01 阅读 73

1 报错描述

1.1 系统环境

Hardware Environment(Ascend/GPU/CPU): GPU
Software Environment:
– MindSpore version (source or binary): 1.6.0
– Python version (e.g., Python 3.7.5): 3.7.6
– OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 4.15.0-74-generic
– GCC/Compiler version (if compiled from source):

1.2 基本信息

1.2.1 脚本

根据标杆torch算子NLLLoss的用例,编写输入为(N, C, d1, d2)的用例,脚本如下:

 01 loss = nn.NLLLoss()
 02 input = torch.randn(5, 4, 8, 8)
 03 m = nn.LogSoftmax(dim=1)
 04 target = torch.empty(5, 8, 8, dtype=torch.long).random_(0, 4)
 05 loss = loss(m(input), target)
 06 print('torch_loss',loss)
 07
 08 m = mn.LogSoftmax(axis=1)
 09 loss = ops.NLLLoss()
 10 input = Tensor(np.random.randn(5, 4,8,8), mindspore.float32)
 11 labels = Tensor([1, 0,1, 1], mindspore.int32)
 12 weight = Tensor(np.random.rand(5,8,8), mindspore.float32)
 13 loss, weight = loss(m(input), labels, weight)
 14 print('mindspore_loss:',loss)

1.2.2 报错

这里报错信息如下:

Traceback (most recent call last):
 File "demo.py", line 13, in <module>
​    loss, weight = loss(m(input), labels, weight)
…
File "/lib/python3.7/site-packages/mindspore/_checkparam.py", line 238, in check_int
  return check_number(arg_value, value, rel, int, arg_name, prim_name)
 File " /lib/python3.7/site-packages/mindspore/_checkparam.py", line 168, in check_number
  raise type_except(f'{prim_info} should be {arg_type.__name__} and must {rel_str}, '
ValueError: `x rank` in `NLLLoss` should be int and must in [1, 2], but got `4` with type `int

原因分析

​ 在MindSpore 1.6版本,利用对标算子的用例编写输入为(N, C, d1, d2)的用例。先看报错信息,在ValueError中,写到x rank in NLLLos should be int and must in [1, 2], but got 4 with type int,意思是传的NLLLoss的x_rank参数应该为int,而且应该在[1, 2]之间,但是你传进去的是int类型的4,由报错行数line13,检查传入数据可知我们传入了为4维的(5, 4, 8, 8),而torch传入该类型数据能支持,这是由于目前MindSpore暂不支持(N,C,d1,d2,…,dK) with K≥1类型,这点在旧版本的PyTorch与MindSpore API映射对比中描述为功能一致,此处文档有误(见下图),目前标杆算子和MindSpore框架支持的NLLLoss算子功能存在一定差异,MindSpore目前只支持 shape为(N,C)的数据(如下图所示)。

​ 在新版本中的描述中已经进行修改。参考链接为:https://www.mindspore.cn/docs/migration_guide/zh-CN/r1.5/api_mapping/pytorch_diff/NLLLoss.html。

2 解决方法

​ 基于上面已知的原因,该算子存在部分输入不支持的情况,目前需要用户自己封装,该操作会对用户带来一定的困扰,我们将在后续统一考虑这种需求。

3 总结

定位报错问题的步骤:

1、找到报错的用户代码行:loss, weight = loss(m(input), labels, weight)

2、 根据日志报错信息中的关键字,缩小分析问题的范围:loss, weight = loss(m(input), labels, weight)

3、需要重点关注变量定义、初始化的正确性。

4 参考文档

4.1 NLLLoss介绍

举报

相关推荐

Log4j 1.x如何升级到Log4j 2.x

0 条评论