인프런 커뮤니티 질문&답변

Baek Kyun Shin님의 프로필 이미지
Baek Kyun Shin

작성한 질문수

[개정판] 파이썬 머신러닝 완벽 가이드

교차검증 성능평가 cross_val_score()와 하이퍼 파라미터 튜닝을 위한 GridSearchCV - 01

estimator 질문

작성

·

187

3

안녕하세요.

마지막 부분에서 질문이 있습니다.

grid_dtree.predict(X_test)

estimator = grid_dtree.best_estimator_

estimator.predict(X_test)

로 나뉜 부분에서 predict 결과가 같습니다.

fit을 한 후라면 grid_dtree도 best estimator이고 grid_dtree.best_estimator_도 best estimator일텐데

grid_dtree와 estimator (= grid_dtree.best_estimator_)가 동일한 객체인가요?

즉, grid_dtree와 grid_dtree.best_estimator_가 최적인 하이퍼 파라미터로 훈련된 DecisionTree모델일텐데, 둘의 차이가 무엇인지 궁금합니다.

아니면, 둘이 같은 객체는 아니지만, 단지 grid_tree 객체에 predict 속성을 주면 grid_tree.best_estimator_.predict와 같은 기능을 주는 건가요?

감사합니다.

답변 2

2

권 철민님의 프로필 이미지
권 철민
지식공유자

안녕하십니까,

소스코드에서 grid_dtree는 Estimator객체가 아닌 GridSearchCV 객체 입니다.

GridSearchCV는 생성자로 Estimator객체와 하이퍼 파라미터 튜닝을 수행할 파라미터들, 그리고 cv 횟수등을 생성 인자로 받습니다.

GridSearchCV는 본질적으로는 교차검증을 위한 하이퍼 파라미터 튜닝 방법을 제공하는 클래스이지만 Estimator(Regressor, Classifier)와 유사하게 fit()과 predict() 메소드를 가지고 있습니다.

GridSearchCV의 fit()는 인자로 피처 데이터 세트, 타겟 데이터 세트를 받으며, GridSearchCV의 생성자로 입력받은 Estimator객체를 입력 받은 하이퍼 파라미터들을 번갈아 가면서 fit()인자로 입력받은 데이터 세트에 대해 학습과 검증용 데이터 폴드 세트로 나누면서 최적의 하이퍼 파라미터를 찾습니다. 이때 만일 GridSearchCV의 생성자로 refit=True(디폴트가 refit=True 입니다)를 입력하게 되면 GridSearchCV에서 사용된 Estimator객체를 최적 하이퍼 파라미터로 최종적으로 학습 시킨 뒤 이를 best_estimator_ 객체 속성으로 가지고 있게 됩니다.

GridSearchCV는 predict() 메소드 역시 가지고 있습니다. 이 predict()메소드를 수행하게 되면 GridSearchCV 객체내의 best_estimator_ 객체, 즉 GridSearchCV의 생성자로 입력되어 최적의 하이퍼 파라미터로 학습된 Estimator의 predict() 메소드를 호출하는 것입니다. GridSearchCV의 predict()는 단순히 이 역할만 하는 것입니다.

소스코드에서 estimator = grid_dtree.best_estimator_ 는 결국은 최종 하이퍼 파라미터로 학습된 결정트리 dtree 객체를 의미하며, grid_dtree.predict(X_test)는 결국 estimator.predict(X_test)와 동일합니다. 

감사합니다.

0

확실히 하이퍼 파라미터가 뭔지 모르니까 이해하기 힘들었는데 이거 보고 이해햇어요

Baek Kyun Shin님의 프로필 이미지
Baek Kyun Shin

작성한 질문수

질문하기