본문 바로가기
Notes

Pytorch-GradCAM 치트시트

by Yonghip 2024. 9. 7.

이번에 신입 프로젝트를 진행하며 처음으로 GradCAM을 써봤는데 그 사용법에 대해 정리해보려 한다.

 

현재 pytorch에서 GradCAM을 구현한 레포 중 가장 대표적인 건 https://github.com/jacobgil/pytorch-grad-cam 이건데 사용법은 분명 다 나와있는데 조금 직관적이지 않다고 느껴 한 번 정리하고 싶었다.

Pytorch GradCAM

일단 설치부터 한다.

pip install grad-cam

 

라이브러리 가져오고 적당히 설정을 한 다음

import timm
import os
import random
from collections import defaultdict
from importlib import import_module

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
import random
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchvision.datasets import ImageFolder

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, BinaryClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

#아래에서 자신의 데이터셋 경로를 지정하자
IMAGE_PATH = "/kaggle/input/microsoft-catsvsdogs-dataset/PetImages"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

 

원하는 데이터셋과 모델을 지정해 주면 준비는 끝이다. 

def make_dataset():
    dataset = ImageFolder(IMAGE_PATH)    
    
    #dont need to make data_loader if you just want to visualize random image in dataset
    train_loader = DataLoader(
        dataset= dataset,
        batch_size=32,            #args.batch_size
        shuffle=True,
        num_workers=0,      #cpu 코어 절반
        drop_last=False)

    return dataset, train_loader
    
dataset, train_loader = make_dataset()
model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=1)

 

target_layers = [model.conv_head]

fig, axes = plt.subplots(4, 4, figsize=(14, 10))
axes = axes.flatten()
for i in range(16):
    img = torch.tensor(np.array(dataset[i][0])).float()
    img /= 255.0                       
    label = torch.tensor(dataset[i][1])
    targets = [BinaryClassifierOutputTarget(label)]
    img = img.permute(2, 0, 1)
    cam = GradCAM(model=model, target_layers=target_layers)
    grayscale_cam = cam(input_tensor=img.unsqueeze(0).float(), targets=targets)
    img = img.permute(1, 2, 0)

    visualization = show_cam_on_image(np.float32(img), grayscale_cam[0, :], use_rgb=True)
    axes[i].imshow(visualization)
    axes[i].axis("off")

efficientnet_b0라 엄청 좋은 결과는 나오지 않았지만 그래도 몇몇 데이터는 모델이 제대로 고양이를 타겟으로 한 것을 알 수 있다.

 

위의 코드만 있어도 실행에 큰 문제는 없을 것이다. 그래도 분명 조금 더 알고 싶은 분들이 있을 것이라 생각해 사족을 덧붙여 설명해 보겠다.

More detail

이번에 나는 모델의 출력층쪽의 layer만 사용했는데 이를 적당히 지정해 사용할 수 있다. 

for name,ch in model.named_children():
    print("name :",name)
    print("child :", ch)
    print('===========================')

코드를 실행시키면 위의 크림처럼 모델의 layer이름을 알 수 있는데 이 중 CNN 모듈의 이름을 target_layer에 넣어 원하는 층이 이미지의 특정 부분을 얼마나 집중했는지 알 수 있다.

 

target_layers = [model.conv_stem]

fig, axes = plt.subplots(4, 4, figsize=(14, 10))
axes = axes.flatten()
for i in range(16):
    img = torch.tensor(np.array(dataset[i][0])).float()
    img /= 255.0                       
    label = torch.tensor(dataset[i][1])
    targets = [BinaryClassifierOutputTarget(label)]
    img = img.permute(2, 0, 1)
    cam = GradCAM(model=model, target_layers=target_layers)
    grayscale_cam = cam(input_tensor=img.unsqueeze(0).float(), targets=targets)
    img = img.permute(1, 2, 0)

    visualization = show_cam_on_image(np.float32(img), grayscale_cam[0, :], use_rgb=True)
    axes[i].imshow(visualization)
    axes[i].axis("off")

예를 들기 위해 efficientnet_b0의 가장 위 층을 출력해 보면 예상대로 모델이 global한 특징에 조금 더 집중하고 있는걸 알 수 있다.

target_layers = model.blocks

fig, axes = plt.subplots(4, 4, figsize=(14, 10))
axes = axes.flatten()
for i in range(16):
    img = torch.tensor(np.array(dataset[i][0])).float()
    img /= 255.0                       
    label = torch.tensor(dataset[i][1])
    targets = [BinaryClassifierOutputTarget(label)]
    img = img.permute(2, 0, 1)
    cam = GradCAM(model=model, target_layers=target_layers)
    grayscale_cam = cam(input_tensor=img.unsqueeze(0).float(), targets=targets)
    img = img.permute(1, 2, 0)

    visualization = show_cam_on_image(np.float32(img), grayscale_cam[0, :], use_rgb=True)
    axes[i].imshow(visualization)
    axes[i].axis("off")

layer에 여러 층을 넣어서도 사용할 수 있다 위의 이미지는 모든 layer의 가중치를 평균 낸 GradCAM 출력이다. 이번 예제에서 나는 dataset의 이미지를 그냥 불러왔는데 실제로는 모델이 보지 못한 valid set에 대해 GradCAM을 찍어 모델의 일반화 성능을 체크해 보는 게 좋다.

 

전체 코드를 캐글 노트북에서 돌려볼 수 있으니 직접 돌려보고 싶으면 아래의 링크를 참조하면 좋을 것 같다.

https://www.kaggle.com/code/hykhhijk/gradcam-sample


경진대회 때도 손만 많이 들어 실제로 시도해 본 적은 없는 기법이지만 드디어 사용해 봐서 스스로는 만족하지만 이번에 과제를 발표하며 느낀게 대체 이걸 어디다 쓸지였다. 

 

객관적으로 생각해봐도 이걸 구현할 시간이면 앙상블을 더 해보거나 모델을 경량화하는 등의 시도를 하는게 성능, 효율면에서 더 이득이고 좋은 결과를 얻었을 것 같다.

 

그렇다고 이러한 부분이 아예 쓸모가 없다는 건 아니고 적어도 목적을 확실히 하고 시도하는게 좋을 것 같다는 뜻이다. 예를 들어 "위에서 고양이가 아니라 배경을 보고 있는데 이런 문제는 어떻게 해결할까?"같은 인사이트를 얻게 해주는 용도 정도로는 사용하지 않는게 좋을것 같다. 비전공자나 학생에게는 설득력 있는 자료가 되겠지만 실제로는 이를 해결하지 못하면 결국 큰 의미를 가지지 못한다고 생각한다.