본문 바로가기
C++/과제

K-d tree vs Brute-force 탐색 속도비교

by 2744m 2019. 6. 7.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctime>
 
#define MAX_DIM 2
#define N 1000000
#define rand1() (rand() / (double)RAND_MAX)
 
using namespace std;
 
double kd_time, bf_time;
long long kd_visited, bf_visited;
clock_t kd_s, kd_e, start_, end_;
 
typedef struct kd_node_t {
    double x[MAX_DIM];
    struct kd_node_t *left, *right;
}node;
 
inline double
dist(struct kd_node_t *a, struct kd_node_t *b, int dim)
{
    double t, d = 0;
    while (dim--) {
        t = a->x[dim] - b->x[dim];
        d += t * t;
    }
    return d;
}
inline void swap(struct kd_node_t *x, struct kd_node_t *y) {
    double tmp[MAX_DIM];
    memcpy(tmp, x->x, sizeof(tmp));
    memcpy(x->x, y->x, sizeof(tmp));
    memcpy(y->x, tmp, sizeof(tmp));
}
 
 
/* see quickselect method */
struct kd_node_t* find_median(struct kd_node_t *start, struct kd_node_t *endint idx)
{
    if (end <= start) return NULL;
    if (end == start + 1)
        return start;
 
    struct kd_node_t *p, *store, *md = start + (end - start) / 2;
    double pivot;
    while (1) {
        pivot = md->x[idx];
 
        swap(md, end - 1);
        for (store = p = start; p < end; p++) {
            if (p->x[idx] < pivot) {
                if (p != store)
                    swap(p, store);
                store++;
            }
        }
        swap(store, end - 1);
 
        /* median has duplicate values */
        if (store->x[idx] == md->x[idx])
            return md;
 
        if (store > md) end = store;
        else        start = store;
    }
}
 
struct kd_node_t* make_tree(struct kd_node_t *t, int len, int i, int dim) {
    struct kd_node_t *n;
    if (!len) return 0;
    if ((n = find_median(t, t + len, i))) {
        i = (i + 1) % dim;
        n->left = make_tree(t, n - t, i, dim);
        n->right = make_tree(n + 1, t + len - (n + 1), i, dim);
    }
    return n;
}
 
/* global variable, so sue me */
 
void nearest(struct kd_node_t *root, struct kd_node_t *nd, int i, int dim, struct kd_node_t **best, double *best_dist) {
    kd_s = clock();
 
    double d, dx, dx2;
    if (!root) return;
    d = dist(root, nd, dim);
    dx = root->x[i] - nd->x[i];
    dx2 = dx * dx;
    kd_visited++;
    if (!*best || d < *best_dist) {
        *best_dist = d;
        *best = root;
    }
 
    /* if chance of exact match is high */
    if (!*best_dist) return;
    if (++>= dim) i = 0;
    nearest(dx > 0 ? root->left : root->right, nd, i, dim, best, best_dist);
    if (dx2 >= *best_dist) return;
    nearest(dx > 0 ? root->right : root->left, nd, i, dim, best, best_dist);
 
    kd_e = clock();
    kd_time += double(kd_e - kd_s) / CLOCKS_PER_SEC;
}
 
int bruteForce(node *data, node query) {
    start_ = clock();
    //int n = N, d = MAX_DIM;
    double best_dist = DBL_MAX; int best_idx;
    for (int i = 0; i < N; i++) {
        double dist = 0;
        bf_visited++;
        for (int j = 0; j < MAX_DIM; j++) {
            dist += (data[i].x[j] - query.x[j])*(data[i].x[j] - query.x[j]);
        }
        if (dist < best_dist) {
            best_dist = dist;
            best_idx = i;
        }
    }
    end_ = clock();
    bf_time += double(end_ - start_) / CLOCKS_PER_SEC;
    
    return best_idx;
}
 
 
int main(void) {
 
    printf("차원(DIM) : %d\n", MAX_DIM);
    srand(unsigned(time(0)));
 
 
    struct kd_node_t *root, *found, *data;
    struct kd_node_t query;//기준점
    double best_dist;
    for (int d = 0; d < MAX_DIM; d++) query.x[d] = rand1();//기준점 랜덤 데이터 입력
    data = (struct kd_node_t*calloc(N, sizeof(struct kd_node_t));//1,000,000개 데이터 생성
 
    for (int i = 0; i < N; i++) {
        for (int d = 0; d < MAX_DIM; d++) {
            data[i].x[d] = rand1();
        }
    }
    //브루트포스 작동 테스트
    /*int idx = bruteForce(data, query);
    cout << "query : ";
    for (int i = 0; i < MAX_DIM; i++) {
        cout << data[idx].x[i] << ' ';
    }
    cout << '\n'<< "found data : ";
    for (int i = 0; i < MAX_DIM; i++) {
        cout << data[idx].x[i] << ' ';
    }*/
    //트리생성
    root = make_tree(data, N, 0, MAX_DIM);
    int test_runs = 1000// 쿼리 갯수
    
    for (int i = 0; i < test_runs; i++) {
        found = 0;
        for (int d = 0; d < MAX_DIM; d++) query.x[d] = rand1();//기준점 랜덤 데이터 입력
        nearest(root, &query, 0, MAX_DIM, &found, &best_dist);
        bruteForce(data, query);
    }
 
    printf("KD 트리 1000개 쿼리 총 탐색 시간 : %f\n", kd_time);
    printf("브루트포스 1000개 쿼리 총 탐색 시간 : %f\n", bf_time);
    printf("쿼리당 KD트리 평균 탐색 횟수 : %d\n", kd_visited / 1000);
    printf("쿼리당 브루트포스 평균 탐색 횟수 : %d\n", bf_visited / 1000);
    
 
    //free(data);
 
    return 0;
}
cs





'C++ > 과제' 카테고리의 다른 글

[백트레킹] 점들을 연결했을 때 최소 거리 구하기  (0) 2018.10.09

댓글