[머신러닝] PyTorch Lightning

Updated:

TorchIO를 이용해서 3D Image Augmentation 하는 코드 예시를 찾다가 PyTorch Lightning[1]이라는 라이브러리를 알게 됐다.

PyTorch Lightning

PyTorch LightningPyTorch 프레임워크를 보다 효율적으로 사용할 수 있도록 만들어진 라이브러리다.
Lightning 모듈을 사용하면 코드가 훨씬 직관적이고 간결하다는 장점이 있다.

PyTorch Lightning을 사용했을 때와 사용하지 않았을 때의 코드를 비교해 보자.
먼저 아래의 코드는 PyTorch Lightning을 안 쓰고 training과 validation을 구현한 것이다.

# Training
def train(model, train_loader, device, lr=0.0001, optim_class=optim.Adam, scheduler=None):
  criterion = nn.L1Loss()
  optimizer = optim_class(model.parameters(), lr=lr)
  total_loss = 0

  model.train()
  for inputs, labels in tqdm(train_loader):
    inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device, dtype=torch.float)

    optimizer.zero_grad()

    output = model(inputs)

    loss = criterion(output, labels)
    loss.backward()

    optimizer.step()
    if scheduler:
      scheduler.step()

    total_loss += loss.data.item()

    gc.collect()
    torch.cuda.empty_cache()

  return total_loss
# Validation
def valid(model, valid_loader, device):
  criterion = nn.L1Loss()
  total_loss = 0

  model.eval()
  with torch.no_grad():
    for inputs, labels in tqdm(valid_loader):
      inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device, dtype=torch.float)

      output = model(inputs)

      loss = criterion(output, labels)
      total_loss += loss.data.item()

  return total_loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_epoch = 100
batch_size = 16
lr = 0.0001
optim_class = optim.Adam

dataset = CustomDataset()
train_set, valid_set = train_test_split(dataset, test_size=0.2)
train_loader = DataLoader(train_set, batch_size=batch_size)
valid_loader = DataLoader(valid_set, batch_size=batch_size)

model = Net()
model.to(device)
for epoch in range(n_epoch):
  print('# Epoch %d / %d' % (epoch + 1, n_epoch))
  train_mae = train(model, train_loader, device, lr=lr, optim_class=optim_class)
  valid_mae = valid(model, valid_loader, device)
  print('Train MAE: %.3f / Validation MAE: %.3f' % (train_mae, valid_mae))

이 코드가 PyTorch Lightning을 쓰면 이렇게 깔끔해진다:

import pytorch_lightning as pl
...

data = CustomDataModule(batch_size=16, train_val_ratio=0.2)

net = Net()

model = Model(
  net=net,
  criterion=nn.L1Loss(),
  learning_rate=0.0001,
  optimizer_class=optim.Adam
)

trainer = pl.Trainer(
  gpus=1 if torch.cuda.is_available() else 0
)

trainer.fit(model=model, datamodule=data)

더 자세히 살펴보기

데이터 모듈

PyTorch Lightning의 또다른 장점은 데이터 모듈을 통해 data augmentation도 손쉽게 구현할 수 있다는 점이다 [2].

# augmentation 라이브러리
import torchio as tio
...

class CustomDataModule(pl.LightningDataModule):
  def __init__(self, batch_size, train_val_ratio):
    super().__init__()
    self.prepare_data_per_node = True
    self.batch_size = batch_size
    self.train_val_ratio = train_val_ratio
    self.train_set = None
    self.valid_set = None

  def get_augmentation_transform(self):
    random_rotate = ...
    random_flip = ...
    random_shift = ...
    augment = tio.transforms.OneOf([random_rotate, random_flip, random_shift])
    return augment

  def prepare_data(self):
    ...
    train, val = train_test_split(dataset, test_size=self.train_val_ratio)
    augment = self.get_augmentation_transform()

    self.train_set = tio.SubjectsDataset(train, transform=augment)
    self.valid_set = tio.SubjectsDataset(val)

  def train_dataloader(self):
    return DataLoader(self.train_set, batch_size=self.batch_size)

  def val_dataloader(self):
    return DataLoader(self.valid_set, batch_size=self.batch_size)

학습 루프 구조

일반적으로 쓰는 구조는 training_step - validation_step - validation_epoch_end이다 [3].

class Model(pl.LightningModule):
  def __init__(self, net, criterion, learning_rate, optimizer_class):
    super().__init__()
    self.net = net
    self.lr = learning_rate
    self.criterion = criterion
    self.optimizer_class = optimizer_class

  def configure_optimizers(self):
    optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
    return optimizer

  def prepare_batch(self, batch):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return batch['X'][tio.DATA].to(device, dtype=torch.float), batch['y'].to(device, dtype=torch.float)

  def infer_batch(self, batch):
    X, y = self.prepare_batch(batch)
    y_hat = self.net(X)
    return y_hat, y

  def training_step(self, batch, batch_idx):
    y_hat, y = self.infer_batch(batch)
    loss = self.criterion(y_hat, y)
    self.log('train_loss': loss)
    return loss

  def validation_step(self, batch, batch_idx):
    y_hat, y = self.infer_batch(batch)
    loss = self.criterion(y_hat, y)
    self.log('val_loss': loss)
    return loss

  def validation_epoch_end(self, outputs):
        # called at the end of the validation epoch
        # outputs is an array with what you returned in validation_step for each batch
        # outputs = [{'loss': batch_0_loss}, {'loss': batch_1_loss}, ..., {'loss': batch_n_loss}]
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

WandB 연동

WandB는 실험 모니터링을 제공하는 머신러닝 플랫폼이다 [4]. PyTorch Lightning에는 WandbLogger 모듈이 있어서 WandB에 쉽게 연동할 수 있다.

# logger 세팅
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name='experiment-001', project='pytorch-lightning')

...

trainer = pl.Trainer(
  logger=wandb_logger,  # logger 사용
  gpus=1 if torch.cuda.is_available() else 0
)

PyTorch Lightning 사용 후기

확실히 Lightning 모듈을 사용하니까 나중에 코드를 읽거나 수정할 때도 편하고 online augmentation을 구현하기도 쉽다는 점에서 이게 좋은 툴이라고 느꼈다. 하지만 max_epochs 값을 설정했는데도 설정이 안 됐다는 warning이 뜨면서 기본값인 1000으로 설정되는 등의 버그가 있어서 그런 건 좀 아쉬웠다. (아님 내가 설정 방법을 잘못 이해한 건지…)
잘 활용하면 좋을 것 같은데 좀 더 공부하고 내 능력만 된다면 이 라이브러리에 contribute 하고 싶다.

참고 문서 출처

[1] PyTorch Lightning Github
[2] TorchIO를 이용한 3D Segmentation
[3] 우리가 PyTorch Lightning을 써야 하는 이유
[4] Weights & Biases Documentation
Pytorch lightning 튜토리얼
PyTorch Lightning

Leave a comment