"""Translate an exported random-forest model (JSON) into C decision functions.

Each tree node is a dict with keys ``index`` (feature index), ``value``
(split threshold), ``left`` and ``right``. A child is either another such
dict or a leaf value (a number, or anything else, which is emitted as a
C string literal).
"""

# Indentation unit used in the generated C source.
indent_prefix = 2 * ' '


def indent_line(line):
    """Return *line* indented by one level of generated-C indentation."""
    return indent_prefix + line


def _c_string_literal(value):
    """Return ``str(value)`` as a double-quoted C string literal.

    Backslashes, double quotes and newlines are escaped so that a label
    cannot terminate the literal early and break the generated C.
    """
    text = str(value)
    # Escape backslashes first so the escapes added below are not doubled.
    text = text.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
    return '"{}"'.format(text)


def encode_branch(branch):
    """Return the C lines for one child of a split node.

    A dict is a subtree and is encoded recursively via ``encode_tree``.
    An int/float leaf becomes ``return <number>;``. Any other leaf becomes
    ``return "<label>";``. Leaf lines are indented one level.
    """
    if isinstance(branch, dict):
        return encode_tree(branch)
    if isinstance(branch, (int, float)):
        return [indent_line('return {};'.format(branch))]
    # NOTE(review): make_classifier declares the function as returning
    # `char`, but this emits a string literal (a `const char *`). Confirm
    # whether labels are meant to be emitted as char literals ('x') instead.
    return [indent_line('return {};'.format(_c_string_literal(branch)))]


def encode_tree(tree):
    """Return the C lines for a split node and, recursively, its children.

    Feature ``r[index]`` is compared with the threshold ``value``. The left
    child is taken when the feature is strictly smaller, otherwise the right
    child is taken. Every returned line is indented one level. A list is
    returned (not a one-shot ``map`` iterator), so the result can be reused.

    Raises:
        KeyError: if *tree* lacks ``index``, ``value``, ``left`` or ``right``.
    """
    index = tree['index']
    threshold = tree['value']
    lines = [
        'print("Considering feature {}, is it smaller than {}?");'.format(
            index, threshold),
        '// timeout code',
        'if (r[{}] < {}) {{'.format(index, threshold),
        indent_line('print("Yes");'),
        *encode_branch(tree['left']),
        '}',
        'else {',
        indent_line('print("No");'),
        *encode_branch(tree['right']),
        '}',
    ]
    return [indent_line(line) for line in lines]


def make_classifier(tree):
    """Return the complete C source of ``char predict(float *r)`` for *tree*."""
    # NOTE(review): every tree is emitted under the same name `predict`, so
    # including more than one generated header in a single translation unit
    # would redefine the function.
    lines = ['char predict(float *r) {', *encode_tree(tree), '}']
    return '\n'.join(lines)


def main():
    """Ask which exported model to convert, then write one C header per tree.

    Models are looked up next to this script as
    ``random_forest_model_*-trees.json``. Each tree of the chosen forest is
    written to ``classifier_<model>_tree_<idx>.h`` in the same directory.
    """
    import glob
    import json
    import os.path

    basepath = os.path.dirname(os.path.realpath(__file__))
    globpath = os.path.realpath(
        os.path.join(basepath, 'random_forest_model_*-trees.json'))

    # Search for exported models. glob returns matches in arbitrary order;
    # sort them so the numbers shown to the user are stable between runs.
    models = sorted(glob.glob(globpath))
    if not models:
        raise SystemExit('No models found matching {}'.format(globpath))

    print("Found:")
    for k, modelpath in enumerate(models):
        print("[{}] {}".format(k, modelpath))
    model_key = int(input("Which model?\n"))
    # Reject out-of-range choices; a negative number would otherwise
    # silently select a model counted from the end of the list.
    if not 0 <= model_key < len(models):
        raise SystemExit('Invalid choice {}; expected 0-{}'.format(
            model_key, len(models) - 1))
    modelpath = models[model_key]

    # Open and parse the forest.
    with open(modelpath, 'r', encoding='utf-8') as file_in:
        forest = json.load(file_in)

    modelname, _ = os.path.splitext(os.path.basename(modelpath))
    classifiernamepattern = 'classifier_{}_tree_{{}}.h'.format(modelname)

    # Walk through the forest and write each tree as C source.
    for idx, tree in enumerate(forest):
        # Report 1-based progress: "tree 1 of N" ... "tree N of N".
        print('Transforming tree {} of {}'.format(idx + 1, len(forest)))
        code = make_classifier(tree)
        outpath = os.path.join(basepath, classifiernamepattern.format(idx))
        with open(outpath, 'w', encoding='utf-8') as h:
            h.write(code)
    print('Classifiers placed in: {}'.format(basepath))


if __name__ == '__main__':
    main()