决策树:最清晰明了的分类模型
一共有14个样本,其中9个早上都出去打球,5个早上没出去打球。在原始数据中,统计了每个早上的天气,湿度,是否有风这3个条件
输入数据的每一个特征作为决策树中的一个节点,根据其取值的不同,划分不同的分支,根据各个特征的取值,按照这个树状结构就可以解释一个样本的分类情况。
对于决策树模型,其解释性非常强,可以看做是一连串的if-else条件,根据该条件就可以轻松的预测一个新的样本点。决策树的输入和输出都比较直观,核心就在于构建合适的分类树。
在构建决策树的过程中, 对于一个原始的特征,根据其取值分割成不同的分支,分割的过程其实是一个取子集的过程。以outlook为例,分割前是14个样本,9个play,5个no play; 根据3种取值分割后,sunny有2个play, 3个no play, overcast有4个play, rain有3个play, 2个no paly。
为了量化特征以及分割前后的变化,引入了以下概念
1. 熵
熵是从信息论中引入的概念,用来衡量一个事物的混乱状态,熵越大,越无序,具体的计算公式如下
p代表的是概率,上述示例数据为例共14个样本,其中9个play, 4个no play, 对应的熵如下
>>> -(np.log2(9/14) * (9/14) + np.log2(5/14) * (5/14))
0.9402859586706309
和条件概率这个概念类似,也要条件熵的概念,即再特征X下数据集的熵,公式如下
>>> - (5/14) * (np.log2(2/5) * (2/5) + np.log2(3/5) * (3/5)) - (5/14) * (np.log2(3/5) * (3/5) + np.log2(2/5) * (2/5))
0.6935361388961918
在取值为overcast时,出现了no play为0的情况,无法计算log值,此时直接将其熵定义为0, 所以上述公式只考虑了取值为sunny和rain的情况。
2. 信息增益
在决策树中,我们根据特征的取值将原始的特征拆分成了不同的分支,信息增益用来衡量拆分前后复杂度的变化,具体的计算公式如下
>>> -(np.log2(9/14) * (9/14) + np.log2(5/14) * (5/14)) + (5/14) * (np.log2(2/5) * (2/5) + np.log2(3/5) * (3/5)) + (5/14) * (np.log2(3/5) * (3/5) + np.log2(2/5) * (2/5))
0.2467498197744391
信息增益衡量的是用某个特征拆分前后的复杂度变化,理想的拆分情况是复杂度降低的越多,对应的信息增益值越大,所以对于输入的多个特征,一般选择信息增益大的特征进行拆分。
3. 信息增益率
也叫做信息增益比, 具体的计算公式如下
可以看到,相比信息增益,信息增益比用总体的经验熵进行了矫正,将数据转换到0到1的范围,从而可以直接在不同特征之间进行比较。
决策树的构建,是一个递归的过程,从根节点开始,不断选择信息增益大的特征作为节点,依次进行拆分,直到信息增益很小或者没有特征可以选择为止。基于熵模型的信息增益先后出现了两种算法。
首先是ID3算法,只针对离散型的特征,通过信息增益来选择特征;接下来是C4.5算法,对ID3进行了改进,选择了信息增益比来选择特征, 同时支持处理连续型的特征,对特征值排序之后,迭代选取阈值, 大于阈值为一类,小于阈值为另一类,从而将连续性的特征转换为离散型。
相比熵而言,基尼系数没有对数运算,计算更快捷。
对于决策树而言,常见的问题是过拟合。此时,就需要通过剪枝来优化决策树,所谓剪枝,就是去除决策树中的某些分支。根据生成决策树和剪枝的顺序,可以分为以下两种策略
1. 预剪枝,pre-pruning
2. 后剪枝,post-pruning
在监督学习中,数据集会分为训练集和测试集两部分,在预剪枝中,训练集用来决定划分所用的特征,测试集用来决定该特征是否可以用来划分数据;在后剪枝中,先基于训练集生成一颗决策树,然后用测试集的数据判断某个决策树中的节点是否需要去除。
在scikit-learn中,可以方便的构建决策树,用法如下
>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>> X, y = load_iris(return_X_y=True)
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(X, y)
>>> import matplotlib.pyplot as plt
>>> plt.figure()
>>> tree.plot_tree(clf, filled=True)
>>> plt.show()
输出结果如下
默认使用的就是CART算法,同时通过可视化决策树,可以让我们清晰的了解决策的过程。