Script to convert a tree into c if / else tree.
parent
a8dac2ad23
commit
8e132e035a
@ -0,0 +1,79 @@
|
||||
indent_prefix = 2 * ' '
|
||||
|
||||
|
||||
def indent_line(line):
|
||||
return indent_prefix + line
|
||||
|
||||
|
||||
def encode_branch(branch):
|
||||
if isinstance(branch, dict):
|
||||
return encode_tree(branch)
|
||||
elif isinstance(branch, (int, float)):
|
||||
return [ indent_line('return {};'.format(branch)) ]
|
||||
else:
|
||||
return [ indent_line('return "{}";'.format(branch)) ]
|
||||
|
||||
# if value at index is smaller than threshold
|
||||
# enter left, else right
|
||||
|
||||
|
||||
def encode_tree(tree):
|
||||
lines = [
|
||||
"print(\"Considering feature {index}, is it smaller than {threshold}?\");".format(index=tree['index'], threshold=tree['value']),
|
||||
"// timeout code",
|
||||
"if (r[{index}] < {threshold}) {{".format(index=tree['index'], threshold=tree['value']),
|
||||
indent_line("print(\"Yes\");"),
|
||||
*encode_branch(tree['left']),
|
||||
"}",
|
||||
"else {",
|
||||
indent_line("print(\"No\");"),
|
||||
*encode_branch(tree['right']),
|
||||
"}"
|
||||
]
|
||||
|
||||
return map(indent_line, lines)
|
||||
|
||||
|
||||
def make_classifier (tree):
|
||||
lines = [
|
||||
"char predict(float *r) {",
|
||||
*encode_tree(tree),
|
||||
"}"
|
||||
]
|
||||
|
||||
return('\n'.join(lines))
|
||||
|
||||
if __name__ == '__main__':
|
||||
import json
|
||||
import os.path
|
||||
import glob
|
||||
basepath = os.path.dirname(os.path.realpath(__file__))
|
||||
globpath = os.path.realpath(os.path.join(basepath, 'random_forest_model_*-trees.json'))
|
||||
|
||||
models = glob.glob(globpath)
|
||||
|
||||
# Search for exported models
|
||||
print("Found:")
|
||||
for k, modelpath in enumerate(models):
|
||||
print("[{}] {}".format(k, modelpath))
|
||||
|
||||
model_key = int(input("Which model?\n"))
|
||||
|
||||
modelpath = models[model_key]
|
||||
|
||||
# Open model
|
||||
with open(modelpath, 'r') as file_in:
|
||||
# Parse the forest
|
||||
forest = json.load(file_in)
|
||||
modelname, _ = os.path.splitext(os.path.basename(modelpath))
|
||||
classifiernamepattern = 'classifier_{}_tree_{{}}.h'.format(modelname)
|
||||
|
||||
# Walk through the forest and visualize the trees
|
||||
for idx, tree in enumerate(forest):
|
||||
print('Transforming tree {} of {}'.format(idx, len(forest)))
|
||||
code = make_classifier(tree)
|
||||
|
||||
with open(os.path.join(basepath, classifiernamepattern.format(idx)), 'w') as h:
|
||||
h.write(code)
|
||||
|
||||
print('Classifiers placed in: {}'.format(basepath))
|
Loading…
Reference in New Issue