Decision Trees

Exploring the material in chapter 17 of our "from scratch" textbook.

Jim M | April 2020

In [25]:
from math import log2
from typing import List, Any
from collections import Counter
import numpy as np
from matplotlib import pyplot as plt

def entropy(probabilities: List[float]) -> float:
    """Given a list of probabilities, return their information entropy."""
    # essentially the same as scratch.decision_trees
    return sum(p * log2(1/p)
               for p in probabilities
               if p > 0)    # ignore zero probabilities

assert entropy([1.0]) == 0
assert entropy([0.5, 0.5]) == 1
assert 0.81 < entropy([0.25, 0.75]) < 0.82

def class_probabilities(labels: List[Any]) -> List[float]:
    """ Given a list of labels, return a list of probabilities of each label """
    total_count = len(labels)
    return [count / total_count
            for count in Counter(labels).values()]

assert class_probabilities(['red', 'red', 'blue', 'green']) == [0.5, 0.25, 0.25]

def data_entropy(labels: List[Any]) -> float:
    """Given a list of labels, return the entropy of their class probabilities."""
    return entropy(class_probabilities(labels))

assert data_entropy(['a']) == 0
assert data_entropy([True, False]) == 1
assert data_entropy([3, 4, 4, 4]) == entropy([0.25, 0.75])

def partition_entropy(subsets: List[List[Any]]) -> float:
    """Returns the entropy from this partition of data into subsets"""
    total_count = sum(len(subset) for subset in subsets)
    return sum(data_entropy(subset) * len(subset) / total_count
               for subset in subsets)
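
In symbols, what the code above computes: entropy is $H(p_1, \ldots, p_n) = \sum_i p_i \log_2(1/p_i)$, and the partition entropy of subsets $S_1, \ldots, S_m$ is the weighted average $\sum_j \frac{|S_j|}{|S|} H(S_j)$, where $|S| = \sum_j |S_j|$ is the total number of labels.
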
In [46]:
# You ask someone to guess a number from 0 to 9.
# You then play twenty questions with questions like
#
#  "Is it bigger then 8?"
#
# Using a one-hot encoding for the guess: if the number is 3,
# then the data would be [False, False, False, True, False, ...].
#
# Here I'll use partition_entropy to see which way of partitioning
# the data gives the most specific result.
#
# As you can see from the results, the answer is that you should
# split near the middle. Doing so over and over gives you a binary
# search, with a total of log2(n) guesses.
#
# Splitting on bad choices, like "bigger than 2?", "bigger than 4?",
# "bigger than 6?", finds the answer in about n/2 guesses ... *much* worse.
# (See the sketch after the output below.)
#
# Choosing the partition with the lowest average entropy means that 
# you have narrowed down the choice as much as possible, on average.
#

n = 10
guess = [False] * n
guess[3] = True
for i in range(n):
    partition = [guess[:i], guess[i:]]
    pe = partition_entropy(partition)
    print(f'split={i}, entropy={pe:<05.3} : ', partition)
split=0, entropy=0.469 :  [[], [False, False, False, True, False, False, False, False, False, False]]
split=1, entropy=0.453 :  [[False], [False, False, True, False, False, False, False, False, False]]
split=2, entropy=0.435 :  [[False, False], [False, True, False, False, False, False, False, False]]
split=3, entropy=0.414 :  [[False, False, False], [True, False, False, False, False, False, False]]
split=4, entropy=0.325 :  [[False, False, False, True], [False, False, False, False, False, False]]
split=5, entropy=0.361 :  [[False, False, False, True, False], [False, False, False, False, False]]
split=6, entropy=0.390 :  [[False, False, False, True, False, False], [False, False, False, False]]
split=7, entropy=0.414 :  [[False, False, False, True, False, False, False], [False, False, False]]
split=8, entropy=0.435 :  [[False, False, False, True, False, False, False, False], [False, False]]
split=9, entropy=0.453 :  [[False, False, False, True, False, False, False, False, False], [False]]
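To make the log2(n)-vs-n/2 claim concrete, here's a small sketch of my own (not from the book): it counts how many "is it bigger than x?" questions each strategy needs to pin down a secret number in range(n). The helper names midpoint_questions and stride_questions are made up for this illustration.

In [ ]:
def midpoint_questions(secret: int, n: int) -> int:
    """Always split the remaining candidate range in half (binary search)."""
    lo, hi = 0, n              # candidates are lo, lo+1, ..., hi-1
    questions = 0
    while hi - lo > 1:
        mid = (lo + hi) // 2
        questions += 1
        if secret >= mid:      # "is it bigger than mid - 1?" -> yes
            lo = mid
        else:
            hi = mid
    return questions

def stride_questions(secret: int, n: int, stride: int = 2) -> int:
    """Ask "bigger than 2?", "bigger than 4?", ... until the answer is no."""
    questions = 0
    for cut in range(stride, n, stride):
        questions += 1
        if secret <= cut:      # answered "no": secret is pinned to a small bucket
            break
    return questions

n = 100
print(max(midpoint_questions(s, n) for s in range(n)))   # 7  ~= log2(100)
print(max(stride_questions(s, n) for s in range(n)))     # 49 ~= n/2

Note that the stride version only narrows the secret down to a bucket of a few numbers, and it still needs roughly n/2 questions in the worst case; the midpoint version identifies the number exactly in at most 7 questions.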
In [28]:
partition_entropy([ [True], [False, True, True]])
Out[28]:
0.6887218755408671
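As a quick check by hand: the partition has 4 labels total, so this is $\frac{1}{4}\cdot 0 + \frac{3}{4}\cdot H(\tfrac{1}{3}, \tfrac{2}{3}) = \frac{3}{4}\cdot 0.918 \approx 0.689$, matching the value above.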

The mandatory entropy plot for two probabilities $p$ and $1-p$

... since no discussion of information is complete without it. ;)
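
For two outcomes this is the binary entropy function $H(p) = p \log_2(1/p) + (1-p) \log_2\bigl(1/(1-p)\bigr)$, which the next cell plots.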

In [22]:
# Entropy of two values (i.e. 0, 1) with probability(0)=p, probability(1)=1-p
# Max entropy of 1 is at p=(1-p)=0.5.
# Min entropy of 0 is at p=0 , p=1
probs = np.linspace(0, 1)
entropies = [entropy([p, 1-p]) for p in probs]
plt.plot(probs, entropies)
plt.title('entropy of (p, 1-p)')
plt.xlabel('p')
plt.ylabel('entropy')
None
In [24]:
n = 100
guess = [False] * n
guess[17] = True

data_entropy(guess)
Out[24]:
0.08079313589591126
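A quick check by hand: the class probabilities are $\tfrac{1}{100}$ and $\tfrac{99}{100}$, so the entropy is $0.01 \log_2(100) + 0.99 \log_2(100/99) \approx 0.0664 + 0.0144 \approx 0.081$, matching the value above.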