Script to convert a tree into c if / else tree.

main
Gijs 2 years ago
parent a8dac2ad23
commit 8e132e035a

@ -0,0 +1,79 @@
indent_prefix = 2 * ' '
def indent_line(line):
return indent_prefix + line
def encode_branch(branch):
if isinstance(branch, dict):
return encode_tree(branch)
elif isinstance(branch, (int, float)):
return [ indent_line('return {};'.format(branch)) ]
else:
return [ indent_line('return "{}";'.format(branch)) ]
# if value at index is smaller than threshold
# enter left, else right
def encode_tree(tree):
lines = [
"print(\"Considering feature {index}, is it smaller than {threshold}?\");".format(index=tree['index'], threshold=tree['value']),
"// timeout code",
"if (r[{index}] < {threshold}) {{".format(index=tree['index'], threshold=tree['value']),
indent_line("print(\"Yes\");"),
*encode_branch(tree['left']),
"}",
"else {",
indent_line("print(\"No\");"),
*encode_branch(tree['right']),
"}"
]
return map(indent_line, lines)
def make_classifier (tree):
lines = [
"char predict(float *r) {",
*encode_tree(tree),
"}"
]
return('\n'.join(lines))
if __name__ == '__main__':
import json
import os.path
import glob
basepath = os.path.dirname(os.path.realpath(__file__))
globpath = os.path.realpath(os.path.join(basepath, 'random_forest_model_*-trees.json'))
models = glob.glob(globpath)
# Search for exported models
print("Found:")
for k, modelpath in enumerate(models):
print("[{}] {}".format(k, modelpath))
model_key = int(input("Which model?\n"))
modelpath = models[model_key]
# Open model
with open(modelpath, 'r') as file_in:
# Parse the forest
forest = json.load(file_in)
modelname, _ = os.path.splitext(os.path.basename(modelpath))
classifiernamepattern = 'classifier_{}_tree_{{}}.h'.format(modelname)
# Walk through the forest and visualize the trees
for idx, tree in enumerate(forest):
print('Transforming tree {} of {}'.format(idx, len(forest)))
code = make_classifier(tree)
with open(os.path.join(basepath, classifiernamepattern.format(idx)), 'w') as h:
h.write(code)
print('Classifiers placed in: {}'.format(basepath))
Loading…
Cancel
Save