C++/과제
K-d tree vs Brute-force 탐색 속도비교
2744m
2019. 6. 7. 05:03
/*
 * K-d tree vs. brute-force nearest-neighbour search benchmark.
 *
 * Generates N random points in MAX_DIM dimensions, builds a k-d tree over
 * them in place, then answers the same random queries with both the tree
 * and a linear scan, reporting total time and average nodes visited.
 *
 * Define KD_BENCHMARK_NO_MAIN before compiling to leave out the benchmark
 * driver (main) and reuse only the search routines, e.g. from a unit test.
 */
#include <algorithm>
#include <cfloat>
#include <ctime>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <vector>

/* Dimensionality of every point. */
const int MAX_DIM = 2;
/* Number of random data points generated by the benchmark driver. */
const int N = 1000000;

/* Uniform random value in [0, 1], based on rand(). */
inline double rand1()
{
    return rand() / static_cast<double>(RAND_MAX);
}

/* Accumulated wall-clock-independent CPU time (seconds) spent in each search. */
double kd_time, bf_time;
/* Accumulated number of points examined by each search. */
long long kd_visited, bf_visited;

/* One point of the data set; doubles as a k-d tree node once make_tree ran. */
typedef struct kd_node_t {
    double x[MAX_DIM];
    struct kd_node_t *left, *right;
} node;

/*
 * Squared Euclidean distance between a and b over the first `dim`
 * coordinates. The square root is skipped because only comparisons matter.
 */
inline double dist(const struct kd_node_t *a, const struct kd_node_t *b, int dim)
{
    double d = 0;
    while (dim--) {
        const double t = a->x[dim] - b->x[dim];
        d += t * t;
    }
    return d;
}

/* Orders nodes by a single coordinate axis (strict weak ordering). */
struct kd_axis_less {
    int axis;
    explicit kd_axis_less(int a) : axis(a) {}
    bool operator()(const kd_node_t &lhs, const kd_node_t &rhs) const
    {
        return lhs.x[axis] < rhs.x[axis];
    }
};

/*
 * Rearranges [start, end) so that the middle element is the median along
 * coordinate `idx`: nothing before it is greater and nothing after it is
 * smaller on that axis. Returns the middle element, or NULL for an empty
 * range.
 *
 * std::nth_element is used instead of a hand-written quickselect because
 * the previous early-exit on "pivot equals middle value" could return while
 * larger values were still sitting before the middle whenever coordinates
 * repeated, which broke the ordering that nearest() relies on for pruning.
 */
struct kd_node_t* find_median(struct kd_node_t *start, struct kd_node_t *end, int idx)
{
    if (end <= start) return NULL;
    if (end == start + 1) return start;

    struct kd_node_t *md = start + (end - start) / 2;
    std::nth_element(start, md, end, kd_axis_less(idx));
    return md;
}

/*
 * Builds a k-d tree in place over the `len` nodes starting at `t`.
 *
 * t   : array of nodes; it is permuted and the left/right links are written.
 * len : number of nodes in the array.
 * i   : coordinate axis used to split at this level.
 * dim : number of axes to cycle through.
 *
 * Returns the root of the (sub)tree, or NULL when len is not positive.
 */
struct kd_node_t* make_tree(struct kd_node_t *t, int len, int i, int dim)
{
    if (len <= 0) return NULL;

    struct kd_node_t *n = find_median(t, t + len, i);
    if (n) {
        i = (i + 1) % dim;
        n->left = make_tree(t, static_cast<int>(n - t), i, dim);
        n->right = make_tree(n + 1, static_cast<int>(t + len - (n + 1)), i, dim);
    }
    return n;
}

/*
 * Recursive worker behind nearest(); takes the same arguments but does no
 * timing, so the clock is read once per query instead of once per node.
 */
static void nearest_search(struct kd_node_t *root, struct kd_node_t *nd, int i, int dim,
                           struct kd_node_t **best, double *best_dist)
{
    if (!root) return;

    const double d = dist(root, nd, dim);
    const double dx = root->x[i] - nd->x[i];
    const double dx2 = dx * dx;

    kd_visited++;

    if (!*best || d < *best_dist) {
        *best_dist = d;
        *best = root;
    }

    /* An exact match cannot be improved on. */
    if (!*best_dist) return;

    if (++i >= dim) i = 0;

    /* Descend first into the half that contains the query point. */
    nearest_search(dx > 0 ? root->left : root->right, nd, i, dim, best, best_dist);
    /* The other half lies at least |dx| away along this axis: skip it if that is already too far. */
    if (dx2 >= *best_dist) return;
    nearest_search(dx > 0 ? root->right : root->left, nd, i, dim, best, best_dist);
}

/*
 * Finds the tree node closest to `nd`.
 *
 * root      : root of a tree built by make_tree (may be NULL).
 * nd        : query point; only its coordinates are read.
 * i         : axis the root level was split on (0 for a full tree).
 * dim       : number of axes.
 * best      : in/out; must point to NULL on entry, receives the closest node.
 * best_dist : out; squared distance to *best (untouched if the tree is empty).
 *
 * Adds the number of nodes examined to kd_visited and the CPU time of the
 * whole query to kd_time. The time is measured with clock(), so very short
 * queries may fall below the clock resolution.
 */
void nearest(struct kd_node_t *root, struct kd_node_t *nd, int i, int dim,
             struct kd_node_t **best, double *best_dist)
{
    const std::clock_t begin = std::clock();
    nearest_search(root, nd, i, dim, best, best_dist);
    kd_time += double(std::clock() - begin) / CLOCKS_PER_SEC;
}

/*
 * Linear-scan nearest neighbour.
 *
 * data  : array of at least `count` points.
 * query : point to search for.
 * count : number of points to scan (defaults to the full data set, N).
 *
 * Returns the index of the closest point (the first one on ties), or -1
 * when count is not positive. Adds the number of points examined to
 * bf_visited and the CPU time of the scan to bf_time.
 */
int bruteForce(const node *data, node query, int count = N)
{
    const std::clock_t begin = std::clock();

    double best_dist = DBL_MAX;
    int best_idx = -1;
    for (int i = 0; i < count; i++) {
        double sq = 0;
        bf_visited++;
        for (int j = 0; j < MAX_DIM; j++) {
            const double diff = data[i].x[j] - query.x[j];
            sq += diff * diff;
        }
        if (best_idx < 0 || sq < best_dist) {
            best_dist = sq;
            best_idx = i;
        }
    }

    bf_time += double(std::clock() - begin) / CLOCKS_PER_SEC;
    return best_idx;
}

#ifndef KD_BENCHMARK_NO_MAIN
/* Benchmark driver: builds the tree once and runs both searches on random queries. */
int main(void)
{
    printf("차원(DIM) : %d\n", MAX_DIM);
    srand(static_cast<unsigned>(std::time(0)));

    /* Generate the 1,000,000 random data points (storage is zero-initialised and freed automatically). */
    std::vector<kd_node_t> data(N);
    for (int i = 0; i < N; i++) {
        for (int d = 0; d < MAX_DIM; d++) {
            data[i].x[d] = rand1();
        }
    }

    /* Build the tree. */
    struct kd_node_t *root = make_tree(&data[0], N, 0, MAX_DIM);

    const int test_runs = 1000; /* number of queries */
    struct kd_node_t query = { {0}, NULL, NULL }; /* query point */
    for (int i = 0; i < test_runs; i++) {
        struct kd_node_t *found = NULL;
        double best_dist = 0;
        /* Fill the query point with random coordinates. */
        for (int d = 0; d < MAX_DIM; d++) query.x[d] = rand1();
        nearest(root, &query, 0, MAX_DIM, &found, &best_dist);
        bruteForce(&data[0], query);
    }

    printf("KD 트리 1000개 쿼리 총 탐색 시간 : %f\n", kd_time);
    printf("브루트포스 1000개 쿼리 총 탐색 시간 : %f\n", bf_time);
    /* The counters are long long, so they need %lld rather than %d. */
    printf("쿼리당 KD트리 평균 탐색 횟수 : %lld\n", kd_visited / test_runs);
    printf("쿼리당 브루트포스 평균 탐색 횟수 : %lld\n", bf_visited / test_runs);
    return 0;
}
#endif /* KD_BENCHMARK_NO_MAIN */