learn_fast_tree
executable.
learn_fast_tree [--weight.
x weight ] ... <
infile >
outfile
learn_fast_tree
uses ID3 to learn a ternary decision tree for corner detection. The data is read from the standard input, and the tree is written to the standard output. This is designed to learn FAST feature detectors, and does not allow for the possibility of ambiguity in the input data.
5 [-1 -1] [1 1] [3 4] [5 6] [-3 4] bbbbb 1 0 bsdsb 1000 1 . . .
The first row is the number of features. The second row is the list of offsets associated with each feature. This list has no effect on the learning of the tree, but it is passed through to the output for convenience.
The remaining rows contain the data. The first field is the ternary feature vector. The three characters "b", "d" and "s" correspond to brighter, darker and similar respectively, with the first feature being stored in the first character and so on.
The next field is the number of instances of the particular feature. The third field is the class, with 1 for corner, and 0 for background.
Additionally, the program fast_N_features can be used to generate all possible feature combinations for FAST-N features. When run without arguments, it generates data for FAST-9 features; otherwise the argument can be used to specify N.
The structure of the tree is described in detail in print_tree.
Definition in file learn_fast_tree.cc.
Go to the source code of this file.
Classes | |
struct | datapoint< FEATURE_SIZE > |
This structure represents a datapoint. More... | |
struct | tree |
This class represents a decision tree. More... | |
Defines | |
#define | fatal(E, S,...) vfatal((E), (S), (tag::Fmt,## __VA_ARGS__)) |
Enumerations | |
enum | Ternary { Brighter = 'b', Darker = 'd', Similar = 's' } |
Functions | |
template<class C> | |
void | vfatal (int err, const string &s, const C &list) |
template<int S> | |
V_tuple< shared_ptr < vector< datapoint < S > > >, uint64_t > ::type | load_features (unsigned int nfeats) |
double | entropy (uint64_t n, uint64_t c1) |
template<int S> | |
int | find_best_split (const vector< datapoint< S > > &fs, const vector< double > &weights, unsigned int nfeats) |
template<int S> | |
shared_ptr< tree > | build_tree (vector< datapoint< S > > &corners, const vector< double > &weights, int nfeats) |
void | print_tree (const tree *node, ostream &o, const string &i="") |
template<int S> | |
V_tuple< shared_ptr < tree >, uint64_t > ::type | load_and_build_tree (unsigned int num_features, const vector< double > &weights) |
int | main (int argc, char **argv) |
enum Ternary |
V_tuple<shared_ptr<vector<datapoint<S> > >, uint64_t >::type load_features | ( | unsigned int | nfeats | ) | [inline] |
This function loads as many datapoints from the standard input as possible.
Datapoints consist of a feature vector (a string containing the characters "b", "d" and "s"), a number of instances and a class.
See datapoint::pack_trits for a more complete description of the feature vector.
The tokens are whitespace separated.
nfeats | Number of features in a feature vector. Used to spot errors. |
Definition at line 246 of file learn_fast_tree.cc.
References fatal.
00247 { 00248 shared_ptr<vector<datapoint<S> > > ret(new vector<datapoint<S> >); 00249 00250 00251 string unpacked_feature; 00252 00253 uint64_t total_num = 0; 00254 00255 uint64_t line_num=2; 00256 00257 for(;;) 00258 { 00259 uint64_t count; 00260 bool is; 00261 00262 cin >> unpacked_feature >> count >> is; 00263 00264 if(!cin) 00265 break; 00266 00267 line_num++; 00268 00269 if(unpacked_feature.size() != nfeats) 00270 fatal(1, "Feature string length is %i, not %i on line %i", unpacked_feature.size(), nfeats, line_num); 00271 00272 if(count == 0) 00273 fatal(4, "Zero count is invalid"); 00274 00275 ret->push_back(datapoint<S>(unpacked_feature, count, is)); 00276 00277 total_num += count; 00278 } 00279 00280 cerr << "Num features: " << total_num << endl 00281 << "Num distinct: " << ret->size() << endl; 00282 00283 return make_vtuple(ret, total_num); 00284 }
double entropy | ( | uint64_t | n, | |
uint64_t | c1 | |||
) |
Compute the entropy of a set with binary annotations.
n | Number of elements in the set | |
c1 | Number of elements in class 1 |
Definition at line 291 of file learn_fast_tree.cc.
Referenced by find_best_split().
/// Compute the entropy of a set with binary annotations.
/// Returns n times the per-element binary entropy, measured in bits
/// (i.e. the weighted entropy used by the information-gain computation).
///
/// @param n  Number of elements in the set
/// @param c1 Number of elements in class 1
double entropy(uint64_t n, uint64_t c1)
{
	assert(c1 <= n);

	// An empty set, or a pure set (all one class), has zero entropy.
	if(n == 0 || c1 == 0 || c1 == n)
		return 0;

	double p1 = (double)c1 / n;
	double p2 = 1 - p1;

	// Divide by log(2.0), not log(2.f): the original used a single-precision
	// constant for the base conversion, needlessly losing precision.
	return -(double)n * (p1*std::log(p1) + p2*std::log(p2)) / std::log(2.0);
}
int find_best_split | ( | const vector< datapoint< S > > & | fs, | |
const vector< double > & | weights, | |||
unsigned int | nfeats | |||
) | [inline] |
Find the feature that has the highest weighted entropy change.
fs | datapoints to split in to three subsets. | |
weights | weights on features | |
nfeats | Number of features in use. |
Definition at line 313 of file learn_fast_tree.cc.
References Brighter, Darker, entropy(), fatal, and Similar.
00314 { 00315 assert(nfeats == weights.size()); 00316 uint64_t num_total = 0, num_corners=0; 00317 00318 for(typename vector<datapoint<S> >::const_iterator i=fs.begin(); i != fs.end(); i++) 00319 { 00320 num_total += i->count; 00321 if(i->is_a_corner) 00322 num_corners += i->count; 00323 } 00324 00325 double total_entropy = entropy(num_total, num_corners); 00326 00327 double biggest_delta = 0; 00328 int feature_num = -1; 00329 00330 for(unsigned int i=0; i < nfeats; i++) 00331 { 00332 uint64_t num_bri = 0, num_dar = 0, num_sim = 0; 00333 uint64_t cor_bri = 0, cor_dar = 0, cor_sim = 0; 00334 00335 for(typename vector<datapoint<S> >::const_iterator f=fs.begin(); f != fs.end(); f++) 00336 { 00337 switch(f->get_trit(i)) 00338 { 00339 case Brighter: 00340 num_bri += f->count; 00341 if(f->is_a_corner) 00342 cor_bri += f->count; 00343 break; 00344 00345 case Darker: 00346 num_dar += f->count; 00347 if(f->is_a_corner) 00348 cor_dar += f->count; 00349 break; 00350 00351 case Similar: 00352 num_sim += f->count; 00353 if(f->is_a_corner) 00354 cor_sim += f->count; 00355 break; 00356 } 00357 } 00358 00359 double delta_e = total_entropy - (entropy(num_bri, cor_bri) + entropy(num_dar, cor_dar) + entropy(num_sim, cor_sim)); 00360 00361 delta_e *= weights[i]; 00362 00363 if(delta_e > biggest_delta) 00364 { 00365 biggest_delta = delta_e; 00366 feature_num = i; 00367 } 00368 } 00369 00370 if(feature_num == -1) 00371 fatal(3, "Couldn't find a split."); 00372 00373 return feature_num; 00374 }
shared_ptr<tree> build_tree | ( | vector< datapoint< S > > & | corners, | |
const vector< double > & | weights, | |||
int | nfeats | |||
) | [inline] |
This function uses ID3 to construct a decision tree.
The entropy changes are weighted by the list of weights, to allow bias towards certain features. This function assumes that the class is an exact function of the data. If datapoints with different classes share the same feature vector, the program will crash with error code 3.
corners | Datapoints in this part of the subtree to classify | |
weights | Weights on the features | |
nfeats | Number of features actually used |
Definition at line 468 of file learn_fast_tree.cc.
References Brighter, tree::CornerLeaf(), Darker, tree::NonCornerLeaf(), and Similar.
00469 { 00470 //Find the split 00471 int f = find_best_split<S>(corners, weights, nfeats); 00472 00473 //Split corners in to the three chunks, based on the result of find_best_split. 00474 //Also, count how many of each class ends up in each of the three bins. 00475 //It may apper to be inefficient to use a vector here instead of a list, in terms 00476 //of memory, but the per-element storage overhead of the list is such that it uses 00477 //considerably more memory and is much slower. 00478 vector<datapoint<S> > brighter, darker, similar; 00479 uint64_t num_bri=0, cor_bri=0, num_dar=0, cor_dar=0, num_sim=0, cor_sim=0; 00480 00481 for(size_t i=0; i < corners.size(); i++) 00482 { 00483 switch(corners[i].get_trit(f)) 00484 { 00485 case Brighter: 00486 brighter.push_back(corners[i]); 00487 num_bri += corners[i].count; 00488 if(corners[i].is_a_corner) 00489 cor_bri += corners[i].count; 00490 break; 00491 00492 case Darker: 00493 darker.push_back(corners[i]); 00494 num_dar += corners[i].count; 00495 if(corners[i].is_a_corner) 00496 cor_dar += corners[i].count; 00497 break; 00498 00499 case Similar: 00500 similar.push_back(corners[i]); 00501 num_sim += corners[i].count; 00502 if(corners[i].is_a_corner) 00503 cor_sim += corners[i].count; 00504 break; 00505 } 00506 } 00507 00508 //Deallocate the memory now it's no longer needed. 00509 corners.clear(); 00510 00511 //This is not the same as corners.size(), since the corners (datapoints) 00512 //have a count assosciated with them. 00513 uint64_t num_tests = num_bri + num_dar + num_sim; 00514 00515 00516 //Build the subtrees 00517 shared_ptr<tree> b_tree, d_tree, s_tree; 00518 00519 00520 //If the sublist contains a single class, then instantiate a leaf, 00521 //otherwise recursively build the tree. 
00522 if(cor_bri == 0) 00523 b_tree = tree::NonCornerLeaf(num_bri); 00524 else if(cor_bri == num_bri) 00525 b_tree = tree::CornerLeaf(num_bri); 00526 else 00527 b_tree = build_tree<S>(brighter, weights, nfeats); 00528 00529 00530 if(cor_dar == 0) 00531 d_tree = tree::NonCornerLeaf(num_dar); 00532 else if(cor_dar == num_dar) 00533 d_tree = tree::CornerLeaf(num_dar); 00534 else 00535 d_tree = build_tree<S>(darker, weights, nfeats); 00536 00537 00538 if(cor_sim == 0) 00539 s_tree = tree::NonCornerLeaf(num_sim); 00540 else if(cor_sim == num_sim) 00541 s_tree = tree::CornerLeaf(num_sim); 00542 else 00543 s_tree = build_tree<S>(similar, weights, nfeats); 00544 00545 return shared_ptr<tree>(new tree(b_tree, d_tree, s_tree, f, num_tests)); 00546 }
void print_tree | ( | const tree * | node, | |
ostream & | o, | |||
const string & | i = "" | |||
) |
This function traverses the tree and produces a textual representation of it.
Additionally, if any of the subtrees are the same, then a single subtree is produced and the test is removed.
A subtree has the following format:
subtree = leaf | node; leaf = "corner" | "background" ; node = node2 | node3; node3 = "if_brighter" feature_number n1 n2 n3 subtree "elsf_darker" feature_number subtree "else" subtree "end"; node2 = if_statement feature_number n1 n2 subtree "else" subtree "end"; if_statement = "if_brighter" | "if_darker" | "if_either"; feature_number = integer; n1 = integer; n2 = integer; n3 = integer;
feature_number refers to the index of the feature that the test is performed on.
In node3, a 3 way test is performed. n1, n2 and n3 refer to the number of training examples landing in the if block, the elsf block and the else block respectively.
In a node2 node, one of the tests has been removed. n1 and n2 refer to the number of training examples landing in the if block and the else block respectively.
Although not mentioned in the grammar, the indenting is kept very strict.
This representation has been designed to be parsed very easily with simple regular expressions, hence the use of "elsf" as opposed to "elif" or "elseif".
node | (sub)tree to serialize | |
o | Stream to serialize to. | |
i | Indent to print before each line of the serialized tree. |
Definition at line 601 of file learn_fast_tree.cc.
References tree::brighter, tree::Corner, tree::darker, tree::feature_to_test, tree::is_a_corner, tree::NonCorner, tree::num_datapoints, and tree::similar.
Referenced by main().
00602 { 00603 if(node->is_a_corner == tree::Corner) 00604 o << i << "corner" << endl; 00605 else if(node->is_a_corner == tree::NonCorner) 00606 o << i << "background" << endl; 00607 else 00608 { 00609 string b = node->brighter->stringify(); 00610 string d = node->darker->stringify(); 00611 string s = node->similar->stringify(); 00612 00613 const tree * bt = node->brighter.get(); 00614 const tree * dt = node->darker.get(); 00615 const tree * st = node->similar.get(); 00616 string ii = i + " "; 00617 00618 int f = node->feature_to_test; 00619 00620 if(b == d && d == s) //All the same 00621 { 00622 //o << i << "if " << f << " is whatever\n"; 00623 print_tree(st, o, i); 00624 } 00625 else if(d == s) //Bright is different 00626 { 00627 o << i << "if_brighter " << f << " " << bt->num_datapoints << " " << dt->num_datapoints+st->num_datapoints << endl; 00628 print_tree(bt, o, ii); 00629 o << i << "else" << endl; 00630 print_tree(st, o, ii); 00631 o << i << "end" << endl; 00632 00633 } 00634 else if(b == s) //Dark is different 00635 { 00636 o << i << "if_darker " << f << " " << dt->num_datapoints << " " << bt->num_datapoints + st->num_datapoints << endl; 00637 print_tree(dt, o, ii); 00638 o << i << "else" << endl; 00639 print_tree(st, o, ii); 00640 o << i << "end" << endl; 00641 } 00642 else if(b == d) //Similar is different 00643 { 00644 o << i << "if_either " << f << " " << bt->num_datapoints + dt->num_datapoints << " " << st->num_datapoints << endl; 00645 print_tree(bt, o, ii); 00646 o << i << "else" << endl; 00647 print_tree(st, o, ii); 00648 o << i << "end" << endl; 00649 } 00650 else //All different 00651 { 00652 o << i << "if_brighter " << f << " " << bt->num_datapoints << " " << dt->num_datapoints << " " << st->num_datapoints << endl; 00653 print_tree(bt, o, ii); 00654 o << i << "elsf_darker " << f << endl; 00655 print_tree(dt, o, ii); 00656 o << i << "else" << endl; 00657 print_tree(st, o, ii); 00658 o << i << "end" << endl; 00659 } 00660 } 00661 }
V_tuple<shared_ptr<tree>, uint64_t>::type load_and_build_tree | ( | unsigned int | num_features, | |
const vector< double > & | weights | |||
) | [inline] |
This function loads data and builds a tree.
It is templated because datapoint is templated, for reasons of memory efficiency.
num_features | Number of features used | |
weights | Weights on each feature. |
Definition at line 668 of file learn_fast_tree.cc.
00669 { 00670 assert(weights.size() == num_features); 00671 00672 shared_ptr<vector<datapoint<S> > > l; 00673 uint64_t num_datapoints; 00674 00675 //Load the data 00676 make_rtuple(l, num_datapoints) = load_features<S>(num_features); 00677 00678 cerr << "Loaded.\n"; 00679 00680 //Build the tree 00681 shared_ptr<tree> tree; 00682 tree = build_tree<S>(*l, weights, num_features); 00683 00684 return make_vtuple(tree, num_datapoints); 00685 }
int main | ( | int | argc, | |
char ** | argv | |||
) |
The main program.
argc | Number of commandline arguments | |
argv | Commandline arguments |
Each feature takes up 2 bits. Since GCC doesn't pack any finer than 32 bits for heterogeneous structs, there is no point in having granularity finer than 16 features.
Definition at line 692 of file learn_fast_tree.cc.
References fatal, offsets, and print_tree().
00693 { 00694 //Set up default arguments 00695 GUI.parseArguments(argc, argv); 00696 00697 cin.sync_with_stdio(false); 00698 cout.sync_with_stdio(false); 00699 00700 00701 /////////////////// 00702 //read file 00703 00704 //Read number of features 00705 unsigned int num_features; 00706 cin >> num_features; 00707 if(!cin.good() || cin.eof()) 00708 fatal(6, "Error reading number of features."); 00709 00710 //Read offset list 00711 vector<ImageRef> offsets(num_features); 00712 for(unsigned int i=0; i < num_features; i++) 00713 cin >> offsets[i]; 00714 if(!cin.good() || cin.eof()) 00715 fatal(7, "Error reading offset list."); 00716 00717 //Read weights for the various offsets 00718 vector<double> weights(offsets.size()); 00719 for(unsigned int i=0; i < weights.size(); i++) 00720 weights[i] = GV3::get<double>(sPrintf("weights.%i", i), 1, 1); 00721 00722 00723 shared_ptr<tree> tree; 00724 uint64_t num_datapoints; 00725 00726 ///Each feature takes up 2 bits. Since GCC doesn't pack any finer 00727 ///then 32 bits for hetrogenous structs, there is no point in having 00728 ///granularity finer than 16 features. 00729 if(num_features <= 16) 00730 make_rtuple(tree, num_datapoints) = load_and_build_tree<16>(num_features, weights); 00731 else if(num_features <= 32) 00732 make_rtuple(tree, num_datapoints) = load_and_build_tree<32>(num_features, weights); 00733 else if(num_features <= 48) 00734 make_rtuple(tree, num_datapoints) = load_and_build_tree<48>(num_features, weights); 00735 else if(num_features <= 64) 00736 make_rtuple(tree, num_datapoints) = load_and_build_tree<64>(num_features, weights); 00737 else 00738 fatal(8, "Too many feratures (%i). To learn from this, see %s, line %i.", num_features, __FILE__, __LINE__); 00739 00740 00741 cout << num_features << endl; 00742 copy(offsets.begin(), offsets.end(), ostream_iterator<ImageRef>(cout, " ")); 00743 cout << endl; 00744 print_tree(tree.get(), cout); 00745 }