0
点赞
收藏
分享

微信扫一扫

什么是 Embedding()

Hyggelook 2022-01-20 阅读 46

本文只是初步解释, 或者解释

Embedding(input_dim, 
               output_dim,  
               embeddings_initializer='uniform',  embeddings_regularizer=None, activity_regularizer=None, 		embeddings_constraint=None, mask_zero=False,   
               input_length=None)

中的 input_dim, output_dim, input_length.

令 S 为一良序集合 S, Embedding 为函数 f : S → R n f: \text{S} \to \R^n f:SRn, n 为手选参数. 从 S 中放回抽取 m 次, 按抽取顺序排成一序列.可类比从词典中选词组成句子.

input_dim : 集合 S 的大小,
output_dim : n
input_length : m 的最大值.

假定 S : = { v 0 , v 1 , ⋯   , v 9 } \text{S}:=\{v_0,v_1,\cdots, v_9\} S:={v0,v1,,v9}. 构作编号 f 1 ( v i ) = i f_1(v_i)=i f1(vi)=i.
选 n=4, m=2.

基本想法是 使用 one-hot 编码
f 2 ( i ) = [ 0 , 0 , ⋯   , 0 , 1 , 0 , ⋯   , 0 ] f_2(i)=[0,0, \cdots, 0, 1, 0,\cdots,0] f2(i)=[0,0,,0,1,0,,0]
这两个函数是确定好的,不产生参数.
每次输入为两个 元素, 例如 ( v 1 v 2 ) → f 2 f 1 ( 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ) \left( \begin{array}{c} v_1 \\ v_2 \end{array} \right) \xrightarrow{f_2f_1} \left( \begin{array}{c} 1,0,0,0,0,0,0,0,0,0\\ 0,1,0,0,0,0,0,0,0,0\\ \end{array} \right) (v1v2)f2f1 (1,0,0,0,0,0,0,0,0,00,1,0,0,0,0,0,0,0,0)

即得到是 2 × 9 2 \times 9 2×9 的矩阵, 将之转换为 2 × 4 2 \times 4 2×4 的矩阵
转换需要的矩阵的规模是 9 × 4 9 \times 4 9×4. 也即是
参数为 36.
输出的矩阵大小为 2 × 4 2 \times 4 2×4.

此即表明. Embedding() 的
输入是 input_length 大小的向量,
输出是 input_length × \times × out_put 的张量,
需要学习的参数是 input_dim × \times × out_put.

这个只是讨论了一个样本的情况, 对多个样本, 实际是 l × \times × input_length 的形式. 通常使用 keras 中的 Input 函数 , Input(shape(input_length,))
Embedding 改写为 Embedding(input_dim, output_dim).
l 为样本数, 可用 l 个句子理解.
(Input(shape(input_length,)))

其他参数的解释超出了能力范围, 不清楚. 下面是代码, 可自行运行.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
import numpy as np
model = Sequential()
embedding_layer = Embedding(input_dim=10,output_dim=4,input_length=2)
model.add(embedding_layer)
model.compile('adam','mse')
input_data = np.array([[1,2]])
pred = model.predict(input_data)
print(input_data.shape)
print(pred)

此处么有训练, 能出数据是因为 Embedding 自带了初值.

举报

相关推荐

0 条评论