Embedding
场景说明
将离散ID映射为稠密向量,目前kdnn支持torch.float16数据类型,其他数据类型会走开源分支。
示例代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | import torch import torch.nn as nn # 启用kdnn torch._C._set_kdnn_enabled(True) # 构造Embedding层 embed = nn.Embedding(num_embeddings=1000, embedding_dim=128, torch.float16) # 输入数据2x2的token索引 input_ids = torch.LongTensor([[1, 2], [3, 4]]) embeddings = embed(input_ids) # 输出 [2, 2, 128] print(embeddings) |
父主题: 使用示例