"""Translate exported random-forest decision trees into C-style source code.

Each tree is a nested dict with the keys ``index``, ``value``, ``left`` and
``right``.  A branch is either another such dict (an inner node) or a leaf
value that becomes a ``return`` statement in the generated function.

Run as a script to pick one of the ``random_forest_model_*-trees.json`` files
that sit next to this module and write one ``classifier_<model>_tree_<n>.h``
file per tree into the same directory.
"""

import glob
import json
import os.path
import sys

indent_prefix = 2 * ' '


def indent_line(line):
    """Return *line* prefixed with one level of indentation."""
    return indent_prefix + line


def encode_branch(branch):
    """Return the list of generated source lines for one side of a split.

    * ``dict``          -> nested decision node, encoded via ``encode_tree``.
    * ``int``/``float`` -> numeric leaf, emitted as ``return <value>;``.
    * anything else     -> emitted as a quoted string: ``return "<value>";``.
    """
    if isinstance(branch, dict):
        return encode_tree(branch)
    if isinstance(branch, (int, float)):
        return [indent_line('return {};'.format(branch))]
    return [indent_line('return "{}";'.format(branch))]


def encode_tree(tree):
    """Return the generated source lines for a decision node, as a list.

    The emitted code compares ``r[index]`` with the node's threshold
    (``tree['value']``): if the value at that index is smaller than the
    threshold the left branch is entered, otherwise the right branch.

    Every returned line is indented by one level so the result can be nested
    directly inside an enclosing block.  A real list is returned (not a lazy
    ``map`` object) so callers may iterate it more than once, take its length
    or index into it.

    NOTE(review): the emitted ``print(...)`` calls assume the target
    environment provides a ``print`` function -- confirm against the consumer
    of the generated headers.
    """
    index = tree['index']
    threshold = tree['value']
    lines = [
        "print(\"Considering feature {index}, is it smaller than {threshold}?\");".format(
            index=index, threshold=threshold),
        "// timeout code",
        "if (r[{index}] < {threshold}) {{".format(index=index, threshold=threshold),
        indent_line("print(\"Yes\");"),
        *encode_branch(tree['left']),
        "}",
        "else {",
        indent_line("print(\"No\");"),
        *encode_branch(tree['right']),
        "}",
    ]

    return [indent_line(line) for line in lines]


def make_classifier(tree):
    """Return the source text of a ``predict`` function implementing *tree*.

    The result is a single string with lines joined by ``\\n`` and no
    trailing newline.
    """
    lines = [
        "char predict(float *r) {",
        *encode_tree(tree),
        "}",
    ]

    return '\n'.join(lines)


def _choose_model(models):
    """List *models*, ask the user for an index and return the chosen path.

    Exits with an explanatory message when the answer is not an integer or is
    outside ``0 .. len(models) - 1``.
    """
    print("Found:")
    for k, modelpath in enumerate(models):
        print("[{}] {}".format(k, modelpath))

    answer = input("Which model?\n")
    try:
        model_key = int(answer)
    except ValueError:
        sys.exit('Not a valid model number: {!r}'.format(answer))

    # Reject negative numbers too: Python would otherwise silently index from
    # the end of the list.
    if not 0 <= model_key < len(models):
        sys.exit('Model number must be between 0 and {}'.format(len(models) - 1))

    return models[model_key]


def main():
    """Interactively convert one exported forest into per-tree header files."""
    basepath = os.path.dirname(os.path.realpath(__file__))
    globpath = os.path.realpath(
        os.path.join(basepath, 'random_forest_model_*-trees.json'))

    # Search for exported models
    models = glob.glob(globpath)
    if not models:
        sys.exit('No exported models found matching: {}'.format(globpath))

    modelpath = _choose_model(models)

    # Parse the forest; the input file is closed again before any output is
    # written.
    with open(modelpath, 'r', encoding='utf-8') as file_in:
        forest = json.load(file_in)

    modelname, _ = os.path.splitext(os.path.basename(modelpath))
    classifiernamepattern = 'classifier_{}_tree_{{}}.h'.format(modelname)

    # Walk through the forest and generate one classifier file per tree
    for idx, tree in enumerate(forest):
        print('Transforming tree {} of {}'.format(idx, len(forest)))
        code = make_classifier(tree)

        outpath = os.path.join(basepath, classifiernamepattern.format(idx))
        with open(outpath, 'w', encoding='utf-8') as h:
            h.write(code)

    print('Classifiers placed in: {}'.format(basepath))


if __name__ == '__main__':
    main()