LSTM News Headline Classification

程序员漫画编程 2022-04-29
pytorch, nlp

Dataset Loading

The dataset is a disaster-news-headline dataset found online; five disaster categories were selected and merged into a new dataset, roughly as shown below:
[image: sample of the dataset]

`label` is the category label.
`comment` is the news headline text.
Here is a download link for the dataset:
Dataset download address

Next, convert the text labels into numeric labels, which are easier to work with. The processed data looks like this:
[image: dataset after label conversion]
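The label conversion step can be sketched with pandas; the column names and category strings below are hypothetical stand-ins, since the post does not show the conversion code:

```python
import pandas as pd

# Hypothetical columns: 'label' holds the text category, 'comment' the headline.
df = pd.DataFrame({
    "label": ["earthquake", "flood", "flood", "fire"],
    "comment": ["headline 1", "headline 2", "headline 3", "headline 4"],
})

# Build a stable text-label -> integer mapping and apply it as a new column.
label2id = {name: i for i, name in enumerate(sorted(df["label"].unique()))}
df["pos"] = df["label"].map(label2id)
print(df[["pos", "comment"]])
```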
Once the dataset is ready, load it with torchtext.

import random

import torch
import pkuseg
# in torchtext < 0.9 these classes live in torchtext.data instead
from torchtext.legacy.data import Field, TabularDataset, BucketIterator

SEED = 1234       # assumed values; the original post does not show them
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.manual_seed(SEED)

# pkuseg performs Chinese word segmentation
seg = pkuseg.pkuseg()


def tokenizer(text):
    return seg.cut(text)


# pad or truncate every headline to 35 tokens
TEXT = Field(sequential=True, tokenize=tokenizer, fix_length=35)
POS = Field(sequential=False, use_vocab=False)

FIELD = [('label', None), ('content', TEXT), ('pos', POS)]

df = TabularDataset(
    path='./data/news.csv', format='csv',
    fields=FIELD, skip_header=True)

# Note: glove.6B.50d is an English vector set, so most Chinese tokens will
# fall outside its vocabulary; a Chinese embedding would be a better fit.
TEXT.build_vocab(df, min_freq=3, vectors='glove.6B.50d')

train, valid = df.split(split_ratio=0.7, random_state=random.seed(SEED))

train_iter, valid_iter = BucketIterator.splits(
    (train, valid),
    batch_sizes=(batch_size, batch_size),
    device=device,
    sort_key=lambda x: len(x.content),
    sort_within_batch=False,
    repeat=False
)
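The effect of `fix_length=35` can be illustrated in plain Python: longer token lists are truncated and shorter ones padded to a fixed length. The function below is an illustrative sketch, not torchtext's internal implementation:

```python
# Sketch of what Field(fix_length=35) does to each tokenized headline.
def pad_or_truncate(tokens, fix_length=35, pad_token="<pad>"):
    tokens = tokens[:fix_length]
    return tokens + [pad_token] * (fix_length - len(tokens))

print(len(pad_or_truncate(["地震", "救援"])))  # always fix_length tokens
```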

Model Definition

The model is an LSTM, which PyTorch already provides as a built-in module (nn.LSTM).


import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, emb_len, emb_dim, out_dim):
        super(LSTM, self).__init__()
        self.embedding = nn.Embedding(emb_len, emb_dim)
        self.lstm = nn.LSTM(emb_dim, out_dim, batch_first=True, dropout=0.5,
                            bidirectional=True, num_layers=2)
        self.linear = nn.Sequential(
            nn.Linear(2 * out_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 5)
        )

    def forward(self, x):
        # input shape: (length, batch_size)
        out = self.embedding(x)
        # (length, batch_size, emb) -> (batch_size, length, emb)
        out = torch.transpose(out, 0, 1)

        out, (h, c) = self.lstm(out)
        # concatenate the forward and backward hidden states of the top layer
        out = torch.cat((h[-2, :, :], h[-1, :, :]), dim=1)
        out = self.linear(out)

        return out

`bidirectional=True` makes the LSTM run in both directions, so in forward() the hidden states of the two directions have to be concatenated.
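The hidden-state layout can be checked with a quick shape experiment (toy sizes, chosen only for illustration): with `num_layers=2` and `bidirectional=True`, `h` has `num_layers * num_directions = 4` rows, and rows `-2` and `-1` are the top layer's forward and backward states.

```python
import torch
import torch.nn as nn

# Shape check for a 2-layer bidirectional LSTM with toy dimensions.
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 8)                 # (batch, length, emb)
out, (h, c) = lstm(x)
# h: (num_layers * num_directions, batch, hidden)
top = torch.cat((h[-2], h[-1]), dim=1)    # (batch, 2 * hidden)
print(out.shape, h.shape, top.shape)
```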

Training

Instantiate the model, then set up the optimizer, loss function, and learning-rate scheduler.


model = LSTM(len(TEXT.vocab), 64, 128).to(device)

import torch.optim as optim
import torch.nn.functional as F

learning_rate = 1e-3  # assumed value; not shown in the original post
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = F.cross_entropy
# halve the learning rate every 20 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# ----------------------------------- Training -----------------------------------
epochs = 100
stop = 20          # early-stopping patience
cnt = 0
best_valid_acc = float('-inf')
model_save_path = './model/torchtext.pkl'

for epoch in range(epochs):
    loss_one_epoch = 0.0
    correct_num = 0.0
    total_num = 0.0

    for i, batch in enumerate(train_iter):
        model.train()
        pos, content = batch.pos, batch.content
        # forward pass, backward pass, weight update
        optimizer.zero_grad()
        pred = model(content)
        loss = criterion(pred, pos)
        loss.backward()
        optimizer.step()

        # accumulate statistics
        total_num += pos.size(0)
        # count how many predictions match the labels
        correct_num += (torch.argmax(pred, dim=1) == pos).sum().float().item()
        loss_one_epoch += loss.item()

    loss_avg = loss_one_epoch / len(train_iter)

    print("Train: Epoch[{:0>3}/{:0>3}]  Loss: {:.4f} Acc:{:.2%}".
          format(epoch + 1, epochs, loss_avg, correct_num / total_num))

    # --------------------------------- Validation ---------------------------------
    loss_one_epoch = 0.0
    total_num = 0.0
    correct_num = 0.0

    model.eval()
    # disable gradient tracking during evaluation
    with torch.no_grad():
        for i, batch in enumerate(valid_iter):
            pos, content = batch.pos, batch.content
            pred = model(content)

            # accumulate statistics
            total_num += pos.size(0)
            # count how many predictions match the labels
            correct_num += (torch.argmax(pred, dim=1) == pos).sum().float().item()

    # learning-rate adjustment
    scheduler.step()

    print('Valid Acc:{:.2%}'.format(correct_num / total_num))

    # After each epoch, check validation accuracy; if the model improved, save it
    if (correct_num / total_num) > best_valid_acc:
        print("New best model, saving")
        best_valid_acc = (correct_num / total_num)
        torch.save(model.state_dict(), model_save_path)
        cnt = 0
    else:
        cnt = cnt + 1
        if cnt > stop:
            # early stopping
            print("No further improvement, stopping training")
            print("Best validation accuracy: %.2f" % best_valid_acc)
            break

The training loop uses both an early-stopping mechanism and a learning-rate adjustment strategy.
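The effect of the `StepLR(step_size=20, gamma=0.5)` schedule above can be computed directly: the learning rate halves every 20 epochs. A minimal sketch (the function name is illustrative):

```python
# Closed-form view of StepLR(step_size=20, gamma=0.5).
def stepped_lr(base_lr, epoch, step_size=20, gamma=0.5):
    return base_lr * gamma ** (epoch // step_size)

print([stepped_lr(1e-3, e) for e in (0, 19, 20, 40)])
```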

The final training results are shown below.
[image: training log]
Training accuracy is quite high, but validation accuracy is noticeably lower, so the model is likely overfitting somewhat.
The full code is on GitHub:
Code address
