작성
·
173
2
안녕하세요
GridSearch 가 어떻게 작동되는지 많이 햇갈려서요
예제코드가 iris data 값들을 파라미터값들인 깊이 와 split 으로 분할해서 학습시키는것이 맞는지요? 그리고 학습시키고 예측은 어느부분에서 실행되는것인지요
감사합니다.
답변 1
1
안녕하십니까,
사이킷런은 GridSearchCV 라는 객체를 제공합니다. GridSearchCV는 하이퍼 파라미터들을 순차적으로 대입해 가면서 모델을 학습 시키고 검증해서 최적 하이퍼 파라미터를 찾는데, 이때 Cross validation (교차 검증) 방식으로 학습과 검증을 합니다.
GridSearchCV() 생성 시 입력 인자로 사이킷런 Decision Tree와 같은 Estimator객체와 테스트할 하이퍼 파라미터들을 Dictionary 형태로 입력해주면 자동으로 Decision Tree를 입력된 하이퍼 파라미터들로 반복적으로 학습 후 성능 검증을 합니다.
# DecisionTree 객체 생성.
dt_clf = DecisionTreeClassifier()
# 테스트할 하이퍼 파라미터 설정. max_depth를 4, 5, 6 값을 번갈아서 테스트
hyper_params = {'max_depth': [ 4, 5, 6] }
# GridSearchCV 객체 생성, DecisionTreeClassifier 객체의 max_depth 4, 5, 6을 번갈아 가면서 학습, 검증 수행하되, 교차 검증 데이터 세트를 3으로 설정해서 수행. 즉 max_depth 4일때 모델 학습, 성능 검증, 5일때 학습 및 성능 검증, 6일때 학습 및 성능 검증을 하는데 교차 검증을 3으로 하므로 3개의 max_depth 하이퍼 파라미터 각각에 대해서 3번씩 학습, 검증 수행하므로 3x3번 학습, 3x3번 검증, 총 9번의 학습 검증을 수행
grid_cv = GridSearchCV(dt_clf, params=hyper_params, cv=3, refit=True)
# 학습데이터 x_train, y_train에 대해서 GridSearchCV 수행. max_depth 4일때 모델 학습, 성능 검증, 5일때 학습 및 성능 검증, 6일때 학습 및 성능 검증을 하는데 교차 검증을 3으로 하므로 3개의 max_depth 하이퍼 파라미터 각각에 대해서 3번씩 학습, 검증 수행하므로 3x3번 학습, 3x3번 검증, 총 9번의 학습 검증을 수행
grid_cv.fit(x_train, y_train)
학습과 예측 모두 GridSearchCV에서 입력된 DecisionTreeClassifier 객체를 이용해서 수행합니다.
GridSearchCV는 해당 DecisionTreeClassifier 객체를 GridSearchCV에 입력된 하이퍼 파라미터와 CV값에 따라 반복적으로 학습 시키고 성능 검증하고, 그 결과를 별도의 Dictionary에 기록하는 역할을 수행합니다. 최종적으로 가장 좋은 성능을 나타낸 하이퍼 파라미터로 DecisionTreeClassifier 객체를 최종 학습 시키고 종료 합니다.
감사합니다.