최소 신장 트리(최소 스패닝 트리)

최소 신장 트리를 배워봅시다.

Imagem de capa

이제 난이도가 있는 알고리즘을 서서히 배워나가봅시다.

최소 신장 트리는 그래프와 구별을 해야하는데요! 트리와 그래프의 차이는, 트리는 사이클이 존재할 수 없지만 그래프는 사이클이 존재할 수 있다는 점입니다.

이 알고리즘을 풀기 위해서는 프림과 크루스칼을 알아야하는데요. 프림은 정점 위주로 간선을 선택해 나갑니다. 반면에 크루스칼은 간선 위주로 사이클이 생성되지않게 간선들을 선택해 나갑니다.

아래 링크를 통해 이론과 그림을 자세히 봐봅시다!

프림 알고리즘

크루스칼 알고리즘

이제 선택해나가는 방법을 알게되었다면 코딩으로 한번 풀어보죠!

백준 1197번

프림


import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main{

    static final int INF = 100000000;

    public static void main(String[] args) throws IOException {

        int V; // V 그룹 정점의 개수
        int E; // 간선의 개수
        int[][] w; // 간선의 정보
        int sumOfWeight = 0; // MST 가중치의 합
        boolean[] visited;
        int numOfMST = 0; // MST 그룹 정점의 개수
        int[] d; // MST 그룹에서 V그룹의 vertex까지의 거리

        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

        StringTokenizer st = new StringTokenizer(br.readLine());

        V = Integer.parseInt(st.nextToken());
        E = Integer.parseInt(st.nextToken());

        visited = new boolean[V + 1];
        w = new int[V + 1][V + 1];
        d = new int[V + 1];

        for (int i = 0; i < V + 1; i++) {

            d[i] = INF;
            visited[i] = false;
            for (int j = 0; j < V + 1; j++) {

                w[i][j] = INF;
            }
        }

        // 간선 정보 입력
        for (int i = 0; i < E; i++) {
            st = new StringTokenizer(br.readLine());

            int A = Integer.parseInt(st.nextToken());
            int B = Integer.parseInt(st.nextToken());
            int C = Integer.parseInt(st.nextToken());

            if (w[A][B] > C) {

                // 무방향 그래프
                w[A][B] = C;
                w[B][A] = C;
            }
        }

        // 초기값
        numOfMST = 0;
        // 첫 시작 정점을 1로 시작하기 위한 코드입니다.
        d[1] = 0;
 	//모든 노드 방문을 하기전까지는 아래 코드를 반복합니다.
        while (numOfMST < V) {

            int min = INF;
            int u = -1;

            // 처음 정점은 1로 시작하고 다음 for문에서 간선 d를 통해 가장 작은 값을 선택해 다른 정점으로 향해갑니다.
            for (int i = 1; i < V + 1; i++) {

                if (!visited[i] && d[i] < min) {

                    min = d[i];
                    u = i;
                }
            }

            // 선택된 정점 u에서부터 i까지의 거리를 통해 d를 갱신합니다. 그럼 연결된 정점은 d에 차곡차곡 저장이 되니 간선 비교가 되겠죠?
            for (int i = 1; i < V + 1; i++) {

                if (!visited[i] && w[u][i] != INF) {

                    if (w[u][i] < d[i]) {

                        d[i] = w[u][i];
                    }
                }
            }

            numOfMST++;
            sumOfWeight += min;
            visited[u] = true;

        }

        // 결과 출력
        System.out.println(sumOfWeight);

        br.close();

    }

}

크루스칼

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.StringTokenizer;

/*
   간선을 추가할 때마다 트리가 사이클을 이루는지 판단해야 하는데 그 자료구조가 바로 Union-Find 자료 구조입니다.
 */

public class Main{

    static final int INF = 100000000;

    public static void main(String[] args) throws IOException {

        int V; // V 그룹 정점의 개수
        int E; // 간선의 개수
        int weightSumOfMST = 0; // MST 가중치의 합
        int edgeNumOfMST = 0; // MST 정점의 개수

        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

        StringTokenizer st = new StringTokenizer(br.readLine());

        V = Integer.parseInt(st.nextToken());
        E = Integer.parseInt(st.nextToken());

        Queue<Edge> edgePriorityQueue = new PriorityQueue<Edge>();

        for (int i = 0; i < E; i++) {

            st = new StringTokenizer(br.readLine());

            int A = Integer.parseInt(st.nextToken());
            int B = Integer.parseInt(st.nextToken());
            int C = Integer.parseInt(st.nextToken());

            // 우선순위큐-간선의 Weight 기준으로 정렬시킵니다.
            edgePriorityQueue.add(new Edge(A, B, C));
        }

        UnionFind unionFind = new UnionFind(V);
        //정점을 다 방문하지않았고 큐가 비기전까지 돌려줍니다.
        while (edgeNumOfMST < V && !edgePriorityQueue.isEmpty()) {
        	//가장 짧은 간선을 꺼내 이 간선의 정점들이 사이클을 이루는지 확인을 합니다.
            Edge edge = edgePriorityQueue.poll();

            // Union-Find로 사이클을 이루는지 확인
            // 사이클을 이룬다면
            if (unionFind.find(edge.v1) == unionFind.find(edge.v2)) {

                continue;
            }
            // 사이클을 이루지 않는다면 union함수를 통해 그룹을 지어줍니다.
            else {

                unionFind.union(edge.v1, edge.v2);
                weightSumOfMST += edge.weight;
                edgeNumOfMST++;
            }
        }

        // 결과 출력
        System.out.println(weightSumOfMST);

        br.close();

    }

}

class UnionFind {

    int[] root;

    public UnionFind(int V) {

        root = new int[V + 1];

        initialize();
    }
    //배열 값이 음수가 나올때까지 find함수를 돌려 같은 그룹인지 확인해줍니다.
    public int find(int a) {

        if (root[a] < 0) {

            return a;
        }

        return root[a] = find(root[a]);
    }

    public void union(int a, int b) {

        int root1 = find(a);
        int root2 = find(b);

        // 이미 같은 그룹이라면
        if (root1 == root2) {

            return;
        }

        // 다른 그룹이라면

        // root1의 그룹이 더 작다면 (root1 < root2)
        if (root[root1] > root[root2]) {

            root1 ^= root2;
            root2 ^= root1;
            root1 ^= root2;
        }

        // root1과 root2를 결합하고, root2의 부모를 root1로 설정
        root[root1] += root[root2];
        root[root2] = root1;

    }
    //루트 배열의 값을 -1로 초기화합니다.
    private void initialize() {

        for (int i = 0; i < root.length; i++) {

            root[i] = -1;
        }
    }

    // a를 포함하는 그룹의 정점의 개수를 확인하는 함수
    public int size(int a) {

        return -root[find(a)];
    }
}

class Edge implements Comparable<Edge> {

    int v1;
    int v2;
    int weight;

    Edge(int v1, int v2, int weight) {

        this.v1 = v1;
        this.v2 = v2;
        this.weight = weight;
    }

//오름차순

    @Override
    public int compareTo(Edge o) {

        return (this.weight > o.weight ? 1 : (this.weight == o.weight ? 0 : -1));
    }

}