23 node_data_ = (
char*)malloc((2 * n - 1) * node_size);
31 KM_ASSERT(bound_v1 !=
nullptr && bound_v2 !=
nullptr);
34 for (
int i = 1; i < n; i++)
35 for (
int j = 0; j < d; j++)
37 if (bound_v1[j] > points[i *
d_ + j])
38 bound_v1[j] = points[i *
d_ + j];
39 if (bound_v2[j] < points[i *
d_ + j])
40 bound_v1[j] = points[i *
d_ + j];
63 memset(bad_center, 0xff,
d_ *
sizeof(
Scalar));
67 int* counts = (
int*)calloc(k,
sizeof(
int));
68 int num_candidates = 0;
69 int* candidates = (
int*)malloc(k *
sizeof(
int));
70 KM_ASSERT(sums !=
nullptr && counts !=
nullptr && candidates !=
nullptr);
71 for (
int i = 0; i < k; i++)
72 if (memcmp(centers + i *
d_, bad_center,
d_ *
sizeof(
Scalar)) != 0)
73 candidates[num_candidates++] = i;
77 top_node_, num_candidates, candidates, centers, sums, counts,
81 for (
int i = 0; i < k; i++)
107 Scalar* points,
int first_index,
int last_index,
char** next_node_data)
110 Node* node = (
Node*)(*next_node_data);
111 (*next_node_data) +=
sizeof(
Node);
113 (*next_node_data) +=
sizeof(
Scalar) *
d_;
115 (*next_node_data) +=
sizeof(
Scalar) *
d_;
117 (*next_node_data) +=
sizeof(
Scalar) *
d_;
120 node->
num_points = (last_index - first_index + 1);
127 KM_ASSERT(bound_p1 !=
nullptr && bound_p2 !=
nullptr);
130 for (
int i = first_index + 1; i <= last_index; i++)
131 for (
int j = 0; j <
d_; j++)
134 if (bound_p1[j] > c) bound_p1[j] = c;
135 if (bound_p2[j] < c) bound_p2[j] = c;
141 for (
int j = 0; j <
d_; j++)
143 node->
median[j] = (bound_p1[j] + bound_p2[j]) / 2;
144 node->
radius[j] = (bound_p2[j] - bound_p1[j]) / 2;
145 if (node->
radius[j] > max_radius)
147 max_radius = node->
radius[j];
159 if (last_index != first_index)
171 int i1 = first_index, i2 = last_index, size1 = 0;
178 if (!is_i1_good && !is_i2_good)
183 is_i1_good = is_i2_good =
true;
197 KM_ASSERT(size1 >= 1 && size1 <= last_index - first_index);
199 points, first_index, first_index + size1 - 1, next_node_data);
201 BuildNodes(points, first_index + size1, last_index, next_node_data);
234 for (
int i = 0; i <
d_; i++)
254 int* counts,
int* assignment)
const 259 int closest_i = candidates[0];
260 for (
int i = 1; i < k; i++)
264 if (dist_sq < min_dist_sq)
266 min_dist_sq = dist_sq;
267 closest_i = candidates[i];
276 int* new_candidates = (
int*)malloc(k *
sizeof(
int));
278 for (
int i = 0; i < k; i++)
282 new_candidates[new_k++] = candidates[i];
289 centers, sums, counts, assignment) +
292 centers, sums, counts, assignment);
293 free(new_candidates);
298 free(new_candidates);
305 if (assignment !=
nullptr)
308 i < node->first_point_index + node->
num_points; i++)
327 int test_index)
const 329 if (best_index == test_index)
return false;
331 Scalar* best = centers + best_index *
d_;
332 Scalar* test = centers + test_index *
d_;
334 for (
int i = 0; i <
d_; i++)
336 Scalar component = test[i] - best[i];
337 lhs += component * component;
339 rhs += (box_median[i] + box_radius[i] - best[i]) * component;
341 rhs += (box_median[i] - box_radius[i] - best[i]) * component;
343 return (lhs >= 2 * rhs);
356 for (
int j = 0; j <
n_; j++)
359 total_cost += dist_sq[j];
363 for (
int new_cluster = 1; new_cluster < k; new_cluster++)
367 Scalar cutoff = (rand() /
Scalar(RAND_MAX)) * total_cost;
369 for (i = 0; i <
n_; i++)
371 cur_cost += dist_sq[i];
372 if (cur_cost >= cutoff)
break;
418 i < node->first_point_index + node->
num_points; i++)
421 centers + new_cluster *
d_,
d_);
433 node->
lower_node, new_cluster, centers, dist_sq) +
435 node->
upper_node, new_cluster, centers, dist_sq);
438 if (i1 == i2 && i1 != -1)
KmTree(int n, int d, Scalar *points)
Scalar * PointAllocate(int d)
void PointCopy(Scalar *p1, const Scalar *p2, int d)
bool ShouldBePruned(Scalar *box_median, Scalar *box_radius, Scalar *centers, int best_index, int test_index) const
Scalar GetNodeCost(const Node *node, Scalar *center) const
void PointAdd(Scalar *p1, const Scalar *p2, int d)
void PointScale(Scalar *p, Scalar scale, int d)
Scalar DoKMeansStep(int k, Scalar *centers, int *assignment) const
Scalar SeedKmppUpdateAssignment(const Node *node, int new_cluster, Scalar *centers, Scalar *dist_sq) const
#define KM_ASSERT(expression)
Scalar DoKMeansStepAtNode(const Node *node, int k, int *candidates, Scalar *centers, Scalar *sums, int *counts, int *assignment) const
Scalar SeedKMeansPlusPlus(int k, Scalar *centers) const
Node * BuildNodes(Scalar *points, int first_index, int last_index, char **next_node_data)
void PointFree(Scalar *p)
Scalar PointDistSq(const Scalar *p1, const Scalar *p2, int d)
void SeedKmppSetClusterIndex(const Node *node, int index) const
void memcpy(void *dest, size_t destSize, const void *src, size_t copyCount) noexcept
An OS and compiler independent version of "memcpy".