해결된 질문
작성
·
261
0
안녕하십니까! 배운내용 바탕으로 프로젝트 해보는 와중 해결되지 않는 부분이 있어 문의드립니다.
pytorch에서 densenet(pytorch 기본제공)과 SWIN-transformer(pytorch 미제공, git에 공유된 모델 활용)을 backbone으로 활용하고 싶습니다.
swin-unet처럼 모듈에서 제공하지 않는 모델의 경우 어떤식으로 코드를 작성해야할까요?
dictionary 형태의 pth 파일이라 *****.load_state_dict('~~~~.pth')을 활용해야할것 같은데 *****부분에 모델을 넣어줘야되는데 기본제공 모듈이 아니라 어떤식으로 해야할지 감이 오지 않습니다.
1번 문제가 해결이 된다면, densenet과 swin-transformer을 sequential 형태로 조합하여 pretrain으로 시키고 싶은데 어떻게 접근을 하면 될까요?
아래 코드에서 backbone 부분 어떻게 해야할지 방향 잡아주시면 너무 감사할것 같습니다.
path = '/content/drive/MyDrive/swin_tiny_patch4_window7_224.pth' #swin-Transformer 모델
pretrained_weights = torch.load(path, map_location='cpu')
class ImgFeatureExtractor(nn.Module):
def __init__(self):
super(ImgFeatureExtractor, self).__init__()
# self.backbone = models.efficientnet_b0(pretrained=True)
self.backbone = models.densenet201(pretrained=True)
self.backbone = *****.load_state_dict(pretrained_weights)
self.embedding = nn.Linear(1000,512)
def forward(self, x):
x = self.backbone(x)
x = self.embedding(x)
return x
답변 1
0
안녕하십니까,
custom backbone을 설정하는 것은 저도 해보진 않았습니다.
아래 URL을 보면 backbone과 neck을 customization하는 방법을 기술하고 있으니, 참조 부탁드립니다.
https://mmdetection.readthedocs.io/en/latest/tutorials/customize_models.html
감사합니다.