I am creating a decision tree from scratch and implementing pruning. Currently I believe the problem in my code is that when I prune a tree, the new leaf node I create does not get placed into the original tree, so when I calculate the new accuracy, it is valid and all my branches get pruned. Attached is my code for the tree class, leaf class and pruning function.
class TreeNode:
def __init__(self, feature, split, depth, left = None, right = None):
"""
self.feature = the feature the node splits upon
self.split = the value the of the feature the node splits upon
self.left = the left child
self.right = the right child
self.depth = depth of the tree at this point
"""
self.feature = feature
self.split = split
self.left = left
self.right = right
self.depth = depth
def getLeft(self):
return self.left
def getRight(self):
return self.right
def getFeature(self):
return self.feature
def getSplit(self):
return self.split
def getDepth(self):
return self.depth
def eval(self, sample):
if sample[self.feature] < self.split:
return self.left.eval(sample)
else:
return self.right.eval(sample)
class LeafNode:
def __init__(self, roomNumber, users, depth):
"""
self.roomNumer = the room number that the leaf is assigned
self.depth = depth of the leaf in the tree
"""
self.roomNumber = roomNumber
self.depth = depth
self.users = users
def getRoomNumber(self):
return self.roomNumber
def getDepth(self):
return self.depth
def getUsers(self):
return self.users
def eval(self, sample):
self.users += 1
return self.getRoomNumber()
def pruneTree(original_tree, validation, node):
if node is None:
return None
if isinstance(node, LeafNode):
return node
node.left = pruneTree(original_tree, validation, node.left)
node.right = pruneTree(original_tree, validation, node.right)
if isinstance(node.left, LeafNode) and isinstance(node.right, LeafNode):
current_accuracy = evaluate(validation, original_tree)
leftRoom, leftPopulation = node.left.getRoomNumber(), node.left.getUsers()
rightRoom, rightPopulation = node.right.getRoomNumber(), node.right.getUsers()
previous_feature, previous_split, previous_depth, previous_left, previous_right = node.getFeature(), node.getSplit(), node.getDepth(), node.getLeft(), node.getRight()
newRoom = -1
newPopulation = leftPopulation + rightPopulation
if rightPopulation >= leftPopulation:
newRoom = rightRoom
else:
newRoom = leftRoom
node = LeafNode(roomNumber = newRoom, users=newPopulation, depth = previous_depth)
new_accuracy = evaluate(validation, original_tree)
if new_accuracy < current_accuracy:
node = TreeNode(split = previous_split, feature=previous_feature, depth=previous_depth)
node.left = previous_left
node.right = previous_right
return node
def evaluate(test_db, trained_tree):
num_correct = 0
for data in test_db:
sample = data[:-1]
prediction = trained_tree.eval(sample)
if prediction == data[-1]:
num_correct += 1
return num_correct/len(test_db)
You need to sign in to view this answers