You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
94 lines
3.4 KiB
Python
94 lines
3.4 KiB
Python
from graph_utils import make_name_generator
|
|
from graphviz import Graph
|
|
|
|
# Visualizes the decision trees of a random forest,
# as generated by random_forest_model_altered.py.
#
# Node: {index: int, value: float, gini: float, left: Node|TerminalNode, right: Node|TerminalNode}
# TerminalNode: {index: int, value: float, gini: float, left: int (class), right: int (class)}
|
|
|
|
def make_regular_node(graph: Graph, generate_node_name: callable, node: dict):
    """Draw an internal (split) node of a decision tree as a diamond.

    The label is an HTML-like graphviz label with one line each for the
    node's ``index``, ``value`` and ``gini`` (the latter to 3 decimals).

    Args:
        graph: graph the node is added to.
        generate_node_name: zero-argument callable returning a fresh node name.
        node: tree node dict; must provide ``index``, ``value`` and a numeric
            ``gini``. Any other keys (e.g. ``left``/``right``) are ignored here.

    Returns:
        The generated node name, so the caller can draw edges to/from it.
    """
    # The name is reserved first, before the label is built.
    name = generate_node_name()
    # The outer <...> marks the label as HTML-like; each <BR align="left"/>
    # ends a line and left-justifies it.
    template = '<Index:{index}<BR align="left"/>Value: {value}<BR align="left"/>Gini: {gini:.3f}<BR align="left"/>>'
    graph.node(name, label=template.format(**node), shape='diamond')
    return name
|
|
|
|
def make_terminal_node(graph: Graph, generate_node_name: callable, className: str):
    """Draw a leaf of the decision tree (a predicted class) as plain text.

    Args:
        graph: graph the node is added to.
        generate_node_name: zero-argument callable returning a fresh node name.
        className: class shown at this leaf. The exported trees store classes
            as ``int`` values, so anything printable is accepted.

    Returns:
        The generated node name, so the caller can draw an edge to it.
    """
    node_name = generate_node_name()
    # graphviz quotes labels with a regex match, which raises TypeError for
    # non-string values such as the int class ids found in the JSON trees.
    graph.node(node_name, label=str(className), shape='plaintext')
    return node_name
|
|
|
|
def make_invisible_node(graph: Graph, generate_node_name: callable):
    """Add an invisible placeholder node and return its name.

    The node is labelled with its own generated name and styled ``invis``.
    It is not drawn, but it still takes part in the layout.

    Args:
        graph: graph the node is added to.
        generate_node_name: zero-argument callable returning a fresh node name.

    Returns:
        The generated node name.
    """
    placeholder = generate_node_name()
    graph.node(placeholder, label=placeholder, style='invis')
    return placeholder
|
|
|
|
def visualize_node(graph, generate_node_name, node):
    """Recursively draw ``node`` and everything below it into ``graph``.

    A dict is drawn as a split node with three outgoing edges, created in
    this order:
      1. a visible edge to the left subtree;
      2. an invisible edge to an invisible spacer node;
      3. a visible edge to the right subtree.
    Presumably the spacer keeps the two children apart in the layout.
    Any non-dict value is treated as a leaf (class) and drawn as text.

    Args:
        graph: graphviz graph the nodes and edges are added to.
        generate_node_name: zero-argument callable returning a fresh node name.
        node: tree node dict with ``left``/``right`` children (plus the fields
            shown by ``make_regular_node``), or a leaf class value.

    Returns:
        The graphviz name of the node created for ``node``.
    """
    # Leaves end the recursion.
    if not isinstance(node, dict):
        return make_terminal_node(graph, generate_node_name, node)

    parent_name = make_regular_node(graph, generate_node_name, node)

    # Left child/subtree: leaves the parent at its north-west port.
    left_name = visualize_node(graph, generate_node_name, node['left'])
    graph.edge(parent_name, left_name, tailport='nw', headport='s')

    # Invisible spacer between the two children, reached via an invisible edge.
    spacer_name = make_invisible_node(graph, generate_node_name)
    graph.edge(parent_name, spacer_name, tailport='n', headport='s', style='invis')

    # Right child/subtree: leaves the parent at its north-east port.
    right_name = visualize_node(graph, generate_node_name, node['right'])
    graph.edge(parent_name, right_name, tailport='ne', headport='s')

    return parent_name
|
|
|
|
def make_graph(graphname):
    """Create an empty undirected graph configured for drawing a tree.

    The graph is laid out by the ``dot`` engine and rendered as SVG. Edges are
    drawn as straight lines (``splines='line'``) and ranks run bottom-to-top
    (``rankdir='BT'``).

    Args:
        graphname: name given to the graph.

    Returns:
        The new ``Graph`` instance.
    """
    layout = {'splines': 'line', 'rankdir': 'BT'}
    drawing = Graph(engine='dot', format='svg', name=graphname)
    drawing.attr('graph', **layout)
    return drawing
|
|
|
|
def visualize(tree, graphname, generate_node_name=None, directory=None):
    """Draw a whole decision tree and render it to a file.

    Args:
        tree: root node of the tree, in the format accepted by
            ``visualize_node``.
        graphname: graph name, also used as the output file name.
        generate_node_name: optional zero-argument callable returning fresh
            node names. When omitted, a new ``make_name_generator(length=3)``
            is created for this call.
        directory: optional output directory passed to ``Graph.render``.

    Returns:
        The path of the rendered file, as returned by ``Graph.render``.
    """
    if generate_node_name is None:
        # The previous default was evaluated once, at import time, so a single
        # stateful generator was shared by every call and every tree.
        generate_node_name = make_name_generator(length=3)

    graph = make_graph(graphname)
    visualize_node(graph, generate_node_name, tree)
    return graph.render(graphname, directory=directory)
|
|
|
|
if __name__ == '__main__':
    # Batch mode: render every tree of every exported forest.
    # Input:  ../random_forest_model_*-trees.json, relative to this script.
    # Output: graphs written into this script's directory.
    import glob
    import json
    import os.path

    basepath = os.path.dirname(os.path.realpath(__file__))
    globpath = os.path.realpath(os.path.join(basepath, '..', 'random_forest_model_*-trees.json'))

    # Search for exported models
    models = glob.glob(globpath)
    if not models:
        print('No models found matching: {}'.format(globpath))

    for modelpath in models:
        print("Found {}".format(modelpath))

        # Open the model and parse the forest. Assumed to be a JSON array of
        # tree dicts (it is iterated with enumerate and measured with len).
        with open(modelpath, 'r', encoding='utf-8') as file_in:
            forest = json.load(file_in)

        modelname, _ = os.path.splitext(os.path.basename(modelpath))
        graphnamepattern = '{}_tree_{{}}'.format(modelname)

        # Walk through the forest and visualize the trees
        total = len(forest)
        for idx, tree in enumerate(forest):
            # Output names stay 0-based; the progress message counts from 1.
            graphname = graphnamepattern.format(idx)
            print('Visualizing tree {} of {}'.format(idx + 1, total))
            visualize(tree, graphname, directory=basepath)

        print()

    print('Graphs placed in: {}'.format(basepath))