You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
159 lines
4.1 KiB
Python
159 lines
4.1 KiB
Python
# One indentation step (two spaces) of the generated source; indent_line()
# prepends it once for every nesting level a line is wrapped in.
indent_prefix = ' ' * 2
|
|
|
|
|
|
def indent_line(line):
    """Return ``line`` with one ``indent_prefix`` step prepended."""
    return ''.join((indent_prefix, line))
|
|
|
|
|
|
def encode_branch_with_unease(branch):
    """Encode one child of an unease-tree node as indented source lines.

    * A dict that carries a ``'label'`` key is an inner node and is handed
      to ``encode_tree_with_unease``.
    * A dict without ``'label'`` becomes the single line ``return "?";``.
    * An ``int``/``float`` becomes an unquoted ``return <value>;``.
    * Anything else becomes a quoted ``return "<value>";``.

    Returns an iterable of strings (a one-element list for leaves, whatever
    ``encode_tree_with_unease`` returns for inner nodes).
    """
    if not isinstance(branch, dict):
        # Leaf value: numbers are emitted bare, everything else is quoted.
        is_number = isinstance(branch, (int, float))
        template = 'return {};' if is_number else 'return "{}";'
        return [indent_line(template.format(branch))]

    if 'label' not in branch:
        # Dict without a label: emit the "?" placeholder result.
        return [indent_line('return "?";')]

    return encode_tree_with_unease(branch)
|
|
|
|
# if value at index is smaller than threshold
|
|
# enter left, else right
|
|
|
|
|
|
def encode_tree_with_unease(tree):
    """Encode a labelled decision node as an ``if``/``else`` in generated source.

    Reads ``tree['label']``, ``tree['index']``, ``tree['value']``,
    ``tree['left']`` and ``tree['right']``. The emitted condition calls
    ``considerWithUnease`` with the label, ``getObservationValue(index)``,
    the index and the threshold (``tree['value']``); the left child fills
    the ``if`` body and the right child the ``else`` body.

    Returns:
        A list of source lines, each indented by one level. A real list is
        returned (not a lazy ``map``) so the result matches the leaf case of
        ``encode_branch_with_unease`` and can be iterated more than once.

    Raises:
        KeyError: if one of the keys listed above is missing from ``tree``.
    """
    condition = "if ((considerWithUnease)(\"{unease}\",(getObservationValue)({index}), {index}, {threshold})) {{".format(
        index=tree['index'],
        threshold=tree['value'],
        unease=tree['label'],
    )
    lines = [
        condition,
        *encode_branch_with_unease(tree['left']),
        "}",
        "else {",
        *encode_branch_with_unease(tree['right']),
        "}",
    ]
    return [indent_line(line) for line in lines]
|
|
|
|
|
|
def encode_branch(branch):
    """Encode one child of a decision node as indented source lines.

    A dict is treated as a nested node and passed to ``encode_tree``; an
    ``int``/``float`` becomes an unquoted ``return <value>;``; any other
    value becomes a quoted ``return "<value>";``.

    Returns an iterable of strings (a one-element list for leaves, whatever
    ``encode_tree`` returns for nested nodes).
    """
    if isinstance(branch, dict):
        return encode_tree(branch)

    # Leaf value: numbers are emitted bare, everything else is quoted.
    is_number = isinstance(branch, (int, float))
    template = 'return {};' if is_number else 'return "{}";'
    return [indent_line(template.format(branch))]
|
|
|
|
# if value at index is smaller than threshold
|
|
# enter left, else right
|
|
|
|
|
|
def encode_tree(tree):
    """Encode a decision node as an ``if``/``else`` in generated source.

    Reads ``tree['index']``, ``tree['value']``, ``tree['left']`` and
    ``tree['right']``. The emitted condition calls ``consider`` with
    ``getObservationValue(index)``, the index and the threshold
    (``tree['value']``); the left child fills the ``if`` body and the right
    child the ``else`` body.

    Returns:
        A list of source lines, each indented by one level. A real list is
        returned (not a lazy ``map``) so the result matches the leaf case of
        ``encode_branch`` and can be iterated more than once.

    Raises:
        KeyError: if one of the keys listed above is missing from ``tree``.
    """
    condition = "if ((consider)((getObservationValue)({index}), {index}, {threshold})) {{".format(
        index=tree['index'],
        threshold=tree['value'],
    )
    lines = [
        condition,
        *encode_branch(tree['left']),
        "}",
        "else {",
        *encode_branch(tree['right']),
        "}",
    ]
    return [indent_line(line) for line in lines]
|
|
|
|
|
|
def make_classifier(tree):
    """Render ``tree`` as the text of a header defining ``DecisionTree``.

    The encoded tree becomes the body of ``DecisionTree::predict`` inside
    ``namespace PublishingHouse { namespace RandomForest { ... } }``. Each
    enclosing scope adds one level of indentation to the lines it contains.

    Returns the complete header as one newline-joined string (no trailing
    newline).
    """
    # Innermost scope first: the predict() method wrapping the encoded tree.
    method_lines = [
        "const char* predict(float (*getObservationValue)(int), bool (*consider)(float, int, float))",
        "{",
        *encode_tree(tree),
        "}",
    ]

    class_lines = ["class DecisionTree", "{", "public:"]
    class_lines += [indent_line(row) for row in method_lines]
    class_lines += ["private:", "};"]

    inner_namespace = ["namespace RandomForest", "{"]
    inner_namespace += [indent_line(row) for row in class_lines]
    inner_namespace.append("}")

    header = ["#pragma once", "#include <cstdarg>", "namespace PublishingHouse", "{"]
    header += [indent_line(row) for row in inner_namespace]
    header.append("}")

    return '\n'.join(header)
|
|
|
|
|
|
def make_classifier_unease(tree):
    """Render ``tree`` as the text of a header defining ``TreeOfUnease``.

    The tree, encoded with ``encode_tree_with_unease``, becomes the body of
    ``TreeOfUnease::traverse`` inside
    ``namespace PublishingHouse { namespace RandomForest { ... } }``. Each
    enclosing scope adds one level of indentation to the lines it contains.

    Returns the complete header as one newline-joined string (no trailing
    newline).
    """
    def nested(rows):
        # Push a group of lines one indentation level deeper.
        return [indent_line(row) for row in rows]

    # Innermost scope first: the traverse() method wrapping the encoded tree.
    traverse_method = [
        "const char* traverse(float (*getObservationValue)(int), bool (*considerWithUnease)(const char*, float, int, float))",
        "{",
        *encode_tree_with_unease(tree),
        "}",
    ]
    class_scope = [
        "class TreeOfUnease",
        "{",
        "public:",
        *nested(traverse_method),
        "private:",
        "};",
    ]
    namespace_scope = [
        "namespace RandomForest",
        "{",
        *nested(class_scope),
        "}",
    ]
    document = [
        "#pragma once",
        "#include <cstdarg>",
        "namespace PublishingHouse",
        "{",
        *nested(namespace_scope),
        "}",
    ]
    return '\n'.join(document)
|
|
|
|
|
|
if __name__ == '__main__':
    # Interactive entry point: list the JSON models next to this script, ask
    # which one to convert, and write two headers per tree of that forest.
    import json
    import os.path
    import glob

    basepath = os.path.dirname(os.path.realpath(__file__))
    # Every JSON file in the script directory is offered as a candidate.
    # (An earlier, narrower 'random_forest_model_*-trees.json' pattern was
    # assigned here and immediately overwritten; only this one ever applied.)
    globpath = os.path.realpath(os.path.join(basepath, '*.json'))

    # Search for exported models. glob returns results in arbitrary order,
    # so sort them to keep the menu numbering stable between runs.
    models = sorted(glob.glob(globpath))
    if not models:
        raise SystemExit('No model files found matching: {}'.format(globpath))

    print("Found:")
    for k, modelpath in enumerate(models):
        print("[{}] {}".format(k, modelpath))

    answer = input("Which model?\n")
    try:
        model_key = int(answer)
    except ValueError:
        raise SystemExit('Not a model number: {!r}'.format(answer)) from None
    # Only the indices shown in the menu are accepted; a negative number
    # would otherwise silently select a model counted from the end.
    if not 0 <= model_key < len(models):
        raise SystemExit(
            'Model number must be between 0 and {}.'.format(len(models) - 1))

    modelpath = models[model_key]

    # Open model and parse the forest; the file is closed again before any
    # header is written.
    with open(modelpath, 'r') as file_in:
        forest = json.load(file_in)

    # Output files are numbered by tree index only.
    # NOTE(review): the original built these patterns with
    # 'Tree_{{}}.h'.format(modelname), which never inserted the model name,
    # so headers of different models overwrite each other. File names are
    # kept unchanged here -- confirm whether the model name was meant to be
    # part of them.
    classifiernamepattern = 'Tree_{}.h'
    uneasenamepattern = 'Tree_of_unease_{}.h'

    # Walk through the forest and emit both header variants for every tree
    for idx, tree in enumerate(forest):
        print('Transforming tree {} of {}'.format(idx, len(forest)))

        with open(os.path.join(basepath, classifiernamepattern.format(idx)), 'w') as h:
            h.write(make_classifier(tree))

        with open(os.path.join(basepath, uneasenamepattern.format(idx)), 'w') as h:
            h.write(make_classifier_unease(tree))

    print('Classifiers placed in: {}'.format(basepath))