SapBERT (Self-alignment Pretraining for BERT): A Usage Example


I have recently been studying SapBERT for computing entity similarity, and found that the official repo does not ship a usage example. I wrote one modeled on its code so SapBERT can be used directly. My environment:

torch         1.7.1+cu101
torchvision   0.11.3
transformers  4.16.2

The code below is all you need. The key point is that SapBERT simply extracts an embedding vector for each entity name; you can then retrieve similar entities with an approximate nearest-neighbor tool such as faiss:

from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch

tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")

# Encode the query; max_length can be short since inputs are entity names
query = "cardiopathy"
query_toks = tokenizer.batch_encode_plus([query],
                                         padding="max_length",
                                         max_length=25,
                                         truncation=True,
                                         return_tensors="pt")
print(query_toks)
with torch.no_grad():  # inference only, no autograd graph needed
    query_output = model(**query_toks)
# The [CLS] token embedding (position 0) is the entity representation
query_cls_rep = query_output[0][:, 0, :]
print(query_cls_rep)

all_names = ['Neoplasm of anterior aspect of epiglottis']

# Encode the candidate names exactly the same way as the query
toks = tokenizer.batch_encode_plus(all_names,
                                   padding="max_length",
                                   max_length=25,
                                   truncation=True,
                                   return_tensors="pt")

output = model(**toks)
cls_rep = output[0][:, 0, :]
print(cls_rep)
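With only one candidate name this single call is fine, but a real entity list may contain hundreds of thousands of names, which will not fit through the model at once. A minimal mini-batching sketch (the helper name `embed_names` and the batch size are my own, not from the official repo):

```python
import torch

def embed_names(names, tokenizer, model, batch_size=128, max_length=25):
    """Encode entity names in mini-batches and return stacked [CLS] vectors."""
    reps = []
    with torch.no_grad():
        for i in range(0, len(names), batch_size):
            batch = names[i:i + batch_size]
            toks = tokenizer.batch_encode_plus(batch,
                                               padding="max_length",
                                               max_length=max_length,
                                               truncation=True,
                                               return_tensors="pt")
            out = model(**toks)
            reps.append(out[0][:, 0, :])  # [CLS] embedding per name
    return torch.cat(reps, dim=0)
```

Usage: `all_reps = embed_names(all_names, tokenizer, model)`.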

# For a large candidate set, switch to faiss; cdist is fine at small scale
from scipy.spatial.distance import cdist

dist = cdist(query_cls_rep.cpu().detach().numpy(), cls_rep.cpu().detach().numpy())
nn_index = np.argmin(dist)  # index of the nearest candidate in all_names
# print ("predicted label:", snomed_sf_id_pairs_100k[nn_index])
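Before reaching for faiss, you can already get top-k retrieval (rather than just the single nearest neighbor) with plain NumPy: `np.argpartition` selects the k smallest distances without a full sort. A sketch assuming `query_reps` and `cand_reps` are 2-D embedding arrays (the function name is my own):

```python
import numpy as np
from scipy.spatial.distance import cdist

def top_k_neighbors(query_reps, cand_reps, k=5):
    """Return indices of the k nearest candidates (Euclidean) per query row."""
    dist = cdist(query_reps, cand_reps)                 # (n_query, n_cand)
    k = min(k, dist.shape[1])
    idx = np.argpartition(dist, k - 1, axis=1)[:, :k]   # k smallest, unordered
    # sort the k survivors by their actual distance
    order = np.argsort(np.take_along_axis(dist, idx, axis=1), axis=1)
    return np.take_along_axis(idx, order, axis=1)
```

For millions of candidates, replace this brute-force scan with a faiss index over the same vectors.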

