GEL 2
GEL is a library for Geometry and Linear Algebra.
00001 #ifndef __GEOMETRY_KDTREE_H 00002 #define __GEOMETRY_KDTREE_H 00003 00004 #include <cmath> 00005 #include <iostream> 00006 #include <vector> 00007 #include <algorithm> 00008 #include "CGLA/CGLA.h" 00009 #include "CGLA/ArithVec.h" 00010 00011 #if (_MSC_VER >= 1200) 00012 #pragma warning (push) 00013 #pragma warning (disable: 4018) 00014 #endif 00015 00016 namespace Geometry 00017 { 00023 template<class KeyT, class ValT> 00024 class KDTree 00025 { 00026 typedef typename KeyT::ScalarType ScalarType; 00027 typedef KeyT KeyType; 00028 typedef std::vector<KeyT> KeyVectorType; 00029 typedef std::vector<ValT> ValVectorType; 00030 00032 struct KDNode 00033 { 00034 KeyT key; 00035 ValT val; 00036 short dsc; 00037 00038 KDNode(): dsc(0) {} 00039 00040 KDNode(const KeyT& _key, const ValT& _val): 00041 key(_key), val(_val), dsc(-1) {} 00042 00043 ScalarType dist(const KeyType& p) const 00044 { 00045 KeyType dist_vec = p; 00046 dist_vec -= key; 00047 return dot(dist_vec, dist_vec); 00048 } 00049 }; 00050 00051 typedef std::vector<KDNode> NodeVecType; 00052 bool is_built; 00053 NodeVecType init_nodes; 00054 NodeVecType nodes; 00055 00059 class Comp 00060 { 00061 const int dsc; 00062 public: 00063 Comp(int _dsc): dsc(_dsc) {} 00064 bool operator()(const KeyType& k0, const KeyType& k1) const 00065 { 00066 int dim=KeyType::get_dim(); 00067 for(int i=0;i<dim;i++) 00068 { 00069 int j=(dsc+i)%dim; 00070 if(k0[j]<k1[j]) 00071 return true; 00072 if(k0[j]>k1[j]) 00073 return false; 00074 } 00075 return false; 00076 } 00077 00078 bool operator()(const KDNode& k0, const KDNode& k1) const 00079 { 00080 return (*this)(k0.key,k1.key); 00081 } 00082 }; 00083 00084 00087 void optimize(int, int, int); 00088 00090 int closest_point_priv(int, const KeyType&, ScalarType&) const; 00091 00092 00093 void in_sphere_priv(int n, 00094 const KeyType& p, 00095 const ScalarType& dist, 00096 std::vector<KeyT>& keys, 00097 std::vector<ValT>& vals) const; 00098 00103 int opt_disc(int,int) const; 00104 
00105 public: 00106 00108 KDTree(): is_built(false), init_nodes(1) {} 00109 00112 void insert(const KeyT& key, const ValT& val) 00113 { 00114 if(is_built) 00115 { 00116 assert(init_nodes.size()==1); 00117 init_nodes.swap(nodes); 00118 is_built=false; 00119 } 00120 init_nodes.push_back(KDNode(key,val)); 00121 } 00122 00125 void build() 00126 { 00127 assert(!is_built); 00128 nodes.resize(init_nodes.size()); 00129 if(init_nodes.size() > 1) 00130 optimize(1,1,init_nodes.size()); 00131 NodeVecType v(1); 00132 init_nodes.swap(v); 00133 is_built = true; 00134 } 00135 00141 bool closest_point(const KeyT& p, ScalarType& dist, KeyT&k, ValT&v) const 00142 { 00143 assert(is_built); 00144 if(nodes.size()>1) 00145 { 00146 ScalarType max_sq_dist = CGLA::sqr(dist); 00147 if(int n = closest_point_priv(1, p, max_sq_dist)) 00148 { 00149 k = nodes[n].key; 00150 v = nodes[n].val; 00151 dist = std::sqrt(max_sq_dist); 00152 return true; 00153 } 00154 } 00155 return false; 00156 } 00157 00165 int in_sphere(const KeyType& p, 00166 ScalarType dist, 00167 std::vector<KeyT>& keys, 00168 std::vector<ValT>& vals) const 00169 { 00170 assert(is_built); 00171 if(nodes.size()>1) 00172 { 00173 ScalarType max_sq_dist = CGLA::sqr(dist); 00174 in_sphere_priv(1,p,max_sq_dist,keys,vals); 00175 return keys.size(); 00176 } 00177 return 0; 00178 } 00179 00180 00181 }; 00182 00183 template<class KeyT, class ValT> 00184 int KDTree<KeyT,ValT>::opt_disc(int kvec_beg, 00185 int kvec_end) const 00186 { 00187 KeyType vmin = init_nodes[kvec_beg].key; 00188 KeyType vmax = init_nodes[kvec_beg].key; 00189 for(int i=kvec_beg;i<kvec_end;i++) 00190 { 00191 vmin = CGLA::v_min(vmin,init_nodes[i].key); 00192 vmax = CGLA::v_max(vmax,init_nodes[i].key); 00193 } 00194 int od=0; 00195 KeyType ave_v = vmax-vmin; 00196 for(int i=1;i<KeyType::get_dim();i++) 00197 if(ave_v[i]>ave_v[od]) od = i; 00198 return od; 00199 } 00200 00201 template<class KeyT, class ValT> 00202 void KDTree<KeyT,ValT>::optimize(int cur, 00203 int kvec_beg, 
00204 int kvec_end) 00205 { 00206 // Assert that we are not inserting beyond capacity. 00207 assert(cur < nodes.size()); 00208 00209 // If there is just a single element, we simply insert. 00210 if(kvec_beg+1==kvec_end) 00211 { 00212 nodes[cur] = init_nodes[kvec_beg]; 00213 nodes[cur].dsc = -1; 00214 return; 00215 } 00216 00217 // Find the axis that best separates the data. 00218 int disc = opt_disc(kvec_beg, kvec_end); 00219 00220 // Compute the median element. See my document on how to do this 00221 // www.imm.dtu.dk/~jab/publications.html 00222 int N = kvec_end-kvec_beg; 00223 int M = 1<< (CGLA::two_to_what_power(N)); 00224 int R = N-(M-1); 00225 int left_size = (M-2)/2; 00226 int right_size = (M-2)/2; 00227 if(R < M/2) 00228 { 00229 left_size += R; 00230 } 00231 else 00232 { 00233 left_size += M/2; 00234 right_size += R-M/2; 00235 } 00236 00237 int median = kvec_beg + left_size; 00238 00239 // Sort elements but use nth_element (which is cheaper) than 00240 // a sorting algorithm. All elements to the left of the median 00241 // will be smaller than or equal the median. All elements to the right 00242 // will be greater than or equal to the median. 00243 const Comp comp(disc); 00244 std::nth_element(&init_nodes[kvec_beg], 00245 &init_nodes[median], 00246 &init_nodes[kvec_end], comp); 00247 00248 // Insert the node in the final data structure. 00249 nodes[cur] = init_nodes[median]; 00250 nodes[cur].dsc = disc; 00251 00252 // Recursively build left and right tree. 
00253 if(left_size>0) 00254 optimize(2*cur, kvec_beg, median); 00255 00256 if(right_size>0) 00257 optimize(2*cur+1, median+1, kvec_end); 00258 } 00259 00260 template<class KeyT, class ValT> 00261 int KDTree<KeyT,ValT>::closest_point_priv(int n, const KeyType& p, 00262 ScalarType& dist) const 00263 { 00264 int ret_node = 0; 00265 ScalarType this_dist = nodes[n].dist(p); 00266 00267 if(this_dist<dist) 00268 { 00269 dist = this_dist; 00270 ret_node = n; 00271 } 00272 if(nodes[n].dsc != -1) 00273 { 00274 int dsc = nodes[n].dsc; 00275 ScalarType dsc_dist = CGLA::sqr(nodes[n].key[dsc]-p[dsc]); 00276 bool left_son = Comp(dsc)(p,nodes[n].key); 00277 00278 if(left_son||dsc_dist<dist) 00279 { 00280 int left_child = 2*n; 00281 if(left_child < nodes.size()) 00282 if(int nl=closest_point_priv(left_child, p, dist)) 00283 ret_node = nl; 00284 } 00285 if(!left_son||dsc_dist<dist) 00286 { 00287 int right_child = 2*n+1; 00288 if(right_child < nodes.size()) 00289 if(int nr=closest_point_priv(right_child, p, dist)) 00290 ret_node = nr; 00291 } 00292 } 00293 return ret_node; 00294 } 00295 00296 template<class KeyT, class ValT> 00297 void KDTree<KeyT,ValT>::in_sphere_priv(int n, 00298 const KeyType& p, 00299 const ScalarType& dist, 00300 std::vector<KeyT>& keys, 00301 std::vector<ValT>& vals) const 00302 { 00303 ScalarType this_dist = nodes[n].dist(p); 00304 assert(n<nodes.size()); 00305 if(this_dist<dist) 00306 { 00307 keys.push_back(nodes[n].key); 00308 vals.push_back(nodes[n].val); 00309 } 00310 if(nodes[n].dsc != -1) 00311 { 00312 const int dsc = nodes[n].dsc; 00313 const ScalarType dsc_dist = CGLA::sqr(nodes[n].key[dsc]-p[dsc]); 00314 00315 bool left_son = Comp(dsc)(p,nodes[n].key); 00316 00317 if(left_son||dsc_dist<dist) 00318 { 00319 int left_child = 2*n; 00320 if(left_child < nodes.size()) 00321 in_sphere_priv(left_child, p, dist, keys, vals); 00322 } 00323 if(!left_son||dsc_dist<dist) 00324 { 00325 int right_child = 2*n+1; 00326 if(right_child < nodes.size()) 00327 
in_sphere_priv(right_child, p, dist, keys, vals); 00328 } 00329 } 00330 } 00331 } 00332 namespace GEO = Geometry; 00333 00334 #if (_MSC_VER >= 1200) 00335 #pragma warning (pop) 00336 #endif 00337 00338 00339 #endif