이미지 분류 모델을 만들 때는 backbone을 가져올 때 주로 timm을 사용할 때가 많다. torchvision은 최근 모델 업데이트가 느리고 mmdetection은 이미지 분류에 사용하기에는 수고가 너무 많이 든다고 생각한다.
timm은 유명한 모델 아키텍처는 간단히 클래스로 간단히 불러올 수 있다.
timm 모델 사용법
##model.py
import timm
class TIMM(nn.Module):
def __init__(self, model):
super().__init__()
self.model = timm.create_model(model, pretrained=True, num_classes=1)
self.layer = nn.Sigmoid()
def forward(self, x):
x = self.model(x)
x = self.layer(x)
return x
#train.py
from importlib import import_module
import argparse
from model import *
if __name__ == '__main__':
parser.add_argument("--model", type=str, default="resnet18")
model = getattr(import_module("model"), "TIMM")(args.model)
사람마다, 팀마다 모듈화 시키는 법이 다른것 같은데 개인으로 짤 때는 보통 위와 같이 선언해서 사용한다.
timm 모델 리스팅
timm 공식 docs에는 모델을 timm.list_model()이었나? 어쨌든 이런 함수를 사용해 모델 리스트를 확인하라고 한다. 보통 다들 starts_with("vit")와 같은 코드로 원하는 모델을 찾는데 모델을 찾을 때마다 이런 수고를 하는게 너무 귀찮아 검색을 해보니
허깅페이스에 작성된 공식개발자의 포스트에서 모델을 편하게 리스팅할 수 있었다.
아래로 좀 스크롤하면 models와 검색칸이 있는데 이곳에 원하는 모델 이름을 넣고 검색하면 원하는 베이스 모델의 모든 변형을 편하게 볼 수 있다. 그리고 정렬기준을 most star, most download로 설정하면 어떤 모델을 자주 쓰는지 확인할 수 있어 처음 베이스 모델을 고를 때 고민을 덜어준다.
모델을 클릭하면 파라미터, img_size, 훈련에 사용된 데이터셋 등 다양한 정보를 알 수 있다.
timm 모델 성능 확인
중요한 성능확인 이야기를 안했는데 먼저 공식 깃허브 사이트에서 result 폴더에서 성능을 csv형태로 확인할 수 있으며
아래의 레포에서 streamlit으로 SOTA benchmark 느낌으로 timm model들을 리스팅 해뒀으므로 더 편하게 확인할 수 있다. 아마 csv를 읽어서 시각화하는거 같은데 6개월 이상 업데이트를 안 했기에 최근 모델은 없을 수도 있다.
연구자는 모르겠지만 AI개발자 혹은 경진대회 참여자들은. 일단 좋은걸 갖다 쓰고 성능을 올리는게 최우선이기에 이런 다양한 정보를 빠르게 검색하는게 중요하다고 느껴진다.
'Notes' 카테고리의 다른 글
WSL VScode에서 SSH 에러 (Could not establish connection to "ip")해결방법 (1) | 2024.09.20 |
---|---|
Pytorch-GradCAM 치트시트 (0) | 2024.09.07 |
기존 딥러닝 개발환경 복제하기 (0) | 2024.06.24 |
[Poetry] toml, lock파일 기반 가상환경 설치하기 (0) | 2023.07.24 |
[Poetry] Poetry를 사용해 프로젝트 버전 관리하기 (0) | 2023.05.01 |