문제 링크: https://leetcode.com/problems/sum-of-distances-in-tree/

 

Sum of Distances in Tree - LeetCode

Sum of Distances in Tree - There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges. You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the

leetcode.com

 

문제

문제를 요약하자면 tree가 주어질 때 모든 정점으로부터의 거리가 최소가 되는 정점에서의 모든 정점으로까지의 길이의 합을 출력하는 문제가 되겠습니다.

직관

tree기 때문에 Cycle이 없습니다. 따라서 한 edge를 제거했을 때,  그래프로 나눠집니다.

A와 B를 잇는 edge가 있다고 가정해봅시다. 그리고, A쪽 그래프에있는 모든 노드로부터 A까지의 거리를 알 수 있다고 가정해봅시다.

그럼 B쪽 노드의 관점에서 봤을 때는 우리는 모든 A쪽 그래프의 노드로부터 B까지의 거리의 총합을 다음 수식으로 구할 수 있습니다. 

$$TotalDist_{A side} + NodeCnt_{A side}$$

왜냐하면 A쪽 노드로부터 B쪽 노드로 오는 모든 길은 edge A->B를 지날 수 밖이 없기 때문이죠.

접근법

그러므로 우리는 DP를 계산할 수 있습니다. 

노드 N과 S가 있을 때, edge (N,S)에 의해서 N쪽 graph와 S쪽 graph 로 나뉘게됩니다.

그리고
$cnt[n][s] = $ $n$ 쪽 그래프에 있는 노드의 수
$dist[n][s] =$ N쪽 그래프에서 S로 가는 총 거리의 합

라고 정의할 때 우리는 이 두 배열을 다음과 같은 방식으로 업데이트할 수 있습니다.

$G[n]$ 을 노드 N과 이어진 모든 이웃 노드들을 포함하고 있을 때


$$ dp[n][s] = sum_{i \in G[n], i \neq s}(dp[n][i] + cnt[n][i])$$
$$ cnt[n][s] = sum_{i \in G[n], i \neq s}(cnt[n][i])$$

로 계산되게 됩니다. 

복잡도

이 풀이는 시간복잡도와 공간복잡도가 오직 edge의 개수의 영향받으며, edge는 총 N-1개 존재합니다.

시간 복잡도: $O(N)$

공간 복잡도:  $O(N)$

코드

class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        self.G = [[] for i in range(n)]
        
        for e in edges:
            self.G[e[0]].append(e[1])
            self.G[e[1]].append(e[0])
        
        ans = []
        for i in range(n):
            ans.append(self.dp(i, -1)[0])
        return ans
    
    @cache
    def dp(self, pnode, source):
        dist = 0
        cnt = 0
        for n in self.G[pnode]:
            if n == source:
                continue
            n_dist, n_cnt = self.dp(n, pnode)
            dist += n_dist + n_cnt
            cnt += n_cnt
        cnt += 1
        return dist, cnt

 

영문 솔루션

https://leetcode.com/problems/sum-of-distances-in-tree/discuss/2939374/Python-Super-Simple-O(N)-DP-Solution 

라떼는 말이야... 다 C언어 부터 배웠어...

그런데 요즘은 다 python으로 문제풀고 말이야...

Python으로 문제를 풀어?

그렇습니다. 이제는 python으로도 문제를 풀수 있는 시대가 도래했습니다.

C나 C++로 문제를 풀던 많은 사람들이 python으로 문제푸는 것을 두려워합니다.

python의 입출력 방식이 C나 C++과 매우 사맛디 아니하기 때문이죠.

하지만 과연 그럴까요? 익숙해지면 C나 C++보다도 손쉽게 input을 받을 수 있습니다.

머신러닝을 전공하시던 분들은 취직시장에 나와보니 아뿔사 코딩테스트는 봐야하는데, 익숙한 언어가 python 뿐입니다.

알고리즘을 하시는 분들에게 문의를 해보니,

본디 알고리즘은 C++이 짱이다!

python이 세상에서 제일 편한 언어가 되어버렸는데, C++이라뇨...

코딩 테스트 때문에 새로운 언어를 익혀야 하는 불상사를 막고자 이번 글을 작성했습니다.

본 게시글은 맛보기인데, 제가 python으로 문제풀 때 주로 사용하는 테크닉들을 알려드리도록 하겠습니다.

제가 심심할 때 푸는 문제의 솔루션이나, 핫한 코딩테스트 문제들도 여기에 솔루션을 올리도록 할께요.

여러분들의 시간은 소중하니까.

인풋 받기

정수 하나 받기

코드

인풋을 받는 여러가지 방법들이 있습니다. 간단하게 정수 하나를 입력으로 받아 볼까요?

# "10"
N = int(input())
print(N)

이건 솔직히 새로울게 없습니다. input함수를 받고, 그 결과를 int로 캐스팅한 후, 변수 N에 넣어주는 것입니다.

input란?

input()은 입력을 받는 함수라고 알고 있습니다.

그리고 input()의 return형은 항상 str 형입니다.

그렇기 때문에, 들어온 입력을 적절히 가공해 주어야 합니다.

int()

지금으로서는 정수로 변환해주는 캐스팅 같은 역할을 한다고 생각하시면 좋을 것 같습니다.

계속 넘어가도록 하죠. 더 자세히 알 필요가 있다면 더 깊게 설명하도록 하겠습니다.

정수 두개 받기

코드

# 10 20
N, M = map(int, input().split(" "))
print(N+M) # 30

input이 한줄에 2개가 들어온다면 매우 귀찮아 집니다.

2개부터는 각각 인풋을 받아서 나누고 거기에 int로 형까지 변환해줘야합니다.

매우매우 귀찮은 일이 아닐 수 없죠.

위의 코드로 입력을 손쉽게 받을 수 있습니다.

이 코드는 아래 코드와 동일한 방식으로 작동하게 됩니다.

# 10 20
inp = input()
sp_inp = inp.split(" ")
N = int(sp_inp[0])
M = int(sp_inp[1])

int의 정체

N, M = map(int , input().split(" "))

아실지 모르겠지만 python에서 캐스팅하드시 사용하는 int, float와 같은 자료형처럼 보이는 모든 것들이 모두 함수의 형태로 사용합니다.

엄연히 말하자면 class의 형태를 가지고, int로 변환해주는 캐스팅 작업을 해줍니다.

python의 정확한 구조를 모르신다면 함수라고 생각하셔도 무방합니다! 단지 int라는 이름을 가진 함수다 라고 생각하셔도 괜찮습니다.

input 함수가 하는 일

N, M = map(int , input().split(" "))

위에서 설명했듯이 input은 입력을 받는 함수이며, 동시에 str 를 return 합니다.

그리고 더 정확히는 한 라인을 통째로 받는 함수입니다.

이 부분의 기존의 Problem Solving (PS) 계열에 있던 사람들이 고통 받는 이유중 하나입니다.

PS러들은 문자열을 스페이스 단위로 받아서 처리하고 싶어하기 때문이죠...

입력이 뭐가 들어왔던 간에 str로 한 라인을 통째로 받기 때문에 이를 필요에 따라 여러가지 형태로 가공해줘야하는 것이 매우 거슬릴 것입니다.

str의 내장 함수 split

N, M = map(int, input().split(" "))

str에는 다양한 내장 함수들이 있습니다. 이를 통해서 손쉽게 string을 가공할 수 있습니다.

ps러라면 #include <string.h> 나 #include 을 하고 고통받아 보신 분이 있을 겁니다.

python에서는 훨씬 더 쉬운 형태의 인터페이스를 제공 합니다!

pycharm같은 훌륭한 ide를 쓰신다면 "abc". 하고 컨트롤 탭을 하시면 다양한 함수들을 구경하실 수 있을 겁니다.

오늘 알아볼 함수는 split입니다.

이름부터 직관적이게 해당 문자열을 나눠주는 함수입니다.

아래 코드를 확인하시면 어떤 짓을 할 수 있는 녀석인지 직관적으로 깨달으실 수 있을 겁니다.

"10 20 30".split(" ") # ["10", "20", "30"]
"10,20,30".split(",") # ["10", "20", "30"]
"10,20,30,".split(",") # ["10", "20", "30", ""]
"10 20 30 ".split(" ") # ["10", "20", "30", ""]

다만 마지막 두 예제를 조심하세요!

마지막에도 delimiter (구분자?)가 있다면 비어 있는 string을 return할 거에요!

위와 같은 예제를 피하기 위해서는 아래와 같이 strip 함수를 사용해주면 됩니다.

"10 20 30 ".strip().split(" ") # ["10", "20", "30"]

map 함수

이제 마지막으로, map함수를 알아봅시다.

map 함수는 첫번째 인자로 함수를, 두번째 인자로 배열을 받습니다. 그래서 각 배열의 원소에 해당 함수를 적용해주는 함수입니다.

  • (사족1) 사실 함수가 아니라 callable 한 모든 것들이 여기에 올 수 있습니다.
  • (사족2) 사실 배열이 아니라 iterator를 받습니다. 대충 반복할수 있는거라고 생각하시면 될것 같습니다.

말로는 잘 이해가지 않을 것이라고 생각해요. 예제로 이해하는 편이 좋을 겁니다.

arr = ["10", "20", "30"]
map(int, arr) 
# == [int(arr[0]), int(arr[1]), int(arr[2])]
# == [10, 20, 30]

이제 https://www.acmicpc.net/problem/1000 를 풀어볼까요?

마치며

아니 벌써 마치다니요? 사실 맥주 먹으러 가야해서 시간이 없네요.

다 먹고 살자고 하는 건데 맥주한잔하고 오겠습니다.

2탄에서 계속 하도록 할게요!

+ Recent posts