1/************************************************************************/
2/* */
3/* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35#ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
36#define VIGRA_RANDOM_FOREST_SPLIT_HXX
37#include <algorithm>
38#include <cstddef>
39#include <map>
40#include <numeric>
41#include <math.h>
42#include "../mathutil.hxx"
43#include "../array_vector.hxx"
44#include "../sized_int.hxx"
45#include "../matrix.hxx"
46#include "../random.hxx"
47#include "../functorexpression.hxx"
48#include "rf_nodeproxy.hxx"
49//#include "rf_sampling.hxx"
50#include "rf_region.hxx"
51//#include "../hokashyap.hxx"
52//#include "vigra/rf_helpers.hxx"
53
54namespace vigra
55{
56
57// Incomplete Class to ensure that findBestSplit is always implemented in
58// the derived classes of SplitBase
59class CompileTimeError;
60
61
62namespace detail
63{
64 template<class Tag>
65 class Normalise
66 {
67 public:
68 template<class Iter>
69 static void exec(Iter /*begin*/, Iter /*end*/)
70 {}
71 };
72
73 template<>
74 class Normalise<ClassificationTag>
75 {
76 public:
77 template<class Iter>
78 static void exec (Iter begin, Iter end)
79 {
80 double bla = std::accumulate(begin, end, 0.0);
81 for(int ii = 0; ii < end - begin; ++ii)
82 begin[ii] = begin[ii]/bla ;
83 }
84 };
85}
86
87
88/** Base Class for all SplitFunctors used with the \ref RandomForest class
89 defines the interface used while learning a tree.
90**/
91template<class Tag>
93{
94 public:
95
96 typedef Tag RF_Tag;
99
100 ProblemSpec<> ext_param_;
101
104
105 NodeBase node_;
106
107 /** returns the DecisionTree Node created by
108 \ref SplitBase::findBestSplit() or \ref SplitBase::makeTerminalNode().
109 **/
110
111 template<class T>
113 {
114 ext_param_ = in;
115 t_data.push_back(in.column_count_);
116 t_data.push_back(in.class_count_);
117 }
118
119 NodeBase & createNode()
120 {
121 return node_;
122 }
123
124 int classCount() const
125 {
126 return int(t_data[1]);
127 }
128
129 int featureCount() const
130 {
131 return int(t_data[0]);
132 }
133
134 /** resets internal data. Should always be called before
135 calling findBestSplit or makeTerminalNode
136 **/
137 void reset()
138 {
139 t_data.resize(2);
140 p_data.resize(0);
141 }
142
143
144 /** findBestSplit has to be re-implemented in derived split functor.
145 The defaut implementation only insures that a CompileTime error is issued
146 if no such method was defined.
147 **/
148
149 template<class T, class C, class T2, class C2, class Region, class Random>
151 MultiArrayView<2, T2, C2> /*labels*/,
152 Region /*region*/,
153 ArrayVector<Region> /*childs*/,
154 Random /*randint*/)
155 {
156#ifndef __clang__
157 // FIXME: This compile-time checking trick does not work for clang.
159#endif
160 return 0;
161 }
162
163 /** Default action for creating a terminal Node.
164 sets the Class probability of the remaining region according to
165 the class histogram
166 **/
167 template<class T, class C, class T2,class C2, class Region, class Random>
169 MultiArrayView<2, T2, C2> /* labels */,
170 Region & region,
171 Random /* randint */)
172 {
173 Node<e_ConstProbNode> ret(t_data, p_data);
174 node_ = ret;
175 if(ext_param_.class_weights_.size() != region.classCounts().size())
176 {
177 std::copy(region.classCounts().begin(),
178 region.classCounts().end(),
179 ret.prob_begin());
180 }
181 else
182 {
183 std::transform(region.classCounts().begin(),
184 region.classCounts().end(),
185 ext_param_.class_weights_.begin(),
186 ret.prob_begin(), std::multiplies<double>());
187 }
188 detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
189// std::copy(ret.prob_begin(), ret.prob_end(), std::ostream_iterator<double>(std::cerr, ", " ));
190// std::cerr << std::endl;
191 ret.weights() = region.size();
192 return e_ConstProbNode;
193 }
194
195
196};
197
198/** Functor to sort the indices of a feature Matrix by a certain dimension
199**/
200template<class DataMatrix>
202{
203 DataMatrix const & data_;
204 MultiArrayIndex sortColumn_;
205 double thresVal_;
206 public:
207
210 double thresVal = 0.0)
211 : data_(data),
212 sortColumn_(sortColumn),
213 thresVal_(thresVal)
214 {}
215
216 void setColumn(MultiArrayIndex sortColumn)
217 {
218 sortColumn_ = sortColumn;
219 }
220 void setThreshold(double value)
221 {
222 thresVal_ = value;
223 }
224
225 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
226 {
227 return data_(l, sortColumn_) < data_(r, sortColumn_);
228 }
229 bool operator()(MultiArrayIndex l) const
230 {
231 return data_(l, sortColumn_) < thresVal_;
232 }
233};
234
235template<class DataMatrix>
236class DimensionNotEqual
237{
238 DataMatrix const & data_;
239 MultiArrayIndex sortColumn_;
240
241 public:
242
243 DimensionNotEqual(DataMatrix const & data,
245 : data_(data),
246 sortColumn_(sortColumn)
247 {}
248
249 void setColumn(MultiArrayIndex sortColumn)
250 {
251 sortColumn_ = sortColumn;
252 }
253
254 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
255 {
256 return data_(l, sortColumn_) != data_(r, sortColumn_);
257 }
258};
259
260template<class DataMatrix>
261class SortSamplesByHyperplane
262{
263 DataMatrix const & data_;
264 Node<i_HyperplaneNode> const & node_;
265
266 public:
267
268 SortSamplesByHyperplane(DataMatrix const & data,
269 Node<i_HyperplaneNode> const & node)
270 :
271 data_(data),
272 node_(node)
273 {}
274
275 /** calculate the distance of a sample point to a hyperplane
276 */
277 double operator[](MultiArrayIndex l) const
278 {
279 double result_l = -1 * node_.intercept();
280 for(int ii = 0; ii < node_.columns_size(); ++ii)
281 {
282 result_l += rowVector(data_, l)[node_.columns_begin()[ii]]
283 * node_.weights()[ii];
284 }
285 return result_l;
286 }
287
288 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
289 {
290 return (*this)[l] < (*this)[r];
291 }
292
293};
294
295/** makes a Class Histogram given indices in a labels_ array
296 * usage:
297 * MultiArrayView<2, T2, C2> labels = makeSomeLabels()
298 * ArrayVector<int> hist(numberOfLabels(labels), 0);
299 * RandomForestClassCounter<T2, C2, ArrayVector> counter(labels, hist);
300 *
301 * Container<int> indices = getSomeIndices()
302 * std::for_each(indices, counter);
303 */
304template <class DataSource, class CountArray>
306{
307 DataSource const & labels_;
308 CountArray & counts_;
309
310 public:
311
312 RandomForestClassCounter(DataSource const & labels,
314 : labels_(labels),
315 counts_(counts)
316 {
317 reset();
318 }
319
320 void reset()
321 {
322 counts_.init(0);
323 }
324
325 void operator()(MultiArrayIndex l) const
326 {
327 counts_[labels_[l]] +=1;
328 }
329};
330
331
332/** Functor To Calculate the Best possible Split Based on the Gini Index
333 given Labels and Features along a given Axis
334*/
335
namespace detail
{
    /** Constant pseudo-array: operator[] yields N for every index.
        Serves as the neutral weight vector for the impurity functors. */
    template<int N>
    class ConstArr
    {
      public:
        double operator[](std::size_t) const
        {
            return static_cast<double>(N);
        }
    };
}
350
351
352
353
354/** Functor to calculate the entropy based impurity
355 */
357{
358public:
359 /**calculate the weighted gini impurity based on class histogram
360 * and class weights
361 */
362 template<class Array, class Array2>
363 double operator() (Array const & hist,
364 Array2 const & weights,
365 double total = 1.0) const
366 {
367 return impurity(hist, weights, total);
368 }
369
370 /** calculate the gini based impurity based on class histogram
371 */
372 template<class Array>
373 double operator()(Array const & hist, double total = 1.0) const
374 {
375 return impurity(hist, total);
376 }
377
378 /** static version of operator(hist total)
379 */
380 template<class Array>
381 static double impurity(Array const & hist, double total)
382 {
383 return impurity(hist, detail::ConstArr<1>(), total);
384 }
385
386 /** static version of operator(hist, weights, total)
387 */
388 template<class Array, class Array2>
389 static double impurity (Array const & hist,
390 Array2 const & weights,
391 double total)
392 {
393
394 int class_count = hist.size();
395 double entropy = 0.0;
396 if(class_count == 2)
397 {
398 double p0 = (hist[0]/total);
399 double p1 = (hist[1]/total);
400 entropy = 0 - weights[0]*p0*std::log(p0) - weights[1]*p1*std::log(p1);
401 }
402 else
403 {
404 for(int ii = 0; ii < class_count; ++ii)
405 {
406 double w = weights[ii];
407 double pii = hist[ii]/total;
408 entropy -= w*( pii*std::log(pii));
409 }
410 }
412 return entropy;
413 }
414};
415
416/** Functor to calculate the gini impurity
417 */
419{
420public:
421 /**calculate the weighted gini impurity based on class histogram
422 * and class weights
423 */
424 template<class Array, class Array2>
425 double operator() (Array const & hist,
426 Array2 const & weights,
427 double total = 1.0) const
428 {
429 return impurity(hist, weights, total);
430 }
431
432 /** calculate the gini based impurity based on class histogram
433 */
434 template<class Array>
435 double operator()(Array const & hist, double total = 1.0) const
436 {
437 return impurity(hist, total);
438 }
439
440 /** static version of operator(hist total)
441 */
442 template<class Array>
443 static double impurity(Array const & hist, double total)
444 {
445 return impurity(hist, detail::ConstArr<1>(), total);
446 }
447
448 /** static version of operator(hist, weights, total)
449 */
450 template<class Array, class Array2>
451 static double impurity (Array const & hist,
452 Array2 const & weights,
453 double total)
454 {
455
456 int class_count = hist.size();
457 double gini = 0.0;
458 if(class_count == 2)
459 {
460 double w = weights[0] * weights[1];
461 gini = w * (hist[0] * hist[1] / total);
462 }
463 else
464 {
465 for(int ii = 0; ii < class_count; ++ii)
466 {
467 double w = weights[ii];
468 gini += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
469 }
470 }
471 return gini;
472 }
473};
474
475
476template <class DataSource, class Impurity= GiniCriterion>
477class ImpurityLoss
478{
479
480 DataSource const & labels_;
481 ArrayVector<double> counts_;
482 ArrayVector<double> const class_weights_;
483 double total_counts_;
484 Impurity impurity_;
485
486 public:
487
488 template<class T>
489 ImpurityLoss(DataSource const & labels,
490 ProblemSpec<T> const & ext_)
491 : labels_(labels),
492 counts_(ext_.class_count_, 0.0),
493 class_weights_(ext_.class_weights_),
494 total_counts_(0.0)
495 {}
496
497 void reset()
498 {
499 counts_.init(0);
500 total_counts_ = 0.0;
501 }
502
503 template<class Counts>
504 double increment_histogram(Counts const & counts)
505 {
506 std::transform(counts.begin(), counts.end(),
507 counts_.begin(), counts_.begin(),
508 std::plus<double>());
509 total_counts_ = std::accumulate( counts_.begin(),
510 counts_.end(),
511 0.0);
512 return impurity_(counts_, class_weights_, total_counts_);
513 }
514
515 template<class Counts>
516 double decrement_histogram(Counts const & counts)
517 {
518 std::transform(counts.begin(), counts.end(),
519 counts_.begin(), counts_.begin(),
520 std::minus<double>());
521 total_counts_ = std::accumulate( counts_.begin(),
522 counts_.end(),
523 0.0);
524 return impurity_(counts_, class_weights_, total_counts_);
525 }
526
527 template<class Iter>
528 double increment(Iter begin, Iter end)
529 {
530 for(Iter iter = begin; iter != end; ++iter)
531 {
532 counts_[labels_(*iter, 0)] +=1.0;
533 total_counts_ +=1.0;
534 }
535 return impurity_(counts_, class_weights_, total_counts_);
536 }
537
538 template<class Iter>
539 double decrement(Iter const & begin, Iter const & end)
540 {
541 for(Iter iter = begin; iter != end; ++iter)
542 {
543 counts_[labels_(*iter,0)] -=1.0;
544 total_counts_ -=1.0;
545 }
546 return impurity_(counts_, class_weights_, total_counts_);
547 }
548
549 template<class Iter, class Resp_t>
550 double init (Iter /*begin*/, Iter /*end*/, Resp_t resp)
551 {
552 reset();
553 std::copy(resp.begin(), resp.end(), counts_.begin());
554 total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0);
555 return impurity_(counts_,class_weights_, total_counts_);
556 }
557
558 ArrayVector<double> const & response()
559 {
560 return counts_;
561 }
562};
563
564
565
566 template <class DataSource>
567 class RegressionForestCounter
568 {
569 public:
570 typedef MultiArrayShape<2>::type Shp;
571 DataSource const & labels_;
572 ArrayVector <double> mean_;
573 ArrayVector <double> variance_;
574 ArrayVector <double> tmp_;
575 size_t count_;
576 int* end_;
577
578 template<class T>
579 RegressionForestCounter(DataSource const & labels,
580 ProblemSpec<T> const & ext_)
581 :
582 labels_(labels),
583 mean_(ext_.response_size_, 0.0),
584 variance_(ext_.response_size_, 0.0),
585 tmp_(ext_.response_size_),
586 count_(0)
587 {}
588
589 template<class Iter>
590 double increment (Iter begin, Iter end)
591 {
592 for(Iter iter = begin; iter != end; ++iter)
593 {
594 ++count_;
595 for(unsigned int ii = 0; ii < mean_.size(); ++ii)
596 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
597 double f = 1.0 / count_,
598 f1 = 1.0 - f;
599 for(unsigned int ii = 0; ii < mean_.size(); ++ii)
600 mean_[ii] += f*tmp_[ii];
601 for(unsigned int ii = 0; ii < mean_.size(); ++ii)
602 variance_[ii] += f1*sq(tmp_[ii]);
603 }
604 double res = std::accumulate(variance_.begin(),
605 variance_.end(),
606 0.0,
607 std::plus<double>());
608 //std::cerr << res << " ) = ";
609 return res;
610 }
611
612 template<class Iter> //This is BROKEN
613 double decrement (Iter begin, Iter end)
614 {
615 for(Iter iter = begin; iter != end; ++iter)
616 {
617 --count_;
618 }
619
620 begin = end;
621 end = end + count_;
622
623
624 for(unsigned int ii = 0; ii < mean_.size(); ++ii)
625 {
626 mean_[ii] = 0;
627 for(Iter iter = begin; iter != end; ++iter)
628 {
629 mean_[ii] += labels_(*iter, ii);
630 }
631 mean_[ii] /= count_;
632 variance_[ii] = 0;
633 for(Iter iter = begin; iter != end; ++iter)
634 {
635 variance_[ii] += (labels_(*iter, ii) - mean_[ii])*(labels_(*iter, ii) - mean_[ii]);
636 }
637 }
638 double res = std::accumulate(variance_.begin(),
639 variance_.end(),
640 0.0,
641 std::plus<double>());
642 //std::cerr << res << " ) = ";
643 return res;
644 }
645
646
647 template<class Iter, class Resp_t>
648 double init (Iter begin, Iter end, Resp_t /*resp*/)
649 {
650 reset();
651 return this->increment(begin, end);
652
653 }
654
655
656 ArrayVector<double> const & response()
657 {
658 return mean_;
659 }
660
661 void reset()
662 {
663 mean_.init(0.0);
664 variance_.init(0.0);
665 count_ = 0;
666 }
667 };
668
669
670template <class DataSource>
671class RegressionForestCounter2
672{
673public:
674 typedef MultiArrayShape<2>::type Shp;
675 DataSource const & labels_;
676 ArrayVector <double> mean_;
677 ArrayVector <double> variance_;
678 ArrayVector <double> tmp_;
679 size_t count_;
680
681 template<class T>
682 RegressionForestCounter2(DataSource const & labels,
683 ProblemSpec<T> const & ext_)
684 :
685 labels_(labels),
686 mean_(ext_.response_size_, 0.0),
687 variance_(ext_.response_size_, 0.0),
688 tmp_(ext_.response_size_),
689 count_(0)
690 {}
691
692 template<class Iter>
693 double increment (Iter begin, Iter end)
694 {
695 for(Iter iter = begin; iter != end; ++iter)
696 {
697 ++count_;
698 for(int ii = 0; ii < mean_.size(); ++ii)
699 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
700 double f = 1.0 / count_,
701 f1 = 1.0 - f;
702 for(int ii = 0; ii < mean_.size(); ++ii)
703 mean_[ii] += f*tmp_[ii];
704 for(int ii = 0; ii < mean_.size(); ++ii)
705 variance_[ii] += f1*sq(tmp_[ii]);
706 }
707 double res = std::accumulate(variance_.begin(),
708 variance_.end(),
709 0.0,
710 std::plus<double>())
711 /((count_ == 1)? 1:(count_ -1));
712 //std::cerr << res << " ) = ";
713 return res;
714 }
715
716 template<class Iter> //This is BROKEN
717 double decrement (Iter begin, Iter end)
718 {
719 for(Iter iter = begin; iter != end; ++iter)
720 {
721 double f = 1.0 / count_,
722 f1 = 1.0 - f;
723 for(int ii = 0; ii < mean_.size(); ++ii)
724 mean_[ii] = (mean_[ii] - f*labels_(*iter,ii))/(1-f);
725 for(int ii = 0; ii < mean_.size(); ++ii)
726 variance_[ii] -= f1*sq(labels_(*iter,ii) - mean_[ii]);
727 --count_;
728 }
729 double res = std::accumulate(variance_.begin(),
730 variance_.end(),
731 0.0,
732 std::plus<double>())
733 /((count_ == 1)? 1:(count_ -1));
734 //std::cerr << "( " << res << " + ";
735 return res;
736 }
737 /* west's algorithm for incremental variance
738 // calculation
739 template<class Iter>
740 double increment (Iter begin, Iter end)
741 {
742 for(Iter iter = begin; iter != end; ++iter)
743 {
744 ++count_;
745 for(int ii = 0; ii < mean_.size(); ++ii)
746 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
747 double f = 1.0 / count_,
748 f1 = 1.0 - f;
749 for(int ii = 0; ii < mean_.size(); ++ii)
750 mean_[ii] += f*tmp_[ii];
751 for(int ii = 0; ii < mean_.size(); ++ii)
752 variance_[ii] += f1*sq(tmp_[ii]);
753 }
754 return std::accumulate(variance_.begin(),
755 variance_.end(),
756 0.0,
757 std::plus<double>())
758 /(count_ -1);
759 }
760
761 template<class Iter>
762 double decrement (Iter begin, Iter end)
763 {
764 for(Iter iter = begin; iter != end; ++iter)
765 {
766 --count_;
767 for(int ii = 0; ii < mean_.size(); ++ii)
768 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
769 double f = 1.0 / count_,
770 f1 = 1.0 + f;
771 for(int ii = 0; ii < mean_.size(); ++ii)
772 mean_[ii] -= f*tmp_[ii];
773 for(int ii = 0; ii < mean_.size(); ++ii)
774 variance_[ii] -= f1*sq(tmp_[ii]);
775 }
776 return std::accumulate(variance_.begin(),
777 variance_.end(),
778 0.0,
779 std::plus<double>())
780 /(count_ -1);
781 }*/
782
783 template<class Iter, class Resp_t>
784 double init (Iter begin, Iter end, Resp_t resp)
785 {
786 reset();
787 return this->increment(begin, end, resp);
788 }
789
790
791 ArrayVector<double> const & response()
792 {
793 return mean_;
794 }
795
796 void reset()
797 {
798 mean_.init(0.0);
799 variance_.init(0.0);
800 count_ = 0;
801 }
802};
803
804template<class Tag, class Datatyp>
805struct LossTraits;
806
807struct LSQLoss
808{};
809
810template<class Datatype>
811struct LossTraits<GiniCriterion, Datatype>
812{
813 typedef ImpurityLoss<Datatype, GiniCriterion> type;
814};
815
816template<class Datatype>
817struct LossTraits<EntropyCriterion, Datatype>
818{
819 typedef ImpurityLoss<Datatype, EntropyCriterion> type;
820};
821
822template<class Datatype>
823struct LossTraits<LSQLoss, Datatype>
824{
825 typedef RegressionForestCounter<Datatype> type;
826};
827
828/** Given a column, choose a split that minimizes some loss
829 */
830template<class LineSearchLossTag>
832{
833public:
834 ArrayVector<double> class_weights_;
835 ArrayVector<double> bestCurrentCounts[2];
836 double min_gini_;
837 std::ptrdiff_t min_index_;
838 double min_threshold_;
839 ProblemSpec<> ext_param_;
840
842 {}
843
844 template<class T>
846 :
847 class_weights_(ext.class_weights_),
848 ext_param_(ext)
849 {
850 bestCurrentCounts[0].resize(ext.class_count_);
851 bestCurrentCounts[1].resize(ext.class_count_);
852 }
853 template<class T>
854 void set_external_parameters(ProblemSpec<T> const & ext)
855 {
856 class_weights_ = ext.class_weights_;
857 ext_param_ = ext;
858 bestCurrentCounts[0].resize(ext.class_count_);
859 bestCurrentCounts[1].resize(ext.class_count_);
860 }
861 /** calculate the best gini split along a Feature Column
862 * \param column the feature vector - has to support the [] operator
863 * \param labels the label vector
864 * \param begin
865 * \param end (in and out)
866 * begin and end iterators to the indices of the
867 * samples in the current region.
868 * the range begin - end is sorted by the column supplied
869 * during function execution.
870 * \param region_response
871 * ???
872 * class histogram of the range.
873 *
874 * precondition: begin, end valid range,
875 * class_counts positive integer valued array with the
876 * class counts in the current range.
877 * labels.size() >= max(begin, end);
878 * postcondition:
879 * begin, end sorted by column given.
880 * min_gini_ contains the minimum gini found or
881 * NumericTraits<double>::max if no split was found.
882 * min_index_ contains the splitting index in the range
883 * or invalid data if no split was found.
884 * BestCirremtcounts[0] and [1] contain the
885 * class histogram of the left and right region of
886 * the left and right regions.
887 */
888 template< class DataSourceF_t,
889 class DataSource_t,
890 class I_Iter,
891 class Array>
892 void operator()(DataSourceF_t const & column,
893 DataSource_t const & labels,
894 I_Iter & begin,
895 I_Iter & end,
896 Array const & region_response)
897 {
898 std::sort(begin, end,
900 typedef typename
901 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
902 LineSearchLoss left(labels, ext_param_); //initialize left and right region
903 LineSearchLoss right(labels, ext_param_);
904
905
906
907 min_gini_ = right.init(begin, end, region_response);
908 min_threshold_ = *begin;
909 min_index_ = 0; //the starting point where to split
911
912 I_Iter iter = begin;
913 I_Iter next = std::adjacent_find(iter, end, comp);
914 //std::cerr << std::distance(begin, end) << std::endl;
915 while( next != end)
916 {
917 double lr = right.decrement(iter, next + 1);
918 double ll = left.increment(iter , next + 1);
919 double loss = lr +ll;
920 //std::cerr <<lr << " + "<< ll << " " << loss << " ";
921#ifdef CLASSIFIER_TEST
922 if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
923#else
924 if(loss < min_gini_ )
925#endif
926 {
927 bestCurrentCounts[0] = left.response();
928 bestCurrentCounts[1] = right.response();
929#ifdef CLASSIFIER_TEST
930 min_gini_ = loss < min_gini_? loss : min_gini_;
931#else
932 min_gini_ = loss;
933#endif
934 min_index_ = next - begin +1 ;
935 min_threshold_ = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
936 }
937 iter = next +1 ;
938 next = std::adjacent_find(iter, end, comp);
939 }
940 //std::cerr << std::endl << " 000 " << std::endl;
941 //int in;
942 //std::cin >> in;
943 }
944
945 template<class DataSource_t, class Iter, class Array>
946 double loss_of_region(DataSource_t const & labels,
947 Iter & begin,
948 Iter & end,
949 Array const & region_response) const
950 {
951 typedef typename
952 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
953 LineSearchLoss region_loss(labels, ext_param_);
954 return
955 region_loss.init(begin, end, region_response);
956 }
957
958};
959
960namespace detail
961{
962 template<class T>
963 struct Correction
964 {
965 template<class Region, class LabelT>
966 static void exec(Region & /*in*/, LabelT & /*labels*/)
967 {}
968 };
969
970 template<>
971 struct Correction<ClassificationTag>
972 {
973 template<class Region, class LabelT>
974 static void exec(Region & region, LabelT & labels)
975 {
976 if(std::accumulate(region.classCounts().begin(),
977 region.classCounts().end(), 0.0) != region.size())
978 {
979 RandomForestClassCounter< LabelT,
980 ArrayVector<double> >
981 counter(labels, region.classCounts());
982 std::for_each( region.begin(), region.end(), counter);
983 region.classCountsIsValid = true;
984 }
985 }
986 };
987}
988
989/** Chooses mtry columns and applies ColumnDecisionFunctor to each of the
990 * columns. Then Chooses the column that is best
991 */
992template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
993class ThresholdSplit: public SplitBase<Tag>
994{
995 public:
996
997
998 typedef SplitBase<Tag> SB;
999
1000 ArrayVector<Int32> splitColumns;
1001 ColumnDecisionFunctor bgfunc;
1002
1003 double region_gini_;
1004 ArrayVector<double> min_gini_;
1005 ArrayVector<std::ptrdiff_t> min_indices_;
1006 ArrayVector<double> min_thresholds_;
1007
1008 int bestSplitIndex;
1009
1010 double minGini() const
1011 {
1012 return min_gini_[bestSplitIndex];
1013 }
1014 int bestSplitColumn() const
1015 {
1016 return splitColumns[bestSplitIndex];
1017 }
1018 double bestSplitThreshold() const
1019 {
1020 return min_thresholds_[bestSplitIndex];
1021 }
1022
1023 template<class T>
1024 void set_external_parameters(ProblemSpec<T> const & in)
1025 {
1026 SB::set_external_parameters(in);
1027 bgfunc.set_external_parameters( SB::ext_param_);
1028 int featureCount_ = SB::ext_param_.column_count_;
1029 splitColumns.resize(featureCount_);
1030 for(int k=0; k<featureCount_; ++k)
1031 splitColumns[k] = k;
1032 min_gini_.resize(featureCount_);
1033 min_indices_.resize(featureCount_);
1034 min_thresholds_.resize(featureCount_);
1035 }
1036
1037
1038 template<class T, class C, class T2, class C2, class Region, class Random>
1039 int findBestSplit(MultiArrayView<2, T, C> features,
1041 Region & region,
1043 Random & randint)
1044 {
1045
1046 typedef typename Region::IndexIterator IndexIterator;
1047 if(region.size() == 0)
1048 {
1049 std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
1050 "continuing learning process....";
1051 }
1052 // calculate things that haven't been calculated yet.
1053 detail::Correction<Tag>::exec(region, labels);
1054
1055
1056 // Is the region pure already?
1057 region_gini_ = bgfunc.loss_of_region(labels,
1058 region.begin(),
1059 region.end(),
1060 region.classCounts());
1061 if(region_gini_ <= SB::ext_param_.precision_)
1062 return this->makeTerminalNode(features, labels, region, randint);
1063
1064 // select columns to be tried.
1065 for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
1066 std::swap(splitColumns[ii],
1067 splitColumns[ii+ randint(features.shape(1) - ii)]);
1068
1069 // find the best gini index
1070 bestSplitIndex = 0;
1071 double current_min_gini = region_gini_;
1072 int num2try = features.shape(1);
1073 for(int k=0; k<num2try; ++k)
1074 {
1075 //this functor does all the work
1076 bgfunc(columnVector(features, splitColumns[k]),
1077 labels,
1078 region.begin(), region.end(),
1079 region.classCounts());
1080 min_gini_[k] = bgfunc.min_gini_;
1081 min_indices_[k] = bgfunc.min_index_;
1082 min_thresholds_[k] = bgfunc.min_threshold_;
1083#ifdef CLASSIFIER_TEST
1084 if( bgfunc.min_gini_ < current_min_gini
1085 && !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
1086#else
1087 if(bgfunc.min_gini_ < current_min_gini)
1088#endif
1089 {
1090 current_min_gini = bgfunc.min_gini_;
1091 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
1092 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
1093 childRegions[0].classCountsIsValid = true;
1094 childRegions[1].classCountsIsValid = true;
1095
1096 bestSplitIndex = k;
1097 num2try = SB::ext_param_.actual_mtry_;
1098 }
1099 }
1100 //std::cerr << current_min_gini << "curr " << region_gini_ << std::endl;
1101 // did not find any suitable split
1102 // FIXME: this is wrong: sometimes we must execute bad splits to make progress,
1103 // especially near the root.
1104 if(closeAtTolerance(current_min_gini, region_gini_))
1105 return this->makeTerminalNode(features, labels, region, randint);
1106
1107 //create a Node for output
1108 Node<i_ThresholdNode> node(SB::t_data, SB::p_data);
1109 SB::node_ = node;
1110 node.threshold() = min_thresholds_[bestSplitIndex];
1111 node.column() = splitColumns[bestSplitIndex];
1112
1113 // partition the range according to the best dimension
1115 sorter(features, node.column(), node.threshold());
1116 IndexIterator bestSplit =
1117 std::partition(region.begin(), region.end(), sorter);
1118 // Save the ranges of the child stack entries.
1119 childRegions[0].setRange( region.begin() , bestSplit );
1120 childRegions[0].rule = region.rule;
1121 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
1122 childRegions[1].setRange( bestSplit , region.end() );
1123 childRegions[1].rule = region.rule;
1124 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
1125
1126 return i_ThresholdNode;
1127 }
1128};
1129
1133
1134namespace rf
1135{
1136
1137/** This namespace contains additional Splitfunctors.
1138 *
1139 * The Split functor classes are designed in a modular fashion because new split functors may
1140 * share a lot of code with existing ones.
1141 *
1142 * ThresholdSplit implements the functionality needed for any split functor, that makes its
1143 * decision via one dimensional axis-parallel cuts. The Template parameter defines how the split
1144 * along one dimension is chosen.
1145 *
1146 * The BestGiniOfColumn class chooses a split that minimizes one of the Loss functions supplied
1147 * (GiniCriterion for classification and LSQLoss for regression). Median chooses the Split in a
1148 * kD tree fashion.
1149 *
1150 *
1151 * Currently defined typedefs:
1152 * \code
1153 * typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit;
1154 * typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit;
1155 * typedef ThresholdSplit<Median> MedianSplit;
1156 * \endcode
1157 */
1158namespace split
1159{
1160
1161/** This Functor chooses the median value of a column
1162 */
1164{
1165public:
1166
1168 ArrayVector<double> class_weights_;
1169 ArrayVector<double> bestCurrentCounts[2];
1170 double min_gini_;
1171 std::ptrdiff_t min_index_;
1172 double min_threshold_;
1173 ProblemSpec<> ext_param_;
1174
1175 Median()
1176 {}
1177
1178 template<class T>
1179 Median(ProblemSpec<T> const & ext)
1180 :
1181 class_weights_(ext.class_weights_),
1182 ext_param_(ext)
1183 {
1184 bestCurrentCounts[0].resize(ext.class_count_);
1185 bestCurrentCounts[1].resize(ext.class_count_);
1186 }
1187
1188 template<class T>
1189 void set_external_parameters(ProblemSpec<T> const & ext)
1190 {
1191 class_weights_ = ext.class_weights_;
1192 ext_param_ = ext;
1193 bestCurrentCounts[0].resize(ext.class_count_);
1194 bestCurrentCounts[1].resize(ext.class_count_);
1195 }
1196
1197 template< class DataSourceF_t,
1198 class DataSource_t,
1199 class I_Iter,
1200 class Array>
1201 void operator()(DataSourceF_t const & column,
1202 DataSource_t const & labels,
1203 I_Iter & begin,
1204 I_Iter & end,
1205 Array const & region_response)
1206 {
1207 std::sort(begin, end,
1209 typedef typename
1210 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1211 LineSearchLoss left(labels, ext_param_);
1212 LineSearchLoss right(labels, ext_param_);
1213 right.init(begin, end, region_response);
1214
1215 min_gini_ = NumericTraits<double>::max();
1216 min_index_ = floor(double(end - begin)/2.0);
1217 min_threshold_ = column[*(begin + min_index_)];
1219 sorter(column, 0, min_threshold_);
1220 I_Iter part = std::partition(begin, end, sorter);
1222 if(part == begin)
1223 {
1224 part= std::adjacent_find(part, end, comp)+1;
1225
1226 }
1227 if(part >= end)
1228 {
1229 return;
1230 }
1231 else
1232 {
1233 min_threshold_ = column[*part];
1234 }
1235 min_gini_ = right.decrement(begin, part)
1236 + left.increment(begin , part);
1237
1238 bestCurrentCounts[0] = left.response();
1239 bestCurrentCounts[1] = right.response();
1240
1241 min_index_ = part - begin;
1242 }
1243
1244 template<class DataSource_t, class Iter, class Array>
1245 double loss_of_region(DataSource_t const & labels,
1246 Iter & begin,
1247 Iter & end,
1248 Array const & region_response) const
1249 {
1250 typedef typename
1251 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1252 LineSearchLoss region_loss(labels, ext_param_);
1253 return
1254 region_loss.init(begin, end, region_response);
1255 }
1256
1257};
1258
1260
1261
1262/** This Functor chooses a random value of a column
1263 */
1265{
1266public:
1267
1269 ArrayVector<double> class_weights_;
1270 ArrayVector<double> bestCurrentCounts[2];
1271 double min_gini_;
1272 std::ptrdiff_t min_index_;
1273 double min_threshold_;
1274 ProblemSpec<> ext_param_;
1275 typedef RandomMT19937 Random_t;
1276 Random_t random;
1277
1279 {}
1280
1281 template<class T>
1283 :
1284 class_weights_(ext.class_weights_),
1285 ext_param_(ext),
1286 random(RandomSeed)
1287 {
1288 bestCurrentCounts[0].resize(ext.class_count_);
1289 bestCurrentCounts[1].resize(ext.class_count_);
1290 }
1291
1292 template<class T>
1294 :
1295 class_weights_(ext.class_weights_),
1296 ext_param_(ext),
1297 random(random_)
1298 {
1299 bestCurrentCounts[0].resize(ext.class_count_);
1300 bestCurrentCounts[1].resize(ext.class_count_);
1301 }
1302
1303 template<class T>
1304 void set_external_parameters(ProblemSpec<T> const & ext)
1305 {
1306 class_weights_ = ext.class_weights_;
1307 ext_param_ = ext;
1308 bestCurrentCounts[0].resize(ext.class_count_);
1309 bestCurrentCounts[1].resize(ext.class_count_);
1310 }
1311
1312 template< class DataSourceF_t,
1313 class DataSource_t,
1314 class I_Iter,
1315 class Array>
1316 void operator()(DataSourceF_t const & column,
1317 DataSource_t const & labels,
1318 I_Iter & begin,
1319 I_Iter & end,
1320 Array const & region_response)
1321 {
1322 std::sort(begin, end,
1324 typedef typename
1325 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1326 LineSearchLoss left(labels, ext_param_);
1327 LineSearchLoss right(labels, ext_param_);
1328 right.init(begin, end, region_response);
1329
1330
1331 min_gini_ = NumericTraits<double>::max();
1332 int tmp_pt = random.uniformInt(std::distance(begin, end));
1333 min_index_ = tmp_pt;
1334 min_threshold_ = column[*(begin + min_index_)];
1336 sorter(column, 0, min_threshold_);
1337 I_Iter part = std::partition(begin, end, sorter);
1339 if(part == begin)
1340 {
1341 part= std::adjacent_find(part, end, comp)+1;
1342
1343 }
1344 if(part >= end)
1345 {
1346 return;
1347 }
1348 else
1349 {
1350 min_threshold_ = column[*part];
1351 }
1352 min_gini_ = right.decrement(begin, part)
1353 + left.increment(begin , part);
1354
1355 bestCurrentCounts[0] = left.response();
1356 bestCurrentCounts[1] = right.response();
1357
1358 min_index_ = part - begin;
1359 }
1360
1361 template<class DataSource_t, class Iter, class Array>
1362 double loss_of_region(DataSource_t const & labels,
1363 Iter & begin,
1364 Iter & end,
1365 Array const & region_response) const
1366 {
1367 typedef typename
1368 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1369 LineSearchLoss region_loss(labels, ext_param_);
1370 return
1371 region_loss.init(begin, end, region_response);
1372 }
1373
1374};
1375
1377}
1378}
1379
1380
1381} //namespace vigra
1382#endif // VIGRA_RANDOM_FOREST_SPLIT_HXX
rounding down.
Definition fixedpoint.hxx:667
bool closeAtTolerance(T1 l, T2 r, typename PromoteTraits< T1, T2 >::Promote epsilon)
Tolerance based floating-point equality.
Definition mathutil.hxx:1638
std::ptrdiff_t MultiArrayIndex
Definition multi_fwd.hxx:60

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1