문제를 요약하자면 tree가 주어질 때 모든 정점으로부터의 거리가 최소가 되는 정점에서의 모든 정점으로까지의 길이의 합을 출력하는 문제가 되겠습니다.
직관
tree기 때문에 Cycle이 없습니다. 따라서 한 edge를 제거했을 때, 그래프로 나눠집니다.
A와 B를 잇는 edge가 있다고 가정해봅시다. 그리고, A쪽 그래프에있는 모든 노드로부터 A까지의 거리를 알 수 있다고 가정해봅시다.
그럼 B쪽 노드의 관점에서 봤을 때는 우리는 모든 A쪽 그래프의 노드로부터 B까지의 거리의 총합을 다음 수식으로 구할 수 있습니다.
$$TotalDist_{A side} + NodeCnt_{A side}$$
왜냐하면 A쪽 노드로부터 B쪽 노드로 오는 모든 길은 edge A->B를 지날 수 밖이 없기 때문이죠.
접근법
그러므로 우리는 DP를 계산할 수 있습니다.
노드 N과 S가 있을 때, edge (N,S)에 의해서 N쪽 graph와 S쪽 graph 로 나뉘게됩니다.
그리고 $cnt[n][s] = $ $n$ 쪽 그래프에 있는 노드의 수 $dist[n][s] =$ N쪽 그래프에서 S로 가는 총 거리의 합
라고 정의할 때 우리는 이 두 배열을 다음과 같은 방식으로 업데이트할 수 있습니다.
$G[n]$ 을 노드 N과 이어진 모든 이웃 노드들을 포함하고 있을 때
$$ dp[n][s] = sum_{i \in G[n], i \neq s}(dp[n][i] + cnt[n][i])$$ $$ cnt[n][s] = sum_{i \in G[n], i \neq s}(cnt[n][i])$$
로 계산되게 됩니다.
복잡도
이 풀이는 시간복잡도와 공간복잡도가 오직 edge의 개수의 영향받으며, edge는 총 N-1개 존재합니다.
시간 복잡도: $O(N)$
공간 복잡도: $O(N)$
코드
class Solution:
def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
self.G = [[] for i in range(n)]
for e in edges:
self.G[e[0]].append(e[1])
self.G[e[1]].append(e[0])
ans = []
for i in range(n):
ans.append(self.dp(i, -1)[0])
return ans
@cache
def dp(self, pnode, source):
dist = 0
cnt = 0
for n in self.G[pnode]:
if n == source:
continue
n_dist, n_cnt = self.dp(n, pnode)
dist += n_dist + n_cnt
cnt += n_cnt
cnt += 1
return dist, cnt
제가 예전에 담당했던 업무중에 사람 얼굴의 특징점 (눈, 코, 입 등이라고생각하면 편하다) 가지고 기능을 짜야했던 적이 있었습니다. 지금도 그렇지만 그때는 텐서감수성이 떨어질 때라서, 현재 짜고 있는 코드가 맞게 동작하는지 잘 이해가 안가는 경우가 굉장히 많았습니다. 🙂
특히, 이게 단순 산술연산이 아니라 이미지상에서 어느 부분을 짜르거나, 어느 부분을 기준으로 뭔짓을 가해야하는 종류라면 정말 디버깅에 많은 시간이 걸립니다. (예를 들면 얼굴 부분만 자르기!)
짜도 그 이미지에 알맞게 적용됐는지, 머리속으로 상상은 해도 실제로 그렇게 짜졌는지 실행하기 전까지 불안에 떨어야 했죠.
따라서 디버깅 하면서 이미지를 보는 건 디버깅 시에 확실히 이점이 있습니다. 보면서 실수를 그때 그때 잡을 수 있으니까요. 특히 저 같은 뉴비 머신러닝러는 더더더더더더욱 그렇습니다.
이처럼 머신러닝, 특히 비전관련한 테스크를 하다보면 디버깅을 할 때 이미지 파일을 보면서 할 때 더편한 경우가 왕왕 있습니다.
대안들
중간중간 이미지를 보는 작업은 생각보다 편하지만, 단순하게 이미지를 로그처럼 찍는 방법은 생각보다 불편합니다. 이렇게 할 경우 폴더를 열어서 이미지를 켜봐야 하고, 코드에 삽입해서 작성함으로 그때 그때 모델을 실행해야하죠.
특히 후처리를 해야할 경우 매 실행마다 모델을 로드해야하기 때문에 더 오랜 시간이 걸리게 됩니다. 이것 때문에 코드하나 고치고 다시 재시작하고, 모델을 로드하는데 많은 시간을 소모하게 됩니다. 😢
다른 방법으로는 plt.show() 같은 짓을 코드에 삽입하는 겁니다. 그러면 scientific mode로 pycharm을 사용하면 되죠. 하지만 전 이 scientific mode를 좋아하지 않습니다. 솔직히 창도 번거롭게 많아지고 생각보다 기능이 불편했습니다. 제가 잘 사용하지 못하는걸까요.
플러그인
그래서 저는 주로 pycharm의 디버거를 사용하는데, 이 디버거에서 사용하기 매우 좋은 플로그인을 소개하고자 합니다.
오늘 알아볼 내용은 GAN의 증명과 한계 그리고 해결법에 관한 내용입니다. 오늘 내용은 WGAN을 들어가기 앞서 필요한 GAN에 대한 증명과 한계들을 다뤄볼 예정입니다.
수학적인 내용을 최소한으로 준비했지만, 도저히 못덜어낸 수학들이 좀 많이 남았습니다.
작성한 글에 오류가 있다면 코멘트 달아주세요! 그럼 바로 시작하겠습니다.
복습
역시 복습을 안하고 갈수는 없겠죠?
복습에서는 이전 포스팅의 내용을 요약하도록 하겠습니다.
GAN의 Component
GAN은 두가지 component로 구성되어 있습니다. Generator와 Discriminator입니다. Generator는 실제 데이터와 유사한 분포의 데이터를 생성하기 위해 노력하는 Component죠. Discriminator는 Generator가 만든 가짜 데이터와 실제 Data를 구별해주는 Component였습니다. Discriminator로 머신러닝 네트워크를 사용하는 이유는 generator가 생성하는 분포와 실제 데이터 분포를 수식화 하기 힘든 매우 어려운 분포이기 때문입니다.
GAN의 Sampling
그리고, GAN은 latent space에서 latent vector z를 sampling합니다. 초기의 latent space는 아무런 의미를 가지지 않기 때문에 여기서 sampling된 latent vector z도 아무런 의미를 가지지 않습니다만, 학습을 통해서 latent vector와 특정 데이터를 연결하게 되면서 latent space가 의미를 가지도록 바꿔줍니다. 마치 사람들이 아무의미없는 것들에 의미를 부여하는것처럼 네트워크도 경험적으로 latent space에 의미를 부여하기 시작할겁니다. ㅎㅎ
GAN의 학습 Algorithm
GAN의 학습은 2가지 step으로 이루어졌었죠. 첫번째 스텝에서는 Generator를 고정시키고, Discriminator를 학습시킵니다. 두번째 스텝에서는 Discriminator를 고정시키고, Generator를 학습시킵니다. 그리고 고정시킴에 따라 Loss Function이 일부 변경되었었습니다.
전체를 요약해보자면
GAN의 전체 구조를 요약해보자면 GAN은 latent space에서 sampling 한 z로부터 Generator는 어떠한 데이터 분포를 만듭니다. 그리고 이렇게 생성한 Generator의 분포와 실제 데이터 분포를 Discriminator로부터 비교하게 함으로서 G가 잘생성하면 D를 혼내주고, D가 잘 구분하면 G를 혼내주는 방법으로 학습을 수행합니다.
아래와같은 구조를 가지고 있죠.
GAN과 JS Divergence
직관적으로 GAN이 왜 동작하는지는 쉽게 이해가 됩니다. 그런데 직관적으로만 ㅇㅋ 하고 가면 조금 찝찝하죠. 그래서 GAN이 동작하는 이유에 대해서 증명을 하나하고 넘어가겠습니다.
이 내용은 이 후 WGAN으로 넘어가기 위해서 반드시 필요하다고 생각해서 넣었는데, 생략하시고 싶으시면 바로 WGAN으로 넘어가시면 됩니다.
한줄 요약
최초 제안된 GAN의 Loss function은 JSD를 계산하게 되는데, JSD는 두 분포가 겹치지 않았을 때, 상수 값을 가져서 gradient 계산으로 optimal을 찾는 것이 쉽지 않다.
GAN의 동작 조건
GAN의 목표는 실제 데이터 분포인 $P_{data}$ 와 Generator 가 만든 가짜 데이터 분포인 $P_{g}$ 가 같아지는 것을 목표로 합니다.
즉 GAN의 Loss Function인 아래 식이 최적일 때, $P_{data}=P_g$ 임을 만족해야한다는 뜻이죠.
오우 정신 사납습니다. 결론은 저희의 Loss function은 JSD로 귀결되게 됩니다.
그리고 JSD는 두 분포사이의 거리로 사용할 수 있음으로, 따라서 위의 수식을 최소로 만드는 것은 두 분포사이의 거리를 최소로 만드는 것이란 거죠!
WGAN의 등장
자 이제부터 쉬워져요. 왜냐면 여기서 수학을 다뺏거든요.
한숨돌리시고 갑시다!
JS Divergence는 학습에 적합하지 않다!
WGAN에서 주장하길 좋은친구아저씨가 제안한 loss의 JSD는 학습에 적합하지 않다는 것이었습니다. 이걸 이해하기 위해서는 SUPP를 이해하셔야하는데, 간단히 설명할께요.
지지 집합 SUPP
SUPP는 support라고 읽는데, 한국어로는 지지집합이라고 합니다. 멋진 위키를 참고하자면
수학에서, 함수의 지지집합(支持集合, 영어: support 서포트[*]) 또는 받침은 그 함수가 0이 아닌 점들의 집합의 폐포이다.
X가 위상 공간이고, ${\displaystyle f\colon X\to \mathbb {R} }$이 함수라고 하자. 그렇다면 ${\displaystyle f}$의 지지집합 ${\displaystyle \operatorname {supp} f}$는 다음과 같다. ${\displaystyle \operatorname {supp} f=\operatorname {cl} {x\in X\colon f(x)\neq 0}}$ 여기서 ${\displaystyle \operatorname {cl} }$ 폐포 연산자다.
인데, 사실 어려운말 다 쳐내고, 이해하시면 좋은 것 단하나 바로 0이 아닌 점들의 집합 입니다. 콤펙트하고 머시기하고 유계 집합 머시기 이런것들이 있는데 사실 저도 잘모릅니다. 수학과가아니라서요. 개념상 0이 아닌 점들의 집합으로 이해하고 넘어가셔도 큰문제가 없습니다.
그림으로 표현해보겠습니다. 두 분포 $\mathbb{P}_r$ 과 $\mathbb{P}_g$ 가 있다고 하겠습니다. 그 분포가 아래와 같이 있다고 가정해볼께요.
여기서 $SUPP \space \space \mathbb{P}_r$ 는 B 공간이 되고, $SUPP \space \space \mathbb{P}_g$ 는 A공간이 된다는 겁니다. 그야말로 0이 아닌 점들의 집합입니다.
이 때, $\mathbb{P}_r(A)=0, \mathbb{P}_r (B)=1$ 이 될 것입니다. 반대로, $\mathbb{P}_g(A)=1, \mathbb{P}_g (B)=0$ 이 되겠죠?
분포의 거리 측정
자 그럼 조금더 직관적인 설명을 위해서, 두 확률 변수를 2차원상의 분포라고 가정해서 생각해보겠습니다.
두 분포가 겹치지 않는다는 사실은 명확하죠. 서로 $0$ 과 $\theta$ 사이에서만 움직이니까요.
여기서 $\theta$ 는 다양한 값을 가질 수 있습니다.
하지만, $\theta \neq0$ 인 경우를 제외한 어떤 경우에도 두 분포가 겹칠 수는 없습니다. 이 경우 두 분포 중 하나가 0이 아닌 값을 가질 때는 다른 분포에서는 무조건 0인 값을 가지게 될 것입니다.
가 된다는 것이죠. 따라서 마찬가지로 $\theta$ 의 값과 상관없이 상항 같은 상수값을 나타내게됩니다.
GAN에서의 의미
이는 GAN을 학습하는데 큰 문제가 됩니다. Loss Function은 가까우면 가깝다고, 멀면 멀다고 명확히 말을 해주어야 이에 따른 gradient를 계산할 수 있습니다. 하지만 두 분포의 SUPP이 겹치지 않는다면 '두 분포가 완전히 다르다.' 라는 정보만 줄 뿐 어떻게 가깝게 만들지에 대한 정보 즉 gradient를 계산할 수 없다는 뜻입니다.
즉, D가 너무 깐깐하게 두 분포를 판단해서, 만약 두 분포가 겹치지 않는다면 두분포를 어떻게 가깝게 만들지에 대한 gradient를 계산할 수 없게 된다는 거죠.
이는 사실 일만적으로 이미지 생성과 같은 높은 dimension의 문제를 푸는 GAN에서는 이러한 분포가 겹치지 않는 문제가 더 심하게 발상할 것이기 때문에 GAN의 학습 성능이 떨어지는 등 꽤 크리티컬한 문제로 다가올 것이었습니다.
이 문제를 해결하러 왔다.
이 문제를 해결하기 위해서 2가지 방법을 소개해드리도록 하겠습니다.
분포를 건드리는 방법
분포의 거리 측정 방식을 개선하는 방법
당연히 우리가 앞으로 할 짓은 후자입니다만, 전자 쪽은 매우 간단하고 직관적으로 이해가 쉬우니까 한번 보고 가실께요.
노이즈를 통한 해결
Support가 겹치지 않는 것이 문제였습니다. 때문에 기존의 이미지에 노이즈 n을 추가해주면서 아래 그림처럼 두 분포의 Support영역을 넓혀 겹칠수 있도록 만들어주는 것입니다. 이렇게 두 분포를 겹치게 만들어주면 JSD를 사용해서 문제를 해결할 수 있습니다.
하지만 이러한 해결방법은 생성된 이미지가 굉장히 흐릿하게 나오는 등 문제가 발생하는 등 성능이 좋지 않게 나왔다고 합니다.
JSD가 아닌 다른 형태의 거리 측정 방식 사용
JSD는 두 분포가 겹치지 않을 때 상수가 나와 gradient를 계산하기 힘든 문제가 있었습니다.
따라서 저자는 분포가 겹치지 않아도 두 분포의 거리를 측정할 수 있는 방법인 Earth Mover Distance를 사용할 것을 제안합니다. 그리고 이제부터가 WGAN의 시작이죠.
WGAN은 다음 포스팅에서 다루도록 하겠습니다. 길이 너무 길어져서 여기서 한번 끊고 가도록 하겠습니다.
Reference
(Paper) Goodfellow, Ian J., et al. "Generative adversarial networks." arXiv preprint arXiv:1406.2661 (2014).
1B는 통과할만한 것같은데, 간만에 문제를 풀어서 그런가 문제 푸는속도도 코드도 영 맘에들지 않는다.
무엇보다 아이디어 떠오르는 속도가 정말 너무 많은 trial-and-error을 필요로한다.
Append Sort
문제는 아래에서 자세히 읽을 수 있다.
문제 설명
정수가 N개가 주어진다.
정수 N개를 정렬하고 싶은데, 정수끼리의 순서를 바꿀수 없다.
우리가수행할 수 있는 연산은 각 정수에다가 숫자들을 더하는 건데 예를 들면 123이 있다면 여기다가 4를 더해서 1234로만드는 이런연산이가능하다.
123 → 1234
이런 연산을 최소한으로 수행해서 strictly increasing order로 만들어라.
sample Input
4 3 100 7 10 2 10 10 3 4 19 1 3 1 2 3
Sample Output
Case #1: 4 Case #2: 1 Case #3: 2 Case #4: 0
설명:
첫번째는 아래처럼 추가하면 100, 700, 1000 으로 오름차순으로 만들수 있다.
100
7 → 700
10 → 1000
총 0을 4번 추가했고, 이거보다 적게 추가해서 오름차순으로 만들 수 없기 때문에 답은 4다
솔루션 코드
for tc in range(int(input())):
N = int(input())
X = list(map(int, input().split(" ")))
ans = 0
for i in range(1, N):
if X[i] > X[i-1]:
continue
s_b = list(str(X[i-1]))
s_p_max = list(str(X[i]))
s_p_min = list(str(X[i]))
count = 0
ori_len = len(s_p_max)
while int("".join(s_b)) >= int("".join(s_p_max)):
s_p_max += "9"
s_p_min += "0"
count += 1
ans += count
if int("".join(s_b)) < int("".join(s_p_min)):
X[i] = int("".join(s_p_min))
else:
# same
X[i] = X[i-1]+1
print(f"Case #{tc+1}: {ans}")
그지처럼 풀었다.
첫번째 원소는 건드리지 않아도 되니까 range(1, N) 으로 시작한다. 시작부터 크다면 신경쓸필요없으니까 그리고 뒤를 9로 일단 채워나가면서(그 자리수에서 체울 수 있는 가장 큰 수) 될때까지 자리수를 늘려나간다.
그리고 늘려나간 최종 값중에서 가장 작은 값 (예를들자면, 10에서 2개의 digit을 추가했다면, 1099이 최대값, 1000이 최솟값이 된다.)이 배열의 이전 값 (s_b, X[i-1]) 보다 크다면 그게 답이 되기 때문에 X[i]에 그 값 자체를 넣어준다. int("".join(s_p_min))
만약 최솟값이 같다면, X[i-1]의 값에 +1해주면 된다. 이부분은 +1해줌으로서 다른자리까지 전부 변화시켜주려면 9여야하는데, 모두 9로 되어있는 경우는 14번 째 줄의
while int("".join(s_b)) >= int("".join(s_p_max))
에서 컷팅당한다. 즉 +1만해줘도된다.
이해가 안가시면 질문주세요. 대충 설명드립니다.
Prime Time
문제는 아래에서 자세히 읽을 수 있다.
소수가 주어진다. (중복된 소수들도 주어진다.)
소수를 두 그룹으로 만들 것인데, 한 그룹은 덧셈으로, 한 그룹은 곱셈으로 모든 연산을 수행할것이다. 이 때 합을 한 그룹과 곱을 한 그룹의 계산 값이 같아야 한다.
예를 들자면
2 - 2개
3 - 1개
5 - 2개
7 - 1개
11 - 1개
가 주어졌을 때,
2+ 2 + 3 + 7 + 11 == 5 * 5
가 되고 두 그룹의 합은 25로 같다. 이렇게 만들 수 있는 가장 큰 값을 구하라.
중요한 인풋 설명
소수는 499 이하의 것들로만 주어진다.
소수는 서로다른 소수 95개 까지 주어질 수있으며, 각 소수는 여러개 존재할 수 있다.
Sample Input
4 5 2 2 3 1 5 2 7 1 11 1 1 17 2 2 2 2 3 1 1 2 7
Sample Output
Case #1: 25 Case #2: 17 Case #3: 0 Case #4: 8
설명:
첫번째 input은 2가 2개, 3이 1개, 5가 2개, 7이 1개 11이 1개다.
이 때, 더하기 그룹은 2, 2, 3, 7, 11을 선택하고, 나머지 곱하기 그룹은 5 5를 선택하면 최대값 25가 뽑인다. 그래서 답은 25
Test Set 1 솔루션 소스코드
$2≤N_1+N_2+⋯+N_M≤10$. (소수의 총 개수가 10개 이하)
from functools import reduce
for tc in range(int(input())):
N = int(input())
P = []
for i in range(N):
p_i, n_i = map(int, input().split())
P += [p_i for j in range(n_i)]
ans = 0
for i in range(1 << len(P)):
arr = list(enumerate([1 if i & (1 << j) else 0 for j in range(len(P))]))
arr_mul = reduce(lambda x, y: x * y, [P[i] if v == 0 else 1 for i, v in arr])
arr_sum = sum([P[i] if v == 1 else 0 for i, v in arr])
if arr_mul == arr_sum:
ans = max(ans, arr_mul)
print(f"Case #{tc+1}: {ans}")
소수의 총 개수가 10개 이하다.
잘모를 땐 다해보자. 모든 소수를 분해한다. 0번그룹 1번그룹
그래서 두개의 곱과 합을 구한 후 두개가 같다면 ans로 출력한다.
Test Set 2 솔루션 소스코드
$2≤N_1+N_2+⋯+N_M≤100$. (소수의 총 개수가 100개 이하)
from functools import reduce
for tc in range(int(input())):
N = int(input())
P = []
UP = {}
for i in range(N):
p_i, n_i = map(int, input().split())
UP[p_i] = n_i
P += [p_i for j in range(n_i)]
all_sum = sum(P)
ans = 0
unique_p = UP.keys()
for t in range(all_sum, 0, -1):
is_fail = False
may_ans = t
tot = 0
for p in unique_p:
p_count = UP[p]
while t % p == 0:
t //= p
p_count -= 1
tot += p
if p_count < 0:
is_fail = True
if is_fail:
break
if is_fail:
break
if t != 1:
is_fail = True
if not is_fail and all_sum - tot == may_ans:
ans = may_ans
break
print(f"Case #{tc+1}: {ans}")
소스코드가 개판이다.
자 이제 어떻게 줄일 수 있을 지 생각해보자. 100개니까 모두를 해보는 건 불가능하다. ($2^{100}$ 은 안봐도 TLE)
합과 곱중 하나를 고정시켜야하는데, 생각해보면 모든 수의 합은 매우 작다. 그러니까 곱하는걸 선택할 때 많은 수를 곱할 수 없다.
100개를 다합쳐보면 최대 얼마가나올까?
$499 * 100 = 49900$ 밖에 안나온다.
그럼 모든 소수를 다더하고, 1씩 빼가면서 그 수를 만들 수 있는지 확인한다. 만약 만드는 것을 실패하면, 불가능하다는 것이고, 만드는 것을 성공하면 만드는데 든 소수의 합을 뺀 것이 지금의 수인지 체크하면 된다. 그 코드는 아래와 같다.
if not is_fail and all_sum - tot == may_ans:
ans = may_ans
break
시간복잡도는 대충 49900 * len(unique_p) * log (49900) 정도 될껀데, 뒤에껀 대충 빼고, 계산해도 49,90,000 정도다. 매우적으니까 당연히 통과다. 문제는 테스트 셋 3인데...
Test Set 3 솔루션 소스코드
$2≤N_1+N_2+⋯+N_M≤10^{15}$. (소수의 총 개수가 $10^{15}$개 이하)
이거 보고 솔직히 아 당연히 수학문제겠지 하고 제꼈다. 구글 수학문제좋아하니까 그럴꺼라고 생각했다.
그런데 멍청했다.
from functools import reduce
for tc in range(int(input())):
N = int(input())
P = []
UP = {}
all_sum = 0
for i in range(N):
p_i, n_i = map(int, input().split())
UP[p_i] = n_i
all_sum += p_i * n_i
ans = 0
unique_p = UP.keys()
for t in range(all_sum, max(0, all_sum-8982), -1):
is_fail = False
may_ans = t
tot = 0
for p in unique_p:
p_count = UP[p]
while t % p == 0:
t //= p
p_count -= 1
tot += p
if p_count < 0:
is_fail = True
if is_fail:
break
if is_fail:
break
if t != 1:
is_fail = True
if not is_fail and all_sum - tot == may_ans:
ans = may_ans
break
print(f"Case #{tc+1}: {ans}")
거의 메인 로직은 똑같고 중요하게 추가된게 아래 이거 하나다.
for t in range(all_sum, max(0, all_sum-8982), -1):
생각해보면, 곱셈으로 아무리 빼고싶어도 소수를 빼는데 한계가 있다. 곱셈이 너무 커지니까. 그러니까 곱셈측에도 컷팅같은 느낌을 줄 수 있다.
잘 생각해보면 499를 $10^{15}$ 더해봤자 $499*10^{15}$ 이다. 아니 엄청 큰 숫자아니냐고? 곱셈의 세계에서는 매우 작은 숫자다.
그럼 소수 최대 몇개를 뺄 수 있을까. 를 계산해보면
$$\text{log}_2(499*10^{15})=17.xx$$
이긴한데, 사실 계산하기 귀찮으니까 아무리그래도 $499^{18}$ 보다는 작을거라고 생각할 수 있다. 그러니까 아무리 빼고싶어도 499를 18번보다 많이 뺄수없으며, 어떤 소수든 18개이상 빼면 전체합을 넘어가게된다.
즉 뺄수있는 것의 범위가 $499*18=8982$보다는 작게된다는 것이다.
그래서 정말 뺄수있는최소값을 range(all_sum-8982, all_sum) 로 잡으면 되는데, all_sum-8982 가 0보다 작을수있으니까 max를 씌워준것일 뿐이다.
이건 대회중에 떠올리지 못한 아이디어다. 답안지보고 아 수학아니었는데... 하고 너무 아쉬워서 딱 저 아이디어 추가하고 메모리문제좀개선해서 제출했떠니 바로 답이 뜨더라.
눈에 보이는 조건들만이 조건이 아닙니다. 들어온 순서나, 길이, 입력받은 상수값 같은 것들도 이에 활용할 수 있습니다. 위 문제에도 이를 활용할 수 있습니다.
문제 설명은 지면이 길어져서 생략하도록 할게요.
입력 예제 1
3 21 Junkyu 21 Dohyun 20 Sunyoung
정답 코드
N = int(input())
D = [input().split(" ") for i in range(N)]
D = sorted([(int(age), order, name) for order, (age, name) in enumerate(D)])
for age, order, name in D:
print(age, name)
이 코드에도 몇가지 보여드릴께 있습니다. 이 코드에서 배울 것은 크게 3가지입니다.
enumerate
for문의 unpack 트릭
order를 정렬의 조건으로 넣기
enumerate
enumerate는 반복문을 돌 때, 배열의 인덱스, 배열의 원소 두가지를 활용할 수 있게 해줍니다.
3개라구요? 아닙니다 잘 보세요! 2개입니다.
order, (age, name) 2개라구요.
for문에서 이름 붙여서 가져오기
이것이 바로 python에서 for문에서 container를 unpack할 시에 제공하는 문법입니다.
아마 컨테이너 안에 있는 tuple에서 몇가지 원소를 이름을 붙여서 가져올 때 사용해보셨을 텐데, 컨테이너 안에 컨테이너에서 이름을 붙여서 가져올 수 있다는 사실은 사람들이 잘 모르시는 경우가 많았습니다.
이 경우에는 괄호를 생략하실수 없습니다! 바로 container에서 가져옴과 동시에 이름을 붙일 수 있기 때문에 코드의 가독성이 더 올라가게 되죠.
순서를 정렬의 기준으로 만들기
enumerate에서 가져온 배열의 인덱스는 곧 input을 받은 순서입니다. 문제의 조건에서