35#ifndef VIGRA_RF3_RANDOM_FOREST_HXX
36#define VIGRA_RF3_RANDOM_FOREST_HXX
41#include "../multi_shape.hxx"
42#include "../binary_forest.hxx"
43#include "../threadpool.hxx"
44#include "random_forest_common.hxx"
64template <
typename FEATURES,
66 typename SPLITTESTS = LessEqualSplitTest<typename FEATURES::value_type>,
67 typename ACCTYPE = ArgMaxVectorAcc<double>>
78 typedef typename ACC::input_type AccInputType;
82 static ContainerTag
const container_tag = VectorTag;
118 const std::vector<size_t> &
tree_indices = std::vector<size_t>()
123 template <
typename PROBS>
128 const std::vector<size_t> &
tree_indices = std::vector<size_t>()
133 template <
typename IDS>
138 const std::vector<size_t>
tree_indices = std::vector<size_t>()
183 template <
typename IDS,
typename INDICES>
184 double leaf_ids_impl(
192 template<
typename PROBS>
193 void predict_probabilities_impl(
201template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
210template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
213 typename NodeMap<SplitTests>::type
const &
split_tests,
223template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
228 "RandomForest::merge(): You cannot merge with different problem specs.");
232 size_t const offset = num_nodes();
233 graph_.merge(other.
graph_);
236 split_tests_.insert(Node(p.first.id()+
offset), p.second);
240 node_responses_.insert(Node(p.first.id()+
offset), p.second);
246template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
253 vigra_precondition(features.shape()[0] == labels.shape()[0],
254 "RandomForest::predict(): Shape mismatch between features and labels.");
255 vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
256 "RandomForest::predict(): Number of features in prediction differs from training.");
260 for (
size_t i = 0;
i < (
size_t)features.shape()[0]; ++
i)
265 labels(
i) = problem_spec_.distinct_classes_[label];
272template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
273template <
typename PROBS>
280 vigra_precondition(features.shape()[0] ==
probs.shape()[0],
281 "RandomForest::predict_probabilities(): Shape mismatch between features and probabilities.");
282 vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
283 "RandomForest::predict_probabilities(): Number of features in prediction differs from training.");
284 vigra_precondition((
size_t)
probs.shape()[1] == problem_spec_.num_classes_,
285 "RandomForest::predict_probabilities(): Number of labels in probabilities differs from training.");
300 vigra_precondition(
i < graph_.numRoots(),
"RandomForest::leaf_ids(): Tree index out of range.");
303 size_t const num_instances = features.shape()[0];
306 n_threads = std::thread::hardware_concurrency();
319template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
320template <
typename PROBS>
337 Node node = graph_.getRoot(
k);
338 while (graph_.outDegree(node) > 0)
347 auto sub_probs = probs.template bind<0>(i);
348 acc(tree_results.begin(), tree_results.end(), sub_probs.begin());
351template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
352template <
typename IDS>
359 vigra_precondition(features.shape()[0] ==
ids.shape()[0],
360 "RandomForest::leaf_ids(): Shape mismatch between features and probabilities.");
361 vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
362 "RandomForest::leaf_ids(): Number of features in prediction differs from training.");
363 vigra_precondition(
ids.shape()[1] == graph_.numRoots(),
364 "RandomForest::leaf_ids(): Leaf array has wrong shape.");
370 vigra_precondition(
i < graph_.numRoots(),
"RandomForest::leaf_ids(): Tree index out of range.");
379 size_t const num_instances = features.shape()[0];
381 n_threads = std::thread::hardware_concurrency();
385 std::vector<size_t> indices(num_instances);
386 std::iota(indices.begin(), indices.end(), 0);
393 split_comparisons[thread_id] += this->leaf_ids_impl(features, ids, i, i+1, tree_indices);
401template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
402template <
typename IDS,
typename INDICES>
410 vigra_precondition(features.shape()[0] ==
ids.shape()[0],
411 "RandomForest::leaf_ids_impl(): Shape mismatch between features and labels.");
412 vigra_precondition(features.shape()[1] == problem_spec_.num_features_,
413 "RandomForest::leaf_ids_impl(): Number of Features in prediction differs from training.");
414 vigra_precondition(
from >= 0 &&
from <=
to &&
to <= (
size_t)features.shape()[0],
415 "RandomForest::leaf_ids_impl(): Indices out of range.");
416 vigra_precondition(
ids.shape()[1] == graph_.numRoots(),
417 "RandomForest::leaf_ids_impl(): Leaf array has wrong shape.");
425 Node node = graph_.getRoot(
k);
426 while (graph_.outDegree(node) > 0)
432 ids(
i,
k) = node.id();
435 return split_comparisons;
BinaryForest stores a collection of rooted binary trees.
Definition binary_forest.hxx:65
size_t numRoots() const
Return the number of trees in the forest.
Definition binary_forest.hxx:332
detail::NodeDescriptor< index_type > Node
Node descriptor type of the present graph.
Definition binary_forest.hxx:70
size_t numNodes() const
Return the number of nodes (equivalent to maxNodeId()+1).
Definition binary_forest.hxx:289
Class for a single RGB value.
Definition rgbvalue.hxx:128
Base::value_type value_type
Definition rgbvalue.hxx:141
Random forest version 2 (see also vigra::rf3::RandomForest for version 3)
Definition random_forest.hxx:148
size_type size() const
Definition tinyvector.hxx:913
iterator end()
Definition tinyvector.hxx:864
iterator begin()
Definition tinyvector.hxx:861
Options class for vigra::rf3::RandomForest version 3.
Definition random_forest_common.hxx:583
Random forest version 3.
Definition random_forest.hxx:69
void predict_probabilities(FEATURES const &features, PROBS &probs, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const
Predict the class probabilities of the given data and store them in probs.
Definition random_forest.hxx:274
RandomForestOptions options_
The options that were used for training.
Definition random_forest.hxx:178
size_t num_features() const
Return the number of features.
Definition random_forest.hxx:160
void merge(RandomForest const &other)
Grow this forest by incorporating the other.
Definition random_forest.hxx:224
void predict(FEATURES const &features, LABELS &labels, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const
Predict the given data and store the predicted labels in labels.
Definition random_forest.hxx:247
NodeMap< SplitTests >::type split_tests_
Contains a test for each internal node that is used to determine whether given data goes to the left or the right child.
Definition random_forest.hxx:169
Graph graph_
The graph structure.
Definition random_forest.hxx:166
double leaf_ids(FEATURES const &features, IDS &ids, int n_threads=-1, const std::vector< size_t > tree_indices=std::vector< size_t >()) const
For each data point in features, compute the corresponding leaf ids and return the average number of split comparisons.
Definition random_forest.hxx:353
size_t num_trees() const
Return the number of trees.
Definition random_forest.hxx:148
size_t num_nodes() const
Return the number of nodes.
Definition random_forest.hxx:142
ProblemSpec< LabelType > problem_spec_
The specifications.
Definition random_forest.hxx:175
size_t num_classes() const
Return the number of classes.
Definition random_forest.hxx:154
NodeMap< AccInputType >::type node_responses_
Contains the responses of each node (for example the most frequent label).
Definition random_forest.hxx:172
void parallel_foreach(...)
Apply a functor to all items in a range in parallel.