You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

102 lines
2.7 KiB
Python

# Indentation unit used for every nesting level of the generated C++ code.
indent_prefix = ' ' * 2


def indent_line(line):
    """Return *line* prefixed with one indentation unit (``indent_prefix``)."""
    return ''.join((indent_prefix, line))
def c_string_literal(text):
    """Return *text* as a double-quoted C/C++ string literal.

    Backslashes, double quotes and control characters are escaped so that
    arbitrary leaf labels cannot break the generated source. Octal escapes
    are used for the remaining control characters because, unlike hex
    escapes, they stop after three digits and cannot swallow a following
    character.
    """
    named_escapes = {'\\': '\\\\', '"': '\\"', '\n': '\\n', '\r': '\\r', '\t': '\\t'}
    parts = []
    for ch in text:
        if ch in named_escapes:
            parts.append(named_escapes[ch])
        elif ord(ch) < 0x20 or ch == '\x7f':
            parts.append('\\{:03o}'.format(ord(ch)))
        else:
            parts.append(ch)
    return '"' + ''.join(parts) + '"'


def encode_branch(branch):
    """Encode one child of a split node as a list of C++ source lines.

    ``branch`` is either a nested split node (a dict, delegated to
    ``encode_tree``) or a leaf value. Numeric leaves become a bare
    ``return <number>;``; any other leaf is converted with ``str()`` and
    returned as an escaped string literal.

    NOTE(review): ``predict`` is generated with a ``char*`` return type, so
    numeric leaves would not convert cleanly there — confirm whether numeric
    leaves actually occur in exported models.
    """
    if isinstance(branch, dict):
        return encode_tree(branch)
    if isinstance(branch, (int, float)):
        return [indent_line('return {};'.format(branch))]
    # Escape the label: a quote or backslash in a class name would otherwise
    # terminate the literal early and produce malformed C++.
    return [indent_line('return {};'.format(c_string_literal(str(branch))))]
# Each split node sends a sample left when r[index] < value, otherwise right.
def encode_tree(tree):
    """Encode a split node, and recursively its subtrees, as C++ source lines.

    ``tree`` must provide ``'index'`` (the feature index into ``r``),
    ``'value'`` (the split threshold) and ``'left'`` / ``'right'`` children.
    Each child is either another node dict or a leaf value (see
    ``encode_branch``).

    The generated code calls ``clearScreen``, prints the question being
    asked and calls ``delay(1500)``. It then branches on
    ``r[index] < value``. Each side prints "Yes" or "No" and delays again
    before descending into the child. Every returned line is shifted one
    indentation level via ``indent_line``.

    Returns a list rather than a single-use ``map`` iterator. The result can
    then be reused or measured, and it matches ``encode_branch``, which also
    returns lists.
    """
    index = tree['index']
    threshold = tree['value']
    body = [
        "(clearScreen)();",
        "(printLine)(\"Considering feature {index}, is it smaller than {threshold}?\");".format(index=index, threshold=threshold),
        "(delay)(1500);",
        "if (r[{index}] < {threshold}) {{".format(index=index, threshold=threshold),
        indent_line("(printLine)(\"Yes\");"),
        indent_line("(delay)(1500);"),
        *encode_branch(tree['left']),
        "}",
        "else {",
        indent_line("(printLine)(\"No\");"),
        indent_line("(delay)(1500);"),
        *encode_branch(tree['right']),
        "}",
    ]
    return [indent_line(line) for line in body]
def make_classifier(tree):
    """Render *tree* as the text of a C++ header.

    The header declares ``PublishingHouse::RandomForest::DecisionTree``.
    Its single public ``predict`` method has the nested if/else chain
    produced by ``encode_tree`` as its body. Lines are joined with newlines
    and there is no trailing newline.

    NOTE(review): ``predict`` returns string literals through ``char*`` and
    ``printLine`` receives literals through ``char *``. That conversion is
    ill-formed in C++11 and later (compilers typically warn). Consider
    ``const char*`` once the callers can be updated.
    """
    def shift(block):
        # Push every line of *block* one nesting level to the right.
        return [indent_line(text) for text in block]

    predict_method = [
        "char* predict(float *r, void (*printLine)(char *), void (*clearScreen)(), void (*delay)(int))",
        "{",
    ]
    predict_method.extend(encode_tree(tree))
    predict_method.append("}")

    class_section = ["class DecisionTree", "{", "public:"]
    class_section += shift(predict_method)
    class_section += ["private:", "};"]

    inner_namespace = ["namespace RandomForest", "{"] + shift(class_section) + ["}"]

    header = ["#pragma once", "#include <cstdarg>", "namespace PublishingHouse", "{"]
    header += shift(inner_namespace)
    header.append("}")
    return "\n".join(header)
if __name__ == '__main__':
    import json
    import os.path
    import glob

    basepath = os.path.dirname(os.path.realpath(__file__))

    # Search for exported models next to this script. glob() returns matches
    # in arbitrary order, so sort them to keep the menu numbering stable.
    globpath = os.path.realpath(os.path.join(basepath, 'random_forest_model_*-trees.json'))
    models = sorted(glob.glob(globpath))
    if not models:
        raise SystemExit('No models matching {} found.'.format(globpath))

    print("Found:")
    for k, modelpath in enumerate(models):
        print("[{}] {}".format(k, modelpath))

    # Ask until we get a valid index. A negative number would otherwise
    # silently select from the end of the list, and non-numeric input
    # would crash.
    while True:
        answer = input("Which model?\n")
        try:
            model_key = int(answer)
        except ValueError:
            print('Please enter a number.')
            continue
        if 0 <= model_key < len(models):
            break
        print('Please enter a number between 0 and {}.'.format(len(models) - 1))
    modelpath = models[model_key]

    # Open the model and parse the forest (JSON text is UTF-8 per RFC 8259).
    with open(modelpath, 'r', encoding='utf-8') as file_in:
        forest = json.load(file_in)

    modelname, _ = os.path.splitext(os.path.basename(modelpath))
    classifiernamepattern = 'classifier_{}_tree_{{}}.h'.format(modelname)

    # Walk through the forest and write one C++ header per tree.
    for idx, tree in enumerate(forest):
        # idx is 0-based; report progress 1-based ("1 of N" .. "N of N").
        print('Transforming tree {} of {}'.format(idx + 1, len(forest)))
        code = make_classifier(tree)
        with open(os.path.join(basepath, classifiernamepattern.format(idx)), 'w', encoding='utf-8') as h:
            h.write(code)
    print('Classifiers placed in: {}'.format(basepath))