embedding table 内存优化

1,669次阅读
没有评论

前言

Embedding table 优化关键的一点是内存空间占用的优化,比如一个id类特征几个亿,维度32维,你要生成几亿*32的矩阵,这个存储空间消耗可是很大,所以今天介绍其中一个方法,就是 muti hash 的方法,这个是在最近阿里开源的deeprec框架上面,不知道这个框架能不能持续发展下去,看看xdl就很心寒,其他家推出的开源框架基本是只是上传了出版代码后面的就没了。主要大模型这块,不光是训练框架,其他的涉及推理等等都是配套的,这块也算是技术壁垒,毕竟各家都是自研的框架。

今天介绍的 multi hash 的方法就是要解决embedding 维度过大,存储空间占用过多的问题,论文可以参考《Compositional Embeddings Using Complementary Partitions for Memory-Efficent Recommendation Systems》。

先来看下传统的 embedding table

embedding

如果id 特征有 N 种取值情况,那么你就要生成 N*Embedding_size大小的矩阵用于存储向量,这个消耗的内存过于庞大,毕竟id特征还不少。当然最主要的一般都是用户的id 和物料的id ,动辄上亿维度。

multi hash 的方法就是要减少 N 的维度。既然方法里面提到hash,这也是这个方法的重点。

算法逻辑

QUOTIENT-REMAINDER TRICK

算法的详细计算如下图所示:

embedding

  1. 构建两个矩阵,两个矩阵是互补的关系,注意两个的维度大小一个是m\times D 一个是\frac{\vert S\vert}{m}\times D
  2. 计算查询的索引i对应在两个embedding table中的索引
  3. 获取到对应的embedding数据
  4. 对从W_1W_2查到的向量做elemen-wise计算

泛化的QR方法

这里需要强调的是构建多个 Partitions 也就是分成多少组,对应多少个 Embedding table。

embedding

举个例子:

S = {0, 1, 2, 3, 4}. 分成3个Partition如下所示

{{0}, {1, 3, 4}, {2}}, {{0, 1, 3}, {2, 4}}, {{0, 3}, {1, 2, 4}}.

到这一步之后下一步还是要结合多个 Partition下查询出来的embedding 结合问题,这里文章给出了三种方式:concat 、addition和element-wise 乘积

embedding

实验结论

Full table可以看作是ground truth标准,可以看到 qr trick 相比于 hash trick效果好,这个就是hash trick 冲突带来的问题,qr trick 一定程度上是做了缓解,就是hash trick 和full table之间的折中方案,算是效果和工程实现兼顾,对于工业界来说能落地才是最重要的环节。

embedding

admin
版权声明:本站原创文章,由 admin 2022-01-28发表,共计1134字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)