게시물 목록

Saturday, March 12, 2022

경사하강법(stochastic gradient descent)의 일반화 성능이 높은 이유? stochasticity 없이도 달성해보기

참여중인 이론기계학습 연구모임에서 6개월마다 한 번 정도씩 내 발표순서가 돌아온다. 이번에 다룬 프리프린트 (J. Geiping et al., "Stochastic Training is Not Necessary for Generalization," https://arxiv.org/abs/2109.14119)에서는 제목 그대로 stochasticity 그 자체는 좋은 학습, 즉 오버피팅 없이 일반화를 잘하는 학습의 필수요건이 아니라고 주장한다. 즉 SGD라는 확률과정의 여러 특징들을 서로 잘 분리해 내어 보니, stochasticity라는 요인을 빼고 이외의 side effect만 취한 결정론적 최적화로도 똑같은 일반화 성능을 달성할 수 있다는 것이다.


초록도 본문도 상당히 도발적인 뉘앙스로 쓰인 논문인데, 후술하겠지만 실제 결과는 그 정도까지 surprising한 것은 아닌 듯하고, 그래도 유의미한 문제의식을 스텝바이스텝으로 재밌게 풀어가고 있다. 수학적 원리에 기초를 둔 응용연구, 혹은 딥러닝의 작동원리 자체에 대한 연구를 활발히 해온 Tom Goldstein 및 그 동료들의 연구이다.


딥러닝에서 mini-batch를 사용하는 SGD가 왜 좋은가? 일단 가장 흔히 알려져 있듯이, full batch보다 계산 시간 면에서 유리하다는 것이 있겠다. 그렇다면 SGD는 시간 단축을 위해 결과적인 성능을 희생하는 것인가? 딥러닝에선 그렇지가 않고 SGD가 '오히려 좋다'는 것이 꽤나 알려져 있다. 그 이유로는 크게 두가지가 있다.


첫번째는 최적화 관점이다. Full-batch gradient descent (FB GD)는 비용 함수 지형의 saddle point(안장점)에서 속도가 급격히 느려진다. 반면 SGD는 이러한 saddle point를 잘 벗어나서 minima를 잘 찾아간다 (이건 아마 stochasticity 그 자체가 중요할것같다). 그리고 N차원 비용함수 지형에서 기울기가 0이 되는 점들 중, 간단히 산술적으로만 생각해봐도 saddle은 매우 많지만 minima는 훨씬 적을것이다. 이계도함수의 부호가 모두 플러스여야 minimum인데, 몇개는 플러스고 나머지 몇개가 마이너스고 하면 saddle point이기 때문이다. 이런 상황에서 saddle 하나하나에서 느려지는 FB GD에 비해, minima를 잘 찾아가는 SGD는 당연히 훨씬 유리하다.


두번째는 일반화 관점이다. 만약에 기울기 하강을 통해 비용함수의 minima에 도달했더라도, 그 지점이 과적합(overfitting)을 일으키는 파라미터들이라면 딥러닝 관점에서는 별로 안좋을것이다. 학습데이터(트레이닝셋)에 존재하지 않았던 테스트셋 데이터를 넣어도 잘 작동해야 하고, 이를 일반화(generalization)을 잘한다고 한다. FB GD는 설령 minima에 무사히 도달하더라도 overfitting이 심한 곳일 가능성이 높고, 반면 SGD는 일반화를 잘하는 지점에 잘 도달하는 경향이 있다.


왜 그럴까? 일단 minima 부근의 비용함수 지형이 sharp할수록 오버피팅이 심하고 (좀만 벗어나도 많이 달라지니까), flat할수록 일반화를 잘한다는 것은 약간 애매하지만 직관적으로 받아들일 만하다. 물론 이론적, 실험적으로 입증한 논문들도 많으며 거의 정설이다. 그래서 이하에서는 일반화가 잘되는 minima를 flat minima라고 부르겠다.


그러면 SGD가 왜 flat minima를 선호하는가? 이에 대한 다양한 설명이 있다. 먼저 통계물리 관점이다. 무지성(?)이고 위치 및 방향에 대해 균질한 화이트노이즈와 다르게, SGD의 경우는 landscape의 모양에 따라 adaptive, intelligent하게 조절되는 비평형 노이즈기 때문에, sharp minima일수록 더 오래 못빠져나오는 화이트노이즈와는 정반대로 flat minima에서 더 오래 머무른다. 이걸 통계물리학자들은 fluctuation-dissipation relation의 breakdown이라고 한다.


아니면 수학적으로, SGD에 의한 implicit한 효과를 explicit하게 빼내어 주어서 설명하는 것도 있다. SGD는 full batch에 의한 진짜 비용함수 지형 대신에, 매 지점에서 약간씩 틀어져있는 '가짜 지형'을 effectively 겪는다고 할 수 있다. 그 가짜 지형을 계산해보면 실제 지형에, 비용함수의 기울기의 제곱에 비례하는 항이 추가된다. 이것을 최소화한다는 것은 기울기가 별로 안컸으면 좋겠다는 것이고 이는 전형적인 regularization term에 해당한다. 즉 SGD는 regularization 효과가 있고 이것때문에 오버피팅이 방지된다.


위 두 가지는 서로 대립되는게 아니라 서로 통해있는 얘기다. 첫번째 관점에서 말한 것처럼 adaptive하게 조절되는 구체적인 방식이, 바로 두번째 관점에서 말한 regularization term인 것이다. 아무튼 여기서는 후자에 초점을 맞추자. 이 논문의 문제의식은 다음과 같다. SGD를 실행하되 그 이론적 해석만 regularization이라고 하는 게 아니라, 아예 SGD 대신 FB GD를 해버리되 앞서말한 regularization을 직접 해주어 보자. 위의 설명대로면 FB GD로도 SGD의 일반화 성능을 달성할 수 있어야 한다.


저자들은 이를 확인하기 위해 ResNet 모델, CIFAR-10 데이터로 일반화 성능을 확인하는 실험을 돌린다. 이때 fair comparision을 위해, minibatch라는 게 없는 FB GD에서도, SGD에서와 같은 batch size에 해당하는 batch normalization은 계속 해준다.


일단 SGD에서는 validation score가 95.7%가 나온다. 반면 이를 full-batch로만 바꾼 naive FB GD에서는 75.42%가 나온다. 이 20%의 성능 갭을, (앞서 말한 explicit regularization을 포함하여) 어떻게든 non-stochastic한 방식으로만 메워보고 싶다.


첫번째로 스케쥴링을 개선해준다. 처음부터 learning rate를 크게 시작하지 말고, 상당히 느리고 긴 warm-up을 해준다. 이것만으로 87.36%로, 갭이 절반 이상 메워졌다 (근데 이걸 SGD에서도 똑같이 해준다면 그쪽도 성능이 더 좋아지는것 아닌가? 사실 이하에서도 이런 비슷한 의문이 계속 든다).


다음으로는 i) gradient clipping (FB GD에서는 landscape-dependent한 learning rate adaptation과 동등함), ii) regularization (위에서 말했듯 이게 논문의 핵심 문제의식이다), iii) smaller batch size에 해당하는 batch normalization 수행 등을 해준다. 이렇게 하면 95.67%의 성능으로 SGD의 성능에 거의 근접해진다.


그리고 마지막으로, 기본적으로 해주고 있던 random data augmentation도 꺼 주자. 그렇게 하면 SGD의 성능은 84.32%로 떨어지는데, FB GD 쪽의 성능은 89.17%가 되어 상대적으로 현격하게 잘하게 된다.


위에서 말했듯, 완전한 fair comparison이라기엔 FB GD 쪽에 너무 불공평한 추가적 성능 개선작업을 많이 해준 감은 있다. 위에 말했듯 '어떻게든 non-stochastic하게 해보겠다'는 것에 치중해서 흘러간 것이다.


그럼에도 불구하고, SGD의 여러 특징들을 잘 분리해내서 그 중에 실제 일반화 성능에의 주효한 요인을 identify하고, 무작위성이 없는 방식으로도 상당히 높은 성능을 달성할 수 있음을 보인것은 충분히 의미가 있다고 보인다.


다만 full-batch니까 당연히 시간은 더 오래걸렸을 것이다. 저자들 역시, 실제로 어떤 시간 단축과 성능 개선을 하기 위한 논문이라기보다는, (이미 잘 되고있는) 딥러닝의 작동원리에 대한 깊은 이해를 돋구기 위한 연구라는 식으로 말하고 있다.


다음으로는 SGD의 implicit bias ('가짜 지형'을 겪게끔 되는것) 효과를 explicit regularization처럼 보이게 빼내어주는 실제 이론적 계산을 따라가보고싶다. 그리고 flat minima일 때 overfitting이 덜 되는 이유를, 분포라는 관점에서 clear한 argument를 만들어보고 싶다. 발표하면서 작성한 노트를 하단에 이미지로 첨부한다.


Facebook에서 이 글 보기: 링크

Facebook 'Tensorflow KR' 그룹에서 이 글 보기: 링크


사진 설명이 없습니다.


사진 설명이 없습니다.


사진 설명이 없습니다.


사진 설명이 없습니다.

No comments:

Post a Comment