Main MRPT website > C++ reference for MRPT 1.5.7
KmTree.cpp
Go to the documentation of this file.
1 /* +---------------------------------------------------------------------------+
2  | Mobile Robot Programming Toolkit (MRPT) |
3  | http://www.mrpt.org/ |
4  | |
5  | Copyright (c) 2005-2017, Individual contributors, see AUTHORS file |
6  | See: http://www.mrpt.org/Authors - All rights reserved. |
7  | Released under BSD License. See details in http://www.mrpt.org/License |
8  +---------------------------------------------------------------------------+ */
9 // See KmTree.h
10 //
11 // Author: David Arthur (darthur@gmail.com), 2009
12 
13 // Includes
14 #include "KmTree.h"
15 #include <iostream>
16 #include <stdlib.h>
17 using namespace std;
18 
19 KmTree::KmTree(int n, int d, Scalar *points): n_(n), d_(d), points_(points) {
20  // Initialize memory
21  int node_size = sizeof(Node) + d_ * 3 * sizeof(Scalar);
22  node_data_ = (char*)malloc((2*n-1) * node_size);
23  point_indices_ = (int*)malloc(n * sizeof(int));
24  for (int i = 0; i < n; i++)
25  point_indices_[i] = i;
26  KM_ASSERT(node_data_ != 0 && point_indices_ != 0);
27 
28  // Calculate the bounding box for the points
29  Scalar *bound_v1 = PointAllocate(d_);
30  Scalar *bound_v2 = PointAllocate(d_);
31  KM_ASSERT(bound_v1 != 0 && bound_v2 != 0);
32  PointCopy(bound_v1, points, d_);
33  PointCopy(bound_v2, points, d_);
34  for (int i = 1; i < n; i++)
35  for (int j = 0; j < d; j++) {
36  if (bound_v1[j] > points[i*d_ + j]) bound_v1[j] = points[i*d_ + j];
37  if (bound_v2[j] < points[i*d_ + j]) bound_v1[j] = points[i*d_ + j];
38  }
39 
40  // Build the tree
41  char *temp_node_data = node_data_;
42  top_node_ = BuildNodes(points, 0, n-1, &temp_node_data);
43 
44  // Cleanup
45  PointFree(bound_v1);
46  PointFree(bound_v2);
47 }
48 
50  free(point_indices_);
51  free(node_data_);
52 }
53 
54 Scalar KmTree::DoKMeansStep(int k, Scalar *centers, int *assignment) const {
55  // Create an invalid center for comparison purposes
56  Scalar *bad_center = PointAllocate(d_);
57  KM_ASSERT(bad_center != 0);
58  memset(bad_center, 0xff, d_ * sizeof(Scalar));
59 
60  // Allocate data
61  Scalar *sums = (Scalar*)calloc(k * d_, sizeof(Scalar));
62  int *counts = (int*)calloc(k, sizeof(int));
63  int num_candidates = 0;
64  int *candidates = (int*)malloc(k * sizeof(int));
65  KM_ASSERT(sums != 0 && counts != 0 && candidates != 0);
66  for (int i = 0; i < k; i++)
67  if (memcmp(centers + i*d_, bad_center, d_ * sizeof(Scalar)) != 0)
68  candidates[num_candidates++] = i;
69 
70  // Find nodes
71  Scalar result = DoKMeansStepAtNode(top_node_, num_candidates, candidates, centers, sums,
72  counts, assignment);
73 
74  // Set the new centers
75  for (int i = 0; i < k; i++) {
76  if (counts[i] > 0) {
77  PointScale(sums + i*d_, Scalar(1) / counts[i], d_);
78  PointCopy(centers + i*d_, sums + i*d_, d_);
79  } else {
80  memcpy(centers + i*d_, bad_center, d_ * sizeof(Scalar));
81  }
82  }
83 
84  // Cleanup memory
85  PointFree(bad_center);
86  free(candidates);
87  free(counts);
88  free(sums);
89  return result;
90 }
91 
92 // Helper functions for constructor
93 // ================================
94 
95 // Build a kd tree from the given set of points
96 KmTree::Node *KmTree::BuildNodes(Scalar *points, int first_index, int last_index,
97  char **next_node_data) {
98  // Allocate the node
99  Node *node = (Node*)(*next_node_data);
100  (*next_node_data) += sizeof(Node);
101  node->sum = (Scalar*)(*next_node_data);
102  (*next_node_data) += sizeof(Scalar) * d_;
103  node->median = (Scalar*)(*next_node_data);
104  (*next_node_data) += sizeof(Scalar) * d_;
105  node->radius = (Scalar*)(*next_node_data);
106  (*next_node_data) += sizeof(Scalar) * d_;
107 
108  // Fill in basic info
109  node->num_points = (last_index - first_index + 1);
110  node->first_point_index = first_index;
111 
112  // Calculate the bounding box
113  Scalar *first_point = points + point_indices_[first_index] * d_;
114  Scalar *bound_p1 = PointAllocate(d_);
115  Scalar *bound_p2 = PointAllocate(d_);
116  KM_ASSERT(bound_p1 != 0 && bound_p2 != 0);
117  PointCopy(bound_p1, first_point, d_);
118  PointCopy(bound_p2, first_point, d_);
119  for (int i = first_index+1; i <= last_index; i++)
120  for (int j = 0; j < d_; j++) {
121  Scalar c = points[point_indices_[i]*d_ + j];
122  if (bound_p1[j] > c) bound_p1[j] = c;
123  if (bound_p2[j] < c) bound_p2[j] = c;
124  }
125 
126  // Calculate bounding box stats and delete the bounding box memory
127  Scalar max_radius = -1;
128  int split_d = -1;
129  for (int j = 0; j < d_; j++) {
130  node->median[j] = (bound_p1[j] + bound_p2[j]) / 2;
131  node->radius[j] = (bound_p2[j] - bound_p1[j]) / 2;
132  if (node->radius[j] > max_radius) {
133  max_radius = node->radius[j];
134  split_d = j;
135  }
136  }
137  PointFree(bound_p2);
138  PointFree(bound_p1);
139 
140  // If the max spread is 0, make this a leaf node
141  if (max_radius == 0) {
142  node->lower_node = node->upper_node = 0;
143  PointCopy(node->sum, first_point, d_);
144  if (last_index != first_index)
145  PointScale(node->sum, Scalar(last_index - first_index + 1), d_);
146  node->opt_cost = 0;
147  return node;
148  }
149 
150  // Partition the points around the midpoint in this dimension. The partitioning is done in-place
151  // by iterating from left-to-right and right-to-left in the same way that partioning is done for
152  // quicksort.
153  Scalar split_pos = node->median[split_d];
154  int i1 = first_index, i2 = last_index, size1 = 0;
155  while (i1 <= i2) {
156  bool is_i1_good = (points[point_indices_[i1]*d_ + split_d] < split_pos);
157  bool is_i2_good = (points[point_indices_[i2]*d_ + split_d] >= split_pos);
158  if (!is_i1_good && !is_i2_good) {
159  int temp = point_indices_[i1];
160  point_indices_[i1] = point_indices_[i2];
161  point_indices_[i2] = temp;
162  is_i1_good = is_i2_good = true;
163  }
164  if (is_i1_good) {
165  i1++;
166  size1++;
167  }
168  if (is_i2_good) {
169  i2--;
170  }
171  }
172 
173  // Create the child nodes
174  KM_ASSERT(size1 >= 1 && size1 <= last_index - first_index);
175  node->lower_node = BuildNodes(points, first_index, first_index + size1 - 1, next_node_data);
176  node->upper_node = BuildNodes(points, first_index + size1, last_index, next_node_data);
177 
178  // Calculate the new sum and opt cost
179  PointCopy(node->sum, node->lower_node->sum, d_);
180  PointAdd(node->sum, node->upper_node->sum, d_);
181  Scalar *center = PointAllocate(d_);
182  KM_ASSERT(center != 0);
183  PointCopy(center, node->sum, d_);
184  PointScale(center, Scalar(1) / node->num_points, d_);
185  node->opt_cost = GetNodeCost(node->lower_node, center) + GetNodeCost(node->upper_node, center);
186  PointFree(center);
187  return node;
188 }
189 
190 // Returns the total contribution of all points in the given kd-tree node, assuming they are all
191 // assigned to a center at the given location. We need to return:
192 //
193 // sum_{x \in node} ||x - center||^2.
194 //
195 // If c denotes the center of mass of the points in this node and n denotes the number of points in
196 // it, then this quantity is given by
197 //
198 // n * ||c - center||^2 + sum_{x \in node} ||x - c||^2
199 //
200 // The sum is precomputed for each node as opt_cost. This formula follows from expanding both sides
201 // as dot products. See Kanungo/Mount for more info.
202 Scalar KmTree::GetNodeCost(const Node *node, Scalar *center) const {
203  Scalar dist_sq = 0;
204  for (int i = 0; i < d_; i++) {
205  Scalar x = (node->sum[i] / node->num_points) - center[i];
206  dist_sq += x*x;
207  }
208  return node->opt_cost + node->num_points * dist_sq;
209 }
210 
211 // Helper functions for DoKMeans step
212 // ==================================
213 
214 // A recursive version of DoKMeansStep. This determines which clusters all points that are rooted
215 // node will be assigned to, and updates sums, counts and assignment (if not null) accordingly.
216 // candidates maintains the set of cluster indices which could possibly be the closest clusters
217 // for points in this subtree.
218 Scalar KmTree::DoKMeansStepAtNode(const Node *node, int k, int *candidates, Scalar *centers,
219  Scalar *sums, int *counts, int *assignment) const {
220  // Determine which center the node center is closest to
221  Scalar min_dist_sq = PointDistSq(node->median, centers + candidates[0]*d_, d_);
222  int closest_i = candidates[0];
223  for (int i = 1; i < k; i++) {
224  Scalar dist_sq = PointDistSq(node->median, centers + candidates[i]*d_, d_);
225  if (dist_sq < min_dist_sq) {
226  min_dist_sq = dist_sq;
227  closest_i = candidates[i];
228  }
229  }
230 
231  // If this is a non-leaf node, recurse if necessary
232  if (node->lower_node != 0) {
233  // Build the new list of candidates
234  int new_k = 0;
235  int *new_candidates = (int*)malloc(k * sizeof(int));
236  KM_ASSERT(new_candidates != 0);
237  for (int i = 0; i < k; i++)
238  if (!ShouldBePruned(node->median, node->radius, centers, closest_i, candidates[i]))
239  new_candidates[new_k++] = candidates[i];
240 
241  // Recurse if there's at least two
242  if (new_k > 1) {
243  Scalar result = DoKMeansStepAtNode(node->lower_node, new_k, new_candidates, centers,
244  sums, counts, assignment) +
245  DoKMeansStepAtNode(node->upper_node, new_k, new_candidates, centers,
246  sums, counts, assignment);
247  free(new_candidates);
248  return result;
249  } else {
250  free(new_candidates);
251  }
252  }
253 
254  // Assigns all points within this node to a single center
255  PointAdd(sums + closest_i*d_, node->sum, d_);
256  counts[closest_i] += node->num_points;
257  if (assignment != 0) {
258  for (int i = node->first_point_index; i < node->first_point_index + node->num_points; i++)
259  assignment[point_indices_[i]] = closest_i;
260  }
261  return GetNodeCost(node, centers + closest_i*d_);
262 }
263 
264 // Determines whether every point in the box is closer to centers[best_index] than to
265 // centers[test_index].
266 //
267 // If x is a point, c_0 = centers[best_index], c = centers[test_index], then:
268 // (x-c).(x-c) < (x-c_0).(x-c_0)
269 // <=> (c-c_0).(c-c_0) < 2(x-c_0).(c-c_0)
270 //
271 // The right-hand side is maximized for a vertex of the box where for each dimension, we choose
272 // the low or high value based on the sign of x-c_0 in that dimension.
273 bool KmTree::ShouldBePruned(Scalar *box_median, Scalar *box_radius, Scalar *centers,
274  int best_index, int test_index) const {
275  if (best_index == test_index)
276  return false;
277 
278  Scalar *best = centers + best_index*d_;
279  Scalar *test = centers + test_index*d_;
280  Scalar lhs = 0, rhs = 0;
281  for (int i = 0; i < d_; i++) {
282  Scalar component = test[i] - best[i];
283  lhs += component * component;
284  if (component > 0)
285  rhs += (box_median[i] + box_radius[i] - best[i]) * component;
286  else
287  rhs += (box_median[i] - box_radius[i] - best[i]) * component;
288  }
289  return (lhs >= 2*rhs);
290 }
291 
292 Scalar KmTree::SeedKMeansPlusPlus(int k, Scalar *centers) const {
293  Scalar *dist_sq = (Scalar*)malloc(n_ * sizeof(Scalar));
294  KM_ASSERT(dist_sq != 0);
295 
296  // Choose an initial center uniformly at random
298  int i = GetRandom(n_);
299  memcpy(centers, points_ + point_indices_[i]*d_, d_*sizeof(Scalar));
300  Scalar total_cost = 0;
301  for (int j = 0; j < n_; j++) {
302  dist_sq[j] = PointDistSq(points_ + point_indices_[j]*d_, centers, d_);
303  total_cost += dist_sq[j];
304  }
305 
306  // Repeatedly choose more centers
307  for (int new_cluster = 1; new_cluster < k; new_cluster++) {
308  while (1) {
309  Scalar cutoff = (rand() / Scalar(RAND_MAX)) * total_cost;
310  Scalar cur_cost = 0;
311  for (i = 0; i < n_; i++) {
312  cur_cost += dist_sq[i];
313  if (cur_cost >= cutoff)
314  break;
315  }
316  if (i < n_)
317  break;
318  }
319  memcpy(centers + new_cluster*d_, points_ + point_indices_[i]*d_, d_*sizeof(Scalar));
320  total_cost = SeedKmppUpdateAssignment(top_node_, new_cluster, centers, dist_sq);
321  }
322 
323  // Clean up and return
324  free(dist_sq);
325  return total_cost;
326 }
327 
328 // Helper functions for SeedKMeansPlusPlus
329 // =======================================
330 
331 // Sets kmpp_cluster_index to 0 for all nodes
332 void KmTree::SeedKmppSetClusterIndex(const Node *node, int value) const {
333  node->kmpp_cluster_index = value;
334  if (node->lower_node != 0) {
337  }
338 }
339 
340 Scalar KmTree::SeedKmppUpdateAssignment(const Node *node, int new_cluster, Scalar *centers,
341  Scalar *dist_sq) const {
342  // See if we can assign all points in this node to one cluster
343  if (node->kmpp_cluster_index >= 0) {
344  if (ShouldBePruned(node->median, node->radius, centers, node->kmpp_cluster_index, new_cluster))
345  return GetNodeCost(node, centers + node->kmpp_cluster_index*d_);
346  if (ShouldBePruned(node->median, node->radius, centers, new_cluster,
347  node->kmpp_cluster_index)) {
348  SeedKmppSetClusterIndex(node, new_cluster);
349  for (int i = node->first_point_index; i < node->first_point_index + node->num_points; i++)
350  dist_sq[i] = PointDistSq(points_ + point_indices_[i]*d_, centers + new_cluster*d_, d_);
351  return GetNodeCost(node, centers + new_cluster*d_);
352  }
353 
354  // It may be that the a leaf-node point is equidistant from the new center or old
355  if (node->lower_node == 0)
356  return GetNodeCost(node, centers + node->kmpp_cluster_index*d_);
357  }
358 
359  // Recurse
360  Scalar cost = SeedKmppUpdateAssignment(node->lower_node, new_cluster, centers, dist_sq) +
361  SeedKmppUpdateAssignment(node->upper_node, new_cluster, centers, dist_sq);
362  int i1 = node->lower_node->kmpp_cluster_index, i2 = node->upper_node->kmpp_cluster_index;
363  if (i1 == i2 && i1 != -1)
364  node->kmpp_cluster_index = i1;
365  else
366  node->kmpp_cluster_index = -1;
367  return cost;
368 }
void BASE_IMPEXP memcpy(void *dest, size_t destSize, const void *src, size_t copyCount) MRPT_NO_THROWS
An OS and compiler independent version of "memcpy".
Definition: os.cpp:358
Scalar * median
Definition: KmTree.h:56
int kmpp_cluster_index
Definition: KmTree.h:60
GLenum GLsizei n
Definition: glext.h:4618
KmTree(int n, int d, Scalar *points)
Definition: KmTree.cpp:19
Scalar * radius
Definition: KmTree.h:56
Scalar * PointAllocate(int d)
Definition: KmUtils.h:47
void PointCopy(Scalar *p1, const Scalar *p2, int d)
Definition: KmUtils.h:55
STL namespace.
bool ShouldBePruned(Scalar *box_median, Scalar *box_radius, Scalar *centers, int best_index, int test_index) const
Definition: KmTree.cpp:273
GLsizei const GLfloat * points
Definition: glext.h:4797
Scalar GetNodeCost(const Node *node, Scalar *center) const
Definition: KmTree.cpp:202
void PointAdd(Scalar *p1, const Scalar *p2, int d)
Definition: KmUtils.h:60
int n_
Definition: KmTree.h:78
const GLubyte * c
Definition: glext.h:5590
Node * lower_node
Definition: KmTree.h:59
int * point_indices_
Definition: KmTree.h:82
void PointScale(Scalar *p, Scalar scale, int d)
Definition: KmUtils.h:65
Scalar opt_cost
Definition: KmTree.h:58
Scalar * sum
Definition: KmTree.h:57
Scalar DoKMeansStep(int k, Scalar *centers, int *assignment) const
Definition: KmTree.cpp:54
int num_points
Definition: KmTree.h:54
int d_
Definition: KmTree.h:78
int GetRandom(int n)
Definition: KmUtils.h:96
Scalar SeedKmppUpdateAssignment(const Node *node, int new_cluster, Scalar *centers, Scalar *dist_sq) const
Definition: KmTree.cpp:340
#define KM_ASSERT(expression)
Definition: KmUtils.h:84
Scalar DoKMeansStepAtNode(const Node *node, int k, int *candidates, Scalar *centers, Scalar *sums, int *counts, int *assignment) const
Definition: KmTree.cpp:218
Scalar SeedKMeansPlusPlus(int k, Scalar *centers) const
Definition: KmTree.cpp:292
~KmTree()
Definition: KmTree.cpp:49
Node * BuildNodes(Scalar *points, int first_index, int last_index, char **next_node_data)
Definition: KmTree.cpp:96
GLsizei const GLfloat * value
Definition: glext.h:3929
int first_point_index
Definition: KmTree.h:55
GLenum GLint x
Definition: glext.h:3516
void PointFree(Scalar *p)
Definition: KmUtils.h:51
double Scalar
Definition: KmUtils.h:41
char * node_data_
Definition: KmTree.h:81
Scalar PointDistSq(const Scalar *p1, const Scalar *p2, int d)
Definition: KmUtils.h:70
void SeedKmppSetClusterIndex(const Node *node, int index) const
Definition: KmTree.cpp:332
Scalar * points_
Definition: KmTree.h:79
Node * top_node_
Definition: KmTree.h:80
Node * upper_node
Definition: KmTree.h:59



Page generated by Doxygen 1.8.14 for MRPT 1.5.7 Git: 5902e14cc Wed Apr 24 15:04:01 2019 +0200 at lun oct 28 01:39:17 CET 2019