Resolved question
Posted 23.03.27 22:45
Hello,
Thanks to your lecture, I was able to train a faster-rcnn model on a custom dataset.
I'd now like to apply another model (Swin) to the same custom dataset, using the mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py file from https://github.com/open-mmlab/mmdetection/tree/master/configs/swin. I changed the config file and checkpoint accordingly, but when I run the model I get a mask-related error. It looks like the error comes from the model being Mask R-CNN. Googling suggested commenting out that part of the code. Is there a way to comment it out in Colab? If there is a better approach, I'd appreciate it if you could tell me.
2023-03-27 14:19:05,247 - mmdet - INFO - Automatic scaling of learning rate (LR) has been disabled.
<ipython-input-14-f8ce61995cc8>:47: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
'labels': np.array(gt_labels, dtype=np.long),
<ipython-input-14-f8ce61995cc8>:49: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
'label_ignore':np.array(gt_labels_ignore, dtype=np.long)
2023-03-27 14:19:08,688 - mmdet - INFO - load checkpoint from local path: checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth
2023-03-27 14:19:08,849 - mmdet - WARNING - The model and loaded state dict do not match exactly
size mismatch for roi_head.bbox_head.fc_cls.weight: copying a param with shape torch.Size([81, 1024]) from checkpoint, the shape in current model is torch.Size([16, 1024]).
size mismatch for roi_head.bbox_head.fc_cls.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([16]).
size mismatch for roi_head.bbox_head.fc_reg.weight: copying a param with shape torch.Size([320, 1024]) from checkpoint, the shape in current model is torch.Size([60, 1024]).
size mismatch for roi_head.bbox_head.fc_reg.bias: copying a param with shape torch.Size([320]) from checkpoint, the shape in current model is torch.Size([60]).
size mismatch for roi_head.mask_head.conv_logits.weight: copying a param with shape torch.Size([80, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([15, 256, 1, 1]).
size mismatch for roi_head.mask_head.conv_logits.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([15]).
2023-03-27 14:19:08,856 - mmdet - INFO - Start running, host: root@06d3ab7dae34, work_dir: /content/gdrive/MyDrive/htp_dir_swin
2023-03-27 14:19:08,858 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH ) StepLrUpdaterHook
(NORMAL ) CheckpointHook
(LOW ) EvalHook
(VERY_LOW ) TextLoggerHook
--------------------
before_train_epoch:
(VERY_HIGH ) StepLrUpdaterHook
(NORMAL ) NumClassCheckHook
(LOW ) IterTimerHook
(LOW ) EvalHook
(VERY_LOW ) TextLoggerHook
--------------------
before_train_iter:
(VERY_HIGH ) StepLrUpdaterHook
(LOW ) IterTimerHook
(LOW ) EvalHook
--------------------
after_train_iter:
(ABOVE_NORMAL) OptimizerHook
(NORMAL ) CheckpointHook
(LOW ) IterTimerHook
(LOW ) EvalHook
(VERY_LOW ) TextLoggerHook
--------------------
after_train_epoch:
(NORMAL ) CheckpointHook
(LOW ) EvalHook
(VERY_LOW ) TextLoggerHook
--------------------
before_val_epoch:
(NORMAL ) NumClassCheckHook
(LOW ) IterTimerHook
(VERY_LOW ) TextLoggerHook
--------------------
before_val_iter:
(LOW ) IterTimerHook
--------------------
after_val_iter:
(LOW ) IterTimerHook
--------------------
after_val_epoch:
(VERY_LOW ) TextLoggerHook
--------------------
after_run:
(VERY_LOW ) TextLoggerHook
--------------------
2023-03-27 14:19:08,859 - mmdet - INFO - workflow: [('train', 1)], max: 5 epochs
2023-03-27 14:19:08,860 - mmdet - INFO - Checkpoints will be saved to /content/gdrive/MyDrive/htp_dir_swin by HardDiskBackend.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-35-c8cc0d536607> in <module>
4 mmcv.mkdir_or_exist(os.path.abspath(cfg.work_dir))
5 # epochs are set by the runner parameter in the config. Default is 12
----> 6 train_detector(model, datasets, cfg, distributed=False, validate=True)
6 frames
/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/apis/train.py in train_detector(model, dataset, cfg, distributed, validate, timestamp, meta)
244 elif cfg.load_from:
245 runner.load_checkpoint(cfg.load_from)
--> 246 runner.run(data_loaders, cfg.workflow)
/usr/local/lib/python3.9/dist-packages/mmcv/runner/epoch_based_runner.py in run(self, data_loaders, workflow, max_epochs, **kwargs)
134 if mode == 'train' and self.epoch >= self._max_epochs:
135 break
--> 136 epoch_runner(data_loaders[i], **kwargs)
137
138 time.sleep(1) # wait for some hooks like loggers to finish
/usr/local/lib/python3.9/dist-packages/mmcv/runner/epoch_based_runner.py in train(self, data_loader, **kwargs)
47 self.call_hook('before_train_epoch')
48 time.sleep(2) # Prevent possible deadlock during epoch transition
---> 49 for i, data_batch in enumerate(self.data_loader):
50 self.data_batch = data_batch
51 self._inner_iter = i
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in __next__(self)
626 # TODO(https://github.com/pytorch/pytorch/issues/76750)
627 self._reset() # type: ignore[call-arg]
--> 628 data = self._next_data()
629 self._num_yielded += 1
630 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
1331 else:
1332 del self._task_info[idx]
-> 1333 return self._process_data(data)
1334
1335 def _try_put_index(self):
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1357 self._try_put_index()
1358 if isinstance(data, ExceptionWrapper):
-> 1359 data.reraise()
1360 return data
1361
/usr/local/lib/python3.9/dist-packages/torch/_utils.py in reraise(self)
541 # instantiate since we don't know how to
542 raise RuntimeError(msg) from None
--> 543 raise exception
544
545
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/custom.py", line 220, in __getitem__
data = self.prepare_train_img(idx)
File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/custom.py", line 243, in prepare_train_img
return self.pipeline(results)
File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/compose.py", line 41, in __call__
data = t(data)
File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/loading.py", line 398, in __call__
results = self._load_masks(results)
File "/usr/local/lib/python3.9/dist-packages/mmdet-2.28.2-py3.9.egg/mmdet/datasets/pipelines/loading.py", line 347, in _load_masks
gt_masks = results['ann_info']['masks']
KeyError: 'masks'
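The DeprecationWarning lines near the top of the log are unrelated to the KeyError. They come from `dtype=np.long` at lines 47 and 49 of the custom dataset cell. A minimal sketch of the fix, using an explicit `np.int64` as the warning itself suggests; the `gt_labels` / `gt_labels_ignore` values below are made up for illustration, and the `'label_ignore'` key is copied from the code shown in the log:

```python
import numpy as np

# Hypothetical example values; in the real dataset these come from the annotation files.
gt_labels = [1, 3, 3]
gt_labels_ignore = []

ann = {
    # np.int64 replaces the deprecated np.long alias with an explicit precision.
    'labels': np.array(gt_labels, dtype=np.int64),
    'label_ignore': np.array(gt_labels_ignore, dtype=np.int64),
}
print(ann['labels'].dtype, ann['label_ignore'].shape)
```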
1 answer
2023. 03. 28. 16:58
Hello,
Hmm, Faster R-CNN is an object detection model, and I think it will be difficult to apply Mask R-CNN to a dataset built for object detection.
Which dataset are you using right now? Is it the dataset from the lecture, or a different one?
2023. 03. 28. 18:35
I'm using a different dataset. I ran object detection with the Faster R-CNN model, and I'd like to use the mmdetection package to compare its performance with other models. Is Mask R-CNN hard to apply here? I'm also trying SSD, but it keeps throwing errors. Since the mmdetection configs folder is named swin, I assumed it was a Swin Transformer detector, but it seems to be a Mask R-CNN model that only uses Swin Transformer as its backbone.
2023. 03. 28. 20:17
To apply Mask R-CNN, you need a segmentation dataset. A later lecture in the segmentation section uses mmdetection's Mask R-CNN, so it would be good to take that lecture and then try applying it.
Also, mmdetection's SSD has many issues and does not perform well.
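For a box-only comparison like the one described here, one alternative to commenting out installed library code is to remove the mask branch from the config itself. The traceback ends in `LoadAnnotations._load_masks` because the Mask R-CNN pipeline sets `with_mask=True`, and the model also carries a `mask_head`. The sketch below assumes mmdetection 2.x config keys (`model.roi_head.mask_head`, `mask_roi_extractor`, `LoadAnnotations`, `Collect`) and that `data.train` holds its `pipeline` directly. The helper name `strip_mask_branch` is hypothetical, and this is not the instructor's method:

```python
def strip_mask_branch(cfg):
    """Turn a Mask R-CNN style mmdetection 2.x config into a box-only one, in place.

    Assumption: `cfg` is a plain nested dict or an mmcv Config/ConfigDict,
    which support the same item access, .get() and .pop() used here.
    """
    model = cfg['model']
    # Both detector types exist in mmdet 2.x; switching makes the box-only intent explicit.
    if model.get('type') == 'MaskRCNN':
        model['type'] = 'FasterRCNN'

    # Without these two keys the RoI head is built with no mask branch.
    roi_head = model['roi_head']
    roi_head.pop('mask_roi_extractor', None)
    roi_head.pop('mask_head', None)

    # Assumption: data.train is not wrapped (e.g. in RepeatDataset).
    for step in cfg['data']['train']['pipeline']:
        if step.get('type') == 'LoadAnnotations':
            # with_mask=True is what makes _load_masks read ann_info['masks'] (the KeyError).
            step['with_mask'] = False
        elif step.get('type') == 'Collect':
            # gt_masks will no longer be produced, so stop collecting it.
            step['keys'] = [k for k in step['keys'] if k != 'gt_masks']
    return cfg
```

In the Colab notebook this would be called on the loaded `cfg` after the custom dataset settings are applied, and before `build_dataset` and `build_detector`, so both the datasets and the model are rebuilt from the edited config. Loading the Mask R-CNN checkpoint into the box-only model should then only add warnings about unused `mask_head` weights. The size-mismatch warnings already in the log are expected when fine-tuning from COCO: `fc_cls` has num_classes + 1 outputs (81 = 80 + 1 vs 16 = 15 + 1), `fc_reg` has num_classes × 4 (320 vs 60), and the mask `conv_logits` has num_classes (80 vs 15).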