鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

Embedding

接口功能

将离散索引映射为连续向量(查表操作)。

函数原型

torch.nn.Embedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
    _weight: Optional[Tensor] = None,
    device=None,
    dtype=None
)

参数说明

表1 参数说明

参数

类型

必填

说明

num_embeddings

int

嵌入字典大小(索引范围[0, num_embeddings-1])。

embedding_dim

int

嵌入向量维度。

padding_idx

int

填充索引(该索引对应的向量梯度固定为0)。

max_norm

float

最大范数约束(超过时缩放)。

norm_type

float

范数计算类型(默认2.0表示L2范数)。

scale_grad_by_freq

bool

是否按频率缩放梯度(默认False)。

sparse

bool

是否使用稀疏梯度(默认False)。