GEL  2
GEL is a library for Geometry and Linear Algebra
/Users/jab/Documents/Teaching/02585/GEL2_and_demos/GEL/src/Geometry/KDTree.h
00001 #ifndef __GEOMETRY_KDTREE_H
00002 #define __GEOMETRY_KDTREE_H
00003 
00004 #include <cmath>
00005 #include <iostream>
00006 #include <vector>
00007 #include <algorithm>
00008 #include "CGLA/CGLA.h"
00009 #include "CGLA/ArithVec.h"
00010 
00011 #if (_MSC_VER >= 1200)
00012 #pragma warning (push)
00013 #pragma warning (disable: 4018)
00014 #endif
00015 
00016 namespace Geometry
00017 {
00023         template<class KeyT, class ValT>
00024         class KDTree
00025         {
00026                 typedef typename KeyT::ScalarType ScalarType;
00027                 typedef KeyT KeyType;
00028                 typedef std::vector<KeyT> KeyVectorType;
00029                 typedef std::vector<ValT> ValVectorType;
00030         
00032                 struct KDNode
00033                 {
00034                         KeyT key;
00035                         ValT val;
00036                         short dsc;
00037 
00038                         KDNode(): dsc(0) {}
00039 
00040                         KDNode(const KeyT& _key, const ValT& _val):
00041                                 key(_key), val(_val), dsc(-1) {}
00042 
00043                         ScalarType dist(const KeyType& p) const 
00044                         {
00045                                 KeyType dist_vec = p;
00046                                 dist_vec  -= key;
00047                                 return dot(dist_vec, dist_vec);
00048                         }
00049                 };
00050 
00051                 typedef std::vector<KDNode> NodeVecType;
00052                 bool is_built;
00053                 NodeVecType init_nodes;
00054                 NodeVecType nodes;
00055 
00059                 class Comp
00060                 {
00061                         const int dsc;
00062                 public:
00063                         Comp(int _dsc): dsc(_dsc) {}
00064                         bool operator()(const KeyType& k0, const KeyType& k1) const
00065                         {
00066                                 int dim=KeyType::get_dim();
00067                                 for(int i=0;i<dim;i++)
00068                                         {
00069                                                 int j=(dsc+i)%dim;
00070                                                 if(k0[j]<k1[j])
00071                                                         return true;
00072                                                 if(k0[j]>k1[j])
00073                                                         return false;
00074                                         }
00075                                 return false;
00076                         }
00077 
00078                         bool operator()(const KDNode& k0, const KDNode& k1) const
00079                         {
00080                                 return (*this)(k0.key,k1.key);
00081                         }
00082                 };
00083 
00084 
00087                 void optimize(int, int, int);
00088 
00090                 int closest_point_priv(int, const KeyType&, ScalarType&) const;
00091 
00092                                                         
00093                 void in_sphere_priv(int n, 
00094                                                                                                 const KeyType& p, 
00095                                                                                                 const ScalarType& dist,
00096                                                                                                 std::vector<KeyT>& keys,
00097                                                                                                 std::vector<ValT>& vals) const;
00098         
00103                 int opt_disc(int,int) const;
00104         
00105         public:
00106 
00108                 KDTree(): is_built(false), init_nodes(1) {}
00109 
00112                 void insert(const KeyT& key, const ValT& val)
00113                 {
00114                                 if(is_built)
00115                                 {
00116                                                 assert(init_nodes.size()==1);
00117                                                 init_nodes.swap(nodes);
00118                                                 is_built=false;
00119                                 }
00120                                 init_nodes.push_back(KDNode(key,val));
00121                 }
00122 
00125                 void build()
00126                 {
00127                         assert(!is_built);
00128                         nodes.resize(init_nodes.size());
00129                         if(init_nodes.size() > 1)       
00130                                 optimize(1,1,init_nodes.size());
00131                         NodeVecType v(1);
00132                         init_nodes.swap(v);
00133                         is_built = true;
00134                 }
00135 
00141                 bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const
00142                 {
00143                         assert(is_built);
00144                         if(nodes.size()>1)
00145                         {
00146                                         ScalarType max_sq_dist = CGLA::sqr(dist);
00147                                         if(int n = closest_point_priv(1, p, max_sq_dist))
00148                                         {
00149                                                         k = nodes[n].key;
00150                                                         v = nodes[n].val;
00151                                                         dist = std::sqrt(max_sq_dist);
00152                                                         return true;
00153                                         }
00154                         }
00155                         return false;
00156                 }
00157 
00165                 int in_sphere(const KeyType& p, 
00166                                                                         ScalarType dist,
00167                                                                         std::vector<KeyT>& keys,
00168                                                                         std::vector<ValT>& vals) const
00169                 {
00170                         assert(is_built);
00171                         if(nodes.size()>1)
00172                         {
00173                                         ScalarType max_sq_dist = CGLA::sqr(dist);
00174                                         in_sphere_priv(1,p,max_sq_dist,keys,vals);
00175                                         return keys.size();
00176                         }
00177                         return 0;
00178                 }
00179                 
00180 
00181         };
00182 
00183         template<class KeyT, class ValT>
00184         int KDTree<KeyT,ValT>::opt_disc(int kvec_beg,  
00185                                                                                                                                         int kvec_end) const 
00186         {
00187                 KeyType vmin = init_nodes[kvec_beg].key;
00188                 KeyType vmax = init_nodes[kvec_beg].key;
00189                 for(int i=kvec_beg;i<kvec_end;i++)
00190                         {
00191                                 vmin = CGLA::v_min(vmin,init_nodes[i].key);
00192                                 vmax = CGLA::v_max(vmax,init_nodes[i].key);
00193                         }
00194                 int od=0;
00195                 KeyType ave_v = vmax-vmin;
00196                 for(int i=1;i<KeyType::get_dim();i++)
00197                         if(ave_v[i]>ave_v[od]) od = i;
00198                 return od;
00199         } 
00200 
00201         template<class KeyT, class ValT>
00202         void KDTree<KeyT,ValT>::optimize(int cur,
00203                                                                                                                                          int kvec_beg,  
00204                                                                                                                                          int kvec_end)
00205         {
00206                 // Assert that we are not inserting beyond capacity.
00207                 assert(cur < nodes.size());
00208 
00209                 // If there is just a single element, we simply insert.
00210                 if(kvec_beg+1==kvec_end) 
00211                         {
00212                                 nodes[cur] = init_nodes[kvec_beg];
00213                                 nodes[cur].dsc = -1;
00214                                 return;
00215                         }
00216         
00217                 // Find the axis that best separates the data.
00218                 int disc = opt_disc(kvec_beg, kvec_end);
00219 
00220                 // Compute the median element. See my document on how to do this
00221                 // www.imm.dtu.dk/~jab/publications.html
00222                 int N = kvec_end-kvec_beg;
00223                 int M = 1<< (CGLA::two_to_what_power(N));
00224                 int R = N-(M-1);
00225                 int left_size  = (M-2)/2;
00226                 int right_size = (M-2)/2;
00227                 if(R < M/2)
00228                         {
00229                                 left_size += R;
00230                         }
00231                 else
00232                         {
00233                                 left_size += M/2;
00234                                 right_size += R-M/2;
00235                         }
00236 
00237                 int median = kvec_beg + left_size;
00238 
00239                 // Sort elements but use nth_element (which is cheaper) than
00240                 // a sorting algorithm. All elements to the left of the median
00241                 // will be smaller than or equal the median. All elements to the right
00242                 // will be greater than or equal to the median.
00243                 const Comp comp(disc);
00244                 std::nth_element(&init_nodes[kvec_beg], 
00245                                                                                  &init_nodes[median], 
00246                                                                                  &init_nodes[kvec_end], comp);
00247 
00248                 // Insert the node in the final data structure.
00249                 nodes[cur] = init_nodes[median];
00250                 nodes[cur].dsc = disc;
00251 
00252                 // Recursively build left and right tree.
00253                 if(left_size>0) 
00254                         optimize(2*cur, kvec_beg, median);
00255                 
00256                 if(right_size>0) 
00257                         optimize(2*cur+1, median+1, kvec_end);
00258         }
00259 
00260         template<class KeyT, class ValT>
00261         int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 
00262                                                                                                                                                                                 ScalarType& dist) const
00263         {
00264                 int ret_node = 0;
00265                 ScalarType this_dist = nodes[n].dist(p);
00266 
00267                 if(this_dist<dist)
00268                         {
00269                                 dist = this_dist;
00270                                 ret_node = n;
00271                         }
00272                 if(nodes[n].dsc != -1)
00273                         {
00274                                 int dsc         = nodes[n].dsc;
00275                                 ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
00276                                 bool left_son   = Comp(dsc)(p,nodes[n].key);
00277 
00278                                 if(left_son||dsc_dist<dist)
00279                                         {
00280                                                 int left_child = 2*n;
00281                                                 if(left_child < nodes.size())
00282                                                         if(int nl=closest_point_priv(left_child, p, dist))
00283                                                                 ret_node = nl;
00284                                         }
00285                                 if(!left_son||dsc_dist<dist)
00286                                         {
00287                                                 int right_child = 2*n+1;
00288                                                 if(right_child < nodes.size())
00289                                                         if(int nr=closest_point_priv(right_child, p, dist))
00290                                                                 ret_node = nr;
00291                                         }
00292                         }
00293                 return ret_node;
00294         }
00295 
00296         template<class KeyT, class ValT>
00297         void KDTree<KeyT,ValT>::in_sphere_priv(int n, 
00298                                                                                                                                                                  const KeyType& p, 
00299                                                                                                                                                                  const ScalarType& dist,
00300                                                                                                                                                                  std::vector<KeyT>& keys,
00301                                                                                                                                                                  std::vector<ValT>& vals) const
00302         {
00303                 ScalarType this_dist = nodes[n].dist(p);
00304                 assert(n<nodes.size());
00305                 if(this_dist<dist)
00306                         {
00307                                 keys.push_back(nodes[n].key);
00308                                 vals.push_back(nodes[n].val);
00309                         }
00310                 if(nodes[n].dsc != -1)
00311                         {
00312                                 const int dsc         = nodes[n].dsc;
00313                                 const ScalarType dsc_dist  = CGLA::sqr(nodes[n].key[dsc]-p[dsc]);
00314 
00315                                 bool left_son = Comp(dsc)(p,nodes[n].key);
00316 
00317                                 if(left_son||dsc_dist<dist)
00318                                         {
00319                                                 int left_child = 2*n;
00320                                                 if(left_child < nodes.size())
00321                                                         in_sphere_priv(left_child, p, dist, keys, vals);
00322                                         }
00323                                 if(!left_son||dsc_dist<dist)
00324                                         {
00325                                                 int right_child = 2*n+1;
00326                                                 if(right_child < nodes.size())
00327                                                         in_sphere_priv(right_child, p, dist, keys, vals);
00328                                         }
00329                         }
00330         }
00331 }
// Shorthand alias: GEO:: names the same namespace as Geometry::.
00332 namespace GEO = Geometry;
00333 
00334 #if (_MSC_VER >= 1200)
00335 #pragma warning (pop)
00336 #endif
00337 
00338 
00339 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations