0
点赞
收藏
分享

微信扫一扫

BartForConditionalGeneration的使用细节


1. BartForConditionalGeneration 类的各个参数

先聊聊输入到 BartForConditionalGeneration 类的各个参数是什么意思?这个部分是比较重要的。

decoder_input_ids

是必须要以 ​​<s>​​ 开头的。这个参数可以自己生成然后传入到模型中,也可以交由代码自己生成(一般会根据label右移一位再补0)

  • case 1: 直接传入

    此时的 ​​​decoder_input_ids​​ 如下:
  • case 2: 由labels 右移一位生成


    ​​​decoder_start_token_id​​​ 的值为2(一般需要指定),对应的token是​​</s>​​​。最后返回 ​​shifted_input_ids​​​ 作为 ​​decoder_input_ids​
  • 需要注意 ​​labels​​​ 的起始是没有 ​​<s>​​ token的。

细心的读者会发现这两种方法得到的 ​​decoder_input_ids​​​ 是不同的(就是因为这个 ​​decoder_start_token_id​​ 值的不同)。

2.为啥下面两种方法计算的loss值不相同?

就是因为上述说的 这个 ​​decoder_input_ids​​​ 值的原因,以及​​add_special_tokens​​参数的原因。

BartForConditionalGeneration的使用细节_默认值


BartForConditionalGeneration的使用细节_默认值_02

​generate​​​函数中的 ​​max_length​​ 有什么作用?

max_length : The maximum length of the sequence to be generated. 【将会被生成的句子的最大长度。】

有如下两段代码,很好奇,为啥 ​​max_length​​ 会对模型生成结果产生影响?理应来说不是只会限制生成长度,怎么在值不同的时候变了内容呢?

BartForConditionalGeneration的使用细节_人工智能_03


BartForConditionalGeneration的使用细节_人工智能_04

出现这个问题的原因是:max_length 参数是有默认值的,而且默认值较小,所以就会导致生成的结果很短。

BartForConditionalGeneration的使用细节_人工智能_05


举报

相关推荐

0 条评论