인프런 커뮤니티 질문&답변

ecomarine님의 프로필 이미지
ecomarine

작성한 질문수

딥러닝 CNN 완벽 가이드 - Fundamental 편

ImageDataGenerator로 Augmentation 적용 - 01

data_generator.fit(image_batch)에서 fit의 의미

작성

·

308

1

설명 중에 fit을 안해도 되는데

나중에 노멀라이제이션도 전체 데이터에 적용되어야하기 때문필요하다고 하던데  

잘 이해가 안되는데 fit이 뭔지 궁금합니다. 

답변 1

1

권 철민님의 프로필 이미지
권 철민
지식공유자

안녕하십니까,

일반적으로 fit() 메소드는 model 학습시  호출합니다. 모델을 학습하실때 사용해 보셨겠지만, model.fit(ImageDataGenerator, ,,,, ) 와 같이 사용됩니다.   이렇게 fit()을 호출하면 ImageDataGenerator가 Batch size만큼 데이터를 입력받아서, augmentation과 기타 변환을 수행하여 모델에 입력 시켜서 학습을 하게 됩니다.

ImageDataGenerator도 fit()이라는 메소드를 가지고 있습니다. 이는 모델에서 사용하는 fit()하고는 성격이 다릅니다. 보통은 ImageDataGenerator의 fit()을 호출할 일이 없습니다.  모델의 fit()을 호출하면 next(ImageDataGenerator)와 같이 데이터가 batch size만큼  ImageDataGenerator로 자동으로 입력되고, 변환이 됩니다.

이게 보통의 augmentation 변환 로직인데, Normalization인 경우는 좀 사정이 다릅니다. 일반적으로 Normalization을 R, G, B 채널별 전체 데이터에 대해서 평균값, 표준 편차 값을 구한 뒤 Normalization을 수행하게 됩니다. 즉 Normalization의 경우 batch size만큼 순차적으로 데이터가 필요한게 아니라 전체 데이터가 필요하게 됩니다. 

따라서 이 경우는 ImageDataGenerator의 fit()을 호출해서 전체 데이터를 가공할 수 있어야 합니다. Normalization 변환을 ImageDataGenerator로 적용하려면 ImageDataGenerator의 fit()을 먼저 호출해 줘야 합니다. 이는 모델의 fit()하고는 상관이 없습니다.

감사합니다.

안녕하세요 교수님.

교수님이 말씀하신것중에서 질문이 있습니다.

그러면 fit함수를 사용하면 전체데이터를 불러들이는것과 같은것일까요?

그것을 flow를 통해서 이미지를 가공하는것이 맞을까요?

 

제가 정리한 내용은 다음과 같습니다.

 

# 1.fit을 통해서 각 채널별로 평균과 표준편차를 구한것임.

# 2.data_generator를 통해서 위에서 계산한 평균과 표준편차를 통해 zscore변환한다.

# 3.next를 통해서 다음 배치를 반환한다.

ecomarine님의 프로필 이미지
ecomarine

작성한 질문수

질문하기