"""Render decision trees of an exported random forest as SVG graphs.

The trees are the ones generated by random_forest_model_altered.py.

Expected tree structure (as produced by the exporter):

    Node:         {index: int, value: float, gini: float,
                   left: Node|TerminalNode, right: Node|TerminalNode}
    TerminalNode: {index: int, value: float, gini: float,
                   left: int (class), right: int (class)}

Any subtree that is not a ``dict`` is drawn as a terminal (class) node.
"""
from typing import Any, Callable, Optional

from graphviz import Graph

from graph_utils import make_name_generator


def make_regular_node(graph: Graph, generate_node_name: Callable[[], str], node: dict) -> str:
    """Create a regular (decision) node in the graph.

    Args:
        graph: Graph the node is added to.
        generate_node_name: Zero-argument callable returning a new node name.
        node: Decision node; its ``value`` and ``gini`` entries are shown
            in the label.

    Returns:
        The name of the created node, so the caller can draw edges to it.
    """
    node_name = generate_node_name()
    # NOTE(review): in the reviewed source this literal was broken across
    # physical lines and ended in a stray '>'. It is reconstructed here as a
    # Graphviz HTML-like label with <br/> line breaks -- confirm against the
    # intended rendering.
    label = "<Value: {value}<br/>Gini: {gini:.3f}<br/>>".format(**node)
    graph.node(node_name, label=label, shape='diamond')
    return node_name


def make_terminal_node(graph: Graph, generate_node_name: Callable[[], str], className: str) -> str:
    """Create a terminal (class) node of the decision tree in the graph.

    Args:
        graph: Graph the node is added to.
        generate_node_name: Zero-argument callable returning a new node name.
        className: Text used as the node label.

    Returns:
        The name of the created node.
    """
    node_name = generate_node_name()
    graph.node(node_name, label=className, shape='plaintext')
    return node_name


def make_invisible_node(graph: Graph, generate_node_name: Callable[[], str]) -> str:
    """Create a node with style ``invis`` and return its name.

    Args:
        graph: Graph the node is added to.
        generate_node_name: Zero-argument callable returning a new node name.

    Returns:
        The name of the created node.
    """
    node_name = generate_node_name()
    graph.node(node_name, label=node_name, style='invis')
    return node_name


def visualize_node(graph: Graph, generate_node_name: Callable[[], str], node: Any) -> str:
    """Recursively draw ``node`` and everything below it.

    A ``dict`` is drawn as a decision node with a left child, an invisible
    center child and a right child. Anything else is drawn as a terminal
    node labelled with its string representation.

    Args:
        graph: Graph to draw into.
        generate_node_name: Zero-argument callable returning a new node name.
        node: Subtree to draw.

    Returns:
        The name of the graph node created for ``node``.
    """
    if not isinstance(node, dict):
        # Terminal values are documented as ints (class ids); the label must
        # be a string, so convert explicitly.
        return make_terminal_node(graph, generate_node_name, str(node))

    # Draw the node itself
    node_name = make_regular_node(graph, generate_node_name, node)

    # Make left child/subtree and draw edge
    left_child_name = visualize_node(graph, generate_node_name, node['left'])
    graph.edge(node_name, left_child_name, tailport='nw', headport='s')

    # Make the invisible center child and draw an invisible edge to it.
    # Presumably a spacer that keeps the left and right subtrees apart.
    center_node_name = make_invisible_node(graph, generate_node_name)
    graph.edge(node_name, center_node_name, tailport='n', headport='s', style='invis')

    # Make right child/subtree and draw edge
    right_child_name = visualize_node(graph, generate_node_name, node['right'])
    graph.edge(node_name, right_child_name, tailport='ne', headport='s')

    return node_name


def make_graph(graphname: str) -> Graph:
    """Create an empty SVG graph using the ``dot`` engine.

    The graph uses straight-line splines and rank direction ``BT``.

    Args:
        graphname: Name of the graph.

    Returns:
        The configured, empty graph.
    """
    graph = Graph(name=graphname, format='svg', engine='dot')
    graph.attr('graph', splines='line', rankdir='BT')
    return graph


def visualize(tree: Any,
              graphname: str,
              generate_node_name: Optional[Callable[[], str]] = None,
              directory: Optional[str] = None) -> None:
    """Draw ``tree`` into a new graph and render it.

    Args:
        tree: Root of the decision tree (a Node dict or a terminal value).
        graphname: Name of the graph, also used as the rendered file name.
        generate_node_name: Zero-argument callable returning a new node name.
            When omitted, a fresh ``make_name_generator(length=3)`` is
            created for this call.
        directory: Directory passed to ``graph.render``.
    """
    if generate_node_name is None:
        # Build the generator per call instead of once as a default argument
        # (which is evaluated at import time), so separate graphs do not
        # share naming state.
        generate_node_name = make_name_generator(length=3)

    graph = make_graph(graphname)
    visualize_node(graph, generate_node_name, tree)
    graph.render(graphname, directory=directory)


def _main() -> None:
    """Visualize every tree of every exported model found next to this package."""
    import glob
    import json
    import os.path

    basepath = os.path.dirname(os.path.realpath(__file__))
    globpath = os.path.realpath(
        os.path.join(basepath, '..', 'random_forest_model_*-trees.json'))

    # Search for exported models
    for modelpath in glob.glob(globpath):
        print("Found {}".format(modelpath))

        # Open the model and parse the forest
        with open(modelpath, 'r', encoding='utf-8') as file_in:
            forest = json.load(file_in)

        modelname, _ = os.path.splitext(os.path.basename(modelpath))
        tree_count = len(forest)

        # Walk through the forest and visualize the trees
        for idx, tree in enumerate(forest):
            # File names keep the 0-based index; the progress message is
            # 1-based so the last tree reads "N of N".
            graphname = '{}_tree_{}'.format(modelname, idx)
            print('Visualizing tree {} of {}'.format(idx + 1, tree_count))
            visualize(tree, graphname, directory=basepath)

        print()
        print('Graphs placed in: {}'.format(basepath))


if __name__ == '__main__':
    _main()