알고리즘 - 기초

다이나믹 프로그래밍 - 기초

치킨먹고싶어요 2022. 5. 26. 11:11

다이나믹 프로그래밍에 대해서 알아봅시다.

기초 단계에서 다이나믹 프로그래밍은 브루트 포스(완전탐색법)으로 구현하였을때 너무 오래걸리는 문제를 풀기위하여

이전에 구한 답을 기억하였다가 활용하는 것을 뜻합니다. (메모이제이션)

 

일단 문제와 함께 이해해 봅시다.

https://www.acmicpc.net/problem/1463

 

1463번: 1로 만들기

첫째 줄에 1보다 크거나 같고, 106보다 작거나 같은 정수 N이 주어진다.

www.acmicpc.net

 

문제 요약을 하자면 N을 나누기 3을 하던 N / 2를 하던 N - 1 해서 1로 만들어야 합니다. 1로 만들 때의 최소 계산 횟수가 몇 번인가를 구해야 합니다.

푸는 방법을 생각해 보면 처음에는 'N을 최대한 작게 만들기 위해서는 N / 3, N / 2, N - 1 순으로 사용해야겠구나!'라고 생각할 수 있습니다. 그리디(Greedy) 하게 숫자를 적게 만들수록 1에 가까워진다는 생각이죠.
그러나 반례가 있습니다. N = 10일 때는, 10 -> 5 -> 4 -> 2 -> 1 보다 10 -> 9 -> 3 -> 1 가 유리하죠. 분명히 1번의 계산으로 10 -> 5가 되어 10 -> 9 보다 훨씬 더 줄어들었지만 계산을 해보면 10 -> 5는 총 5번, 10 -> 9는 총 3번의 계산으로 1이 나옵니다
그렇다면 반례가 있을 수 있기에 모든 경우의 수를 다 찾아가는 완전 탐색법을 사용해야 할 것입니다.

#include <iostream>
#include <algorithm> // min을 쓰기 위한 헤더 파일
#define fastio() ios::sync_with_stdio(0),cin.tie(0); // 빠른 입출력을 위한 코드
#define INF 100000000
using namespace std;
int solve(int n, int cnt) {
    if (n == 1return cnt; // n이 1이 된다면 cnt를 리턴한다
    else {
        int a, b, c;
        a = INF; b = INF; c = INF; // a, b, c를 cnt로는 나올 수 없는 큰 값을 넣어준다.
        if ((n % 3== 0) a = solve(n / 3, cnt + 1); // n이 3으로 나누어 질때의 cnt를 리턴한다 
        if ((n % 2== 0) b = solve(n / 2, cnt + 1); // n이 2로 나누어 질때의 cnt를 리턴한다
        c = solve(n - 1, cnt + 1); // n - 1로 바꾼 후, 최소 값을 탐색힌다
        return min({a, b, c}); // min(,) 은 2개의 값을 비교하지만 min({, ,})은 여러개의 값의 최소값을 찾을 수 있다.
    }
}
 
int main() {
    fastio();
    int N; cin >> N;
    cout << solve(N, 0); 
    return 0;
}
 
cs

완전 탐색법을 사용하면 위와 같은 코드가 완성됩니다.

 

그러나 이 코드는 너무 느립니다.

왜냐하면 이미 계산한 값도 매번 새롭게 계산하니 시간이 엄청나게 걸릴 수밖에 없죠.
그렇다면 다이나믹 프로그래밍을 사용하면 더 빠른 코드를 만들 수 있을까요?

 

5의 트리

일단 5를 완전 탐색법으로 트리를 그려보았습니다. 5는 3번의 계산으로 1이 나오는군요.

7의 트리

그리고 7을 완전 탐색법으로 트리를 그리려고 했습니다. 그러나 또 5를 계산하자니 막막하군요. 

이전에 나온 트리를 활용할 수는 없을까요?

 

5의 트리

분명히 이전 트리에서 5는 3이라는 것을 알았습니다. 그렇다면

5에 도착했다면 단순히 3을 더해주면 7 - 6 - 5로 갔을때의 cnt를 알 수 있습니다.

5까지 가는데 2번 계산 했으니 7 - > 6 -> 5 -> ... -> 1은 총 5번의 계산이 되겠군요!

 

그렇다면 기억해서 사용하는 메모이제이션 방법으로 다이나믹 프로그래밍을 구현하여

모든 숫자에 대하여 값을 기억하여 활용해 봅시다.

#include <iostream>
#include <algorithm> // min을 쓰기 위한 헤더 파일
#define fastio() ios::sync_with_stdio(0),cin.tie(0); // 빠른 입출력을 위한 코드
#define INF 100000000
using namespace std;
int dp[1111111]; // 이전에 나온 값을 기억할 배열, 전역변수 이므로 0으로 초기화 되어있다.
int solve(int n, int cnt) {
    if (n == 1return cnt; // n이 1이 된다면 cnt를 리턴한다
    else if (dp[n] != 0return dp[n] + cnt; // 이전에 나온 값이라면 재귀할 필요없이 n까지 걸린 cnt와 n이 1이 되는데 걸리는 cnt를 더해준다 
    // dp[n]이 0이 아니라면 이미 계산했다는 뜻이다.
    else { 
        int a, b, c;
        a = INF; b = INF; c = INF; // a, b, c를 cnt로는 나올 수 없는 큰 값을 넣어준다.
        if ((n % 3== 0) a = solve(n / 3, cnt + 1); // n이 3으로 나누어 질때의 cnt를 리턴한다 
        if ((n % 2== 0) b = solve(n / 2, cnt + 1); // n이 2로 나누어 질때의 cnt를 리턴한다
        c = solve(n - 1, cnt + 1); // n - 1로 바꾼 후, 최소 값을 탐색힌다
        dp[n] = min({a, b, c}) - cnt; // dp는 순수히 n에서 1까지 걸리는 값을 구해주어야 하므로 - cnt를 해준다
        return dp[n] + cnt; // 그러나 리턴할때는 n까지 걸리는 값과 n이 1이 되는데 걸리는 값을 리턴한다
    }
}
 
int main() {
    fastio();
    int N; cin >> N;
    cout << solve(N, 0);
    return 0;
}
 
cs

이와 같이 dp라는 배열로 n에서 1까지 걸리는 값을 구해주고, N에서 n까지 걸리는 cnt를 더해주면 답이 나옵니다