You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.1 KiB
Python
79 lines
2.1 KiB
Python
indent_prefix = ' ' * 2


def indent_line(line):
    """Prepend one indentation level (``indent_prefix``) to *line*.

    :param line: a single line of generated source text.
    :return: ``indent_prefix`` followed by *line*.
    """
    return ''.join((indent_prefix, line))
|
|
|
|
|
|
# Characters that cannot appear verbatim inside a double-quoted C string literal.
_C_STRING_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '"': '\\"',
    '\n': '\\n',
    '\r': '\\r',
})


def _escape_c_string(text):
    """Return *text* escaped for use inside a double-quoted C string literal."""
    return text.translate(_C_STRING_ESCAPES)


def encode_branch(branch):
    """Encode one child of a decision node as a list of C source lines.

    :param branch: a nested decision node (a ``dict``, encoded recursively by
        ``encode_tree``) or a leaf value. Numeric leaves become a bare
        ``return <number>;``; booleans become ``return 1;`` / ``return 0;``;
        any other leaf is converted with ``str()`` and returned as a
        double-quoted, escaped C string literal.
    :return: list of source lines; leaf lines are indented one level.
    """
    if isinstance(branch, dict):
        # Materialise the subtree so every kind of branch yields a real list
        # (encode_tree may hand back a one-shot iterator).
        return list(encode_tree(branch))
    if isinstance(branch, bool):
        # bool is a subclass of int, and Python's True/False are not valid C
        # literals, so emit 1/0 instead of "return True;".
        return [indent_line('return {};'.format(int(branch)))]
    if isinstance(branch, (int, float)):
        return [indent_line('return {};'.format(branch))]
    # Escape quotes, backslashes and line breaks so an arbitrary label cannot
    # terminate or corrupt the generated string literal.
    # NOTE(review): the generated function is declared ``char predict(...)``,
    # so returning a string literal relies on a pointer-to-char conversion —
    # confirm the intended label type.
    return [indent_line('return "{}";'.format(_escape_c_string(str(branch))))]
|
|
|
|
# If the feature value at the node's index is smaller than its threshold,
# enter the left branch; otherwise enter the right branch.
|
|
|
|
|
|
def encode_tree(tree):
    """Encode a decision node (and its subtrees) as indented C source lines.

    The node is read as a ``dict`` with keys ``'index'`` (feature index into
    the generated code's ``r`` array), ``'value'`` (split threshold), and
    ``'left'`` / ``'right'`` (child branches, encoded by ``encode_branch``).
    The generated code takes the left branch when ``r[index] < value`` and
    the right branch otherwise.

    :param tree: the decision-node ``dict`` described above.
    :return: list of source lines, each indented one level deeper than the
        caller's scope.
    :raises KeyError: if the node lacks one of the keys above.
    """
    index = tree['index']
    threshold = tree['value']
    lines = [
        'print("Considering feature {index}, is it smaller than {threshold}?");'.format(
            index=index, threshold=threshold),
        "// timeout code",
        "if (r[{index}] < {threshold}) {{".format(index=index, threshold=threshold),
        indent_line('print("Yes");'),
        *encode_branch(tree['left']),
        "}",
        "else {",
        indent_line('print("No");'),
        *encode_branch(tree['right']),
        "}",
    ]
    # Return a list (not a lazy ``map``) so the result can be iterated more
    # than once or measured with len(), and matches encode_branch's leaf output.
    return [indent_line(line) for line in lines]
|
|
|
|
|
|
def make_classifier(tree):
    """Render *tree* as a complete C function ``char predict(float *r)``.

    :param tree: root decision node, in the format accepted by encode_tree.
    :return: the function source as one newline-joined string (with no
        trailing newline).
    """
    body = list(encode_tree(tree))
    return '\n'.join(["char predict(float *r) {"] + body + ["}"])
|
|
|
|
if __name__ == '__main__':
    import glob
    import json
    import os.path
    import sys

    basepath = os.path.dirname(os.path.realpath(__file__))
    globpath = os.path.realpath(os.path.join(basepath, 'random_forest_model_*-trees.json'))

    # Search for exported models. glob.glob returns files in arbitrary order,
    # so sort them to keep the numbering shown to the user stable between runs.
    models = sorted(glob.glob(globpath))
    if not models:
        sys.exit('No models matching {} found.'.format(globpath))

    print("Found:")
    for k, modelpath in enumerate(models):
        print("[{}] {}".format(k, modelpath))

    # Accept only a listed index: a bare models[int(...)] would crash on
    # non-numeric input and silently accept negative indices.
    choice = input("Which model?\n")
    try:
        model_key = int(choice)
    except ValueError:
        sys.exit('Not a model number: {!r}'.format(choice))
    if not 0 <= model_key < len(models):
        sys.exit('Model number must be between 0 and {}.'.format(len(models) - 1))

    modelpath = models[model_key]

    # Open the model and parse the forest (JSON text is UTF-8 by specification).
    with open(modelpath, 'r', encoding='utf-8') as file_in:
        forest = json.load(file_in)

    modelname, _ = os.path.splitext(os.path.basename(modelpath))
    classifiernamepattern = 'classifier_{}_tree_{{}}.h'.format(modelname)

    # Walk through the forest and write one C header per tree.
    for idx, tree in enumerate(forest):
        print('Transforming tree {} of {}'.format(idx, len(forest)))
        code = make_classifier(tree)

        outpath = os.path.join(basepath, classifiernamepattern.format(idx))
        with open(outpath, 'w', encoding='utf-8') as h:
            # A non-empty C source file must end in a newline character.
            h.write(code + '\n')

    print('Classifiers placed in: {}'.format(basepath))