+2

Tìm hiểu về Swin Transformers

Giới thiệu chung

Ngoài các model CNN thì các model họ Transformer cũng đạt những kết quả ấn tượng khi sử dụng trong các task về Computer Vision như object detection, image classification, semantic segmentation,... Mô hình Transformer đầu tiên được sử dụng trong Computer Vision là ViT (Vision Transformer) đã cho những kết quả SOTA tại thời điểm ra mắt. Các mô hình Transformer cải tiến khác cho Computer Vision cũng lần lượt ra đời. Trong bài báo Swin Transformer: Hierarchical Vision Transformer using Shifted Windows đề xuất mô hình Swin Transformer có thể mô hình hóa sự khác biệt giữa sự thay đổi về tỷ lệ của đối tượng và độ phân giải của ảnh đầu vào hiệu quả hơn, cũng như có thể đóng vai trò là một pipeline tổng quát cho các bài toán Computer Vision.

Trong bài viết này ta sẽ review qua một số ý tưởng quan trọng của Swin Transformers và tìm hiểu xem tại sao model này lại hoạt động tốt ở các task Computer Vision.

Tại sao CNN lại phù hợp với các task trong Computer Vision

Bài báo đưa ra quan sát dựa trên nền tảng lý thuyết về lý do tại sao CNNs lại hiệu quả để mô hình hóa các miền trong tầm nhìn. Một số nhược điểm của Transformers đối với dữ liệu hình ảnh bao gồm:

  • Không như word token trong NLP, các phần tử visual có kích thước, độ phân giải khác nhau. Đặc biệt trong object detection.
  • Với ảnh có độ phân giải cao và các task Computer Vision yêu cầu dự đoán ở pixel-level là không thể thực hiện được với transformer vì độ phức tạp tính toán và bộ nhớ sử dụng của self-attention là bình phương của kích thước ảnh. Nếu đã thử so sánh việc train các model CNN và các model Transformers, đa phần bạn sẽ nhận thấy một điều rằng việc train các model Transformers thường rất lâu và yêu cầu nhiều bộ nhớ sử dụng, dễ xảy ra tình trạng OOM (out of memory).

Kiến trúc mô hình

image.png

Quan sát kiến trúc, bạn sẽ thấy có 4 block đặc biệt trong mô hình trên. Đầu tiên, ảnh input RGB được chia thành các patch bởi Patch Partition layer. Mỗi patch có kích thước 4×4×34 \times 4 \times 3 (3 là kênh màu RGB) và được coi như là một token. Các patch này được đi qua Linear embedding layer và được chiếu thành 1 không gian CC chiều giống như trong ViT.

Kiến trúc mô hình bao gồm nhiều stage (4 stage cho Swin-T). Các stage này được xây dựng bằng cách kết nối Patch merging layer với Swin Transformer block.

Swin Transformer block được xây dựng dựa trên việc chỉnh sửa self-attention. Một block bao gồm multi-head self-attention (MSA), layer normalization và 2 layer MLP. Block này có vai trò là computational backbone trong mạng. Số token H/4×W/4H/4 \times W/4 được duy trì xuyên suốt qua các Transformer block.

Hierarchical representation (biểu diễn phân cấp) được thực hiện thông qua các Patch merging layer. Lớp Patch Merging có nhiệm vụ làm giảm số lượng các token bằng cách gộp 4 patch (4 hàng xóm 2x2) thành 1 patch duy nhất (ảnh minh họa trên), như vậy số lượng token khi đi qua Stage 2 sẽ là H/8 x W/8 và độ dài của 1 token là 4C chiều (do gộp 4 path làm 1). Sau đó, các token sẽ được đưa qua 1 lớp Linear để giảm số chiều thành 2C và tiếp tục đưa qua một vài các Swin Transformer Block. Tương tự với các Stage 3 và 4, output của từng Stage lần lượt là H/16 x H/16 x 4C và H/32 x W/32 x 8C.

image.png

Nếu nhìn dưới góc độ như một mạng CNN thì ta có thể coi Merging layer là pooling layer và transformer block là conv layer. Cách tiếp cận này cho phép mạng có thể detect các object với các kích thước khác nhau một cách hiệu quả.

Shifted Windows

Vì Vision Transformer sử dụng self attention trên toàn bộ ảnh gây ra vấn đề là độ phức tạp của thuật toán sẽ tăng theo số patch. Mà số patch thì phụ thuộc vào kích cỡ ảnh đầu vào cũng như phụ thuộc vào bài toán mà chúng ta cần giải quyết (VD: những bài toán như segmentation yêu cầu dense prediction hoặc yêu cầu phải duy trì high resolution của ảnh thì chúng ta không thể để patch ảnh có kích thước quá lớn dẫn đến việc số lượng patch nhiều và làm tăng chi phí tính toán theo bậc 2 của ảnh đầu vào). Shifted window giúp giải quyết vấn đề này.

image.png

Sử dụng Shifted window với mục tiêu là tính self-attention trong local window. Một local window bao gồm M×MM \times M patch không overlap (M=7M = 7) và self attention được tính toán giữa các patch trên cùng 1 cửa sổ thay vì tính self-attention giữa các path trên toàn bộ bức ảnh. Kết quả là MSA ban đầu có độ phức tạp bình phương của hwhw trong khi Window-based MSA là tuyến tính.

image.png

Theo paper, việc thêm relative position bias cho tính toán self-attention làm tăng độ hiệu quả, sử dụng công thức bên dưới

image.png

Kết quả thực nghiệm cho thấy sự hiệu quả được thể hiện trong bảng dưới

image.png

Efficient shifting

Để xử lý hiệu quả cho các window ở viền có kích thước nhỏ hơn M×MM \times M, paper đề xuất ý tưởng là dịch chuyển theo tính chu kỳ, việc tính self attention sẽ được thực hiện trên 4 vùng cửa số giống với lúc chưa dịch và điều này tiết kiệm chi phí tính toán.

image.png

Nhóm tác giả xác nhận cách tiếp cận làm giảm độ trễ của mạng từ Performer, một trong những kiến trúc Transformer nhanh nhất. Giảm thời gian infer là rất quan trọng vì sau này nó có thể được đánh đổi với độ chính xác bằng cách sử dụng các mạng lớn hơn.

image.png

Cài đặt

Ta thử implement model Swin Transformer với bộ dữ liệu CIFAR-100 sử dụng Tensorflow

Import library

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers

Chuẩn bị dữ liệu

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i])
plt.show()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
169009152/169001437 [==============================] - 3s 0us/step
169017344/169001437 [==============================] - 3s 0us/step
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 100)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 100)

image.png

Config các hyperparameter

Một tham số quan trọng cần chọn là patch_size, kích thước của input patch. Để sử dụng từng pixel làm đầu vào riêng lẻ, bạn có thể đặt patch_size thành (1, 1). Dưới đây, ta sử dụng config như paper gốc.

patch_size = (2, 2)  # 2-by-2 sized patches
dropout_rate = 0.03  # Dropout rate
num_heads = 8  # Attention heads
embed_dim = 64  # Embedding dimension
num_mlp = 256  # MLP layer size
qkv_bias = True  # Convert embedded patches to query, key, and values with a learnable additive value
window_size = 2  # Size of attention window
shift_size = 1  # Size of shifting window
image_dimension = 32  # Initial image size

num_patch_x = input_shape[0] // patch_size[0]
num_patch_y = input_shape[1] // patch_size[1]

learning_rate = 1e-3
batch_size = 128
num_epochs = 40
validation_split = 0.1
weight_decay = 0.0001
label_smoothing = 0.1

Helper functions

Tạo 2 helper function để hỗ trợ tạo các patch từ hình ảnh, merge patch và sử dụng dropout

def window_partition(x, window_size):
    _, height, width, channels = x.shape
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = tf.reshape(
        x, shape=(-1, patch_num_y, window_size, patch_num_x, window_size, channels)
    )
    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
    windows = tf.reshape(x, shape=(-1, window_size, window_size, channels))
    return windows


def window_reverse(windows, window_size, height, width, channels):
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = tf.reshape(
        windows,
        shape=(-1, patch_num_y, patch_num_x, window_size, window_size, channels),
    )
    x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))
    x = tf.reshape(x, shape=(-1, height, width, channels))
    return x


class DropPath(layers.Layer):
    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x):
        input_shape = tf.shape(x)
        batch_size = input_shape[0]
        rank = x.shape.rank
        shape = (batch_size,) + (1,) * (rank - 1)
        random_tensor = (1 - self.drop_prob) + tf.random.uniform(shape, dtype=x.dtype)
        path_mask = tf.floor(random_tensor)
        output = tf.math.divide(x, 1 - self.drop_prob) * path_mask
        return output

Window based multi-head self-attention

Thường thì Transformer sẽ thực hiện global self-attention, trong đó mối quan hệ giữa một token với tất cả các token khác được tính toán. Tuy nhiên, ta sẽ thực hiện như trong bài báo là tính self-attention cho local window.

class WindowAttention(layers.Layer):
    def __init__(
        self, dim, window_size, num_heads, qkv_bias=True, dropout_rate=0.0, **kwargs
    ):
        super(WindowAttention, self).__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

    def build(self, input_shape):
        num_window_elements = (2 * self.window_size[0] - 1) * (
            2 * self.window_size[1] - 1
        )
        self.relative_position_bias_table = self.add_weight(
            shape=(num_window_elements, self.num_heads),
            initializer=tf.initializers.Zeros(),
            trainable=True,
        )
        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
        coords = np.stack(coords_matrix)
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)

        self.relative_position_index = tf.Variable(
            initial_value=tf.convert_to_tensor(relative_position_index), trainable=False
        )

    def call(self, x, mask=None):
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
        x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        attn = q @ k

        num_window_elements = self.window_size[0] * self.window_size[1]
        relative_position_index_flat = tf.reshape(
            self.relative_position_index, shape=(-1,)
        )
        relative_position_bias = tf.gather(
            self.relative_position_bias_table, relative_position_index_flat
        )
        relative_position_bias = tf.reshape(
            relative_position_bias, shape=(num_window_elements, num_window_elements, -1)
        )
        relative_position_bias = tf.transpose(relative_position_bias, perm=(2, 0, 1))
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.get_shape()[0]
            mask_float = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32
            )
            attn = (
                tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
                + mask_float
            )
            attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))
            attn = keras.activations.softmax(attn, axis=-1)
        else:
            attn = keras.activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        return x_qkv

Swin Transformer model

class SwinTransformer(layers.Layer):
    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super(SwinTransformer, self).__init__(**kwargs)

        self.dim = dim  # number of input dimensions
        self.num_patch = num_patch  # number of embedded patches
        self.num_heads = num_heads  # number of attention heads
        self.window_size = window_size  # size of window
        self.shift_size = shift_size  # size of window shift
        self.num_mlp = num_mlp  # number of MLP nodes

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = DropPath(dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

    def build(self, input_shape):
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = tf.convert_to_tensor(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size * self.window_size]
            )
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)

    def call(self, x):
        height, width = self.num_patch
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        x = x_skip + x
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = x_skip + x
        return x

Model training and evaluation

Extract và embed patches

class PatchExtract(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super(PatchExtract, self).__init__(**kwargs)
        self.patch_size_x = patch_size[0]
        self.patch_size_y = patch_size[0]

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=(1, self.patch_size_x, self.patch_size_y, 1),
            strides=(1, self.patch_size_x, self.patch_size_y, 1),
            rates=(1, 1, 1, 1),
            padding="VALID",
        )
        patch_dim = patches.shape[-1]
        patch_num = patches.shape[1]
        return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))


class PatchEmbedding(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super(PatchEmbedding, self).__init__(**kwargs)
        self.num_patch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        pos = tf.range(start=0, limit=self.num_patch, delta=1)
        return self.proj(patch) + self.pos_embed(pos)


class PatchMerging(tf.keras.layers.Layer):
    def __init__(self, num_patch, embed_dim):
        super(PatchMerging, self).__init__()
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)

    def call(self, x):
        height, width = self.num_patch
        _, _, C = x.get_shape().as_list()
        x = tf.reshape(x, shape=(-1, height, width, C))
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = tf.concat((x0, x1, x2, x3), axis=-1)
        x = tf.reshape(x, shape=(-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x)

Build model

input = layers.Input(input_shape)
x = layers.RandomCrop(image_dimension, image_dimension)(input)
x = layers.RandomFlip("horizontal")(x)
x = PatchExtract(patch_size)(x)
x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(x)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=0,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=shift_size,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)
x = layers.GlobalAveragePooling1D()(x)
output = layers.Dense(num_classes, activation="softmax")(x)

Training

model = keras.Model(input, output)
model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=validation_split,
)

Thử visualize quá trình training

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

image.png

Kết quả

loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
313/313 [==============================] - 3s 8ms/step - loss: 2.7039 - accuracy: 0.4288 - top-5-accuracy: 0.7366
Test loss: 2.7
Test accuracy: 42.88%
Test top 5 accuracy: 73.66%

Kết quả ở đây khá thấp do chỉ train trên 40 epoch. Bạn đọc có thể tăng số epoch (recommend là 150) để có được performance tốt hơn 😃

Tài liệu tham khảo

[1] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows [2] https://github.com/microsoft/Swin-Transformer [3] https://keras.io/examples/vision/swin_transformers/


All Rights Reserved

Viblo
Let's register a Viblo Account to get more interesting posts.