문제 출처 : https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

알고리즘 분류에 세그먼트 트리가 있어 세그먼트 트리를 활용하여 풀어보았다.

세그먼트 트리는 이진 트리이며 배열의 특정 구간에 대한 정보를 관리하기에 좋다.

리프 노드에 값을 입력 받고 부모 노드에 특정 정보를 저장하여 관리한다.

  3 
 / \
1   2

해당 예에서는 리프 노드에 1, 2를 저장하고 부모 노드에는 두 값의 합인 3을 저장하였다.

 

vector<long long> arr;
vector<long long> segtree;

arr에는 입력받은 값을 저장할 것이고 segtree에는 해당 값을 트리 형태로 저장 할 것이다.

입력으로 들어오는 수가 int형을 넘기 때문에 long long으로 선언하였다.

 

	int n, m, k, a, b;
	long long c;
	cin >> n >> m >> k;
	arr.resize(n);
	for (int i = 0; i < n; i++) {
		cin >> arr[i];
	}
	int h = (int)ceil(log2(n));
	int tree_size = (1 << (h + 1));
	segtree.resize(tree_size);
	init(1, 0, n - 1);

문제에서 주어진대로 n, m, k를 입력 받고 arr의 사이즈를 n으로 설정하였다.

h는 트리의 높이이다.

트리의 높이는 배열의 개수 n에 따라 결정 된다.

레벨이 오르면 이전 레벨의 2배가 되는 노드를 저장할 수 있으므로 트리의 높이는 log2(n)이 된다.

n이 2의 제곱이 아니더라도 모든 배열의 수를 저장해야 하므로 ceil을 사용해서 올림처리 해준다.

높이가 h인 트리의 노드의 수는 2^(h + 1) - 1이다.

트리 사이즈를 2^(h + 1)로 설정하여 첫 노드의 인덱스로 1을 사용할 수 있도록 하였다.

 

long long init(int node, int start, int end) {
	if (start == end) return segtree[node] = arr[start];
	int mid = (start + end) / 2;
	return segtree[node] = init(node * 2, start, mid) + init(node * 2 + 1, mid + 1, end);
}

init 함수를 통해 트리를 만들어준다.

node에는 인덱스, start에는 배열의 시작 인덱스, end에는 배열의 끝 인덱스를 넣어준다.

만약 start와 end가 같아면 원소가 1개인 것이므로 segtree[node]에 arr[start]를 넣어주고 끝낸다.

아니라면 반으로 쪼개어 재귀 호출을 해준다.

이 때 구간합을 구해야 하므로 왼쪽 자식 값과 오른쪽 자식 값의 합을 저장해준다.

 

                  15
               /      \
            6            9
          /   \        /   \
        3      3      4     5
       / \ 
      1   2

예제의 1, 2, 3, 4, 5를 넣고 init을 돌렸다면 위와 같은 형태로 저장이 되어있을 것이다.

 

	for (int i = 0; i < m + k; i++) {
		cin >> a >> b >> c;
		b--;
		if (a == 1) {
			long long diff = c - arr[b];
			arr[b] = c;
			update(1, 0, n - 1, b, diff);
		}
		else if (a == 2) {
            c--;
			cout << sum(1, 0, n - 1, b, c) << endl;
		}
	}

트리를 완성 했다면 세 수를 입력받고 a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력 해준다.

a가 1인 경우 b번째 수를 c로 바꾸는 것 부터 확인을 해보면 c와 b번째 수의 차를 diff에 저장해두고 b번째 수를 c로 바꾸어준다.

이후 update 함수를 호출하여 b번째 수를 포함하는 구간 값을 업데이트 해준다.

 

void update(int node, int start, int end, int idx, long long diff) {
	if (idx < start || idx > end) return;
	segtree[node] += diff;
	if (start != end) {
		int mid = (start + end) / 2;
		update(node * 2, start, mid, idx, diff);
		update(node * 2 + 1, mid + 1, end, idx, diff);
	}
}

만약 인덱스가 start보다 작거나 end보다 크다면 구간을 벗어간 것이므로 return 해준다.

구간에 포함 된다면 해당 인덱스 값에 diff를 더해준다.

만약 start와 end가 다르다면 자식 노드가 존재하는 것이므로 자식노드도 업데이트 해준다.

 

                  18
               /      \
            9            9
          /   \        /   \
        3      6      4     5
       / \ 
      1   2

예제에서의 1 3 6을 입력하여 업데이트 하면 다음과 같이 업데이트 되어있을 것이다.

 

 

long long sum(int node, int start, int end, int left, int right) {
	if (left > end || right < start) return 0;
	if (left <= start && end <= right) return segtree[node];
	int mid = (start + end) / 2;
	return sum(node * 2, start, mid, left, right) + sum(node * 2 + 1, mid + 1, end, left, right);
}

a가 2라면 구간합을 출력해야한다.

구간의 시작 인덱스를 left, 끝 인덱스를 end라고 했을 때, left가 end보다 크거나 right 가 start보다 작다면 구간을 벗어난 것이므로 0을 return 해준다. 0을 return하는 이유는 덧셈의 항등원이기 때문이다.

만약 left가 start이하이고 right가 end이상이라면 해당 구간을 포함하기 때문에 해당 인덱스 값을 return 해주면 된다.

만약 구간을 완전히 포함하지 못한다면 반을 쪼개어 재귀 호출을 해준다.

 

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

vector<long long> arr;
vector<long long> segtree;

long long init(int node, int start, int end) {
	if (start == end) return segtree[node] = arr[start];
	int mid = (start + end) / 2;
	return segtree[node] = init(node * 2, start, mid) + init(node * 2 + 1, mid + 1, end);
}

long long sum(int node, int start, int end, int left, int right) {
	if (left > end || right < start) return 0;
	if (left <= start && end <= right) return segtree[node];
	int mid = (start + end) / 2;
	return sum(node * 2, start, mid, left, right) + sum(node * 2 + 1, mid + 1, end, left, right);
}

void update(int node, int start, int end, int idx, long long diff) {
	if (idx < start || idx > end) return;
	segtree[node] += diff;
	if (start != end) {
		int mid = (start + end) / 2;
		update(node * 2, start, mid, idx, diff);
		update(node * 2 + 1, mid + 1, end, idx, diff);
	}
}

int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);
	int n, m, k, a, b;
	long long c;
	cin >> n >> m >> k;
	arr.resize(n);
	for (int i = 0; i < n; i++) {
		cin >> arr[i];
	}
	int h = (int)ceil(log2(n));
	int tree_size = (1 << (h + 1));
	segtree.resize(tree_size);
	init(1, 0, n - 1);
	for (int i = 0; i < m + k; i++) {
		cin >> a >> b >> c;
		b--;
		if (a == 1) {
			long long diff = c - arr[b];
			arr[b] = c;
			update(1, 0, n - 1, b, diff);
		}
		else if (a == 2) {
            c--;
			cout << sum(1, 0, n - 1, b, c) << endl;
		}
	}
	return 0;
}

전체코드

 

알고리즘

- 자료 구조

- 세그먼트 트리

'알고리즘 > 백준' 카테고리의 다른 글

[백준 1916] 최소비용 구하기 (C++)  (0) 2023.08.08
[백준 1238] 파티 (C++)  (0) 2023.08.07
[백준 1753] 최단경로 (C++)  (0) 2023.08.06
[백준 11438] LCA2 (C++)  (0) 2023.07.25
[백준 2293] 동전1 (C++)  (0) 2023.07.22

+ Recent posts