본문 바로가기
Tabular data

Tree-based-model(트리 기반 모델) Bagging & Boosting 이해하기- Tabular data

by inhovation97 2022. 6. 21.

Tabular data를 공부하면서 연속적으로 포스팅하는 글입니다. 이전 포스팅 참고

2번째 포스팅은 tree-based-model에 대해 설명하려 합니다. 

제가 거의 처음 데이터 과학 분야를 공부할 때에는 언제 딥러닝을 쓰고 언제 tree model을 쓰는지 잘 몰랐습니다.

 

이번 포스팅은 데이터에 대한 제 생각에 대해 적어보는 정도고, tree-based model을 완벽하게 수식으로 설명하는 글이 아니라 직관적인 이해를 돕는 글입니다.

1. Tabular data를 추론하는 문제
2. Tree-based model
     Bagging
     Boosting
3. Tree-based model 고찰

 

 

 

 

1. Tabular data를 추론하는 문제

 

지난번 포스팅에서 Tabular data가 뭔지 알아봤습니다. 그럼 이제 Tabular data를 추론하는 문제로 생각해봅시다. 

제가 거의 처음 데이터 과학 분야를 공부할 때에는 언제 딥러닝을 쓰고 언제 tree model을 쓰는지 잘 몰랐습니다. 그래서 이것부터 짚고 넘어가면 좋을 것 같습니다. 

 

그림 1

이거로만 단순히 모델을 선정하기는 힘들지만 반응 변수와 예측 변수가 비선형 관계인 경우에서 데이터셋이 충분하다면

딥러닝 모델, 데이터셋을 충분히 확보하기 힘들다면 트리 모델을 고려하기 좋은 것 같습니다. 

무조건 딥러닝이 좋은 건 아니라는 거죠.

학습 난이도도 과적합의 위험으로 딥러닝이 좀 더 높겠죠.

 

하지만 TabNet 이전의 tabular data의 대부분 부스팅 트리모델인 XGBoost, LightGBM 등등이 많이 쓰였고, 성능이 좋았습니다. 그럼 Tree-based-model의 원리를 한 번 살펴봅니다. 

 

 

 

2. Tree-based model

 

<Decision Tree>

Descision Tree 원리부터 Bagging, Boosting 순으로 가볍게 살펴봅니다.

 

https://scikit-learn.org/stable/modules/tree.html

Tree 모델은 위 처럼 Spliting Rule에 의해 노드를 분기해나가고, 회귀 및 분류를 진행합니다.

Spliting Rule을 정하는 방법은 여러 변수 중 하나를 선택하고,

회귀 문제에서는 분기된 값들의 RSS(잔차 제곱합)를 최소화하는 방향으로 split합니다. split된 각각의 노드들로 들어온 값들의 평균을 예측 값으로 정한 것이며 이걸로 RSS를 계산한 겁니다.

분류 문제의 경우에는 분기된 값들의 Impurity(불순도)를 최소화하는 방향으로 split합니다. split된 각각의 노드들로 들어온 값들 중 다수 클래스를 예측 클래스로 정한 것이며 이걸로 불순도를 계산한 겁니다. (이때 불순도로 위 처럼 Gini계수나 entropy를 이용합니다.)

 

test set을 넣었을 때, 각각 최종 노드들로 분기된 데이터를 위에서 정한 예측값으로 디코딩하는 것이죠. 

위 그림에서는 첫 번째 노드에서 Petal length 변수에 대해 2.45cm를 기준으로 데이터를 찢었네요. 

이러한 원리 때문에 Tree-based model이 그림1처럼 예측 변수 공간을 어느정도 비선형적으로 분할할 수 있었던 것이죠.

 

학습 방법이 단순하니 설명력도 매우 높아서 Feature importance를 구할 수 있죠. 이건 다음 포스팅에서 다루려합니다.

어쨌든 위의 방식을 기저로 Tree-based model은 더욱 발전하기 시작합니다. 

 

 

<Bagging>

여기서 기본적인 Decision Tree보다 Bagging이라는 기법으로 높은 성능을 내는 Random Forest가 있습니다. 

Bootstrap의 원리로 훈련세트 B개를 만든 뒤, 이를 Aggregate한다고 하여 Bagging입니다.

 

RandomForest를 기준으로 설명합니다.

중복을 허용하여 훈련 데이터를 B개의 훈련 데이터로 쪼갠 뒤(Bootstrap), 각각 B개에 대해 Decision Tree 모델을 생성하여 추론하면 예측값 y^이 B개가 나올겁니다.

이를 평균(Aggreagate)한 것입니다. 아마 여러개의 나무를 적합하기 때문에 Forest라고 하는 것 같습니다.

이때 각각의 tree들은 다양성을 주기위해 feature를 랜덤으로 몇개씩만 이용합니다. 

좀 더 자세한 원리와 코드는 링크를 참고하세요.

Bagging의 효과를 허접하게 그려봤는데, 위 그림으로 봤을때 DecisionTree를 4개 적합했다고 해봅시다.

과녁의 정중앙으로 갈수록 예측한 값이 정답 값에 가깝게 예측된 것입니다.

1개의 observation에 대해 예측을 했을 때 조금씩 다른 spliting rule로 학습된 각 DT들의 예측값은 위와 같이 조금씩 다를겁니다. 이때 4개의 모델을 고려하여 평균 값으로 예측해줌으로써 모델을 1개만 고려했을 때의 예측값이 튈 수 있는 정도 즉,

분산을 크게 줄여 줍니다.

중심극한의 정리로 생각해볼 수도 있는데 위처럼 1개의 observation에 대해서 B개의 Tree가 예측한 값의 분포의 평균 즉 표본 평균의 분포는 분산이 작은 정규 분포를 따르게 되는거죠.

 

분산을 크게 줄이면서 모델이 더욱 로버스트해졌지만(로버스트하다의 의미) 아직 편향치(E(y^) - y)가 너무 큽니다. 

under fitting의 문제점을 안고있습니다.

 

 

 

<Boosting>

Bagging은 병렬적으로 여러개의 독립적인 DT를 만들어 평균을 낸다면, Boosting은 직렬적으로 1개의 모델을 계속 업데이트 시켜가는 것이 큰 차이점입니다. 

딥러닝 모델처럼 이전에 학습한 모델을 update시켜 잔차를 줄여가는 방식으로 모델을 B번 업그레이드 시키는 방법이기 때문에 편향치를 줄일 수 있는 겁니다. (쪼개지 않은 원데이터에 대해 B번 업데이트시키는 것)

 

부스팅 모델의 프로세스는 간단히 설명하기 좀 어렵네요... 링크에 수식도 그렇고 설명이 잘 되어있습니다. 

링크에서 1. 전체 데이터에서 random sampling라고 설명하시는데, 위에서도 말했지만 부스팅 모델은 부스트랩을 하지 않으니 그 점만 필터링하시고 보시면 될 것 같습니다.

 

Gradient boosting으로 트리 모델이 더욱 발전하면서 XGboost나 LightGBM같은 강력한 모델이 한 동안 대부분 경진대회에서 우승 모델일 정도로 좋은 성능을 얻어냈습니다. 

 

부스팅은 모델의 잔차를 줄이도록 update 해나가기 때문에 이렇게 bias를 줄여나가 예측력이 높아지죠. 

bias를 크게 줄였지만, 분산도 줄이기 위해 여기서 마지막으로 할 수 있는게 바로 앙상블이죠. 

위 그림에서 모델을 update를 하면 할 수록 예측값들이 더욱 가운데로 shift할 거예요. 하지만 그만큼 부스팅 모델은 만져야할 파라미터도 많고 overfitting의 위험이 커져서 적합 난이도가 올라갑니다.

 

 

 

3. Tree-based model 고찰

Tree-based model이 SOTA였던 시절 여러 경진대회에서 부스팅 트리 모델을 적합한 뒤 마지막으로 조금이나마 성능을 올리기 위해 적합한 여러 모델들을 앙상블 하는 것이 관례? 였습니다. 수학적으로 생각해보면 부스팅 트리모델을 적합하여 편차를 최소화 시키고 앙상블을 통해 분산까지 줄이려는 목적으로 해석할 수 있습니다. 

 

Bagging, Boosting 두 방법 모두 모델이 조금씩 복잡해지면서 설명력은 조금 줄어들지만 예측력은 크게 올라가는 방식입니다. 그래도 Tree 모델의 설명력이 높은 기본 원리 때문에 Feature Importance를 계산할 수 있는데, 이것이 트리모델의 최대 장점이자 핵심이죠. 

Tree 모델의 학습과정에서 spliting rule을 만드는 데에 어떤 변수가 가장 많이 쓰였는가? 라는 관점으로 Feature Importance를 계산합니다. 중요한 변수를 뽑아내는 것이죠. 이건 다음 포스팅입니다. 

 

어쨋든 트리 모델에서 가장 중요한 건 변수를 선택하여 Spliting Rule을 만들어 낸다는 것입니다. 

저번 포스팅에서 제가 말한 Tabular data에 대해 한동안 Tree-based model이 딥러닝 모델보다 잘 먹힐 수 있었던 이유가 Tree-based model은 Feature Selection이라는 Tabular data의 핵심을 그대로 가지고 있었기 때문이라고 생각합니다. 

'Tabular data' 카테고리의 다른 글

Tabular data란? Tabular data 이해하기  (0) 2022.06.13

댓글