You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

94 lines
3.4 KiB
Python

from graph_utils import make_name_generator
from graphviz import Graph
# Visualizes a decision tree, from the random forest
# as generated by random_forest_model_altered.py
# Node: {index: int, value: float, gini: float, left: Node|TerminalNode, right: Node|TerminalNode}
# TerminalNode: {index: int, value: float, gini: float, left: int(class), right: int (class)}
# Creates a regular node in the graph and returns its name
# so other functions can draw the edges
def make_regular_node(graph: Graph, generate_node_name: callable, node: dict):
    """Add a diamond-shaped decision node to ``graph`` and return its name.

    The node name is obtained by calling ``generate_node_name()``. The label
    is a ``<...>``-wrapped markup label listing the node's ``index``,
    ``value`` and ``gini`` (three decimals), each followed by a left-aligned
    line break. Other keys in ``node`` (e.g. ``left``/``right``) are not
    referenced by the template and are ignored.

    Raises ``KeyError`` if ``node`` lacks ``index``, ``value`` or ``gini``.
    """
    label_template = "<Index:{index}<BR align=\"left\"/>Value: {value}<BR align=\"left\"/>Gini: {gini:.3f}<BR align=\"left\"/>>"
    # Draw the name first so it is consumed even if formatting fails,
    # exactly as a single inline call would do.
    name = generate_node_name()
    attributes = {
        'label': label_template.format(**node),
        'shape': 'diamond',
    }
    graph.node(name, **attributes)
    return name
# Visualizes a TerminalNode from the decision tree
def make_terminal_node(graph: Graph, generate_node_name: callable, className: str):
    """Add a leaf node showing a predicted class to ``graph``; return its name.

    The node name is obtained by calling ``generate_node_name()`` and the
    node is drawn with the ``plaintext`` shape.

    ``className`` is used as the node label. Terminal values in the tree are
    documented as integer classes, so any non-string value is converted with
    ``str()`` before it is handed to ``graph.node``; string values are passed
    through untouched.
    """
    node_name = generate_node_name()
    # NOTE(review): graphviz appears to process labels as strings; an int
    # class coming straight from the JSON model would not be accepted as-is.
    label = className if isinstance(className, str) else str(className)
    graph.node(
        node_name,
        label=label,
        shape='plaintext')
    return node_name
def make_invisible_node(graph: Graph, generate_node_name: callable):
    """Add a node with ``style='invis'`` to ``graph`` and return its name.

    The node name is obtained by calling ``generate_node_name()`` and is also
    used as the node's label.
    """
    name = generate_node_name()
    hidden_attributes = {'label': name, 'style': 'invis'}
    graph.node(name, **hidden_attributes)
    return name
def visualize_node(graph, generate_node_name, node):
    """Recursively draw ``node`` and its subtree; return the drawn node's name.

    A ``dict`` is treated as a decision node: it is drawn, then its ``left``
    subtree, an invisible centre child and its ``right`` subtree are drawn in
    that order, each connected to the parent by an edge. Any other value is
    treated as a terminal node and drawn as a leaf.
    """
    # Anything that is not a dict is a leaf: draw it and stop recursing.
    if not isinstance(node, dict):
        return make_terminal_node(graph, generate_node_name, node)

    # Draw the node itself
    parent_name = make_regular_node(graph, generate_node_name, node)

    # Left child/subtree, attached to the parent's north-west port
    left_name = visualize_node(graph, generate_node_name, node['left'])
    graph.edge(parent_name, left_name, tailport='nw', headport='s')

    # Hidden centre child with a hidden edge
    # (presumably a layout spacer between the two real children)
    spacer_name = make_invisible_node(graph, generate_node_name)
    graph.edge(parent_name, spacer_name, tailport='n', headport='s', style='invis')

    # Right child/subtree, attached to the parent's north-east port
    right_name = visualize_node(graph, generate_node_name, node['right'])
    graph.edge(parent_name, right_name, tailport='ne', headport='s')

    return parent_name
def make_graph(graphname):
    """Create and return an empty ``Graph`` named ``graphname``.

    The graph is configured for SVG output with the ``dot`` engine, and the
    graph-level attributes ``splines='line'`` and ``rankdir='BT'`` are set.
    """
    tree_graph = Graph(name=graphname, format='svg', engine='dot')
    graph_attributes = {'splines': 'line', 'rankdir': 'BT'}
    tree_graph.attr('graph', **graph_attributes)
    return tree_graph
def visualize(tree, graphname, generate_node_name=None, directory=None):
    """Draw ``tree`` into a new graph and render it.

    Args:
        tree: Root of the decision tree (a dict node or a terminal value), as
            accepted by ``visualize_node``.
        graphname: Used both as the graph's name and as the filename passed
            to ``graph.render``.
        generate_node_name: Zero-argument callable returning a new node name
            on each call. When omitted, a fresh generator is created with
            ``make_name_generator(length=3)`` for this call.
        directory: Forwarded to ``graph.render`` as the output directory.
    """
    # Build the default generator per call instead of in the signature: a
    # default evaluated at definition time would be one shared generator whose
    # state carries over between unrelated graphs.
    if generate_node_name is None:
        generate_node_name = make_name_generator(length=3)
    graph = make_graph(graphname)
    visualize_node(graph, generate_node_name, tree)
    graph.render(graphname, directory=directory)
if __name__ == '__main__':
    # Script entry point: render every tree of every exported forest model
    # found next to this script's parent directory.
    import json
    import os.path
    import glob

    # Models are searched one directory above this script; the graphs are
    # written into the script's own directory.
    basepath = os.path.dirname(os.path.realpath(__file__))
    globpath = os.path.realpath(os.path.join(basepath, '..', 'random_forest_model_*-trees.json'))
    # Search for exported models (sorted so runs process them in a stable order)
    models = sorted(glob.glob(globpath))

    for modelpath in models:
        print("Found {}".format(modelpath))
        # Open the model and parse the forest; the file is closed again
        # before the (potentially slow) rendering starts.
        with open(modelpath, 'r', encoding='utf-8') as file_in:
            forest = json.load(file_in)

        modelname, _ = os.path.splitext(os.path.basename(modelpath))
        graphnamepattern = '{}_tree_{{}}'.format(modelname)
        tree_count = len(forest)

        # Walk through the forest and visualize the trees
        for idx, tree in enumerate(forest):
            # Output names keep the zero-based tree index ...
            graphname = graphnamepattern.format(idx)
            # ... while the progress message counts from 1 so the last line
            # reads "N of N".
            print('Visualizing tree {} of {}'.format(idx + 1, tree_count))
            visualize(tree, graphname, directory=basepath)
        print()

    print('Graphs placed in: {}'.format(basepath))