+2

Paper reading | Tóm tắt mô hình ResNeSt: Split-Attention Networks

Đóng góp của bài báo

Bài báo giới thiệu một kiến trúc mô hình đơn giản có tên ResNeSt sử dụng channel-wise attention trên các nhánh của mạng với mục tiêu tận dụng sức mạnh capture thông tin tương tác giữa các đặc trưng (cross-feature interaction) và học đa dạng các biểu diễn. Mô hình ResNeSt vượt qua mô hình EfficientNet trên khía cạnh đánh đổi độ chính xác và độ trễ (accuracy and latency trade-off) trên task image classification.

image.png

Split-Attention Networks

Toàn bộ ý tưởng hay ho của ResNeSt nằm trong Split-Attention block. Split-Attention block bao gồm 2 thành phần là featuremap group và các split attention.

image.png

Featuremap Group. Tại Featuremap Group, feature được chia thành các nhóm, ta có thể đặt số lượng Featuremap group bằng một cardinality hyperparameter KK. Featuremap group có thể gọi là Cardinal group (xem hình trên). Trong bài báo, nóm tác giả cũng giới thiệu một hyperparameter nữa là RR (radix) thể hiện số lượng split trong cardinal group. Do đó, số lượng feature group là G=KRG = KR. Tại mỗi feature group, ta thực hiện trích xuất feature sử dụng các layer Conv. Đầu ra của các layer này sẽ được đưa vào Split Attention.

image.png

Split Attention trong Cardinal Group. Đầu ra của các split được tổng hợp thông qua phép toán tính tổng element-wise tất cả các split trong cardinal group. Biểu diễn của cardinal group thứ kkU^k=j=R(k1)+1RkUj\hat{U}^k=\sum_{j=R(k-1)+1}^{R k} U_j trong đó UjU_j là biểu diễn đầu ra của từng split. Các thông tin ngữ cảnh toàn cục sau đó được tổng hợp thông qua một layer global average pooling theo chiều không gian. Thành phần thứ cc được tính như sau:

image.png

trong đó skRC/Ks^k \in \mathbb{R}^{C / K}.

Sau đó, mỗi featuremap của channel cc được tính toán như sau:

image.png

trong đó aik(c)a_i^k(c) là trọng số được tính như sau:

image.png

Gic\mathcal{G}_i^c có vai trò xác định trọng số của mỗi split cho channel cc dựa vào biểu diễn ngữ cảnh toàn cục sks^k.

ResNeSt Block. Các biểu diễn của cardinal group sau đó được concat theo chiều channel V=Concat(V1,V2,...,VK)V = Concat(V^1, V^2,..., V^K). Giống như block trong model ResNet, ta sử dụng một kết nối tắt: Y=V+XY = V + X nếu input và output featuremap có cùng kích thước. Nếu kích thước khác nhaum ta có thể sử dụng thêm một lớp convolution hoặc kết hợp convolution với pooling. Khi đó ta có Y=V+T(X)Y = V + \mathcal{T}(X).

Coding

Ta xây dựng các layer của model ResNeSt như sau:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalAvgPool2d(nn.Module):
    '''
    global average pooling 2D class
    '''
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()
    def forward(self, x):
        return F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)


class ConvBlock(nn.Module):
    '''
    convolution 2D -> batch normalization -> ReLU
    '''
    def __init__(self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding
    ):
        super(ConvBlock, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.block(x)
        return x


'''
Split Attention
'''

class rSoftMax(nn.Module):
    '''
    (radix-majorize) softmax class

    input is cardinal-major shaped tensor.
    transpose to radix-major
    '''
    def __init__(self,
        groups=1,
        radix=2
    ):
        super(rSoftMax, self).__init__()

        self.groups = groups
        self.radix = radix

    def forward(self, x):
        B = x.size(0)
        # transpose to radix-major
        x = x.view(B, self.groups, self.radix, -1).transpose(1, 2)
        x = F.softmax(x, dim=1)
        x = x.view(B, -1, 1, 1)

        return x

class SplitAttention(nn.Module):
    def __init__(self,
        in_channels,
        channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        radix=2,
        reduction_factor=4
    ):
        super(SplitAttention, self).__init__()

        self.radix = radix

        self.radix_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=channels*radix,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups*radix,
                bias=bias
            ),
            nn.BatchNorm2d(channels*radix),
            nn.ReLU(inplace=True)
        )

        inter_channels = max(32, in_channels*radix//reduction_factor)

        self.attention = nn.Sequential(
            nn.Conv2d(
                in_channels=channels,
                out_channels=inter_channels,
                kernel_size=1,
                groups=groups
            ),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=inter_channels,
                out_channels=channels*radix,
                kernel_size=1,
                groups=groups
            )
        )

        self.rsoftmax = rSoftMax(
            groups=groups,
            radix=radix
        )

    def forward(self, x):
        
        '''
        input  : |             in_channels               |
        '''

        '''
        radix_conv : |                radix 0            |               radix 1             | ... |                radix r            |
                     | group 0 | group 1 | ... | group k | group 0 | group 1 | ... | group k | ... | group 0 | group 1 | ... | group k |
        '''
        x = self.radix_conv(x)

        '''
        split :  [ | group 0 | group 1 | ... | group k |,  | group 0 | group 1 | ... | group k |, ... ]

        sum   :  | group 0 | group 1 | ...| group k |
        '''
        B, rC = x.size()[:2]
        splits = torch.split(x, rC // self.radix, dim=1)
        gap = sum(splits)

        '''
        !! becomes cardinal-major !!
        attention : |             group 0              |             group 1              | ... |              group k             |
                    | radix 0 | radix 1| ... | radix r | radix 0 | radix 1| ... | radix r | ... | radix 0 | radix 1| ... | radix r |
        '''
        att_map = self.attention(gap)

        '''
        !! transposed to radix-major in rSoftMax !!
        rsoftmax : same as radix_conv
        '''
        att_map = self.rsoftmax(att_map)

        '''
        split : same as split
        sum : same as sum
        '''
        att_maps = torch.split(att_map, rC // self.radix, dim=1)
        out = sum([att_map*split for att_map, split in zip(att_maps, splits)])


        '''
        output : | group 0 | group 1 | ...| group k |

        concatenated tensors of all groups,
        which split attention is applied
        '''

        return out.contiguous()


'''
Bottleneck Block
'''

class BottleneckBlock(nn.Module):
    expansion = 4
    def __init__(self,
        in_channels,
        channels,
        stride=1,
        dilation=1,
        downsample=None,
        radix=2,
        groups=1,
        bottleneck_width=64,
        is_first=False
    ):
        super(BottleneckBlock, self).__init__()
        group_width = int(channels * (bottleneck_width / 64.)) * groups

        layers = [
            ConvBlock(
                in_channels=in_channels,
                out_channels=group_width,
                kernel_size=1,
                stride=1,
                padding=0
            ),
            SplitAttention(
                in_channels=group_width,
                channels=group_width,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=groups,
                bias=False,
                radix=radix
            )
        ]

        if stride > 1 or is_first:
            layers.append(
                nn.AvgPool2d(
                    kernel_size=3,
                    stride=stride,
                    padding=1
                )
            )
        
        layers += [
            nn.Conv2d(
                group_width,
                channels*4,
                kernel_size=1,
                bias=False
            ),
            nn.BatchNorm2d(channels*4)
        ]

        self.block = nn.Sequential(*layers)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        if self.downsample:
            residual = self.downsample(x)
        out = self.block(x)
        out += residual

        return F.relu(out)


if __name__ == "__main__":
    m = BottleneckBlock(256, 64)
    x = torch.randn(3, 256, 4, 4)
    print(m(x).size())

Xây dựng model ResNeSt từ các layer trên như sau:

'''
ResNeSt
'''
import torch
import torch.nn as nn

from layers import ConvBlock
from layers import GlobalAvgPool2d
from layers import BottleneckBlock

class ResNeSt(nn.Module):
    '''
    ResNeSt [1] class

    [1] ResNeSt : Split-Attention Networks,
        Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Mueller, R. Manmatha, Mu Li, Alexander Smola,
        https://arxiv.org/abs/2004.08955

    official implementation : https://github.com/zhanghang1989/ResNeSt

    '''
    def __init__(self,
        layers,
        radix=2,
        groups=1,
        bottleneck_width=64,
        n_classes=1000,
        stem_width=64
    ):
        super(ResNeSt, self).__init__()
        self.radix = radix
        self.groups = groups
        self.bottleneck_width = bottleneck_width

        self.deep_stem = nn.Sequential(
            ConvBlock(
                in_channels=3,
                out_channels=stem_width,
                kernel_size=3,
                stride=2,
                padding=1
            ),
            ConvBlock(
                in_channels=stem_width,
                out_channels=stem_width,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            ConvBlock(
                in_channels=stem_width,
                out_channels=stem_width*2,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.MaxPool2d(
                kernel_size=3,
                stride=2,
                padding=1
            )
        )

        self.in_channels = stem_width*2

        self.layer1 = self._make_layers(
            channels=64,
            blocks=layers[0],
            stride=1,
            is_first=False
        )
        self.layer2 = self._make_layers(
            channels=128,
            blocks=layers[1],
            stride=2
        )
        self.layer3 = self._make_layers(
            channels=256,
            blocks=layers[2],
            stride=2
        )
        self.layer4 = self._make_layers(
            channels=512,
            blocks=layers[3],
            stride=2
        )

        self.classifier = nn.Sequential(
            GlobalAvgPool2d(),
            nn.Linear(
                in_features=512*BottleneckBlock.expansion,
                out_features=n_classes
            )
        )


    def _make_layers(self,
        channels,
        blocks,
        stride=1,
        is_first=True
    ):
        down_layers = None
        if not stride ==1 or not self.in_channels == channels * BottleneckBlock.expansion:
            down_layers = nn.Sequential(
                nn.AvgPool2d(
                    kernel_size=stride,
                    stride=stride,
                    ceil_mode=True,
                    count_include_pad=False
                ),
                nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=channels*BottleneckBlock.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(channels*BottleneckBlock.expansion)
            )

        layers = []
        layers.append(
            BottleneckBlock(
                in_channels=self.in_channels,
                channels=channels,
                stride=stride,
                downsample=down_layers,
                radix=self.radix,
                groups=self.groups,
                bottleneck_width=self.bottleneck_width,
                is_first=is_first
            )
        )

        self.in_channels = channels * BottleneckBlock.expansion
        for _ in range(1, blocks):
            layers.append(
                BottleneckBlock(
                    in_channels=self.in_channels,
                    channels=channels,
                    radix=self.radix,
                    groups=self.groups,
                    bottleneck_width=self.bottleneck_width
                )
            )

        return nn.Sequential(*layers)

    def forward(self, img):
        x = self.deep_stem(img)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.classifier(x)

        return x

if __name__ == "__main__":
    m = ResNeSt(
        [3, 4, 6, 3]
    )
    img = torch.randn(3, 3, 224, 224)
    print(m(img).size())

Thực nghiệm

Bảng dưới là hiệu suất của các cải tiến từ model ResNet trên tập dữ liệu ImageNet.

image.png

Bảng dưới là hiệu suất của model ResNeSt với các setting khác nhau. Ví dụ, 2s2x40d là radix = 2, cardinality = 2 và width = 40.

image.png

Bảng dưới so sánh độ chính xác và tốc độ inference của các SOTA model trên tập dữ liệu ImageNet. ResNeSt thể hiện sự cân bằng giữa độ chính xác và tốc độ inference một cách tối ưu nhất.

image.png

Kết quả trên task Object Detection với tập dữ liệu MS-COCO.

image.png

Bảng dưới so sánh kết quả trên task Instance Segmentation với tập dữ liệu MS-COCO.

image.png

Tương tự với tập dữ liệu ADE20K, ta có kết quả sau:

image.png

Với bộ dữ liệu Citscapes, ResNeSt vẫn thể hiện sự vượt trội với các model SOTA trước đó.

image.png

Không chỉ với các task hình ảnh đơn thuần, bảng dưới thể hiện kết quả trên task Pose estimation với tập MS-COCO.

image.png

Tham khảo

[1] ResNeSt: Split-Attention Networks

[2] Amazon Introduces ResNeSt: Strong, Split-Attention Networks

[3] https://github.com/zhanghang1989/ResNeSt/tree/master

[4] https://paperswithcode.com/method/channel-attention-module


All Rights Reserved

Viblo
Let's register a Viblo Account to get more interesting posts.