Embedding
接口功能
将离散索引映射为连续向量(查表操作)。
函数原型
torch.nn.Embedding(
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
device=None,
dtype=None
)
参数说明
参数 |
类型 |
必填 |
说明 |
|---|---|---|---|
num_embeddings |
int |
是 |
嵌入字典大小(索引范围[0, num_embeddings-1])。 |
embedding_dim |
int |
是 |
嵌入向量维度。 |
padding_idx |
int |
否 |
填充索引(该索引对应的向量梯度固定为0)。 |
max_norm |
float |
否 |
最大范数约束(超过时缩放)。 |
norm_type |
float |
否 |
范数计算类型(默认2.0表示L2范数)。 |
scale_grad_by_freq |
bool |
否 |
是否按频率缩放梯度(默认False)。 |
sparse |
bool |
否 |
是否使用稀疏梯度(默认False)。 |
父主题: 算子接口