Inflearn Community Q&A


[Revised Edition] Complete Guide to Deep Learning Computer Vision

Understanding Config - Config Categories and Key Settings

Question about modifying the config file

Resolved question

Posted 23.03.27 22:45 · edited

Hello, instructor.

Thanks to your lecture, I was able to train a faster-rcnn model
on a custom dataset.

I'd now like to apply another model (swin) to the same custom dataset, using the mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py file from https://github.com/open-mmlab/mmdetection/tree/master/configs/swin. After changing the config file and checkpoint accordingly and running the model, I got a mask-related error, which seems to come from the mask-rcnn part of the model. Googling suggested commenting out that part and running again. Is there a way to comment out that part in Colab? If there is a better approach, I would appreciate it if you could let me know.


2023-03-27 14:19:05,247 - mmdet - INFO - Automatic scaling of learning rate (LR) has been disabled.
<ipython-input-14-f8ce61995cc8>:47: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  'labels': np.array(gt_labels, dtype=np.long),
<ipython-input-14-f8ce61995cc8>:49: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  'label_ignore':np.array(gt_labels_ignore, dtype=np.long)
2023-03-27 14:19:08,688 - mmdet - INFO - load checkpoint from local path: checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth
2023-03-27 14:19:08,849 - mmdet - WARNING - The model and loaded state dict do not match exactly

size mismatch for roi_head.bbox_head.fc_cls.weight: copying a param with shape torch.Size([81, 1024]) from checkpoint, the shape in current model is torch.Size([16, 1024]).
size mismatch for roi_head.bbox_head.fc_cls.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([16]).
size mismatch for roi_head.bbox_head.fc_reg.weight: copying a param with shape torch.Size([320, 1024]) from checkpoint, the shape in current model is torch.Size([60, 1024]).
size mismatch for roi_head.bbox_head.fc_reg.bias: copying a param with shape torch.Size([320]) from checkpoint, the shape in current model is torch.Size([60]).
size mismatch for roi_head.mask_head.conv_logits.weight: copying a param with shape torch.Size([80, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([15, 256, 1, 1]).
size mismatch for roi_head.mask_head.conv_logits.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([15]).
2023-03-27 14:19:08,856 - mmdet - INFO - Start running, host: root@06d3ab7dae34, work_dir: /content/gdrive/MyDrive/htp_dir_swin
2023-03-27 14:19:08,858 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) NumClassCheckHook                  
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_iter:
(VERY_HIGH   ) StepLrUpdaterHook                  
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook                           
 -------------------- 
after_train_iter:
(ABOVE_NORMAL) OptimizerHook                      
(NORMAL      ) CheckpointHook                     
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
after_train_epoch:
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_val_epoch:
(NORMAL      ) NumClassCheckHook                  
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_val_iter:
(LOW         ) IterTimerHook                      
 -------------------- 
after_val_iter:
(LOW         ) IterTimerHook                      
 -------------------- 
after_val_epoch:
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
after_run:
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
2023-03-27 14:19:08,859 - mmdet - INFO - workflow: [('train', 1)], max: 5 epochs
2023-03-27 14:19:08,860 - mmdet - INFO - Checkpoints will be saved to /content/gdrive/MyDrive/htp_dir_swin by HardDiskBackend.
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-35-c8cc0d536607> in <module>
      4 mmcv.mkdir_or_exist(os.path.abspath(cfg.work_dir))
      5 # epochs는 config의 runner 파라미터로 지정됨. 기본 12회
----> 6 train_detector(model, datasets, cfg, distributed=False, validate=True)




/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/apis/train.py in train_detector(model, dataset, cfg, distributed, validate, timestamp, meta)
    244     elif cfg.load_from:
    245         runner.load_checkpoint(cfg.load_from)
--> 246     runner.run(data_loaders, cfg.workflow)

/usr/local/lib/python3.9/dist-packages/mmcv/runner/epoch_based_runner.py in run(self, data_loaders, workflow, max_epochs, **kwargs)
    134                     if mode == 'train' and self.epoch >= self._max_epochs:
    135                         break
--> 136                     epoch_runner(data_loaders[i], **kwargs)
    137 
    138         time.sleep(1)  # wait for some hooks like loggers to finish

/usr/local/lib/python3.9/dist-packages/mmcv/runner/epoch_based_runner.py in train(self, data_loader, **kwargs)
     47         self.call_hook('before_train_epoch')
     48         time.sleep(2)  # Prevent possible deadlock during epoch transition
---> 49         for i, data_batch in enumerate(self.data_loader):
     50             self.data_batch = data_batch
     51             self._inner_iter = i

/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    626                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627                 self._reset()  # type: ignore[call-arg]
--> 628             data = self._next_data()
    629             self._num_yielded += 1
    630             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
   1331             else:
   1332                 del self._task_info[idx]
-> 1333                 return self._process_data(data)
   1334 
   1335     def _try_put_index(self):

/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1357         self._try_put_index()
   1358         if isinstance(data, ExceptionWrapper):
-> 1359             data.reraise()
   1360         return data
   1361 

/usr/local/lib/python3.9/dist-packages/torch/_utils.py in reraise(self)
    541             # instantiate since we don't know how to
    542             raise RuntimeError(msg) from None
--> 543         raise exception
    544 
    545 

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/custom.py", line 220, in __getitem__
    data = self.prepare_train_img(idx)
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/custom.py", line 243, in prepare_train_img
    return self.pipeline(results)
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/compose.py", line 41, in __call__
    data = t(data)
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/loading.py", line 398, in __call__
    results = self._load_masks(results)
  File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/loading.py", line 347, in _load_masks
    gt_masks = results['ann_info']['masks']
KeyError: 'masks'
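Separately from the KeyError, the two DeprecationWarning messages near the top of the log point at the poster's own annotation code (ipython-input lines 47 and 49), which builds label arrays with dtype=np.long; the warning itself suggests np.int64. A minimal sketch of that annotation dict, assuming mmdetection 2.x's middle-format keys (bboxes, labels, bboxes_ignore, labels_ignore). build_ann is a hypothetical helper name, since the rest of the dataset class is not shown in the post. Note that the posted line spells the key label_ignore, while mmdetection's documented middle format uses labels_ignore, so it may be worth checking which one your code needs.

```python
import numpy as np


def build_ann(gt_bboxes, gt_labels, gt_bboxes_ignore, gt_labels_ignore):
    """Hypothetical helper: pack one image's boxes/labels into mmdetection's middle format."""
    return dict(
        bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
        # np.int64 instead of the deprecated np.long alias, as the warning suggests
        labels=np.array(gt_labels, dtype=np.int64),
        bboxes_ignore=np.array(gt_bboxes_ignore, dtype=np.float32).reshape(-1, 4),
        labels_ignore=np.array(gt_labels_ignore, dtype=np.int64),
    )


ann = build_ann([[10, 20, 110, 220]], [3], [], [])
print(ann['labels'].dtype, ann['bboxes'].shape, ann['bboxes_ignore'].shape)  # int64 (1, 4) (0, 4)
```

This only silences the warnings; it does not affect the KeyError: 'masks', which comes from the Mask R-CNN data pipeline.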

1 answer

권 철민
Instructor

2023. 03. 28. 16:58

Hello,

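Hmm, faster rcnn is an object detection model, so I think it would be difficult to apply mask rcnn with a dataset built for object detection, since mask rcnn also needs a segmentation mask for each object...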

Which dataset are you currently using? Is it the dataset from the lecture, or a different one?
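If the goal is only a box-level comparison with the Faster R-CNN run, one option (an assumption on my part, not something suggested in this thread) is to keep the Swin backbone but drop the mask branch from the loaded config instead of commenting out lines in the file. The traceback shows the failure inside LoadAnnotations._load_masks, which runs because the train pipeline has with_mask=True; the Collect step also asks for gt_masks. The sketch below edits plain dicts shaped like mmdetection 2.x configs; to_detection_only is a hypothetical helper, the dicts are trimmed-down stand-ins for the real Swin config, and switching type to FasterRCNN assumes the rest of the two-stage setup can stay as-is.

```python
import copy


def to_detection_only(model_cfg, train_pipeline):
    """Return copies of a Mask R-CNN model config and train pipeline with the mask branch removed."""
    model = copy.deepcopy(model_cfg)
    model['type'] = 'FasterRCNN'  # assumption: same two-stage detector, boxes only
    roi_head = model['roi_head']
    roi_head.pop('mask_roi_extractor', None)  # no RoI features for masks
    roi_head.pop('mask_head', None)           # no mask prediction or mask loss

    pipeline = copy.deepcopy(train_pipeline)
    for step in pipeline:
        if step.get('type') == 'LoadAnnotations':
            step['with_mask'] = False  # stops the pipeline from reading ann_info['masks']
        elif step.get('type') == 'Collect':
            step['keys'] = [k for k in step['keys'] if k != 'gt_masks']
    return model, pipeline


# Trimmed-down stand-ins for the relevant parts of the Swin Mask R-CNN config
mask_rcnn_model = dict(
    type='MaskRCNN',
    backbone=dict(type='SwinTransformer'),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_head=dict(type='Shared2FCBBoxHead', num_classes=15),
        mask_roi_extractor=dict(type='SingleRoIExtractor'),
        mask_head=dict(type='FCNMaskHead', num_classes=15),
    ),
)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]

model, pipeline = to_detection_only(mask_rcnn_model, train_pipeline)
print(model['type'], sorted(model['roi_head']), pipeline[-1]['keys'])
# FasterRCNN ['bbox_head', 'type'] ['img', 'gt_bboxes', 'gt_labels']
```

In the Colab notebook, this would presumably be applied to the config loaded with mmcv's Config.fromfile before building the model and datasets, assigning the results back to cfg.model and cfg.data.train.pipeline. The evaluation metric inherited from the COCO instance base config (bbox and segm) may also need to be set to whatever the Faster R-CNN run used. Because mmcv loads checkpoints non-strictly (the same behaviour that turned the size mismatches in the log into warnings), the unused mask-head weights in the checkpoint should likewise only produce a warning.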

 

bloomingdiana
Asker

2023. 03. 28. 18:35

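I'm currently using a different dataset. I ran object detection with the faster rcnn model, and I'd like to use the mmdetection package to compare its performance against other models. Would mask-rcnn be difficult to apply? I'm also trying SSD, but it keeps throwing errors. Since the mmdetection configs folder is named swin, I assumed it was a Swin Transformer model, but it seems to be a mask rcnn model that only uses Swin Transformer as its backbone.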
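The observation about the swin folder can be checked on the config itself: in mmdetection 2.x the detector type sits in model.type and the backbone in model.backbone.type, and for mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py these resolve to MaskRCNN and SwinTransformer. A small sketch of that check on a plain dict; summarize_detector is a hypothetical helper, and with a config loaded through mmcv's Config.fromfile the same lookups should work on cfg.model.

```python
def summarize_detector(model_cfg):
    """Report the detector type, backbone type, and whether a mask branch is configured."""
    roi_head = model_cfg.get('roi_head') or {}  # single-stage detectors have no roi_head
    return dict(
        detector=model_cfg['type'],
        backbone=model_cfg['backbone']['type'],
        needs_masks=roi_head.get('mask_head') is not None,
    )


swin_cfg = dict(
    type='MaskRCNN',
    backbone=dict(type='SwinTransformer'),
    roi_head=dict(type='StandardRoIHead', mask_head=dict(type='FCNMaskHead')),
)
print(summarize_detector(swin_cfg))
# {'detector': 'MaskRCNN', 'backbone': 'SwinTransformer', 'needs_masks': True}
```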

권 철민
Instructor

2023. 03. 28. 20:17

To apply mask-rcnn, you need to use a segmentation dataset. The segmentation section later in the course has a lecture that uses mmdetection's mask-rcnn, so it would be good to take that lecture and then try applying it.

SSD in mmdetection has many issues and does not perform well.
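On the point that mask-rcnn needs a segmentation dataset: the traceback shows LoadAnnotations._load_masks reading results['ann_info']['masks'], and in mmdetection's COCO loader that entry is a list with one COCO-style segmentation (a list of flat [x1, y1, x2, y2, ...] polygons) per instance. A sketch for spotting entries that lack it, assuming the custom dataset keeps the same middle format as the Faster R-CNN run (dicts with filename, width, height, ann) and that masks would be added as polygons; find_entries_without_masks is a hypothetical helper.

```python
def find_entries_without_masks(data_infos):
    """Return indices of middle-format entries whose 'ann' lacks one mask per labelled instance."""
    bad = []
    for i, info in enumerate(data_infos):
        ann = info.get('ann', {})
        masks = ann.get('masks')
        # One mask (a list of polygons) is expected per labelled instance
        if masks is None or len(masks) != len(ann.get('labels', [])):
            bad.append(i)
    return bad


data_infos = [
    # box plus polygon mask: usable for Mask R-CNN
    dict(filename='a.jpg', width=100, height=100,
         ann=dict(bboxes=[[10, 10, 50, 50]], labels=[0],
                  masks=[[[10, 10, 50, 10, 50, 50, 10, 50]]])),
    # box only, as in a detection dataset: this is what triggers KeyError: 'masks'
    dict(filename='b.jpg', width=100, height=100,
         ann=dict(bboxes=[[20, 20, 60, 60]], labels=[1])),
]
print(find_entries_without_masks(data_infos))  # [1]
```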