"""Render decision trees of a random forest as SVG files using Graphviz.

The trees are the ones generated by ``random_forest_model_altered.py``.

Tree layout, as described by the original author::

    Node:         {index: int, value: float, gini: float,
                   left: Node | TerminalNode, right: Node | TerminalNode}
    TerminalNode: {index: int, value: float, gini: float,
                   left: int (class), right: int (class)}

In this module every ``dict`` is drawn as a decision node and every
non-``dict`` child is drawn as a leaf labelled with that value.
"""

import json
from typing import Callable

from graphviz import Graph

from graph_utils import make_name_generator

# Zero-argument callable whose result is used as a Graphviz node name.
# NOTE(review): assumed to return a new, unique ``str`` on every call --
# confirm against ``graph_utils.make_name_generator``.
NameGenerator = Callable[[], str]


def make_regular_node(graph: Graph, generate_node_name: NameGenerator,
                      node: dict) -> str:
    """Add a diamond-shaped decision node to *graph* and return its name.

    The name is returned so the caller can draw edges to and from the node.
    *node* must provide the ``value`` and ``gini`` keys; other keys are
    ignored by the label formatting.
    """
    node_name = generate_node_name()
    # NOTE(review): the label literal was split over several physical lines
    # (an unterminated string) and ended in a lone '>'. It is restored here
    # as a Graphviz HTML-like label ('<' ... '>') with '<br/>' line breaks --
    # confirm this matches the intended rendering.
    label = "<Value: {value}<br/>Gini: {gini:.3f}<br/>>".format(**node)
    graph.node(node_name, label=label, shape='diamond')
    return node_name


def make_terminal_node(graph: Graph, generate_node_name: NameGenerator,
                       className: str) -> str:
    """Add a plain-text leaf showing *className* and return the node's name.

    *className* is converted with ``str()`` before being used as the label,
    because leaves are documented as ``int`` classes while a label is text.
    """
    node_name = generate_node_name()
    graph.node(node_name, label=str(className), shape='plaintext')
    return node_name


def make_invisible_node(graph: Graph,
                        generate_node_name: NameGenerator) -> str:
    """Add a node with ``style='invis'`` to *graph* and return its name."""
    node_name = generate_node_name()
    graph.node(node_name, label=node_name, style='invis')
    return node_name


def visualize_node(graph, generate_node_name, node):
    """Recursively draw *node* and its subtrees; return the drawn node's name.

    A ``dict`` is drawn as a decision node with three children created in
    this order: the left subtree, an invisible centre node, and the right
    subtree. Anything else is drawn as a terminal (leaf) node.
    """
    if not isinstance(node, dict):
        return make_terminal_node(graph, generate_node_name, node)

    # Draw the node itself.
    node_name = make_regular_node(graph, generate_node_name, node)

    # Left child/subtree and its edge.
    left_child_name = visualize_node(graph, generate_node_name, node['left'])
    graph.edge(node_name, left_child_name, tailport='nw', headport='s')

    # Invisible centre child, attached with an invisible edge.
    # NOTE(review): presumably a spacer that keeps the two real children
    # apart in the layout -- confirm.
    center_node_name = make_invisible_node(graph, generate_node_name)
    graph.edge(node_name, center_node_name,
               tailport='n', headport='s', style='invis')

    # Right child/subtree and its edge.
    right_child_name = visualize_node(graph, generate_node_name, node['right'])
    graph.edge(node_name, right_child_name, tailport='ne', headport='s')

    return node_name


def make_graph(graphname):
    """Create the ``Graph`` used for one tree.

    The graph uses the ``dot`` engine, SVG output, straight-line splines and
    ``rankdir='BT'``.
    """
    graph = Graph(name=graphname, format='svg', engine='dot')
    graph.attr('graph', splines='line', rankdir='BT')
    return graph


def visualize(tree, graphname, generate_node_name=None):
    """Draw *tree* into a new graph and render it under the name *graphname*.

    Args:
        tree: Root of the decision tree (see the module docstring).
        graphname: Used both as the graph's name and as the render filename.
        generate_node_name: Optional zero-argument callable producing node
            names. When omitted, a new ``make_name_generator(length=3)`` is
            created for this call, so separate calls do not share one
            generator through a default argument evaluated at definition
            time.
    """
    if generate_node_name is None:
        generate_node_name = make_name_generator(length=3)

    graph = make_graph(graphname)
    visualize_node(graph, generate_node_name, tree)
    graph.render(graphname)


if __name__ == '__main__':
    with open('../random_forest_model.json', 'r') as file_in:
        forest = json.load(file_in)

    # One output file per tree in the forest.
    for idx, tree in enumerate(forest):
        visualize(tree, 'random-tree-{}'.format(idx))