36#ifndef VIGRA_RANDOM_FOREST_NP_HXX
37#define VIGRA_RANDOM_FOREST_NP_HXX
42#include "vigra/mathutil.hxx"
43#include "vigra/array_vector.hxx"
44#include "vigra/sized_int.hxx"
45#include "vigra/matrix.hxx"
46#include "vigra/random.hxx"
47#include "vigra/functorexpression.hxx"
// Fragment of the node tag constants (presumably the NodeTags enumeration
// used as the template argument of Node<> further below).
// AllColumns (0) is compared against column_data()[0] by the split nodes:
// a stored column count of 0 selects the "use every feature" code path.
58 AllColumns = 0x00000000,
// The two highest bits of a 32-bit value are reserved as flags.
59 ToBePrunedTag = 0x80000000,
60 LeafNodeTag = 0x40000000,
// Interior (split) node type ids are small plain integers ...
64 i_HypersphereNode = 2,
// ... while leaf node type ids are an index OR-ed with LeafNodeTag, so a
// leaf can be recognised by testing that single bit.
65 e_ConstProbNode = 0 | LeafNodeTag,
66 e_LogRegProbNode = 1 | LeafNodeTag
// NodeBase (fragmentary listing): a non-owning view onto one node stored
// inside a shared topology container (integers) and a shared parameter
// container (doubles).
93 typedef T_Container_type::iterator Topology_type;
// mutable so that const accessors can still hand out iterators into the
// topology storage.
97 mutable Topology_type topology_;
101 int parameter_size_ ;
// Read-only accessor for the node's type tag (the value the Node<>
// specializations assign, e.g. i_ThresholdNode).
141 INT
const &
typeID()
const
// Presumably the body of column_data(): per-node column information starts
// at offset 4 into this node's topology entries.
161 return topology_ + 4 ;
// Presumably one branch of columns_size(): the feature count is returned
// (the condition selecting this branch is not visible in this listing).
177 return featureCount_;
197 Topology_type topology_end()
const
// Number of topology (integer) entries this node occupies.
201 int topology_size()
const
203 return topology_size_;
211 Parameter_type parameters_end()
const
// Number of parameter (double) entries this node occupies.
216 int parameters_size()
const
218 return parameter_size_;
// Preconditions for copying the contents of another node `o` into this one:
// both nodes must have identical topology size, feature count, class count
// and parameter size.
243 vigra_precondition(topology_size_==o.topology_size_,
"Cannot copy nodes of different sizes");
244 vigra_precondition(featureCount_==o.featureCount_,
"Cannot copy nodes with different feature count");
245 vigra_precondition(classCount_==o.classCount_,
"Cannot copy nodes with different class counts");
246 vigra_precondition(parameters_size() ==o.parameters_size(),
"Cannot copy nodes with different parameter sizes");
// Constructor fragment: the node sizes come straight from the tLen / pLen
// arguments.
279 topology_size_(
tLen),
281 parameter_size_(
pLen),
// Constructor fragment: shares topology/parameter storage and the
// feature/class counts with an existing `node`, but takes its sizes from
// tLen / pLen.
296 topology_ (node.topology_),
297 topology_size_(
tLen),
298 parameters_ (node.parameters_),
299 parameter_size_(
pLen),
300 featureCount_(node.featureCount_),
301 classCount_(node.classCount_),
// Constructor fragment: sizes again taken from tLen / pLen.
321 topology_size_(
tLen),
322 parameter_size_(
pLen),
// Constructor fragment: sizes are copied from the node `toCopy`.
360 topology_size_(
toCopy.topology_size()),
361 parameter_size_(
toCopy.parameters_size()),
// Primary Node<> template, parameterised on a NodeTags value; each node
// kind is provided as an explicit specialization below.
383template<NodeTags NodeType>
// Axis-aligned threshold split: compares one feature against a threshold.
387class Node<i_ThresholdNode>
// Creates a new node occupying 5 topology entries and 2 parameter entries,
// and stamps it with the i_ThresholdNode type id.
397 Node( BT::T_Container_type & topology,
398 BT::P_Container_type & param)
399 : BT(5,2,topology, param)
401 BT::typeID() = i_ThresholdNode;
// Views an existing node `n` inside const containers with the same 5/2
// layout.
404 Node( BT::T_Container_type
const & topology,
405 BT::P_Container_type
const & param,
407 : BT(5,2,topology, param, n)
// The threshold is stored as parameter [1].
416 return BT::parameters_begin()[1];
419 double const & threshold()
const
421 return BT::parameters_begin()[1];
// The tested feature index is the first column_data() entry.
426 return BT::column_data()[0];
428 BT::INT
const & column()
const
430 return BT::column_data()[0];
// Chooses the child for one sample: child(0) when the selected feature is
// strictly below the threshold, child(1) otherwise. Only row 0 of `feature`
// is read, so a single-sample (one-row) view is presumably expected.
433 template<
class U,
class C>
434 BT::INT next(MultiArrayView<2,U,C>
const & feature)
const
436 return (feature(0, column()) < threshold())? child(0):child(1);
// Oblique (hyperplane) split: thresholds a weighted sum of features.
// Parameter layout used below: [1] = intercept, [2 .. 2+columns_size()) =
// weights, i.e. columns_size() + 2 entries in total.
442class Node<i_HyperplaneNode>
// Creates a new node for nCol features: nCol + 5 topology entries and
// nCol + 2 parameter entries (matching the layout above).
452 BT::T_Container_type & topology,
453 BT::P_Container_type & split_param)
454 : BT(nCol + 5,nCol + 2,topology, split_param)
456 BT::typeID() = i_HyperplaneNode;
// Views an existing node `n`: starts from the fixed 5/2 part and then grows
// the sizes by the stored column count (the AllColumns alternative of the
// conditional is not visible in this listing).
459 Node( BT::T_Container_type
const & topology,
460 BT::P_Container_type
const & split_param,
462 : NodeBase(5 , 2,topology, split_param, n)
465 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
467 : BT::column_data()[0];
468 BT::parameter_size_ += BT::columns_size();
// Same size adjustment in a further constructor.
475 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
477 : BT::column_data()[0];
478 BT::parameter_size_ += BT::columns_size();
// Intercept: parameter [1].
482 double const & intercept()
const
484 return BT::parameters_begin()[1];
488 return BT::parameters_begin()[1];
// Weight vector: starts at parameter [2].
491 BT::Parameter_type weights()
const
493 return BT::parameters_begin()+2;
496 BT::Parameter_type weights()
498 return BT::parameters_begin()+2;
// Evaluates sum_i weights[i] * x[col_i] - intercept for one sample; a
// negative score selects child(0) (the other branch of the return is not
// visible in this listing).
502 template<
class U,
class C>
503 BT::INT next(MultiArrayView<2,U,C>
const & feature)
const
505 double result = -1 * intercept();
// AllColumns: feature i is used directly.
// NOTE(review): single-index access feature[ii] differs from the
// feature(0, col) form used by the threshold node; presumably equivalent
// for a one-row view — confirm.
506 if(*(BT::column_data()) == AllColumns)
508 for(
int ii = 0; ii < BT::columns_size(); ++ii)
510 result +=feature[ii] * weights()[ii];
// Otherwise feature indices are looked up in the stored column list.
515 for(
int ii = 0; ii < BT::columns_size(); ++ii)
517 result +=feature[BT::columns_begin()[ii]] * weights()[ii];
520 return result < 0 ? BT::child(0)
// Hypersphere split: compares the squared distance to a center against a
// squared radius. Parameter layout used below: [1] = squaredRadius,
// [2 .. 2+columns_size()) = center coordinates.
// NOTE(review): that layout needs columns_size() + 2 parameter entries, but
// the constructors below request only nCol + 1 (resp. 1 + columns_size()),
// one fewer than the hyperplane node uses for the same layout. The last
// center coordinate therefore appears to fall outside parameters_size();
// confirm against NodeBase's allocation before relying on this node type.
528class Node<i_HypersphereNode>
// Creates a new node for nCol features (see the NOTE above about nCol + 1).
538 BT::T_Container_type & topology,
539 BT::P_Container_type & param)
540 : NodeBase(nCol + 5,nCol + 1,topology, param)
542 BT::typeID() = i_HypersphereNode;
// Views an existing node `n`: starts from a fixed 5/1 part and then grows
// the sizes by the stored column count (the AllColumns alternative of the
// conditional is not visible in this listing).
545 Node( BT::T_Container_type
const & topology,
546 BT::P_Container_type
const & param,
548 : NodeBase(5, 1,topology, param, n)
550 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
552 : BT::column_data()[0];
553 BT::parameter_size_ += BT::columns_size();
// Same size adjustment in a further constructor.
559 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
561 : BT::column_data()[0];
562 BT::parameter_size_ += BT::columns_size();
// Squared radius: parameter [1].
566 double const & squaredRadius()
const
568 return BT::parameters_begin()[1];
571 double& squaredRadius()
573 return BT::parameters_begin()[1];
// Center coordinates: start at parameter [2].
576 BT::Parameter_type center()
const
578 return BT::parameters_begin()+2;
581 BT::Parameter_type center()
583 return BT::parameters_begin()+2;
// Evaluates ||x - center||^2 - squaredRadius for one sample; samples strictly
// inside the sphere (negative result) select child(0) (the other branch of
// the return is not visible in this listing).
586 template<
class U,
class C>
587 BT::INT next(MultiArrayView<2,U,C>
const & feature)
const
589 double result = -1 * squaredRadius();
// AllColumns: feature i is compared with center coordinate i.
590 if(*(BT::column_data()) == AllColumns)
592 for(
int ii = 0; ii < BT::columns_size(); ++ii)
594 result += (feature[ii] - center()[ii])*
595 (feature[ii] - center()[ii]);
// Otherwise feature indices are looked up in the stored column list.
600 for(
int ii = 0; ii < BT::columns_size(); ++ii)
602 result += (feature[BT::columns_begin()[ii]] - center()[ii])*
603 (feature[BT::columns_begin()[ii]] - center()[ii]);
606 return result < 0 ? BT::child(0)
// Leaf node holding a constant per-class probability vector.
// Parameter layout used below: [1 .. 1+classCount_) = class probabilities,
// i.e. classCount_ + 1 entries in total.
626class Node<e_ConstProbNode>
639 BT::typeID() = e_ConstProbNode;
// Constructor fragment: 2 topology entries and classCount_ + 1 parameter
// entries, sized from an existing node `node_`.
651 :
BT(2, node_.classCount_ +1, node_)
// Presumably prob_begin(): probabilities start at parameter [1].
655 return BT::parameters_begin()+1;
// Presumably prob_end(): one past the last class probability.
659 return prob_begin() + prob_size();
// One probability per class.
661 int prob_size()
const
663 return BT::classCount_;
// Logistic-regression leaf: only declared here (the template<> line and any
// definition are not visible in this listing).
668class Node<e_LogRegProbNode>;
Definition rf_nodeproxy.hxx:88
INT const & child(Int32 l) const
Definition rf_nodeproxy.hxx:231
INT & parameter_addr()
Definition rf_nodeproxy.hxx:148
NodeBase(int tLen, int pLen, NodeBase &node)
Definition rf_nodeproxy.hxx:292
bool data() const
Definition rf_nodeproxy.hxx:128
Topology_type columns_end() const
Definition rf_nodeproxy.hxx:184
INT & child(Int32 l)
Definition rf_nodeproxy.hxx:224
Topology_type topology_begin() const
Definition rf_nodeproxy.hxx:193
NodeBase(int tLen, int pLen, T_Container_type &topology, P_Container_type ¶meter)
Definition rf_nodeproxy.hxx:316
int columns_size() const
Definition rf_nodeproxy.hxx:174
Topology_type column_data() const
Definition rf_nodeproxy.hxx:159
NodeBase(int tLen, int pLen, T_Container_type const &topology, P_Container_type const ¶meter, INT n)
Definition rf_nodeproxy.hxx:272
INT & typeID()
Definition rf_nodeproxy.hxx:136
NodeBase()
Definition rf_nodeproxy.hxx:237
NodeBase(T_Container_type const &topology, P_Container_type const ¶meter, INT n)
Definition rf_nodeproxy.hxx:254
NodeBase(NodeBase const &toCopy, T_Container_type &topology, P_Container_type ¶meter)
Definition rf_nodeproxy.hxx:356
double & weights()
Definition rf_nodeproxy.hxx:115
Parameter_type parameters_begin() const
Definition rf_nodeproxy.hxx:207
Topology_type columns_begin() const
Definition rf_nodeproxy.hxx:167
Class for a single RGB value.
Definition rgbvalue.hxx:128
size_type size() const
Definition tinyvector.hxx:913
iterator begin()
Definition tinyvector.hxx:861
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175