인프런 커뮤니티 질문&답변

kjyn0124님의 프로필 이미지
kjyn0124

작성한 질문수

딥러닝 CNN 완벽 가이드 - Fundamental 편

섹션6 CIFAR10 imshow() 시각화 문제

해결된 질문

작성

·

232

0

안녕하세요 교수님!

5강 시작부분에서 get_preprocessed_data의 scaling 파라미터 값을 False로 하셨는데, 그러면 픽셀값을 255로 나누지 않는 것인데 이렇게 하면 다음과 같은 흰색 배경만 뜨더라구요..

그래서 구글링을 해보니까 plt.imshow() 함수가 0 ~ 1 사이의 float형이나 0 ~ 255 사이의 int형만 가능하다고 해서 다음과 같이 바꾸었는데 제대로 출력되더라구요..!

...

def get_preprocessed_data(images, labels, scaling=True):
    if scaling: # 직접 scaling을 한다고 했을때?
        images = np.array(images/255.0, dtype=np.float32)
    else:
        images = np.array(images, dtype=np.int32) # 이 부분을 수정했습니다.
    
    oh_labels = np.array(labels, dtype=np.float32)
    return images, oh_labels

def get_preprocessed_ohe(images, labels):
    images,labels = get_preprocessed_data(images, labels, scaling=False)
    
    # OHE
    oh_labels = to_categorical(labels)
    return images, oh_labels

...

교수님 코드랑 다른 부분이 없는데 저는 흰 배경으로만 나오고, 저렇게 설정해야지만 올바르게 나오는 점이 이상해서 여쭤보고자 합니다ㅠㅠ!

 

혹시 몰라서 해당 부분 전체 코드 올리겠습니다!

from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

# seed 설정
def set_random_seed(seed_value):
    np.random.seed(seed_value)
    python_random.seed(seed_value)
    tf.random.set_seed(seed_value)

def get_preprocessed_data(images, labels, scaling=True):
    if scaling: # 직접 scaling을 한다고 했을때?
        images = np.array(images/255.0, dtype=np.float32)
    else:
        images = np.array(images, dtype=np.float32)
    
    oh_labels = np.array(labels, dtype=np.float32)
    return images, oh_labels

def get_preprocessed_ohe(images, labels):
    images,labels = get_preprocessed_data(images, labels, scaling=False)
    
    # OHE
    oh_labels = to_categorical(labels)
    return images, oh_labels

def get_train_valid_test_set(train_images, train_labels, test_images, test_labels, valid_size=0.15, random_state=2023):
    train_images, train_ohe_labels = get_preprocessed_ohe(train_images, train_labels)
    test_images, test_ohe_labels = get_preprocessed_ohe(test_images, test_labels)
    
    train_images, valid_images, train_ohe_labels, valid_ohe_labels = train_test_split(train_images, train_ohe_labels, test_size=valid_size, random_state=random_state)
    return train_images, train_ohe_labels, valid_images, valid_ohe_labels, test_images, test_ohe_labels
set_random_seed(2023)
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)

train_images, train_ohe_labels, valid_images, valid_ohe_labels, test_images, test_ohe_labels = get_train_valid_test_set(train_images, train_labels, test_images, test_labels, valid_size=0.15, random_state=2023)
print(train_images.shape, train_ohe_labels.shape, valid_images.shape, valid_ohe_labels.shape, test_images.shape, test_ohe_labels.shape)
NAMES = np.array(['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck'])
def show_images(images, labels, ncols=8):
    figure, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(22, 6))
    
    for i in range(ncols):
        axs[i].imshow(images[i])
        label = labels[i].squeeze()
        axs[i].set_title(NAMES[int(label)])

show_images(train_images[:8], train_labels[:8], ncols=8)
show_images(train_images[8:16], train_labels[8:16], ncols=8)
show_images(train_images[16:24], train_labels[16:24], ncols=8)

감사합니다!

답변 1

0

권 철민님의 프로필 이미지
권 철민
지식공유자

안녕하십니까,

네 맞습니다. plt.imshow() 함수가 0 ~ 1 사이의 float형이나 0 ~ 255 사이의 int형을 가진 array일 때 제대로 시각화 해줍니다.

제 생각엔 get_preprocessed_data()를 변경하실 필요까지는 없고 시각화 할때만 0~255 사이의 값을 가질 경우에 int형으로 변경만 해주시면 될 것 같습니다.

그러니까 시각화 할때만 plt.imshow(train_data[:8].astype(np.int32))

로만 적용해 주시면 될 것 같습니다.

 

감사합니다.

kjyn0124님의 프로필 이미지
kjyn0124

작성한 질문수

질문하기