1 Longformer
I previously did Chinese summarization with BART, but the project calls for summarizing long Chinese documents, so here I use Longformer instead (strictly speaking LED, which adds a decoder on top of Longformer). On an 11 GB GPU the input length can reach 8K, which is very friendly. It doesn't match BART on short texts, but that is not the point here.
1.1 Longformer architecture
The LED architecture is similar to BART's, except for the added global attention. There is no Chinese pretrained LED model, but we do have a Chinese BART, and a script for converting BART weights to LED is given below, so this time we initialize LED from the BART weights.
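For orientation, this is what the global-attention interface looks like in Hugging Face's LED (a minimal sketch with a randomly initialized toy config, not the real converted model, which comes later in this post): every token attends locally within attention_window, and global_attention_mask marks the few tokens that additionally attend to, and are attended by, the whole sequence.

import torch
from transformers import LEDConfig, LEDForConditionalGeneration

# Toy config, just to show the call signature.
config = LEDConfig(
    vocab_size=21128, d_model=768,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=12, decoder_attention_heads=12,
    encoder_ffn_dim=3072, decoder_ffn_dim=3072,
    attention_window=512,
)
model = LEDForConditionalGeneration(config)

input_ids = torch.randint(0, 21128, (1, 2048))   # a "long" input
decoder_input_ids = torch.tensor([[101]])
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1                  # global attention on the first token only

out = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    global_attention_mask=global_attention_mask,
)
print(out.logits.shape)  # torch.Size([1, 1, 21128])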
2 Initializing LED with BART weights
2.1 Set input and output lengths
import copy
import logging

from transformers import LEDConfig, LEDForConditionalGeneration, BertTokenizer
from transformers import BartForConditionalGeneration

logger = logging.getLogger("longformer-chinese")
logging.basicConfig(level=logging.INFO)

max_encoder_position_embeddings = 4096  # maximum input length
max_decoder_position_embeddings = 1024  # maximum decoder length, i.e. the maximum generation length
2.2 Configure LED
Here we set up the LED configuration to match bart-base-chinese, and use it to initialize the LED model.
led_config = LEDConfig(
    vocab_size=21128,
    max_encoder_position_embeddings=max_encoder_position_embeddings,
    max_decoder_position_embeddings=max_decoder_position_embeddings,
    encoder_layers=6,
    encoder_ffn_dim=3072,
    encoder_attention_heads=12,
    decoder_layers=6,
    decoder_ffn_dim=3072,
    decoder_attention_heads=12,
    encoder_layerdrop=0.0,
    decoder_layerdrop=0.0,
    use_cache=True,
    is_encoder_decoder=True,
    activation_function="gelu",
    d_model=768,
    dropout=0.1,
    attention_dropout=0.0,
    activation_dropout=0.0,
    init_std=0.02,
    decoder_start_token_id=102,  # [SEP] in the BERT Chinese vocab
    classifier_dropout=0.0,
    pad_token_id=0,              # [PAD]
    bos_token_id=101,            # [CLS]
    eos_token_id=102,            # [SEP]
    attention_window=512,
)
led_model = LEDForConditionalGeneration(led_config)
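Before moving on, a quick sanity check (a minimal sketch, assuming bart-base-chinese ships the standard BERT Chinese vocabulary): the special-token ids in the config above should line up with the tokenizer.

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-chinese")  # assumption: same vocab as bart-base-chinese
assert tok.vocab_size == 21128
assert tok.pad_token_id == 0    # [PAD]
assert tok.cls_token_id == 101  # [CLS] -> bos_token_id
assert tok.sep_token_id == 102  # [SEP] -> eos_token_id and decoder_start_token_id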
2.3 Load the BART model
bart_model = BartForConditionalGeneration.from_pretrained(r'E:\Project\NLP\long-document\bart-base')
tokenizer = BertTokenizer.from_pretrained(r'E:\Project\NLP\long-document\bart-base')
2.4 Copy the BART weights into LED
# BART stores its learned position embeddings with a 2-row offset, so skip the
# first two rows and tile the remaining positions up to the new maximum length.
current_max_pos, embed_size = bart_model.model.encoder.embed_positions.weight.shape
new_encoder_pos_embed = bart_model.model.encoder.embed_positions.weight.new_empty(
    max_encoder_position_embeddings, embed_size)

k = 0
step = current_max_pos - 2

encoder_position_embeddings = bart_model.model.encoder.embed_positions.weight[2:]
while k < max_encoder_position_embeddings:
    # clamp the final block in case step does not divide the target length evenly
    new_encoder_pos_embed[k:k + step] = encoder_position_embeddings[:max_encoder_position_embeddings - k]
    k += step
led_model.base_model.encoder.embed_positions.weight.data = new_encoder_pos_embed
  
# Same procedure for the decoder position embeddings.
current_max_pos, embed_size = bart_model.model.decoder.embed_positions.weight.shape
new_decoder_pos_embed = bart_model.model.decoder.embed_positions.weight.new_empty(
    max_decoder_position_embeddings, embed_size)

k = 0
step = current_max_pos - 2

decoder_position_embeddings = bart_model.model.decoder.embed_positions.weight[2:]
while k < max_decoder_position_embeddings:
    new_decoder_pos_embed[k:k + step] = decoder_position_embeddings[:max_decoder_position_embeddings - k]
    k += step
led_model.base_model.decoder.embed_positions.weight.data = new_decoder_pos_embed
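To convince yourself the tiling worked, here is a small check (a sketch): consecutive blocks of the new encoder table should be identical copies of BART's original positions.

import torch

enc_step = bart_model.model.encoder.embed_positions.weight.shape[0] - 2
assert torch.equal(new_encoder_pos_embed[:enc_step], new_encoder_pos_embed[enc_step:2 * enc_step])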
  
# Encoder: LED uses Longformer self-attention. Local attention reuses the BART
# projections directly, while the global-attention projections get deep copies
# so they can diverge during fine-tuning.
for bart_encoder_layer, led_encoder_layer in zip(bart_model.model.encoder.layers, led_model.base_model.encoder.layers):
    led_encoder_layer.self_attn.longformer_self_attn.key = bart_encoder_layer.self_attn.k_proj
    led_encoder_layer.self_attn.longformer_self_attn.query = bart_encoder_layer.self_attn.q_proj
    led_encoder_layer.self_attn.longformer_self_attn.value = bart_encoder_layer.self_attn.v_proj
    led_encoder_layer.self_attn.longformer_self_attn.key_global = copy.deepcopy(bart_encoder_layer.self_attn.k_proj)
    led_encoder_layer.self_attn.longformer_self_attn.query_global = copy.deepcopy(bart_encoder_layer.self_attn.q_proj)
    led_encoder_layer.self_attn.longformer_self_attn.value_global = copy.deepcopy(bart_encoder_layer.self_attn.v_proj)
    led_encoder_layer.self_attn.output = bart_encoder_layer.self_attn.out_proj
    led_encoder_layer.self_attn_layer_norm = bart_encoder_layer.self_attn_layer_norm
    led_encoder_layer.fc1 = bart_encoder_layer.fc1
    led_encoder_layer.fc2 = bart_encoder_layer.fc2
    led_encoder_layer.final_layer_norm = bart_encoder_layer.final_layer_norm
  
# Decoder: unchanged relative to BART, so the weights copy over one-to-one.
for bart_decoder_layer, led_decoder_layer in zip(bart_model.model.decoder.layers, led_model.base_model.decoder.layers):
    led_decoder_layer.self_attn.k_proj = bart_decoder_layer.self_attn.k_proj
    led_decoder_layer.self_attn.q_proj = bart_decoder_layer.self_attn.q_proj
    led_decoder_layer.self_attn.v_proj = bart_decoder_layer.self_attn.v_proj
    led_decoder_layer.self_attn.out_proj = bart_decoder_layer.self_attn.out_proj
    led_decoder_layer.self_attn_layer_norm = bart_decoder_layer.self_attn_layer_norm
    led_decoder_layer.encoder_attn.k_proj = bart_decoder_layer.encoder_attn.k_proj
    led_decoder_layer.encoder_attn.q_proj = bart_decoder_layer.encoder_attn.q_proj
    led_decoder_layer.encoder_attn.v_proj = bart_decoder_layer.encoder_attn.v_proj
    led_decoder_layer.encoder_attn.out_proj = bart_decoder_layer.encoder_attn.out_proj  # cross-attention output projection, so it is not left randomly initialized
    led_decoder_layer.encoder_attn_layer_norm = bart_decoder_layer.encoder_attn_layer_norm

    led_decoder_layer.fc1 = bart_decoder_layer.fc1
    led_decoder_layer.fc2 = bart_decoder_layer.fc2
    led_decoder_layer.final_layer_norm = bart_decoder_layer.final_layer_norm
  
# Finally, the pieces the loops above do not cover: the shared token embeddings,
# the embedding layer norms, and the LM head (tied to the input embeddings).
led_model.set_input_embeddings(bart_model.get_input_embeddings())
led_model.base_model.encoder.layernorm_embedding = bart_model.model.encoder.layernorm_embedding
led_model.base_model.decoder.layernorm_embedding = bart_model.model.decoder.layernorm_embedding
led_model.lm_head = bart_model.lm_head
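A quick way to sanity-check the conversion before saving (a minimal sketch): for an input much shorter than one attention window, LED's local attention degenerates to full attention, so the converted model should track BART closely.

import torch

bart_model.eval()
led_model.eval()
sample = tokenizer("用一句话检查权重转换是否正确。", return_tensors="pt")
decoder_input_ids = torch.tensor([[102]])  # decoder start token ([SEP])
with torch.no_grad():
    bart_logits = bart_model(input_ids=sample["input_ids"], decoder_input_ids=decoder_input_ids).logits
    led_logits = led_model(input_ids=sample["input_ids"], decoder_input_ids=decoder_input_ids).logits
print(torch.max(torch.abs(bart_logits - led_logits)))  # expect a value at or near zero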
2.5 Save the LED weights
logger.info("convert bart-chinese to led success")
led_model.save_pretrained(r'E:\Project\NLP\long-document\converted_model')
tokenizer.save_pretrained(r'E:\Project\NLP\long-document\converted_model')
3 Complete code
import copy
import logging

from transformers import LEDConfig, LEDForConditionalGeneration, BertTokenizer
from transformers import BartForConditionalGeneration

logger = logging.getLogger("longformer-chinese")
logging.basicConfig(level=logging.INFO)

max_encoder_position_embeddings = 4096
max_decoder_position_embeddings = 1024

led_config = LEDConfig(
    vocab_size=21128,
    max_encoder_position_embeddings=max_encoder_position_embeddings,
    max_decoder_position_embeddings=max_decoder_position_embeddings,
    encoder_layers=6,
    encoder_ffn_dim=3072,
    encoder_attention_heads=12,
    decoder_layers=6,
    decoder_ffn_dim=3072,
    decoder_attention_heads=12,
    encoder_layerdrop=0.0,
    decoder_layerdrop=0.0,
    use_cache=True,
    is_encoder_decoder=True,
    activation_function="gelu",
    d_model=768,
    dropout=0.1,
    attention_dropout=0.0,
    activation_dropout=0.0,
    init_std=0.02,
    decoder_start_token_id=102,
    classifier_dropout=0.0,
    pad_token_id=0,
    bos_token_id=101,
    eos_token_id=102,
    attention_window=512,
)
led_model = LEDForConditionalGeneration(led_config)
bart_model = BartForConditionalGeneration.from_pretrained(r'E:\Project\NLP\long-document\bart-base')
tokenizer = BertTokenizer.from_pretrained(r'E:\Project\NLP\long-document\bart-base')

# Tile BART's learned position embeddings (stored with a 2-row offset) up to the
# new encoder maximum length.
current_max_pos, embed_size = bart_model.model.encoder.embed_positions.weight.shape
new_encoder_pos_embed = bart_model.model.encoder.embed_positions.weight.new_empty(
    max_encoder_position_embeddings, embed_size)

k = 0
step = current_max_pos - 2
encoder_position_embeddings = bart_model.model.encoder.embed_positions.weight[2:]
while k < max_encoder_position_embeddings:
    new_encoder_pos_embed[k:k + step] = encoder_position_embeddings[:max_encoder_position_embeddings - k]
    k += step
led_model.base_model.encoder.embed_positions.weight.data = new_encoder_pos_embed

# Same procedure for the decoder position embeddings.
current_max_pos, embed_size = bart_model.model.decoder.embed_positions.weight.shape
new_decoder_pos_embed = bart_model.model.decoder.embed_positions.weight.new_empty(
    max_decoder_position_embeddings, embed_size)

k = 0
step = current_max_pos - 2
decoder_position_embeddings = bart_model.model.decoder.embed_positions.weight[2:]
while k < max_decoder_position_embeddings:
    new_decoder_pos_embed[k:k + step] = decoder_position_embeddings[:max_decoder_position_embeddings - k]
    k += step
led_model.base_model.decoder.embed_positions.weight.data = new_decoder_pos_embed

# Encoder: local attention reuses BART's projections; the global-attention
# projections get deep copies so they can diverge during fine-tuning.
for bart_encoder_layer, led_encoder_layer in zip(bart_model.model.encoder.layers, led_model.base_model.encoder.layers):
    led_encoder_layer.self_attn.longformer_self_attn.key = bart_encoder_layer.self_attn.k_proj
    led_encoder_layer.self_attn.longformer_self_attn.query = bart_encoder_layer.self_attn.q_proj
    led_encoder_layer.self_attn.longformer_self_attn.value = bart_encoder_layer.self_attn.v_proj
    led_encoder_layer.self_attn.longformer_self_attn.key_global = copy.deepcopy(bart_encoder_layer.self_attn.k_proj)
    led_encoder_layer.self_attn.longformer_self_attn.query_global = copy.deepcopy(bart_encoder_layer.self_attn.q_proj)
    led_encoder_layer.self_attn.longformer_self_attn.value_global = copy.deepcopy(bart_encoder_layer.self_attn.v_proj)
    led_encoder_layer.self_attn.output = bart_encoder_layer.self_attn.out_proj
    led_encoder_layer.self_attn_layer_norm = bart_encoder_layer.self_attn_layer_norm
    led_encoder_layer.fc1 = bart_encoder_layer.fc1
    led_encoder_layer.fc2 = bart_encoder_layer.fc2
    led_encoder_layer.final_layer_norm = bart_encoder_layer.final_layer_norm

# Decoder: unchanged relative to BART, so the weights copy over one-to-one.
for bart_decoder_layer, led_decoder_layer in zip(bart_model.model.decoder.layers, led_model.base_model.decoder.layers):
    led_decoder_layer.self_attn.k_proj = bart_decoder_layer.self_attn.k_proj
    led_decoder_layer.self_attn.q_proj = bart_decoder_layer.self_attn.q_proj
    led_decoder_layer.self_attn.v_proj = bart_decoder_layer.self_attn.v_proj
    led_decoder_layer.self_attn.out_proj = bart_decoder_layer.self_attn.out_proj
    led_decoder_layer.self_attn_layer_norm = bart_decoder_layer.self_attn_layer_norm
    led_decoder_layer.encoder_attn.k_proj = bart_decoder_layer.encoder_attn.k_proj
    led_decoder_layer.encoder_attn.q_proj = bart_decoder_layer.encoder_attn.q_proj
    led_decoder_layer.encoder_attn.v_proj = bart_decoder_layer.encoder_attn.v_proj
    led_decoder_layer.encoder_attn.out_proj = bart_decoder_layer.encoder_attn.out_proj
    led_decoder_layer.encoder_attn_layer_norm = bart_decoder_layer.encoder_attn_layer_norm

    led_decoder_layer.fc1 = bart_decoder_layer.fc1
    led_decoder_layer.fc2 = bart_decoder_layer.fc2
    led_decoder_layer.final_layer_norm = bart_decoder_layer.final_layer_norm

# Shared token embeddings, embedding layer norms, and the LM head.
led_model.set_input_embeddings(bart_model.get_input_embeddings())
led_model.base_model.encoder.layernorm_embedding = bart_model.model.encoder.layernorm_embedding
led_model.base_model.decoder.layernorm_embedding = bart_model.model.decoder.layernorm_embedding
led_model.lm_head = bart_model.lm_head

logger.info("convert bart-chinese to led success")
led_model.save_pretrained(r'E:\Project\NLP\long-document\converted_model')
tokenizer.save_pretrained(r'E:\Project\NLP\long-document\converted_model')
  
That's it. Now just fine-tune on your own short- or long-document summarization dataset. I have tested this myself: it works, and the results are actually quite decent. The next post will cover loading BART weights into BigBird as an initialization, so we can obtain a Chinese BigBird model the same way. One caveat: directly fine-tuning a converted model like this only gets you so far; for better results you need to continue pretraining.
4 Training
# coding=utf-8
import logging
import datasets
import numpy as np
import lawrouge
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import LEDForConditionalGeneration, BertTokenizer

from datasets import load_dataset

logger = logging.getLogger("longformer-chinese")
logging.basicConfig(level=logging.INFO)

dataset = load_dataset('json', data_files=r'D:\nlp\project\long-document\datasets\xxxx.json', field='data')  # load your own long-document summarization dataset
dataset = dataset.shuffle(seed=42)  # shuffle

tokenizer = BertTokenizer.from_pretrained(r'D:\nlp\project\long-document\bert-base-chinese')  # load the BERT tokenizer
model = LEDForConditionalGeneration.from_pretrained(r'D:\nlp\project\long-document\converted_model')  # load the converted Longformer (LED)
# model.resize_token_embeddings(tokenizer.vocab_size)  # extend the vocabulary, 21128 -> 50000
  
def flatten(example):
    return {
        "document": example["content"],
        "summary": example["title"],
    }

dataset = dataset["train"].map(flatten, remove_columns=["title", "content"])

max_input_length = 8192   # 4096 or other values; must not exceed the max_encoder_position_embeddings used during conversion
max_target_length = 1024  # maximum summary (target) length
  
def preprocess_function(examples):
    inputs = [doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Set up the tokenizer for the targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

dataset = dataset.shuffle()

train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.01).values()
tokenized_datasets = datasets.DatasetDict({"train": train_data_txt, "validation": validation_data_txt}).map(preprocess_function, batched=True)
  
batch_size = 1  # => GPU-poor
args = Seq2SeqTrainingArguments(
    fp16=True,
    output_dir="results_long",
    num_train_epochs=10,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=batch_size,  # demo
    per_device_eval_batch_size=batch_size,
    learning_rate=2e-05,
    warmup_steps=1000,
    weight_decay=0.1,
    label_smoothing_factor=0.15,
    predict_with_generate=True,
    logging_dir="logs",
    logging_strategy="steps",
    logging_steps=1,
    save_total_limit=2,
    evaluation_strategy="steps",
    eval_steps=500,
    gradient_accumulation_steps=64,  # effective batch size = 1 * 64
)
  
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
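Note that LED does not mark any token for global attention on its own, and BertTokenizer will not pad an extra global_attention_mask column for you. A common convention during training is to put global attention on the first ([CLS]) token. Here is a minimal sketch of one way to do that (the LEDDataCollator name is mine, not from the original script; the __call__ signature assumes a reasonably recent transformers version): add the mask after the batch has been padded.

import torch

class LEDDataCollator(DataCollatorForSeq2Seq):
    """Seq2seq collator that adds global attention on the first token."""

    def __call__(self, features, return_tensors=None):
        batch = super().__call__(features, return_tensors)
        # Local attention everywhere, global attention on position 0 ([CLS]).
        global_attention_mask = torch.zeros_like(batch["input_ids"])
        global_attention_mask[:, 0] = 1
        batch["global_attention_mask"] = global_attention_mask
        return batch

# data_collator = LEDDataCollator(tokenizer, model=model)  # drop-in replacement for the line above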
  
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = [pred.replace(" ", "") for pred in decoded_preds]
    decoded_labels = [label.replace(" ", "") for label in decoded_labels]
    # Alternative: ROUGE over jieba-segmented text
    # decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]
    # decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]

    labels_lens = [np.count_nonzero(label != tokenizer.pad_token_id) for label in labels]

    # Guard against empty predictions, which would crash the ROUGE computation.
    for i, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
        if pred == "":
            decoded_preds[i] = "decoding error,skipping..."

    rouge = lawrouge.Rouge()
    result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
    print(result)
    result = {'rouge-1': result['rouge-1']['f'], 'rouge-2': result['rouge-2']['f'], 'rouge-l': result['rouge-l']['f']}

    result = {key: value * 100 for key, value in result.items()}
    result["gen_len"] = np.mean(labels_lens)  # average reference length
    return result
  
  
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Train, then save the model and the training metrics
train_result = trainer.train()
print(train_result)
trainer.save_model()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
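After training, generating a summary from a long document looks like this (a sketch; the document string is a placeholder, and the global-attention convention matches the collator shown above):

import torch

document = "这里放一篇很长的中文文档……"  # placeholder
inputs = tokenizer(document, max_length=max_input_length, truncation=True, return_tensors="pt")
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1  # global attention on [CLS]

summary_ids = model.generate(
    inputs["input_ids"],
    global_attention_mask=global_attention_mask,
    num_beams=4,
    max_length=512,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))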
  
  
  