LaMa (Resolution-robust Large Mask Inpainting with Fourier Convolutions)

이름 LaMa : Resolution-robust Large Mask Inpainting with Fourier Convolutions
저자 Roman Suvorov, Elizaveta Logacheva, Anton Mashikhin, Anastasia Remizova, Arsenii Ashukha, Aleksei Silvestrov, Naejin Kong, Harshith Goka, Kiwoong Park, Victor Lempitsky
토픽 Image Inpainting
년도 2021.09
학회  
링크 [논문], [코드]

Intro

Intro

이미지 인페인팅 (주어진 이미지와 (유저가 입력한) ‘마스크’가 있을때, 마스크로 가려진 부분들을 자연스럽게 유추하는 문제) 는 이미지 생성뿐만이 아닌, 거시적인 이미지 구조를 ‘이해’하는 것을 요구합니다.

현재 GAN을 통한 이미지 생성이 실제와 분간이 되지 않을 정도록 뛰어난 성능을 보임에도 불구하고, 아직 이미지 인페인팅 문제는 이미지의 대부분을 덮는 마스크나 복잡한 구조의 마스크, 고해상도의 이미지 등의 상황에 대해 제대로 대처를 하지 못하고 있으며, 본 논문의 저자들은 이를 small receptive field의 문제 라고 보고 있습니다.

이를 해결하기 위해, 저자들은 LaMa(Large Mask Inpainting) 모델을 소개합니다. LaMa는 Fast Fourier Convolution 을 기반으로 하여 모델의 초기 단계에서 부터 이미지 전체를 덮을수 있는 Receptive Field를 가지고 있으며, loss function 도 High Receptive Field를 기반으로한 Perceptual loss를 사용하여 이미지의 거시적인 구조를 수월하게 이해할 수 있습니다. Fast Fourier Convolution의 특성상, 고해상도의 이미지에 대해서도 문제없이 적응이 가능합니다. 또한 훈련 단계에서 크고 복잡한 모양의 마스크들을 사용해서 해당 마스크들에 대한 저항력을 기릅니다.

저자들은 연구를 통해, 256해상도에서만 훈련시킨 본 모델로도 고해상도의 이미지와 큰 마스크들에 대해서 자연스럽게 이미지 인페인팅을 수행할 수 있음을 보여줍니다.


Main Reference Works


Method

이미지가 x, 마스크가 m일때, 모델 f(변수: $\theta$)는 마스크로 덮혀진 이미지와 마스크를 입력으로 받아, 마스크 부분이 인페인팅된 이미지를 받습니다.

\[\hat{x}=f_{\theta}(\text{stack}(x \odot m,m))\]
FFC구조
FFC 모델의 전반적인 구조. Local (미시적 요소) / Global (거시적 요소) 로 나뉜 feature들 중에서 Local은 고전적인 convolution을, Global은 Fourier Transformation을 통해 얻은 주파수에서의 convolution을 통해 모델의 이미지 전체에 대한 이해를 돕습니다.

Fast Fourier Convolution (FFC)

ResNet과 같은 기존 convolution 기반 모델들은, 작은 사이즈의 커널 (3*3 사이즈) 때문에 유효한 receptive field가 초반에 작아서 이미지 전체를 덮지 못하며, 커지는 속도도 느립니다. 따라서 모델 내부 대다수의 레이어들은 거시적인 구조에 대한 정보를 담지 못하고, 인페인팅시 썩 자연스럽지 못한 결과가 나옵니다. 특히 면적이 큰 마스크들의 경우, receptive field 전체가 마스크 안에 있게 되는 불상사가 발생하여 인페인팅이 불가능한 경우도 발생합니다.

이를 막기 위해 초기 단계에서 거시적인 구조에 대한 정보를 담을 수 있는 레이어가 필요하고, 이를 위해 FFC (Fast Fourier Convolution)이 필요합니다.

  • FFC는 채널 단위 Fast Fourier Transform (FFT) 를 사용하며, 이를 통해 이미지 전체를 덮을 수 있는 receptive field를 가집니다.
  • FFC는 두 가지 갈래의 연산을 수행하는데, local branch는 미시적인 구조를 위한 고전적인(?) convolutional layer를, global branch는 거시적인 구조를 위한 real FFT를 사용합니다. 이후 두 갈래의 출력값을 합쳐서 결과를 출력합니다.
  • FFC는 미분가능하며, convolutional layer에 1:1로 대응이 가능합니다.
    • 초기 단계부터 이미지 전체를 이해하는 네트워크를 생성 가능하며, 이를 통해 효율적인 훈련이 가능합니다.
  • 또한 receptive field가 이미지 전체를 덮을 수 있기 때문에, 고해상도 이미지에 효과적으로 적용가능합니다.
  • 결과를 확인해보면 벽돌, 사다리, 창문 등 인공적인 구조물 패턴에 잘 대응하는걸 볼 수 있습니다.

FFC 중에서도 Real FFT의 상세 구조는 다음과 같습니다;

batch = x.shape[0]
ifft_shape_slice = x.shape[-2:]

# RealFFT2d : Real(batch, c, h, w) --> Complex(batch, c, h, w/2)
ffted = torch.fft.rfftn(x, dim=(-2, -1), norm='ortho')

# ComplexToReal : Complex(batch, c, h, w/2) --> Real(batch, 2c, h, w/2)
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
ffted = ffted.view((batch, -1,) + ffted.size()[3:])

# Convolutional Block : Real(batch, 2c, h, w/2) --> Real(batch, 2c, h, w/2)
# 1*1 kernel이지만, 위 fft에서 변환된 이미지의 주파수 도메인에 대해 작업하므로, 
# 이미지 전역을 덮는 receptive field를 가집니다.
ffted = self.conv_layer(ffted)
ffted = self.bn(ffted)
ffted = sel
# RealToComplex : Real(batch, 2c, h, w/2) --> Complex(batch, c, h, w/2)
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:])
ffted = ffted.permute(0, 1, 3, 4, 2).contiguous()
ffted = torch.complex(ffted[..., 0], ffted[..., 1])

# InverseRealFFT2d : Complex(batch, c, h, w/2) --> Real(batch, c, h, w)
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=(-2,-1), norm='ortho')
모델구조
모델 전체의 구조. 비교적 간단한 모래시계 모양의 구조를 가집니다.

모델 전반 구조

  • 입력 : 마스크로 덮힌 이미지 + 마스크 (크기 : 4해상도해상도)
  • 출력 : 마스크로 덮힌 부분이 칠해진 이미지 (크기 : 3해상도해상도)
  • 모델 : 모래시계 모양의 ResNet-structure 모델
    1. Reflection Padding + Conv-BatchNorm-ReLU Block (ksize=7)
    2. Downsampling Blocks * 3
      • Conv-BatchNorm-ReLU Block (ksize=3, stride=2, padding=1)
      • 마지막 블록에서 feature를 local과 global로 나눔
        • 이때, local은 채널중 25%, global은 75%의 비율로 나눕니다.
    3. FFC Residual Blocks * 9 (big-LAMA 세팅이면 18)
      • 마지막 블록에서 나눠진 local과 global을 다시 하나의 feature로 합침
    4. Upsampling Blocks * 3
      • ConvT-BatchNorm-ReLU Block (ksize=3, stride=2, padding=1)
    5. Reflection Padding + Conv-BatchNorm-ReLU Block (ksize=7) + Sigmoid

손실함수

인페인팅 함수에서 손실함수를 정하는건, 가능한 출력물의 가짓수 때문에 매우 어려운 일이 됩니다.

이를 위해 저자들은 Perceptual Loss 중에서도 특별한 형태의 손실함수를 사용합니다.

  • 기존의 supervised loss의 경우, 마스크가 이미지 대부분을 덮고 있을때 모델의 학습을 어렵게 하기 때문에 출력물의 해상도가 떨어져서 나오는 경우가 많았습니다.
  • 이에 비해 perceptual loss는 사전에 훈련된 네트워크 (ResNet50 기반)를 통해 이미지간의 feature distance를 측정하기 때문에, 큰 마스크에도 상대적으로 잘 되며 다양한 출력물을 지원합니다.
  • Perceptual Loss에서 쓰일 네트워크도 초반 레이어의 넓은 receptive field와, 이를 통한 거시적인 이미지 구조에 대한 이해가 필요하기 때문에, 저자들은 High Receptive Field Model (이하 HRFM)을 새로 설계하여 perceptual loss를 구합니다.
    • Fourier / Dilated convolution을 기존 convolution 대신 사용
  • 해당 HRFM이 어떻게 훈련되었는지도 중요합니다.
    • Segmentation-based 일 경우, 배경의 사물같은 고레벨 구성요소 정보에 대해 집중할 수 있습니다.
    • Classification-based일 경우, 배경 텍스쳐같은 거시적 정보에 집중할 수 있습니다.
    • 선택과 집중의 문제

또한 Pix2Pix에 영향을 받은 Discriminator를 추가해서 Adversarial Loss를 손실함수에 포함시킵니다.

  • 해당 Discriminator는 local patch 단계에서 계산해, 해당 patch가 마스크로 덮힌 부분인지 아닌지를 분별합니다.
  • Adversarial Loss는 Non-saturating adversarial loss의 형태로 사용됩니다.
  • 또한, Discriminator에 대한 feature-matching loss를 추가로 계산해, GAN 훈련의 안정성을 늘립니다.
  • 이 두 손실함수를 통해 자연스럽게 원본 이미지와 맞물리는 이미지를 생성할 수 있습니다.

마지막으로 R1 gradient penalty 함수를 추가하는걸로 손실함수를 완성시킵니다.


데이터셋 예시
훈련용 데이터셋 마스크 생성 예시. LaMa는 밑의 두 행처럼 마스크를 생성합니다. 보시는 것처럼 기존보다 더 크고 다양한 면적의 마스크들을 볼 수 있습니다.

데이터셋 마스크 생성

다양한 형태의 입력 마스크에 성공적으로 대응하기 위해서는 훈련 단계에서 다양한 마스크를 생성해서 같이 훈련시키는 것이 중요합니다.

  • 저자들은 이를 위해 기존보다 더 크고 다양한 면적을 덮는 마스크 생성법을 사용합니다.
  • 이를 위해 두 가지 방식으로 마스크를 생성하는데, 하나는 polygonal chain에서 크기를 더 키운 wide mask 방식이고, 다른 하나는 임의의 비율을 가진 직사각형을 사용한 box mask 방식입니다.
  • 이때, 마스크가 이미지의 50% 이상을 넘지 않도록 합니다.

실험결과
서로 다른 데이터셋과 마스크 종류에 따른 인페인팅 모델들의 결과 비교. LaMa는 거의 모든 결과에 대해 타 모델 대비 우위를 차지하는걸 확인할 수 있습니다.

실험 및 결과

LaMa 모델은 실험에서 모델 크기가 몇배는 큰 모델들과의 비교에서도 LPIPS / FID metric면에서 훨씬 좋은 결과를 보여줍니다.

  • 비교 실험에서 metric면에서 (부분적으로) LaMa 보다 나은 모델들이었던 CoModGANMADF는 LaMA 대비 3~4배는 더 큰 모습을 보이며, segmentation mask를 대상으로 한 실험에서는 LaMa 보다 쳐지는 모습을 보입니다.
FFC 유효성
(훈련에 사용한 입력 해상도보다 큰 해상도의) 고화질 이미지를 대상으로한 인페인팅. FFC를 사용한 LaMa 외의 다른 모델들은 모두 인페인팅 결과가 부정확한 걸 확인할 수 있습니다.

Ablation Study 결과;

  • FFC의 유무는 특히 고화질 이미지를 대상으로한 inference에서 눈에 띄며, 창문이나 철조망 같은 반복적 구조의 inpainting에 큰 영향을 끼칩니다.
    • Dilated Convolution의 경우, FFC와 비슷한 성질을 가지며 (어느 정도는) 대용으로 사용이 가능하지만, receptive field가 더 제한이 되어 고화질 이미지를 대상으로 적용이 어렵다는 단점을 가집니다.
  • Perceptual Loss의 기본이 되는 모델의 선정은 최종 결과에 중요하게 적용됩니다.
    • 논문에서는 ResNet50 기반 Segmentation + ADE20K 데이터셋 훈련 모델을 채택했습니다.

부정적 lama 예시
부정적인 결과 예시. 인물을 지운 부분에 흐릿한 artifact가 생기는걸 볼 수 있으며, 이는 데이터셋 분포에 포함되어 있지 않던 복잡한 배경 및 마스크 크기 때문입니다.

결론

저자들은 FFC 레이어와 특수 perceptual loss, 마스크 데이터의 생성을 통해 단순한 모델구조로도 다양한 상황에 쉽게 사용될수 있는 image inpainting 모델을 제시했습니다.

아직 나름대로의 한계 (데이터셋 분포에 있지 않은 이미지들에 대한 인페인팅, 사람 몸같은 큰 대상의 인페인팅시 왜곡 현상)등이 있지만, 현재까지 나온 모델들 중에서 가장 자연스럽게 인페인팅이 가능한 모델이라고 생각됩니다.

저자들은 또한 FFC나 Dilated Convolution 뿐만이 아닌, Vision Transformer의 사용도 염두에 두면서 향후 이미지 인페인팅 연구에 대한 기대감을 심어줍니다.