You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.6 KiB
Python
74 lines
2.6 KiB
Python
from graph_utils import make_name_generator
|
|
from graphviz import Graph
|
|
|
|
# Visualizes a decision tree, from the random forest
|
|
# as generated by random_forest_model_altered.py
|
|
|
|
# Node: {index: int, value: float, gini: float, left: Node|TerminalNode, right: Node|TerminalNode}
|
|
# TerminalNode: {index: int, value: float, gini: float, left: int (class), right: int (class)}
|
|
|
|
# Creates a regular node in the graph and returns its name
|
|
# so other functions can draw the edges
|
|
|
|
def make_regular_node(graph: Graph, generate_node_name: callable, node: dict):
    """Add a split (non-terminal) node to ``graph`` and return its name.

    The node is drawn as a diamond whose label lists the node's ``index``,
    ``value`` and ``gini`` (the latter with three decimals), each on its own
    left-aligned line.

    Args:
        graph: Graph the node is added to.
        generate_node_name: Zero-argument callable returning a fresh node name.
        node: Split node; must provide the keys ``'index'``, ``'value'`` and
            ``'gini'``.

    Returns:
        The generated node name, so the caller can draw edges to/from it.
    """
    line_break = '<BR align="left"/>'
    rows = [
        'Index:{}'.format(node['index']),
        'Value: {}'.format(node['value']),
        'Gini: {:.3f}'.format(node['gini']),
    ]
    # Wrapping the whole label in '<' ... '>' makes it an HTML-like label in
    # Graphviz, which is what lets the <BR align="left"/> tags act as line
    # breaks. Every row, including the last, ends with a break.
    html_label = '<' + line_break.join(rows) + line_break + '>'

    name = generate_node_name()
    graph.node(name, label=html_label, shape='diamond')
    return name
|
|
|
|
# Visualizes a TerminalNode from the decision tree
|
|
def make_terminal_node(graph: Graph, generate_node_name: callable, className: str):
    """Add a leaf (class prediction) node to ``graph`` and return its name.

    Args:
        graph: Graph the node is added to.
        generate_node_name: Zero-argument callable returning a fresh node name.
        className: Class predicted by this leaf, used as the node's label.
            Non-string values are converted with ``str()``. Leaf classes may
            be stored as ints in the tree data, and the graphviz package's
            label quoting runs a regular expression over the label, which
            fails on a non-string value.

    Returns:
        The generated node name, so the caller can draw an edge to it.
    """
    node_name = generate_node_name()
    # None is passed through unchanged so "no explicit label" still behaves
    # as before; everything else becomes a string label.
    label = className if className is None else str(className)
    graph.node(node_name, label=label, shape='plaintext')
    return node_name
|
|
|
|
def make_invisible_node(graph: Graph, generate_node_name: callable):
    """Add a hidden placeholder node to ``graph`` and return its name.

    The node is labelled with its own generated name and styled ``invis``.
    Graphviz still lays it out, but does not draw it.

    Args:
        graph: Graph the node is added to.
        generate_node_name: Zero-argument callable returning a fresh node name.

    Returns:
        The generated node name.
    """
    hidden = generate_node_name()
    graph.node(hidden, label=hidden, style='invis')
    return hidden
|
|
|
|
def visualize_node(graph, generate_node_name, node):
    """Recursively draw ``node`` and everything below it into ``graph``.

    A ``dict`` is treated as a split node:
      1. the node itself is drawn;
      2. its ``'left'`` subtree is drawn and connected;
      3. a hidden spacer child is added and connected;
      4. its ``'right'`` subtree is drawn and connected.

    This ordering is preserved because it feeds Graphviz's layout.
    Any other value is drawn as a terminal node labelled with that value.

    Args:
        graph: Graph being populated.
        generate_node_name: Zero-argument callable returning a fresh node name.
        node: Split-node dict or terminal value.

    Returns:
        The name of the node drawn for ``node``, for the caller's edge.
    """
    # Leaves end the recursion.
    if not isinstance(node, dict):
        return make_terminal_node(graph, generate_node_name, node)

    parent = make_regular_node(graph, generate_node_name, node)

    # Left subtree: edge starts at the parent's north-west compass point.
    left = visualize_node(graph, generate_node_name, node['left'])
    graph.edge(parent, left, tailport='nw', headport='s')

    # Invisible middle child between the two real subtrees; presumably there
    # to keep the parent centred over its children.
    spacer = make_invisible_node(graph, generate_node_name)
    graph.edge(parent, spacer, tailport='n', headport='s', style='invis')

    # Right subtree: edge starts at the parent's north-east compass point.
    right = visualize_node(graph, generate_node_name, node['right'])
    graph.edge(parent, right, tailport='ne', headport='s')

    return parent
|
|
|
|
def make_graph(graphname):
    """Create and return an empty graph configured for drawing trees.

    Settings:
      - The graph is a ``graphviz.Graph`` named ``graphname``.
      - It is rendered to SVG with the ``dot`` layout engine.
      - Edges are straight lines (``splines='line'``).
      - Ranks are laid out bottom-to-top (``rankdir='BT'``).
    """
    tree_graph = Graph(name=graphname, format='svg', engine='dot')
    tree_graph.attr('graph', splines='line', rankdir='BT')
    return tree_graph
|
|
|
|
def visualize(tree, graphname, generate_node_name=None):
    """Draw one decision tree and render it to disk under ``graphname``.

    Args:
        tree: Root of the tree, either a split-node dict or a terminal value.
        graphname: Name of the graph; also the filename passed to
            ``graph.render``.
        generate_node_name: Zero-argument callable returning a fresh node
            name on each call. When omitted, a new
            ``make_name_generator(length=3)`` is created for this call only.
    """
    if generate_node_name is None:
        # Built per call on purpose. The previous default
        # (``generate_node_name=make_name_generator(length=3)``) was evaluated
        # once at import time, so every call shared one stateful generator.
        # Node names then depended on how many trees had been drawn before,
        # and any finite name pool would be consumed across the whole forest.
        generate_node_name = make_name_generator(length=3)

    graph = make_graph(graphname)
    visualize_node(graph, generate_node_name, tree)
    graph.render(graphname)
|
|
|
|
if __name__ == '__main__':
    import json

    # The path is resolved against the current working directory.
    # JSON text is UTF-8, so the encoding is stated explicitly rather than
    # relying on the platform's locale default.
    with open('../random_forest_model.json', 'r', encoding='utf-8') as file_in:
        forest = json.load(file_in)

    # The loaded value is iterated as a sequence of trees. Each tree is
    # rendered under the name random-tree-<idx>. The file is already closed
    # at this point, so it is not held open during rendering.
    for idx, tree in enumerate(forest):
        visualize(tree, 'random-tree-{}'.format(idx))