October 22, 2024
Chicago 12, Melborne City, USA
python

Prune function for decision tree


I am creating a decision tree from scratch and implementing pruning. Currently I believe the problem in my code is that when I prune a tree, the new leaf node I create does not get placed into the original tree, so when I calculate the new accuracy, it is valid and all my branches get pruned. Attached is my code for the tree class, leaf class and pruning function.

class TreeNode:
    def __init__(self, feature, split, depth, left = None, right = None):
        """
            self.feature = the feature the node splits upon
            self.split = the value the of the feature the node splits upon
            self.left = the left child
            self.right = the right child
            self.depth = depth of the tree at this point
        """
        self.feature = feature
        self.split = split
        self.left = left
        self.right = right
        self.depth = depth

    def getLeft(self):
        return self.left
    
    def getRight(self):
        return self.right
    
    def getFeature(self):
        return self.feature
    
    def getSplit(self):
        return self.split
    
    def getDepth(self):
        return self.depth

    def eval(self, sample):
        if sample[self.feature] < self.split:
            return self.left.eval(sample)
        else:
            return self.right.eval(sample)
        
class LeafNode:
    def __init__(self, roomNumber, users, depth):
        """
            self.roomNumer = the room number that the leaf is assigned
            self.depth = depth of the leaf in the tree
        """
        self.roomNumber = roomNumber
        self.depth = depth
        self.users = users

    def getRoomNumber(self):
        return self.roomNumber
                
    def getDepth(self):
        return self.depth
    
    def getUsers(self):
        return self.users
    
    def eval(self, sample):
        self.users += 1
        return self.getRoomNumber()

def pruneTree(original_tree, validation, node):
    if node is None:
        return None
    if isinstance(node, LeafNode):
        return node
    node.left = pruneTree(original_tree, validation, node.left)
    node.right = pruneTree(original_tree, validation, node.right)
    if isinstance(node.left, LeafNode) and isinstance(node.right, LeafNode):

        current_accuracy = evaluate(validation, original_tree)

        leftRoom, leftPopulation = node.left.getRoomNumber(), node.left.getUsers()

        rightRoom, rightPopulation = node.right.getRoomNumber(), node.right.getUsers()

        previous_feature, previous_split, previous_depth, previous_left, previous_right = node.getFeature(), node.getSplit(), node.getDepth(), node.getLeft(), node.getRight()

        newRoom = -1

        newPopulation = leftPopulation + rightPopulation

        if rightPopulation >= leftPopulation:
            newRoom = rightRoom
        else:
            newRoom = leftRoom

        node = LeafNode(roomNumber = newRoom, users=newPopulation, depth = previous_depth)

        new_accuracy = evaluate(validation, original_tree)
        
        if new_accuracy < current_accuracy:
            node = TreeNode(split = previous_split, feature=previous_feature, depth=previous_depth)
            node.left = previous_left
            node.right = previous_right
    return node

def evaluate(test_db, trained_tree):
    num_correct = 0
    for data in test_db:
        sample = data[:-1]
        prediction = trained_tree.eval(sample)

        if prediction == data[-1]:
            num_correct += 1
    return num_correct/len(test_db)



You need to sign in to view this answers

Leave feedback about this

  • Quality
  • Price
  • Service

PROS

+
Add Field

CONS

+
Add Field
Choose Image
Choose Video