본문 바로가기
기타/알고리즘

[알고리즘]세그먼트 트리(feat. Java)

by 코딩하는 랄로 2023. 9. 24.
728x90

세그먼트 트리(Segment Tree)

세그먼트 트리는 트리 형태의 자료 구조를 사용하여 숫자가 저장된 배열이 존재할 때 해당 배열의 구간 합을 구하거나, 배열의 특정 인덱스의 값을 변경한 후에 다시 구간합을 구해야 하는 경우에 적은 시간 복잡도로 작업을 진행할 수 있도록 해주는 트리 구조이다.

 

즉, 세그먼트 트리에는 특정 구간 배열의 합들을 노드에 저장해 놓는 형태인 것이다. 예를 들어 배열의 크기가 10일 때 세그먼트 트리의 각 노드는 다음을 의미한다.

 

 

리프 노드는 배열의 각 인덱스의 값을 나타내고 (0번 노드는 배열[0] 값) 리프 노드 이외의 노드는 왼쪽 자식과 오른쪽 자식의 합, 즉 구간합을 나타낸다.

 

이렇게 트리를 구성할 경우, 특정 인덱스의 값이 바뀌어도 구간합의 변경이 O(logN)의 시간 복잡도를 가지게 되므로 유용한 구조 중 하나이다.

 

 

 

코드 구현

세그먼트 트리 초기화 메서드

static class SegmentTree{
        // 세그먼트 트리를 구현할 배열
        private long[] tree;

        // 생성자에서 세그먼트 트리의 전체노드 수 계산 (즉, 배열 길이)
        SegmentTree(int n) {
            // 트리의 높이 계산
            double treeHeight = Math.ceil(Math.log(n)/Math.log(2))+1;
            // 트리의 노드 수 계산
            long treeNodeCount = Math.round(Math.pow(2, treeHeight));
            // 트리의 길이 설정
            tree = new long[Math.toIntExact(treeNodeCount)];
        }

        // 세그먼트 트리의 노드 값 초기화
        long init(long[] arr, int node, int start, int end ){
            // 세그먼트 트리의 리프노드인 경우
            if (start == end) {
                // 리프노드에 배열의 값 저장 후 리턴
                return tree[node] = arr[start];
            }else{
                // 리프노드가 아닌 경우에는 자식노드의 값을 더해서 노드의 값 초기화 후 리턴
                return tree[node] = init(arr, node*2, start, (start+end)/2) 
                                  + init(arr, node*2+1, (start+end)/2+1, end);
            }
        }
}

 

배열의 크기가 N일 때, 세그먼트 트리의 높이와 전체 노드 수를 구하는 식은 다음과 같다.

  • 세그먼트 트리의 높이 = logN(밑이 2인 log)의 값을 올림 후 + 1
  • 세그먼트 트리의 전체 노드 수 = 2 ^ (트리의 높이)

 

 

특정 구간의 합 구하는 메서드

// 배열의 특정 구간 합을 세그먼트 트리로 구하기
long sum(int node, int start, int end, int left, int right) {
    // 노드가 가지는 값의 구간이 구하려고 하는 합의 구간에 속하지 않는 경우 0리턴
    if (end < left || right < start) {
        return 0;
    } else if (left <= start && end <= right) {
    // 노드가 가지는 값의 구간이 구하려고 하는 합의 구간에 속하는 경우 노드 값 리턴
        return tree[node];
    } else {
    // 그 외는 2가지 경우가 존재
    // 1. 노드가 가지는 값의 구간이 구하려고 하는 합의 구간에 일부는 속하고 일부는 속하지 않는 경우
    // 2. 노드가 가지는 값의 구간이 구하려고 하는 합의 구간을 모두 포함하는 경우
    // 이와 같은 경우에는 자식노드를 탐색해서 값을 리턴
        return sum(node*2, start, (start+end)/2, left, right)
             + sum(node*2+1, (start+end)/2+1, end, left, right);
    }
}

 

  • 노드가 가지는 구간이 구하려고 하는 배열의 구간에 포함되지 않은 경우 -> 0 리턴
  • 노드가 가지는 구간이 구하려고 하는 배열의 구간에 포함되거나 같은 경우 -> 노드 값 리턴
  • 노드가 가지는 구간이 구하려고 하는 배열의 구간을 모두 포함하고 있는 경우 -> 자식 노드로 이동
  • 노드가 가지는 구간이 구하려고 하는 배열의 구간에 일부는 포함 일부는 미포함인 경우 -> 자식 노드로 이동

 

 

배열의 특정 인덱스 값을 변경하는 메서드

배열의 특정 인덱스 값이 변경되면 구간합도 바뀌기 때문에 세그먼트 트리를 돌면서 배열의 특정 인덱스 값을 변경해줌과 동시에 구간합도 업데이트해야 한다.

 

여기서 세그먼트 구간합을 업데이트 해주는 방법엔느 두가지가 있다. 첫번째 방법은 특정 인덱스의 값이 변경될 경우 그 차이값을 해당 인덱스가 포함된 모든 구간합 노드에 더해주는 것이다.

 

// 배열의 특정 인데스의 값이 변경 될 경우 세그먼트 트리의 노드 값 변경(차이 값을 더하는 방법)
void update(int node, int start, int end, int index, long diff) {
    // 노드가 가지는 값의 구간에 배열의 인덱스(값이 변경 될 인덱스)가 포함되지 않을 경우
    if (index < start || end < index) {
        // 아무것도 안함
        return;
    }else {
        // 노드가 가지는 값의 구간에 배열의 인덱스(값이 변경 될 인덱스)가 포함되는 경우
        // 노드의 값 + 차이값(변경할 값-기존값)
        tree[node] = tree[node] + diff;                                                                  

        // 노드가 리프노드가 아닌 경우
        if (start != end) {
            // 리프노드까지 계속 자식노드를 탐색
            update(node*2, start, (start+end)/2, index, diff) ;
            update(node*2+1, (start+end)/2+1, end, index, diff) ;
        }
    }
}

 

다른 방법은 차이값을 해당 모든 노드에 더해주는 것이 아닌 특정 인덱스에 해당하는 리프 노드 값을 직접 변경해줌으로써 구간합을 업데이트 하는 방법이다.

 

// 배열의 특정 인데스의 값이 변경 될 경우 세그먼트 트리의 노드 값 변경(노드 값을 직접 변경)
long update2(int node, int start, int end, int index, long changeValue) {
    // 노드가 가지는 값의 구간에 배열의 인덱스(값이 변경 될 인덱스)가 포함되지 않을 경우
    if (index < start || end < index) {
        // 트리의 노드 값 리턴
        return tree[node];
    } else if (start == index && end == index) {
        // 노드가 가지는 값의 구간과 배열의 인덱스(값이 변경 될 인덱스)값이 같은 경우
        // 노드의 값을 변경 될 값으로 변경
        return tree[node] = changeValue;
    } else {
        // 노드가 가지는 값의 구간에 배열의 인덱스(값이 변경 될 인덱스)값이 포함되는 경우(같은 경우는 제외)
        // 자식 노드를 탐색 후 값을 더해서 리턴
        return tree[node] = update2(node * 2, start, (start + end) / 2, index, changeValue) +
                update2(node * 2 + 1, (start + end) / 2 + 1, end, index, changeValue);
    }
}

 

 

전체 코드

static class SegmentTree{
        // 세그먼트 트리를 구현할 배열
        private long[] tree;

        // 생성자에서 세그먼트 트리의 전체노드 수 계산 (즉, 배열 길이)
        SegmentTree(int n) {
            // 트리의 높이 계산
            double treeHeight = Math.ceil(Math.log(n)/Math.log(2))+1;
            // 트리의 노드 수 계산
            long treeNodeCount = Math.round(Math.pow(2, treeHeight));
            // 트리의 길이 설정
            tree = new long[Math.toIntExact(treeNodeCount)];
        }

        // 세그먼트 트리의 노드 값 초기화
        long init(long[] arr, int node, int start, int end ){
            // 세그먼트 트리의 리프노드인 경우
            if (start == end) {
                // 리프노드에 배열의 값 저장 후 리턴
                return tree[node] = arr[start];
            }else{
                // 리프노드가 아닌 경우에는 자식노드의 값을 더해서 노드의 값 초기화 후 리턴
                return tree[node] = init(arr, node*2, start, (start+end)/2)
                        + init(arr, node*2+1, (start+end)/2+1, end);
            }
        }

        // 배열의 특정 구간 합을 세그먼트 트리로 구하기
        long sum(int node, int start, int end, int left, int right) {
            // 노드가 가지는 값의 구간이 구하려고 하는 합의 구간에 속하지 않는 경우 0리턴
            if (end < left || right < start) {
                return 0;
            } else if (left <= start && end <= right) {
                // 노드가 가지는 값의 구간이 구하려고 하는 합의 구간에 속하는 경우 노드 값 리턴
                return tree[node];
            } else {
                // 그 외는 2가지 경우가 존재
                // 1. 노드가 가지는 값의 구간이 구하려고 하는 합의 구간에 일부는 속하고 일부는 속하지 않는 경우
                // 2. 노드가 가지는 값의 구간이 구하려고 하는 합의 구간을 모두 포함하는 경우
                // 이와 같은 경우에는 자식노드를 탐색해서 값을 리턴
                return sum(node*2, start, (start+end)/2, left, right)
                        + sum(node*2+1, (start+end)/2+1, end, left, right);
            }
        }

        // 배열의 특정 인데스의 값이 변경 될 경우 세그먼트 트리의 노드 값 변경(차이 값을 더하는 방법)
        void update(int node, int start, int end, int index, long diff) {
            // 노드가 가지는 값의 구간에 배열의 인덱스(값이 변경 될 인덱스)가 포함되지 않을 경우7
            if (index < start || end < index) {
                // 아무것도 안함
                return;
            }else {
                // 노드가 가지는 값의 구간에 배열의 인덱스(값이 변경 될 인덱스)가 포함되는 경우
                // 노드의 값 + 차이값(변경할 값-기존값)
                tree[node] = tree[node] + diff;

                // 노드가 리프노드가 아닌 경우
                if (start != end) {
                    // 리프노드까지 계속 자식노드를 탐색
                    update(node*2, start, (start+end)/2, index, diff) ;
                    update(node*2+1, (start+end)/2+1, end, index, diff) ;
                }
            }
        }

        // 배열의 특정 인데스의 값이 변경 될 경우 세그먼트 트리의 노드 값 변경(노드 값을 직접 변경)
        long update2(int node, int start, int end, int index, long changeValue) {
            // 노드가 가지는 값의 구간에 배열의 인덱스(값이 변경 될 인덱스)가 포함되지 않을 경우
            if (index < start || end < index) {
                // 트리의 노드 값 리턴
                return tree[node];
            } else if (start == index && end == index) {
                // 노드가 가지는 값의 구간과 배열의 인덱스(값이 변경 될 인덱스)값이 같은 경우
                // 노드의 값을 변경 될 값으로 변경
                return tree[node] = changeValue;
            } else {
                // 노드가 가지는 값의 구간에 배열의 인덱스(값이 변경 될 인덱스)값이 포함되는 경우(같은 경우는 제외)
                // 자식 노드를 탐색 후 값을 더해서 리턴
                return tree[node] = update2(node * 2, start, (start + end) / 2, index, changeValue) +
                        update2(node * 2 + 1, (start + end) / 2 + 1, end, index, changeValue);
            }
        }
    }

 

728x90