분류: 세그먼트 트리 /
문제
문제 설명
어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.
입력
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.
입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.
출력
첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.
풀이
#include <iostream>
#define ll long long
using namespace std;
int N, M, K, leafSize = 1;
ll segment[(1 << 21)];
inline int left(int node) { return 2 * node; }
inline int right(int node) { return 2 * node + 1; }
void init() {
for(int i = leafSize; i < leafSize+N; i++)
cin >> segment[i];
for(int i = leafSize-1; i > 0; i--)
segment[i] = segment[left(i)] + segment[right(i)];
}
ll sum(int L, int R, int node, int nodeL, int nodeR) {
if(R < nodeL || L > nodeR) return 0;
if(L <= nodeL and nodeR <= R) return segment[node];
int mid = (nodeL + nodeR) / 2;
return sum(L, R, left(node), nodeL, mid)
+ sum(L, R, right(node), mid+1, nodeR);
}
void update(int i, ll val) {
i = leafSize + i - 1;
segment[i] = val;
while(i >>= 1)
segment[i] = segment[left(i)] + segment[right(i)];
}
int main() {
ios::sync_with_stdio(false); cin.tie(NULL);
cin >> N >> M >> K;
while(leafSize < N) leafSize <<= 1;
init();
for(int Q=M+K; Q--;) {
ll a, b, c; cin >> a >> b >> c;
if(a == 1) update(b, c);
else cout << sum(b, c, 1, 1, leafSize) << '\n';
}
}
세그먼트 트리의 구현에 대해 생각해보자.
# 포화 이진 트리
높이가 `k`인 포화 이진 트리는 총 `2k+1 - 1`개의 노드를 가지며 이때 리프 노드는 `2k`개다.
포화 이진 트리를 구현 할때는 위 그림에서 노드에 쓰인 숫자를 인덱스로 하여 `1-based indexing` 배열로 구현할 수 있다.
이때 배열의 크기는 `2k+1 - 1`개의 노드와 `[0]` 한 개를 포함해 `2k+1`이다.
위 사전 지식을 가지고 문제로 돌아가보자. 문제는 리프 노드가 최소 `N`개 인 세그먼트 트리를 만들어야한다.
세그먼트 트리의 최소 높이 `k`는 리프 노드 2k개에 N개의 입력이 모두 들어가야 하고 리프의 부모 레벨(k-1)의 2k-1개의 노드에는 입력이 모두 들어갈 수 없어야 하므로 `2k-1 < N ≤ 2k`을 만족한다.
위에서 살펴 봤듯 리프 노드가 2k개 일때, 세그먼트 트리는 2k+1 길이의 배열로 구현할 수 있다.
문제에서 N이 최대 1,000,000으로 k는 최대 20이다. 따라서 구현체 배열(`segment[]`)의 길이는 최대 221이다.
// 배열의 최대 크기 = 2^21
ll segment[(1 << 21)];
// 2^(k-1) < N ≤ 2^k = leafSize
int leafSize = 1;
while(leafSize < N) leafSize <<= 1;
# 주의
코드에서 두개의 인덱스가 사용된다. `L, R, nodeL, nodeR`은 입력의 인덱스로 "`L`번째 입력"이라는 의미이고 범위는 `[1, N]`이다.
`node`는 세그먼트 트리 구현체 배열 `segment[]`의 인덱스로 범위는 `[1, 2 * leafSize - 1]`이다.
이때 입력 `[1, N]`은 세그먼트 트리의 리프노드 `[leafSize, 2 * leafSize - 1]`에 대응된다.
# 세그먼트 트리 생성 init()
void init() {
for(int i = leafSize; i < leafSize+N; i++)
cin >> segment[i];
for(int i = leafSize-1; i > 0; i--)
segment[i] = segment[left(i)] + segment[right(i)];
}
세그먼트 트리 배열(`segment[]`)에서 리프 노드의 인덱스 범위는 `[leafSize, 2 * leafSize - 1]`이므로 `i = leafSize`부터 `N`개의 입력을 받는다.
이 후 `leafSize - 1`부터 루트 노드까지 자식 노드의 합으로 계산한다.
# 합 계산 sum()
// [L, R]의 합을 계산
// node: [nodeL, nodeR]
ll sum(int L, int R, int node, int nodeL, int nodeR) {
// 노드가 [L, R]과 겹치지 않음
if(R < nodeL || L > nodeR) return 0;
// 노드가 [L, R]에 완전히 포함됨
if(L <= nodeL and nodeR <= R) return segment[node];
// 노드가 [L, R]과 겹침
int mid = (nodeL + nodeR) / 2;
return sum(L, R, left(node), nodeL, mid)
+ sum(L, R, right(node), mid+1, nodeR);
}
// root: [1, leafSize] 부터 탐색 시작
sum(L, R, 1, 1, leafSize)
`[L, R]`의 합을 계산 하는 쿼리가 주어졌을 때, 루트부터 내려가며 합을 계산한다.
`node`는 현재 노드의 `segment[]`상의 인덱스이며 `[nodeL, nodeR]`의 합을 저장하고 있다.
`sum`함수는 구하고자 하는 구간 `[L, R]`에 노드가 완전히 포함되거나 포함되지 않을 때까지 범위를 분할하며 재귀적으로 탐색한다.
# 업데이트 update()
void update(int i, ll val) {
i = leafSize + i - 1;
segment[i] = val;
while(i >>= 1)
segment[i] = segment[left(i)] + segment[right(i)];
}
업데이트 쿼리 때문에 세그먼트 트리를 사용한다. 업데이트 쿼리가 없다면 `prefix sum`으로 선형 시간에 답을 낼 수 있다.
업데이트는 해당 인덱스의 리프 노드를 업데이트하고 업데이트 된 리프 노드부터 루트까지 경로를 업데이트하며 진행된다.
노드의 부모 노드는 인덱스를 2로 나누는 것으로 쉽게 찾을 수 있다. (`i >>= 1`)
`i` 번째 입력을 업데이트 하라는 쿼리의 `i`는 입력 인덱스로 `[1, N]`의 값이다.
따라서 `i`를 실제 값이 저장되어 있는 `segment[]`의 리프 노드 인덱스 범위(`[leafSize, 2*leafSize-1]`)로 변환해 주어야한다. (`i = leafSize + i - 1`)
'Problem Solving > BOJ' 카테고리의 다른 글
[백준 - 15684] 사다리 조작 - C++ (0) | 2023.11.22 |
---|---|
[백준 - 16987] 계란으로 계란치기 - C++, Python (0) | 2023.11.22 |
[백준 - 1967] 트리의 지름 - C++ (0) | 2023.11.20 |
[백준 - 12865] 평범한 배낭 - C++ (0) | 2023.11.20 |
[백준 - 11659] 구간 합 구하기 4 - C++ (0) | 2023.11.19 |