You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

159 lines
4.1 KiB
Python

# One level of indentation in the generated C++ source: two spaces.
indent_prefix = ' ' * 2


def indent_line(line):
    """Return *line* shifted right by one indentation level.

    The prefix is the module-level ``indent_prefix``. ``line`` is expected to
    be a ``str``; anything else raises ``TypeError``.
    """
    return ''.join((indent_prefix, line))
def encode_branch_with_unease(branch):
    """Encode one child of a TreeOfUnease node as indented C++ lines.

    - A dict carrying a ``'label'`` is a further split node and is encoded
      recursively by ``encode_tree_with_unease``.
    - A dict without ``'label'`` becomes ``return "?";``.
    - Any other value is a leaf and becomes ``return "<value>";``.

    Leaves are always emitted as string literals because the generated
    ``traverse`` method returns ``const char*``. An unquoted number such as
    ``return 3;`` would not compile there. Backslashes, double quotes and line
    breaks in the value are escaped so the literal stays well-formed.

    Returns an iterable of source lines, each indented one level relative to
    the enclosing C++ block.
    """
    if isinstance(branch, dict):
        if 'label' in branch:
            return encode_tree_with_unease(branch)
        return [indent_line('return "?";')]
    # Backslash must be escaped first so the escapes added afterwards are
    # not doubled.
    literal = (str(branch)
               .replace('\\', '\\\\')
               .replace('"', '\\"')
               .replace('\n', '\\n')
               .replace('\r', '\\r'))
    return [indent_line('return "{}";'.format(literal))]
def encode_tree_with_unease(tree):
    """Encode a labelled split node of a TreeOfUnease as indented C++ lines.

    ``tree`` must provide ``'index'``, ``'value'``, ``'label'``, ``'left'``
    and ``'right'``; a missing key raises ``KeyError``.

    The generated code calls the user-supplied
    ``considerWithUnease(label, observation, index, threshold)`` callback.
    When it returns true, the ``left`` child is executed; otherwise the
    ``right`` child is. The comparison itself lives in that callback.
    Presumably it is "value at index smaller than threshold"; confirm against
    the C++ callers.

    The label is escaped (backslash, double quote, line breaks) because it is
    placed inside a C++ string literal.

    Returns a list of lines, each indented one level deeper than the
    caller's block. A list (not a lazy ``map``) matches what the leaf
    encoders return and can be iterated more than once.
    """
    # Backslash first so the escapes added afterwards are not doubled.
    label = (str(tree['label'])
             .replace('\\', '\\\\')
             .replace('"', '\\"')
             .replace('\n', '\\n')
             .replace('\r', '\\r'))
    header = ('if ((considerWithUnease)("{unease}",(getObservationValue)({index}), '
              '{index}, {threshold})) {{').format(
        index=tree['index'], threshold=tree['value'], unease=label)
    lines = [
        header,
        *encode_branch_with_unease(tree['left']),
        "}",
        "else {",
        *encode_branch_with_unease(tree['right']),
        "}",
    ]
    return [indent_line(line) for line in lines]
def encode_branch(branch):
    """Encode one child of a DecisionTree node as indented C++ lines.

    A dict is a further split node and is encoded recursively by
    ``encode_tree``. Any other value is a leaf and becomes
    ``return "<value>";``.

    Leaves are always emitted as string literals because the generated
    ``predict`` method returns ``const char*``. An unquoted number such as
    ``return 3;`` would not compile there. Backslashes, double quotes and line
    breaks in the value are escaped so the literal stays well-formed.

    Returns an iterable of source lines, each indented one level relative to
    the enclosing C++ block.
    """
    if isinstance(branch, dict):
        return encode_tree(branch)
    # Backslash must be escaped first so the escapes added afterwards are
    # not doubled.
    literal = (str(branch)
               .replace('\\', '\\\\')
               .replace('"', '\\"')
               .replace('\n', '\\n')
               .replace('\r', '\\r'))
    return [indent_line('return "{}";'.format(literal))]
def encode_tree(tree):
    """Encode a split node of a DecisionTree as indented C++ lines.

    ``tree`` must provide ``'index'``, ``'value'``, ``'left'`` and
    ``'right'``; a missing key raises ``KeyError``.

    The generated code calls the user-supplied
    ``consider(observation, index, threshold)`` callback. When it returns
    true, the ``left`` child runs; otherwise the ``right`` child does. The
    comparison itself lives in that callback. Presumably it is "value at
    index smaller than threshold"; confirm against the C++ callers.

    Returns a list of lines, each indented one level deeper than the
    caller's block. A list (not a lazy ``map``) matches what
    ``encode_branch`` returns for leaves and can be iterated more than once.
    """
    header = ("if ((consider)((getObservationValue)({index}), {index}, {threshold})) {{"
              .format(index=tree['index'], threshold=tree['value']))
    lines = [
        header,
        *encode_branch(tree['left']),
        "}",
        "else {",
        *encode_branch(tree['right']),
        "}",
    ]
    return [indent_line(line) for line in lines]
def make_classifier(tree):
    """Render *tree* as a C++ header defining ``DecisionTree``.

    The class is placed in ``namespace PublishingHouse::RandomForest``. Its
    ``predict(getObservationValue, consider)`` method contains the tree as
    nested ``if``/``else`` statements produced by ``encode_branch``.

    A tree whose root is already a leaf (not a dict) is accepted and yields a
    ``predict`` that returns that leaf unconditionally.

    Returns the complete header text, terminated by a newline.
    """
    # encode_branch dispatches on the node type, so a leaf root works too.
    # encode_tree alone would fail when indexing a non-dict root.
    body = encode_branch(tree)
    method = [
        "const char* predict(float (*getObservationValue)(int), bool (*consider)(float, int, float))",
        "{",
        *body,
        "}",
    ]
    class_decl = [
        "class DecisionTree",
        "{",
        "public:",
        *map(indent_line, method),
        "private:",
        "};",
    ]
    inner_namespace = [
        "namespace RandomForest",
        "{",
        *map(indent_line, class_decl),
        "}",
    ]
    lines = [
        "#pragma once",
        "#include <cstdarg>",
        "namespace PublishingHouse",
        "{",
        *map(indent_line, inner_namespace),
        "}",
    ]
    # Source files conventionally end with a newline; compilers may warn
    # otherwise.
    return '\n'.join(lines) + '\n'
def make_classifier_unease(tree):
    """Render *tree* as a C++ header defining ``TreeOfUnease``.

    The class is placed in ``namespace PublishingHouse::RandomForest``. Its
    ``traverse(getObservationValue, considerWithUnease)`` method contains the
    tree as nested ``if``/``else`` statements. Each split passes its label to
    the ``considerWithUnease`` callback.

    A dict root is encoded with ``encode_tree_with_unease``, so it must carry
    a ``'label'`` (``KeyError`` otherwise). A root that is already a leaf (not
    a dict) is accepted and yields a ``traverse`` that returns that leaf
    unconditionally.

    Returns the complete header text, terminated by a newline.
    """
    if isinstance(tree, dict):
        body = encode_tree_with_unease(tree)
    else:
        # Leaf root: encode_tree_with_unease would fail indexing a non-dict.
        body = encode_branch_with_unease(tree)
    method = [
        "const char* traverse(float (*getObservationValue)(int), bool (*considerWithUnease)(const char*, float, int, float))",
        "{",
        *body,
        "}",
    ]
    class_decl = [
        "class TreeOfUnease",
        "{",
        "public:",
        *map(indent_line, method),
        "private:",
        "};",
    ]
    inner_namespace = [
        "namespace RandomForest",
        "{",
        *map(indent_line, class_decl),
        "}",
    ]
    lines = [
        "#pragma once",
        "#include <cstdarg>",
        "namespace PublishingHouse",
        "{",
        *map(indent_line, inner_namespace),
        "}",
    ]
    # Source files conventionally end with a newline; compilers may warn
    # otherwise.
    return '\n'.join(lines) + '\n'
if __name__ == '__main__':
    import glob
    import json
    import os.path
    import sys

    basepath = os.path.dirname(os.path.realpath(__file__))

    # Search for exported models next to this script. glob() returns matches
    # in file-system-dependent order, so sort them to keep the numbering shown
    # to the user stable between runs.
    globpath = os.path.realpath(os.path.join(basepath, '*.json'))
    models = sorted(glob.glob(globpath))
    if not models:
        sys.exit('No *.json model found in: {}'.format(basepath))

    print("Found:")
    for k, modelpath in enumerate(models):
        print("[{}] {}".format(k, modelpath))

    # Ask until the answer is one of the listed indices. Without the range
    # check, a negative number would silently pick a model from the end.
    while True:
        answer = input("Which model?\n")
        try:
            model_key = int(answer)
        except ValueError:
            model_key = -1
        if 0 <= model_key < len(models):
            break
        print("Please enter a number between 0 and {}.".format(len(models) - 1))
    modelpath = models[model_key]

    # Open and parse the model. JSON text is UTF-8, independent of the
    # platform's default encoding.
    with open(modelpath, 'r', encoding='utf-8') as file_in:
        forest = json.load(file_in)

    # Output names are Tree_<idx>.h and Tree_of_unease_<idx>.h. The model
    # name is not part of them, so converting another model overwrites these
    # headers. NOTE(review): kept this way because generated code may
    # #include these exact names; confirm before adding the model name.
    classifiernamepattern = 'Tree_{}.h'
    uneasenamepattern = 'Tree_of_unease_{}.h'

    # Generate one DecisionTree header and one TreeOfUnease header per tree.
    for idx, tree in enumerate(forest):
        print('Transforming tree {} of {}'.format(idx, len(forest)))
        classifier_path = os.path.join(basepath, classifiernamepattern.format(idx))
        with open(classifier_path, 'w', encoding='utf-8') as h:
            h.write(make_classifier(tree))
        unease_path = os.path.join(basepath, uneasenamepattern.format(idx))
        with open(unease_path, 'w', encoding='utf-8') as h:
            h.write(make_classifier_unease(tree))
    print('Classifiers placed in: {}'.format(basepath))