You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

74 lines
2.6 KiB
Python

from graph_utils import make_name_generator
from graphviz import Graph
# Visualizes a decision tree from the random forest
# as generated by random_forest_model_altered.py
# Node: {index: int, value: float, gini: float, left: Node|TerminalNode, right: Node|TerminalNode}
# TerminalNode: {index: int, value: float, gini: float, left: int(class), right: int (class)}
def make_regular_node(graph: Graph, generate_node_name: callable, node: dict):
    """Draw an internal (split) node of the decision tree.

    The node is drawn as a diamond whose HTML-like label lists the node's
    ``index``, ``value`` and ``gini`` (gini formatted to three decimals).

    :param graph: graph the node is added to.
    :param generate_node_name: zero-argument callable returning a fresh node name.
    :param node: split node; must provide the keys ``index``, ``value`` and
        ``gini``. Other keys (e.g. ``left``/``right``) are ignored here.
    :return: the generated node name, so callers can attach edges to it.
    """
    label_template = (
        "<Index:{index}<BR align=\"left\"/>"
        "Value: {value}<BR align=\"left\"/>"
        "Gini: {gini:.3f}<BR align=\"left\"/>>"
    )
    name = generate_node_name()
    graph.node(name, label=label_template.format(**node), shape='diamond')
    return name
def make_terminal_node(graph: Graph, generate_node_name: callable, className: str):
    """Draw a leaf (terminal) node of the decision tree as plain text.

    :param graph: graph the node is added to.
    :param generate_node_name: zero-argument callable returning a fresh node name.
    :param className: predicted class shown as the label. Per the tree format
        the class is stored as an int, so it is converted with ``str()``
        before being handed to graphviz.
    :return: the generated node name, so callers can attach edges to it.
    """
    node_name = generate_node_name()
    # graphviz quotes labels by running regexes over them; a non-str label
    # (such as an int class) would raise TypeError, so stringify first.
    graph.node(node_name, label=str(className), shape='plaintext')
    return node_name
def make_invisible_node(graph: Graph, generate_node_name: callable):
    """Add a hidden placeholder node to ``graph`` and return its name.

    The node is labelled with its own generated name but drawn with
    ``style='invis'``, so it does not appear in the rendered output.
    """
    placeholder = generate_node_name()
    attributes = {'label': placeholder, 'style': 'invis'}
    graph.node(placeholder, **attributes)
    return placeholder
def visualize_node(graph, generate_node_name, node):
    """Recursively draw ``node`` and its subtrees into ``graph``.

    A dict is treated as a split node. It is drawn first, and then the
    following are attached in this order: its ``left`` subtree, an invisible
    middle placeholder, and its ``right`` subtree. Any other value is drawn
    as a terminal node.

    :return: the generated name of the node drawn for ``node``.
    """
    if not isinstance(node, dict):
        return make_terminal_node(graph, generate_node_name, node)

    parent = make_regular_node(graph, generate_node_name, node)

    # Left subtree, attached at the parent's 'nw' compass port.
    left = visualize_node(graph, generate_node_name, node['left'])
    graph.edge(parent, left, tailport='nw', headport='s')

    # Invisible middle child with an invisible edge (presumably there to keep
    # the left and right subtrees visually apart).
    middle = make_invisible_node(graph, generate_node_name)
    graph.edge(parent, middle, tailport='n', headport='s', style='invis')

    # Right subtree, attached at the parent's 'ne' compass port.
    right = visualize_node(graph, generate_node_name, node['right'])
    graph.edge(parent, right, tailport='ne', headport='s')

    return parent
def make_graph(graphname):
    """Create the graphviz ``Graph`` used to draw one decision tree.

    The graph uses the ``dot`` engine and renders to SVG. A graph-level
    attribute statement selects straight-line edges and bottom-to-top
    ranking (``rankdir='BT'``).
    """
    layout = {'splines': 'line', 'rankdir': 'BT'}
    tree_graph = Graph(name=graphname, format='svg', engine='dot')
    tree_graph.attr('graph', **layout)
    return tree_graph
def visualize(tree, graphname, generate_node_name=None):
    """Draw ``tree`` and render it to disk under ``graphname``.

    :param tree: root of the decision tree: a split-node dict, or a bare
        terminal value.
    :param graphname: graph name, also used as the output filename given to
        ``Graph.render``.
    :param generate_node_name: optional zero-argument callable that yields
        unique node names. When omitted, a fresh
        ``make_name_generator(length=3)`` is created for this call, so node
        naming does not depend on earlier calls.
    :return: the value returned by ``Graph.render`` (the rendered file path).
    """
    # The previous default, make_name_generator(...), was evaluated once at
    # definition time, so one stateful generator was shared by every call.
    if generate_node_name is None:
        generate_node_name = make_name_generator(length=3)
    graph = make_graph(graphname)
    visualize_node(graph, generate_node_name, tree)
    return graph.render(graphname)
if __name__ == '__main__':
    import json

    # Render one SVG per tree in the saved forest.
    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale encoding.
    with open('../random_forest_model.json', 'r', encoding='utf-8') as file_in:
        forest = json.load(file_in)
    for idx, tree in enumerate(forest):
        visualize(tree, 'random-tree-{}'.format(idx))