当前位置: 首页 > >

DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks

发布时间:

DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks


论文链接:https://dl.acm.org/citation.cfm?id=3330858
代码链接:https://github.com/motefly/DeepGBM


背景

Guolin Ke是lightgbm的核心作者之一,我幸运的作为Hui Xue的实*生去参加了NIPS2018的表数据的AutoML比赛,这个比赛中的三个数据集也被这篇论文选为benchmark的测试数据集,他们对表数据的理解非常深刻,这篇论文是他们在KDD2019的投稿,这篇论文本身写的很好,通俗易懂,我读完后再从自己的角度用大白话讲一下,方便大家对这篇文章的理解。


表数据的输入类型从业务角度可以有多种,比如时间列,离散列,连续值列,多值列,字典列(key-value对儿),但是从模型的角度看只有两种,离散列和连续值列,目前我见过的表数据的AutoML也都是把业务类型的有实际意义的列变换成离散列和连续值列这两种情况,比如时间列我们可以做一些时间特征(提取年月日时分秒,是否是周末),多值列可以用tfidf提取特征,字典列可以展开当做稀疏的特征。


在表数据上竞赛圈目前大家最常用的模型就是gbdt(lightgbm, xgboost),FM/FFM(xlearn)


概述

经典模型各有优劣,主要是两个角度,离散值还是连续值,在线场景还是离线场景。


挑战

    gbdt在连续值,离线场景下表现非常好,但是在离散值多且在线场景下表现很差,原因是离散值很多如果做one-hot的话太稀疏树很难生长,如果不做one-hot直接扔进模型的话会很容易过拟合当前的数据集,trading-off是用一些不如one-hot稀疏的encoding方法,但是多少会丢一些信息。在线场景很难更新当前树的结构,需要重新利用新的统计信息重新构建树,所以在线场景效率很低,不改变结构更新树的方法也有,比如改变叶子节点的值但是这种方法效果都不太好。

    LR和NN,FM这类模型在离散值上处理的很好,且可以在线更新,但是连续值较多的情况下效果远差于gbdt。


DeepGBM希望可以结合模型各自的优点,所以点主要在于如何结合上:


离散值给CatNN,连续值给GBDT2NN。


CatNN是把离散值one-hot以后学一个embedding,用一些对离散值效果好的模型预测然后ensemble,比如FM和NN。
GBDT2NN是先训练一个gbdt,然后把gbdt转成NN(期望相同输入有相同输出,用不同的方法拟合同一个东西),这篇论文还有一个点在于如果更好的把gbdt转成NN。


拿CatNN和GBDT2NN的结果做一个ensemble得到最后的结果。


离线场景:


    拆分数据为离散数据,连续数据利用离散数据训练CatNN,利用连续数据训练GBDT把GBDT转成NN并训练。ensemble CatNN和GBDT2NN

在线场景下更新的时候只更新CatNN。



GBDT2NN

下面重点讲一下他是如何训练GBDT2NN的:
先讲如果是一棵树的话怎么做,再讲如何扩展到有多棵树。


Single Tree Distillation

我们的目的是用一个NN来拟合出这颗树,对相同的输入有相同的输出。


    Tree-Selected Features
    相比NN来说,用树的一个好处是树不会用所有的特征,他只用一部分,我们训练的这个NN只使用树使用过的特征当做输入,这样会大大加速NN训练的过程,相当于是用树做了特征选择,我们做时间和效果的trading-off的时候也可以通过控制列选用比例来加速,因为树的特征选择是可以给出特征的重要性排序的。

    Tree Structure
    NN的拟合能力很强,我们用NN来拟合一个树的结构,对于一个输入样本,树可以返回这个样本的叶子节点的编号,是一个one-hot的向量,NN要去拟合的就是这个过程,对于一个样本,期望返回和树一样的叶子节点的编号(一个one-hot的向量)。



    Tree Outputs
    因为我们在第二步拟合结构的时候输出就是一个one-hot的叶子节点的编号,所以这个输出我们直接用树原来的叶子节点的值就好。

Multiple Tree Distillation

当我们有多棵树的时候,最简单的办法是每一颗树像之前那样训练一个NN,但是这样复杂度太高了,我们通过Leaf embedding和Tree Grouping对这个过程进行加速。


Leaf Embedding Distillation

加一层fully-connected层做embedding,把稀疏的one-hot向量变成一个密集向量,然后再训练得到一个Leaf Output。



Tree Grouping

Leaf Embedding减少了每一颗树的维度,我们通过Tree Grouping来减少NN的数目。


Tree Grouping是从gbdt的所有树种分组,每一组训练一个NN,deepgbm用的方法是随机选树进行分组,每个组树的颗数一样,每个组内的one-hot向量concat起来,然后就和之前一样一个组训练一个NN就好了。


实验


图中对比了当前SOTA的模型在这几个数据集下面的情况,结果我还没有复现,打算先读一下作者的源码,复现了论文的结果,然后在自己的框架里面搭一套试试怎么样,会在接下来尝试一下,看论文的实验结果还是不错的,但是中间有一些细节需要再琢磨一下,比如参数的设置,特征工程等等,因为在实际使用的时候特征工程可能远不止这些,已经ensemble这个过程怎么调比较合适等等。
看完这篇paper我觉得收获主要在于在什么情况下用什么模型,这个结构是什么样的,以及看看9102年大家是如何理解这个问题的,如果只是为了简单,稳定,效果还不错,那就用一个单lightgbm效果就还可以了。



友情链接: