【人工智能】关于FreqFusion.py官方代码的研究(修正版)

[图1:论文中 FreqFusion 的整体模块图]
# TPAMI 2024:Frequency-aware Feature Fusion for Dense Image Prediction

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import warnings
import numpy as np

try:
    from mmcv.ops.carafe import normal_init, xavier_init, carafe
except ImportError:

    def xavier_init(module: nn.Module,
                    gain: float = 1,
                    bias: float = 0,
                    distribution: str = 'normal') -> None:
        """Xavier-initialise ``module.weight`` and fill ``module.bias`` with a constant.

        Local stand-in for mmcv's helper of the same name, used when the mmcv
        import fails.

        Args:
            module: Module whose ``weight`` / ``bias`` are initialised; either is
                skipped when missing or ``None``.
            gain: Scaling factor forwarded to the Xavier initialiser.
            bias: Constant written into ``module.bias``.
            distribution: ``'uniform'`` or ``'normal'`` Xavier variant.
        """
        assert distribution in ['uniform', 'normal']
        weight = getattr(module, 'weight', None)
        if weight is not None:
            init_fn = nn.init.xavier_uniform_ if distribution == 'uniform' else nn.init.xavier_normal_
            init_fn(weight, gain=gain)
        bias_param = getattr(module, 'bias', None)
        if bias_param is not None:
            nn.init.constant_(bias_param, bias)

    def carafe(x, normed_mask, kernel_size, group=1, up=1):
        """Pure-PyTorch fallback for mmcv's ``carafe`` (content-aware reassembly).

        Each output pixel is a weighted sum of the ``kernel_size x kernel_size``
        neighbourhood of its source pixel in ``x``. The weights are read from
        ``normed_mask``, one kernel per output pixel and per channel group.

        Args:
            x: Input features of shape (B, C, H, W).
            normed_mask: Kernels of shape (B, G * kernel_size**2, up * H, up * W),
                laid out with channel index ``g * kernel_size**2 + tap``.
            kernel_size: Side length of the reassembly kernel.
            group: Kept for API compatibility with mmcv. The effective number of
                groups G is read from the mask's channel count, which equals
                ``group`` for masks laid out as mmcv expects. Channels of ``x``
                are split into G contiguous chunks, each using its own kernel.
            up: Integer upsampling factor.

        Returns:
            Tensor of shape (B, C, up * H, up * W).

        Note:
            Borders use reflect padding, so ``F.pad`` requires H and W to be
            larger than ``kernel_size // 2``.
        """
        b, c, h, w = x.shape
        _, m_c, m_h, m_w = normed_mask.shape
        k2 = kernel_size * kernel_size
        assert m_h == up * h
        assert m_w == up * w
        assert m_c % k2 == 0, 'mask channels must be a multiple of kernel_size**2'
        n_groups = m_c // k2
        assert c % n_groups == 0, 'input channels must be divisible by the number of groups'
        pad = kernel_size // 2
        pad_x = F.pad(x, pad=[pad] * 4, mode='reflect')
        # (B, C*k*k, H*W): each pixel's k x k neighbourhood, channel-major.
        unfold_x = F.unfold(pad_x, kernel_size=(kernel_size, kernel_size), stride=1, padding=0)
        unfold_x = unfold_x.reshape(b, c * k2, h, w)
        # Nearest upsampling gives every one of the up*up output pixels the
        # neighbourhood of its source pixel.
        unfold_x = F.interpolate(unfold_x, scale_factor=up, mode='nearest')
        unfold_x = unfold_x.reshape(b, n_groups, c // n_groups, k2, m_h, m_w)
        # One kernel per group, broadcast over that group's channels.
        normed_mask = normed_mask.reshape(b, n_groups, 1, k2, m_h, m_w)
        res = (unfold_x * normed_mask).sum(dim=3)
        return res.reshape(b, c, m_h, m_w)

    def normal_init(module, mean=0, std=1, bias=0):
        """Sample ``module.weight`` from N(mean, std**2) and fill ``module.bias`` with ``bias``.

        Local stand-in for mmcv's helper, used when the mmcv import fails.
        Either parameter is skipped when the module lacks it or it is ``None``.
        """
        weight = getattr(module, 'weight', None)
        if weight is not None:
            nn.init.normal_(weight, mean, std)
        bias_param = getattr(module, 'bias', None)
        if bias_param is not None:
            nn.init.constant_(bias_param, bias)


def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``.

    Parameters that are missing or ``None`` are left untouched.
    """
    for attr_name, fill_value in (('weight', val), ('bias', bias)):
        param = getattr(module, attr_name, None)
        if param is not None:
            nn.init.constant_(param, fill_value)

def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Wrapper around ``F.interpolate`` that may warn about misaligned upsampling.

    A ``UserWarning`` is emitted only when all of the following hold:
    ``warning`` is true, ``size`` is given, ``align_corners`` is truthy, the
    output is larger than the input in some dimension, all input/output sizes
    exceed 1, and ``(out - 1)`` is not a multiple of ``(in - 1)`` in BOTH
    dimensions.

    Returns:
        ``F.interpolate(input, size, scale_factor, mode, align_corners)``.
    """
    if warning and size is not None and align_corners:
        in_h, in_w = (int(d) for d in input.shape[2:])
        out_h, out_w = (int(d) for d in size)
        upsampling = out_h > in_h or out_w > in_w
        if upsampling:
            # Short-circuits before the modulo so (in - 1) is never zero.
            all_gt_one = out_h > 1 and out_w > 1 and in_h > 1 and in_w > 1
            if all_gt_one and (out_h - 1) % (in_h - 1) and (out_w - 1) % (in_w - 1):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(in_h, in_w)} is `x+1` and '
                    f'out size {(out_h, out_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)

def hamming2D(M, N):
    """Build a separable 2-D Hamming window.

    Args:
        M: Number of rows.
        N: Number of columns.

    Returns:
        ``np.ndarray`` of shape (M, N): the outer product of ``np.hamming(M)``
        (rows) and ``np.hamming(N)`` (columns).
    """
    row_window = np.hamming(M)
    col_window = np.hamming(N)
    # Broadcasting a column vector against a row vector is the outer product.
    return row_window[:, None] * col_window[None, :]

class FreqFusion(nn.Module):
    """Frequency-aware feature fusion (FreqFusion, TPAMI 2024).

    Combines a high-resolution feature map ``hr_feat`` with a low-resolution
    feature map ``lr_feat``:

    * the ALPF generator (``content_encoder``) predicts per-pixel low-pass
      kernels that ``carafe`` uses to smooth and upsample ``lr_feat``;
    * the AHPF generator (``content_encoder2``) predicts per-pixel kernels
      ``K``; ``x - carafe(x, K)`` is used as the high-frequency part of
      ``hr_feat``;
    * optionally (``feature_resample=True``) a ``LocalSimGuidedSampler``
      resamples the upsampled ``lr_feat`` with predicted offsets.

    ``forward`` returns ``(mask_lr, hr_feat, lr_feat)`` and does not add the
    two feature maps; combining them is left to the caller.

    Args:
        hr_channels: Channels of ``hr_feat``.
        lr_channels: Channels of ``lr_feat``.
        scale_factor: Multiplies the number of kernel channels predicted by the
            encoders (``kernel**2 * up_group * scale_factor**2``).
        lowpass_kernel: Side length of the low-pass (ALPF) kernels.
        highpass_kernel: Side length of the high-pass (AHPF) kernels.
        up_group: Number of channel groups passed to ``carafe``.
        encoder_kernel: Kernel size of the two kernel-predicting convolutions.
        encoder_dilation: Dilation of the two kernel-predicting convolutions.
        compressed_channels: Channels after the 1x1 compression convolutions.
        align_corners: Passed to ``resize`` when ``semi_conv=False`` and
            ``upsample_mode`` is not ``'nearest'``.
        upsample_mode: Interpolation mode for ``resize`` when ``semi_conv=False``.
        feature_resample: Whether to build and apply the offset generator.
        feature_resample_group: ``groups`` of the offset generator.
        comp_feat_upsample: With ``semi_conv=True``, build the kernel logits
            with the initial ALPF/AHPF stage on the compressed features.
        use_high_pass: Whether to build and apply the AHPF branch.
        use_low_pass: Stored but not read anywhere in this class.
        hr_residual: If True the high-frequency part is added back to
            ``hr_feat``; otherwise ``hr_feat`` is replaced by it.
        semi_conv: If True ``lr_feat`` is upsampled x2 directly by ``carafe``;
            otherwise it is first resized to ``hr_feat``'s spatial size.
        hamming_window: Multiply each kernel by a 2-D Hamming window before
            renormalizing.
        feature_resample_norm: ``norm`` flag of the offset generator.
        **kwargs: Accepted and ignored.
    """
    def __init__(self,
                hr_channels,
                lr_channels,
                scale_factor=1,
                lowpass_kernel=5,
                highpass_kernel=3,
                up_group=1,
                encoder_kernel=3,
                encoder_dilation=1,
                compressed_channels=64,
                align_corners=False,
                upsample_mode='nearest',
                feature_resample=False, # use offset generator or not
                feature_resample_group=4,
                comp_feat_upsample=True, # use ALPF & AHPF for init upsampling
                use_high_pass=True,
                use_low_pass=True,
                hr_residual=True,
                semi_conv=True,
                hamming_window=True, # for regularization, do not matter really
                feature_resample_norm=True,
                **kwargs):
        super().__init__()
        self.scale_factor = scale_factor
        self.lowpass_kernel = lowpass_kernel
        self.highpass_kernel = highpass_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels
        # 1x1 convolutions that compress both inputs to compressed_channels.
        self.hr_channel_compressor = nn.Conv2d(hr_channels, self.compressed_channels,1)
        self.lr_channel_compressor = nn.Conv2d(lr_channels, self.compressed_channels,1)
        self.content_encoder = nn.Conv2d( # ALPF generator
            self.compressed_channels,
            lowpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
            self.encoder_kernel,
            padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
            dilation=self.encoder_dilation,
            groups=1)

        self.align_corners = align_corners
        self.upsample_mode = upsample_mode
        self.hr_residual = hr_residual
        self.use_high_pass = use_high_pass
        self.use_low_pass = use_low_pass
        self.semi_conv = semi_conv
        self.feature_resample = feature_resample
        self.comp_feat_upsample = comp_feat_upsample
        if self.feature_resample:
            self.dysampler = LocalSimGuidedSampler(in_channels=compressed_channels, scale=2, style='lp', groups=feature_resample_group, use_direct_scale=True, kernel_size=encoder_kernel, norm=feature_resample_norm)
        if self.use_high_pass:
            self.content_encoder2 = nn.Conv2d( # AHPF generator
                self.compressed_channels,
                highpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
                self.encoder_kernel,
                padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
                dilation=self.encoder_dilation,
                groups=1)
        self.hamming_window = hamming_window
        lowpass_pad=0
        highpass_pad=0
        if self.hamming_window:
            # (1, 1, k, k) windows, broadcast over every predicted kernel.
            self.register_buffer('hamming_lowpass', torch.FloatTensor(hamming2D(lowpass_kernel + 2 * lowpass_pad, lowpass_kernel + 2 * lowpass_pad))[None, None,])
            self.register_buffer('hamming_highpass', torch.FloatTensor(hamming2D(highpass_kernel + 2 * highpass_pad, highpass_kernel + 2 * highpass_pad))[None, None,])
        else:
            self.register_buffer('hamming_lowpass', torch.FloatTensor([1.0]))
            self.register_buffer('hamming_highpass', torch.FloatTensor([1.0]))
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init for every Conv2d, then small-normal init for the kernel encoders."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')
        normal_init(self.content_encoder, std=0.001)
        if self.use_high_pass:
            normal_init(self.content_encoder2, std=0.001)

    def kernel_normalizer(self, mask, kernel, scale_factor=None, hamming=1):
        """Turn raw kernel logits into normalized per-pixel filters.

        Args:
            mask: Logits of shape (N, G * kernel**2, H, W). When ``scale_factor``
                is given, the input is first passed through
                ``F.pixel_shuffle(mask, scale_factor)``.
            kernel: Side length of each filter.
            scale_factor: Optional pixel-shuffle factor applied before
                normalization.
            hamming: Window broadcast over each ``kernel x kernel`` filter
                (e.g. a (1, 1, kernel, kernel) tensor) or a scalar.

        Returns:
            Tensor of shape (N, G * kernel**2, H, W) with channel index
            ``g * kernel**2 + tap``; the kernel**2 taps of each group sum to 1.
        """
        if scale_factor is not None:
            mask = F.pixel_shuffle(mask, scale_factor)
        n, mask_c, h, w = mask.size()
        mask_channel = mask_c // (kernel ** 2)  # number of kernel groups

        # Softmax over the kernel**2 taps of each group.
        mask = mask.reshape(n, mask_channel, -1, h, w)
        mask = F.softmax(mask, dim=2, dtype=mask.dtype)
        mask = mask.reshape(n, mask_channel, kernel, kernel, h, w)
        # reshape rather than view: with more than one group the permuted
        # tensor is not view-compatible and .view() raises.
        mask = mask.permute(0, 1, 4, 5, 2, 3).reshape(n, -1, kernel, kernel)
        mask = mask * hamming
        # Renormalize so each windowed kernel sums to 1 again.
        mask = mask / mask.sum(dim=(-1, -2), keepdim=True)
        mask = mask.reshape(n, mask_channel, h, w, -1)
        mask = mask.permute(0, 1, 4, 2, 3).reshape(n, -1, h, w).contiguous()
        return mask

    def forward(self, hr_feat, lr_feat, use_checkpoint=False): # use check_point to save GPU memory
        """Fuse ``hr_feat`` and ``lr_feat``; returns ``(mask_lr, hr_feat, lr_feat)``.

        ``use_checkpoint`` runs ``_forward`` under gradient checkpointing.
        """
        if use_checkpoint:
            return checkpoint(self._forward, hr_feat, lr_feat)
        else:
            return self._forward(hr_feat, lr_feat)

    def _forward(self, hr_feat, lr_feat):
        compressed_hr_feat = self.hr_channel_compressor(hr_feat)
        compressed_lr_feat = self.lr_channel_compressor(lr_feat)
        if self.semi_conv:
            if self.comp_feat_upsample:
                if self.use_high_pass:
                    # Initial AHPF logits predicted from the compressed HR feature.
                    mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat)
                    mask_hr_init = self.kernel_normalizer(mask_hr_hr_feat, self.highpass_kernel, hamming=self.hamming_highpass)
                    # x + (x - filtered(x)): boost the high-frequency part of the compressed HR feature.
                    compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1)

                    # Initial ALPF predicted from the enhanced compressed HR feature.
                    mask_lr_hr_feat = self.content_encoder(compressed_hr_feat)
                    mask_lr_init = self.kernel_normalizer(mask_lr_hr_feat, self.lowpass_kernel, hamming=self.hamming_lowpass)

                    # ALPF logits from the compressed LR feature, upsampled x2 with the initial ALPF.
                    mask_lr_lr_feat_lr = self.content_encoder(compressed_lr_feat)
                    mask_lr_lr_feat = F.interpolate(
                        carafe(mask_lr_lr_feat_lr, mask_lr_init, self.lowpass_kernel, self.up_group, 2), size=compressed_hr_feat.shape[-2:], mode='nearest')
                    # Sum of both sources: final ALPF logits.
                    mask_lr = mask_lr_hr_feat + mask_lr_lr_feat

                    # AHPF logits from the LR feature, upsampled with the fused ALPF.
                    mask_lr_init = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)
                    mask_hr_lr_feat = F.interpolate(
                        carafe(self.content_encoder2(compressed_lr_feat), mask_lr_init, self.lowpass_kernel, self.up_group, 2), size=compressed_hr_feat.shape[-2:], mode='nearest')
                    # Sum of both sources: final AHPF logits.
                    mask_hr = mask_hr_hr_feat + mask_hr_lr_feat
                else: raise NotImplementedError
            else:
                mask_lr = self.content_encoder(compressed_hr_feat) + F.interpolate(self.content_encoder(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
                if self.use_high_pass:
                    mask_hr = self.content_encoder2(compressed_hr_feat) + F.interpolate(self.content_encoder2(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
        else:
            compressed_x = F.interpolate(compressed_lr_feat, size=compressed_hr_feat.shape[-2:], mode='nearest') + compressed_hr_feat
            mask_lr = self.content_encoder(compressed_x)
            if self.use_high_pass:
                mask_hr = self.content_encoder2(compressed_x)

        mask_lr = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)
        if self.semi_conv:
            # carafe upsamples x2 here, so hr_feat must be exactly twice lr_feat's size.
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 2)
        else:
            lr_feat = resize(
                input=lr_feat,
                size=hr_feat.shape[2:],
                mode=self.upsample_mode,
                align_corners=None if self.upsample_mode == 'nearest' else self.align_corners)
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 1)

        if self.use_high_pass:
            mask_hr = self.kernel_normalizer(mask_hr, self.highpass_kernel, hamming=self.hamming_highpass)
            # Normalized kernels act as a smoother; x - smooth(x) keeps the high frequencies.
            hr_feat_hf = hr_feat - carafe(hr_feat, mask_hr, self.highpass_kernel, self.up_group, 1)
            if self.hr_residual:
                hr_feat = hr_feat_hf + hr_feat
            else:
                hr_feat = hr_feat_hf

        if self.feature_resample:
            lr_feat = self.dysampler(hr_x=compressed_hr_feat,
                                     lr_x=compressed_lr_feat, feat2sample=lr_feat)

        return  mask_lr, hr_feat, lr_feat



class LocalSimGuidedSampler(nn.Module):
    """Offset generator used by FreqFusion.

    Predicts per-pixel sampling offsets from normalized high-/low-resolution
    guidance features and their local similarity maps (``compute_similarity``).
    It then resamples a feature map at ``scale`` times its resolution with
    ``F.grid_sample``.

    Only ``scale=2`` and ``style='lp'`` are accepted (enforced by the asserts
    below), so the ``'pl'`` branches are unreachable.
    """
    def __init__(self, in_channels, scale=2, style='lp', groups=4, use_direct_scale=True, kernel_size=1, local_window=3, sim_type='cos', norm=True, direction_feat='sim_concat'):
        super().__init__()
        assert scale==2
        assert style=='lp'

        self.scale = scale
        self.style = style
        self.groups = groups
        self.local_window = local_window
        self.sim_type = sim_type
        self.direction_feat = direction_feat

        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            # (x, y) offsets for each of the scale**2 children of every LR pixel, per group.
            out_channels = 2 * groups * scale ** 2
        if self.direction_feat == 'sim':
            self.offset = nn.Conv2d(local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
        elif self.direction_feat == 'sim_concat':
            self.offset = nn.Conv2d(in_channels + local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
        else: raise NotImplementedError
        normal_init(self.offset, std=0.001)
        if use_direct_scale:
            if self.direction_feat == 'sim':
                self.direct_scale = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
            elif self.direction_feat == 'sim_concat':
                self.direct_scale = nn.Conv2d(in_channels + local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
            else: raise NotImplementedError
            constant_init(self.direct_scale, val=0.)

        # HR-side heads predict 2 * groups channels at HR resolution; they are
        # pixel-unshuffled to the LR grid in get_offset_lp.
        out_channels = 2 * groups
        if self.direction_feat == 'sim':
            self.hr_offset = nn.Conv2d(local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
        elif self.direction_feat == 'sim_concat':
            self.hr_offset = nn.Conv2d(in_channels + local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
        else: raise NotImplementedError
        normal_init(self.hr_offset, std=0.001)

        if use_direct_scale:
            if self.direction_feat == 'sim':
                self.hr_direct_scale = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
            elif self.direction_feat == 'sim_concat':
                self.hr_direct_scale = nn.Conv2d(in_channels + local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
            else: raise NotImplementedError
            constant_init(self.hr_direct_scale, val=0.)

        self.norm = norm
        if self.norm:
            self.norm_hr = nn.GroupNorm(in_channels // 8, in_channels)
            self.norm_lr = nn.GroupNorm(in_channels // 8, in_channels)
        else:
            self.norm_hr = nn.Identity()
            self.norm_lr = nn.Identity()
        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        """Base sub-pixel positions of the scale x scale children of each LR pixel.

        Values are in LR-pixel units and repeated for every group; the shape is
        (1, 2 * groups * scale**2, 1, 1).
        """
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset, scale=None):
        """Bilinearly sample ``x`` at pixel centres plus ``offset``.

        ``x`` is split into ``self.groups`` channel groups. ``offset`` is laid
        out on a (B, 2 * groups * scale**2, H, W) grid in pixel units. The
        output has shape (B, C, scale * H, scale * W), and out-of-range
        positions use border padding.
        """
        if scale is None: scale = self.scale
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        # Map pixel coordinates to grid_sample's [-1, 1] range.
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), scale).view(
            B, 2, -1, scale * H, scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, x.size(-2), x.size(-1)), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, scale * H, scale * W)

    def forward(self, hr_x, lr_x, feat2sample):
        """Resample ``feat2sample`` guided by ``hr_x`` (HR) and ``lr_x`` (LR) features.

        The HR-side offsets are pixel-unshuffled by ``self.scale`` and added to
        the LR-side ones, so ``hr_x`` must be ``scale`` times the size of ``lr_x``.
        """
        hr_x = self.norm_hr(hr_x)
        lr_x = self.norm_lr(lr_x)

        if self.direction_feat == 'sim':
            hr_sim = compute_similarity(hr_x, self.local_window, dilation=2, sim=self.sim_type)
            lr_sim = compute_similarity(lr_x, self.local_window, dilation=2, sim=self.sim_type)
        elif self.direction_feat == 'sim_concat':
            hr_sim = torch.cat([hr_x, compute_similarity(hr_x, self.local_window, dilation=2, sim=self.sim_type)], dim=1)
            lr_sim = torch.cat([lr_x, compute_similarity(lr_x, self.local_window, dilation=2, sim=self.sim_type)], dim=1)
            hr_x, lr_x = hr_sim, lr_sim
        offset = self.get_offset_lp(hr_x, lr_x, hr_sim, lr_sim)
        return self.sample(feat2sample, offset)

    def get_offset_lp(self, hr_x, lr_x, hr_sim=None, lr_sim=None):
        """Predict sampling offsets on the LR grid for the 'lp' style.

        Args:
            hr_x, lr_x: Inputs of the ``direct_scale`` heads.
            hr_sim, lr_sim: Inputs of the ``offset`` heads; default to
                ``hr_x`` / ``lr_x``, which was the older two-argument form.

        Returns:
            Offsets of shape (B, 2 * groups * scale**2, H_lr, W_lr), including
            ``init_pos``.
        """
        if hr_sim is None:
            hr_sim = hr_x
        if lr_sim is None:
            lr_sim = lr_x
        if hasattr(self, 'direct_scale'):
            # Offset prediction scaled per-element by a sigmoid gate computed from x.
            offset = (self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)) * (self.direct_scale(lr_x) + F.pixel_unshuffle(self.hr_direct_scale(hr_x), self.scale)).sigmoid() + self.init_pos
        else:
            # The offset heads are built for the similarity features, so feed them here too.
            offset = (self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)) * 0.25 + self.init_pos
        return offset

    def get_offset(self, hr_x, lr_x):
        """Two-argument entry point; uses ``hr_x`` / ``lr_x`` as the similarity features too."""
        if self.style == 'pl':
            raise NotImplementedError
        return self.get_offset_lp(hr_x, lr_x)
    

def compute_similarity(input_tensor, k=3, dilation=1, sim='cos'):
    """Similarity between every pixel and its k x k (dilated) neighbours.

    Args:
        input_tensor: Tensor of shape [B, C, H, W].
        k: Window size (k x k neighbourhood).
        dilation: Spacing between sampled neighbours. The input is zero-padded
            by ``(k // 2) * dilation``.
        sim: ``'cos'`` for cosine similarity over channels, ``'dot'`` for the
            channel-wise dot product.

    Returns:
        Tensor of shape [B, k*k - 1, H, W]. The centre pixel's similarity with
        itself is dropped. Neighbours lying in the padding compare against a
        zero vector.

    Raises:
        NotImplementedError: if ``sim`` is neither ``'cos'`` nor ``'dot'``.
    """
    batch, channels, height, width = input_tensor.shape
    window = k * k
    centre = window // 2

    # Gather each pixel's neighbourhood: [B, C, k*k, H, W].
    patches = F.unfold(input_tensor, k, padding=(k // 2) * dilation, dilation=dilation)
    patches = patches.reshape(batch, channels, window, height, width)
    anchor = patches[:, :, centre:centre + 1]

    if sim == 'cos':
        scores = F.cosine_similarity(anchor, patches, dim=1)
    elif sim == 'dot':
        scores = (anchor * patches).sum(dim=1)
    else:
        raise NotImplementedError

    # Drop the self-similarity entry at the window centre.
    neighbours = torch.cat((scores[:, :centre], scores[:, centre + 1:]), dim=1)
    return neighbours.view(batch, window - 1, height, width)


if __name__ == '__main__':
    # Smoke test: fuse a 512x512 high-resolution map with a 256x256 one
    # (the default semi_conv=True path needs an exact 2x size ratio).
    hr_input = torch.rand(1, 128, 512, 512)
    lr_input = torch.rand(1, 128, 256, 256)
    fusion = FreqFusion(hr_channels=128, lr_channels=128)
    mask_lr, hr_feat, lr_feat = fusion(hr_feat=hr_input, lr_feat=lr_input)
    print(mask_lr.shape)

以上是论文中整个FreqFusion的模块图和GitHub官方代码。其中代码的核心点在于:

FreqFusion类:负责整体的频率感知特征融合,结合了高频和低频信息以改进特征的类内一致性和边界清晰度。

LocalSimGuidedSampler类:偏移生成器,用于基于局部相似性指导重新采样特征。

辅助函数:如 carafe(动态上采样操作)、resize、hamming2D(生成二维Hamming窗)、compute_similarity 等,用于支持主要功能。

FreqFusion的结构

Adaptive Low-Pass Filter (ALPF) Generator

  • 输入:初始融合后的特征 Z^l。
  • 过程:通过 3×3 卷积层和 softmax 操作生成空间可变的低通滤波器,对高层特征进行平滑和上采样。
  • 输出:平滑后的高层特征。

Offset Generator

  • 输入:初始融合后的特征 Z^l 和局部相似度 S。
  • 过程:预测偏移量,用于重新采样类内一致性较高的特征。
  • 输出:偏移后的特征,解决大面积不一致特征和细薄边界的问题。

Adaptive High-Pass Filter (AHPF) Generator

  • 输入:初始融合后的特征 Z^l。
  • 过程:生成空间可变的高通滤波器,增强从低层特征中提取的边界细节。
  • 输出:增强边界的特征。

论文中模块图的具体流程(纯手敲)

[图2:论文中 FreqFusion 的模块结构与流程图]

整个模块输入两个特征图(Y^{l+1},X^l),输出一个融合后的特征图(Y^l)

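Y^{l+1} 是低分辨率的高层特征,对应代码中的 lr_feat;X^l 是高分辨率的低层特征,对应代码中的 hr_feat。两者的通道数可以不同:它们分别由 hr_channels、lr_channels 指定,并各自经 1×1 卷积压缩到 compressed_channels。但模块返回的 hr_feat 与 lr_feat 需要在外部相加才能得到 Y^l,因此这时两者的通道数必须相同。在空间尺寸(张量的最后两个维度)上,hr_feat 是 lr_feat 的 2 倍;默认 semi_conv=True 时,carafe 会断言这一点。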

两个输入特征图都先经过一个 1×1 卷积进行通道压缩(r 为通道缩减率)。然后,X^l(代码中的 hr_feat)压缩后的特征图分别进入 ALPFG 和 AHPFG,生成对应的自适应低通滤波器和自适应高通滤波器。这些滤波器作为卷积核,分别与 Y^{l+1}(代码中的 lr_feat)和 X^l 压缩后的特征进行卷积。之后,图中下路(X^l 经自适应高通滤波得到的高频分量)与压缩后的 X^l 逐像素相加,实现残差式高频增强。上路对 Y^{l+1} 进行低通滤波,并用 Pixel Shuffle(像素重排)上采样,以便对粗糙的高层特征进行上采样。两路再逐像素相加完成初始融合,生成 Z^l(融合后的压缩特征)。

Z^l(融合后的压缩特征)将分别进入 ALPFG、OffsetG 和 AHPFG,依次得到三样东西:自适应低通滤波器;针对高层特征中每个像素的最终预测偏移量 O^l(由偏移方向 D 与控制偏移幅度的 A 共同决定);以及自适应高通滤波器。OffsetG 的作用是弥补 ALPFG 在平滑特征时的不足:仅靠低通平滑,难以纠正大面积不一致的特征区域,也难以细化纤细结构和边界区域。

由第二个 ALPFG 生成的低通滤波器作为卷积核,对 Y^{l+1}(代码中的 lr_feat)进行卷积,再通过 Pixel Shuffle 上采样恢复分辨率,得到平滑并上采样后的高层特征。之后,利用 OffsetG 为每个像素预测的最终偏移量 O^l 进行重新采样,生成对齐后的特征。

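最终融合分两步。先用第二个 AHPFG 生成的高通滤波器作为卷积核,对 X^l(代码中的 hr_feat)进行卷积提取高频分量,并与 X^l 逐像素相加,得到增强后的高分辨率特征。再将它与上面重新采样得到的特征逐像素相加,得到最终输出 Y^l。注意:官方代码的 forward 返回的是 (mask_lr, hr_feat, lr_feat),最后这一步相加需要在调用处自行完成(例如 hr_feat + lr_feat)。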

注释后的代码

# TPAMI 2024:Frequency-aware Feature Fusion for Dense Image Prediction

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import warnings
import numpy as np

# try:
#     from mmcv.ops.carafe import normal_init, xavier_init, carafe
# except ImportError:

def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    """Xavier-initialise ``module.weight`` and fill ``module.bias`` with a constant.

    Local equivalent of mmcv's helper of the same name.

    Args:
        module: Module whose ``weight`` / ``bias`` are initialised; either is
            skipped when missing or ``None``.
        gain: Scaling factor forwarded to the Xavier initialiser.
        bias: Constant written into ``module.bias``.
        distribution: ``'uniform'`` or ``'normal'`` Xavier variant.
    """
    assert distribution in ['uniform', 'normal']
    weight = getattr(module, 'weight', None)
    if weight is not None:
        init_fn = nn.init.xavier_uniform_ if distribution == 'uniform' else nn.init.xavier_normal_
        init_fn(weight, gain=gain)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)


def carafe(x, normed_mask, kernel_size, group=1, up=1):
    """Pure-PyTorch equivalent of mmcv's ``carafe`` (content-aware reassembly).

    Each output pixel is a weighted sum of the ``kernel_size x kernel_size``
    neighbourhood of its source pixel in ``x``. The weights are read from
    ``normed_mask``, one kernel per output pixel and per channel group.

    Args:
        x: Input features of shape (B, C, H, W).
        normed_mask: Kernels of shape (B, G * kernel_size**2, up * H, up * W),
            laid out with channel index ``g * kernel_size**2 + tap``.
        kernel_size: Side length of the reassembly kernel.
        group: Kept for API compatibility with mmcv. The effective number of
            groups G is read from the mask's channel count, which equals
            ``group`` for masks laid out as mmcv expects. Channels of ``x`` are
            split into G contiguous chunks, each using its own kernel.
        up: Integer upsampling factor.

    Returns:
        Tensor of shape (B, C, up * H, up * W).

    Note:
        Borders use reflect padding, so ``F.pad`` requires H and W to be larger
        than ``kernel_size // 2``.
    """
    b, c, h, w = x.shape
    _, m_c, m_h, m_w = normed_mask.shape
    k2 = kernel_size * kernel_size
    assert m_h == up * h
    assert m_w == up * w
    assert m_c % k2 == 0, 'mask channels must be a multiple of kernel_size**2'
    n_groups = m_c // k2
    assert c % n_groups == 0, 'input channels must be divisible by the number of groups'
    pad = kernel_size // 2
    pad_x = F.pad(x, pad=[pad] * 4, mode='reflect')
    # (B, C*k*k, H*W): each pixel's k x k neighbourhood, channel-major.
    unfold_x = F.unfold(pad_x, kernel_size=(kernel_size, kernel_size), stride=1, padding=0)
    unfold_x = unfold_x.reshape(b, c * k2, h, w)
    # Nearest upsampling gives every one of the up*up output pixels the
    # neighbourhood of its source pixel.
    unfold_x = F.interpolate(unfold_x, scale_factor=up, mode='nearest')
    unfold_x = unfold_x.reshape(b, n_groups, c // n_groups, k2, m_h, m_w)
    # One kernel per group, broadcast over that group's channels.
    normed_mask = normed_mask.reshape(b, n_groups, 1, k2, m_h, m_w)
    res = (unfold_x * normed_mask).sum(dim=3)
    return res.reshape(b, c, m_h, m_w)


def normal_init(module, mean=0, std=1, bias=0):
    """Sample ``module.weight`` from N(mean, std**2) and fill ``module.bias`` with ``bias``.

    Either parameter is skipped when the module lacks it or it is ``None``.
    """
    weight = getattr(module, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight, mean, std)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)

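# The functions above are local stand-ins for mmcv's xavier_init / carafe / normal_init (the try/except ImportError around them is commented out, so these local versions are always used).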

def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``.

    Parameters that are missing or ``None`` are left untouched.
    """
    for attr_name, fill_value in (('weight', val), ('bias', bias)):
        param = getattr(module, attr_name, None)
        if param is not None:
            nn.init.constant_(param, fill_value)


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Wrapper around ``F.interpolate`` that may warn about misaligned upsampling.

    A ``UserWarning`` is emitted only when all of the following hold:
    ``warning`` is true, ``size`` is given, ``align_corners`` is truthy, the
    output is larger than the input in some dimension, all input/output sizes
    exceed 1, and ``(out - 1)`` is not a multiple of ``(in - 1)`` in BOTH
    dimensions.

    Returns:
        ``F.interpolate(input, size, scale_factor, mode, align_corners)``.
    """
    if warning and size is not None and align_corners:
        in_h, in_w = (int(d) for d in input.shape[2:])
        out_h, out_w = (int(d) for d in size)
        upsampling = out_h > in_h or out_w > in_w
        if upsampling:
            # Short-circuits before the modulo so (in - 1) is never zero.
            all_gt_one = out_h > 1 and out_w > 1 and in_h > 1 and in_w > 1
            if all_gt_one and (out_h - 1) % (in_h - 1) and (out_w - 1) % (in_w - 1):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(in_h, in_w)} is `x+1` and '
                    f'out size {(out_h, out_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)


def hamming2D(M, N):
    """Build a separable 2-D Hamming window.

    Args:
        M: Number of rows.
        N: Number of columns.

    Returns:
        ``np.ndarray`` of shape (M, N): the outer product of ``np.hamming(M)``
        (rows) and ``np.hamming(N)`` (columns).
    """
    row_window = np.hamming(M)
    col_window = np.hamming(N)
    # Broadcasting a column vector against a row vector is the outer product.
    return row_window[:, None] * col_window[None, :]


class FreqFusion(nn.Module):
    """Frequency-aware fusion of a high-resolution and a low-resolution feature map.

    Both inputs are first projected to ``compressed_channels`` with 1x1 convs.
    Two content encoders then predict per-pixel kernels from the compressed
    features: ``content_encoder`` predicts ``lowpass_kernel x lowpass_kernel``
    low-pass kernels (ALPF) and, when ``use_high_pass`` is set,
    ``content_encoder2`` predicts ``highpass_kernel x highpass_kernel`` kernels
    whose residual ``x - filter(x)`` gives the high-frequency part (AHPF).
    Kernels are softmax-normalised, optionally windowed by a 2-D Hamming window,
    and applied with ``carafe``.

    Args:
        hr_channels: channel count of ``hr_feat``.
        lr_channels: channel count of ``lr_feat`` (may differ from ``hr_channels``).
        scale_factor: multiplies the encoders' output channels
            (``kernel**2 * up_group * scale_factor**2``). ``_forward`` calls
            ``kernel_normalizer`` without a scale factor. NOTE(review): values
            other than 1 look unexercised; TODO confirm intended use.
        lowpass_kernel: spatial size of the predicted low-pass kernels.
        highpass_kernel: spatial size of the predicted high-pass kernels.
        up_group: group count passed to ``carafe``.
        encoder_kernel: kernel size of the content encoders.
        encoder_dilation: dilation of the content encoders.
        compressed_channels: output width of the 1x1 channel compressors.
        align_corners: forwarded to ``resize`` when ``semi_conv`` is False and
            ``upsample_mode`` is not ``'nearest'``.
        upsample_mode: interpolation mode of ``resize`` when ``semi_conv`` is False.
        feature_resample: if True, refine the upsampled ``lr_feat`` with a
            ``LocalSimGuidedSampler`` offset generator.
        feature_resample_group: ``groups`` of that sampler.
        comp_feat_upsample: in semi-conv mode, build the kernels from compressed
            features with an initial ALPF/AHPF pass. This is only implemented
            together with ``use_high_pass``; otherwise ``NotImplementedError``
            is raised.
        use_high_pass: enable the AHPF branch on ``hr_feat``.
        use_low_pass: stored on the module; not read by ``_forward``.
        hr_residual: return ``hr + (hr - filter(hr))`` instead of only the
            high-frequency part.
        semi_conv: upsample ``lr_feat`` x2 with ``carafe`` using kernels predicted
            at the HR resolution. In this mode ``hr_feat`` is expected to be
            exactly twice the spatial size of ``lr_feat``. Otherwise
            ``lr_feat`` is resized to ``hr_feat``'s size and filtered there.
        hamming_window: multiply the normalised kernels by a 2-D Hamming window.
        feature_resample_norm: GroupNorm switch of the sampler.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 hr_channels,
                 lr_channels,
                 scale_factor=1,
                 lowpass_kernel=5,
                 highpass_kernel=3,
                 up_group=1,
                 encoder_kernel=3,
                 encoder_dilation=1,
                 compressed_channels=64,
                 align_corners=False,
                 upsample_mode='nearest',
                 feature_resample=False,
                 feature_resample_group=4,
                 comp_feat_upsample=True,
                 use_high_pass=True,
                 use_low_pass=True,
                 hr_residual=True,
                 semi_conv=True,
                 hamming_window=True,
                 feature_resample_norm=True,
                 **kwargs):
        super().__init__()
        self.scale_factor = scale_factor
        self.lowpass_kernel = lowpass_kernel
        self.highpass_kernel = highpass_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels
        # 1x1 convs projecting both inputs to the same compressed width.
        self.hr_channel_compressor = nn.Conv2d(hr_channels, self.compressed_channels, 1)
        self.lr_channel_compressor = nn.Conv2d(lr_channels, self.compressed_channels, 1)
        # ALPF generator: lowpass_kernel**2 taps per group (times scale_factor**2).
        self.content_encoder = nn.Conv2d(
            self.compressed_channels,
            lowpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
            self.encoder_kernel,
            padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
            dilation=self.encoder_dilation,
            groups=1)

        self.align_corners = align_corners
        self.upsample_mode = upsample_mode
        self.hr_residual = hr_residual
        self.use_high_pass = use_high_pass
        self.use_low_pass = use_low_pass
        self.semi_conv = semi_conv
        self.feature_resample = feature_resample
        self.comp_feat_upsample = comp_feat_upsample
        if self.feature_resample:
            self.dysampler = LocalSimGuidedSampler(in_channels=compressed_channels, scale=2, style='lp',
                                                   groups=feature_resample_group, use_direct_scale=True,
                                                   kernel_size=encoder_kernel, norm=feature_resample_norm)
        if self.use_high_pass:
            # AHPF generator: highpass_kernel**2 taps per group (times scale_factor**2).
            self.content_encoder2 = nn.Conv2d(
                self.compressed_channels,
                highpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
                self.encoder_kernel,
                padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
                dilation=self.encoder_dilation,
                groups=1)
        self.hamming_window = hamming_window
        lowpass_pad = 0
        highpass_pad = 0
        if self.hamming_window:
            # (1, 1, k, k) windows, broadcast over the kernel taps in kernel_normalizer.
            self.register_buffer('hamming_lowpass', torch.FloatTensor(
                hamming2D(lowpass_kernel + 2 * lowpass_pad, lowpass_kernel + 2 * lowpass_pad))[None, None,])
            self.register_buffer('hamming_highpass', torch.FloatTensor(
                hamming2D(highpass_kernel + 2 * highpass_pad, highpass_kernel + 2 * highpass_pad))[None, None,])
        else:
            # Multiplying by [1.0] leaves the kernels unchanged.
            self.register_buffer('hamming_lowpass', torch.FloatTensor([1.0]))
            self.register_buffer('hamming_highpass', torch.FloatTensor([1.0]))
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init for every conv, then small-normal init for the encoders.

        NOTE: ``self.modules()`` is recursive, so this also re-initialises the
        convs inside ``dysampler`` (when present). That overrides the
        initialisation done in its own constructor.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')
        normal_init(self.content_encoder, std=0.001)
        if self.use_high_pass:
            normal_init(self.content_encoder2, std=0.001)

    def kernel_normalizer(self, mask, kernel, scale_factor=None, hamming=1):
        """Turn raw encoder output into normalised per-pixel kernels.

        Args:
            mask: encoder output of shape ``(n, groups * kernel**2, h, w)``.
            kernel: spatial kernel size ``k``.
            scale_factor: if given, ``mask`` is first pixel-shuffled by this
                factor (channels traded for spatial resolution).
            hamming: window broadcast over the ``(k, k)`` kernel dimensions,
                e.g. a ``(1, 1, k, k)`` tensor, ``[1.0]`` or a scalar.

        Returns:
            Tensor of shape ``(n, groups * kernel**2, h, w)``. Each ``k*k``
            kernel is softmax-normalised, multiplied by ``hamming`` and
            re-normalised to sum to 1.
        """
        if scale_factor is not None:
            # Honour the caller's factor (previously self.scale_factor was used,
            # silently ignoring this argument).
            mask = F.pixel_shuffle(mask, scale_factor)
        n, mask_c, h, w = mask.size()
        mask_channel = mask_c // (kernel ** 2)  # number of kernel groups

        # Softmax over the k*k taps of every kernel.
        mask = mask.view(n, mask_channel, -1, h, w)
        mask = F.softmax(mask, dim=2, dtype=mask.dtype)
        mask = mask.view(n, mask_channel, kernel, kernel, h, w)
        # reshape rather than view: the permuted tensor is non-contiguous, and
        # view() raises as soon as mask_channel > 1.
        mask = mask.permute(0, 1, 4, 5, 2, 3).reshape(n, -1, kernel, kernel)
        mask = mask * hamming
        mask = mask / mask.sum(dim=(-1, -2), keepdim=True)
        mask = mask.reshape(n, mask_channel, h, w, -1)
        mask = mask.permute(0, 1, 4, 2, 3).reshape(n, -1, h, w).contiguous()
        return mask

    def forward(self, hr_feat, lr_feat, use_checkpoint=False):
        """Fuse ``hr_feat`` and ``lr_feat``.

        Args:
            hr_feat: high-resolution input with ``hr_channels`` channels.
            lr_feat: low-resolution input with ``lr_channels`` channels.
            use_checkpoint: run ``_forward`` under ``torch.utils.checkpoint``,
                trading recomputation for activation memory.

        Returns:
            ``(mask_lr, hr_feat, lr_feat)``. ``mask_lr`` holds the normalised
            low-pass kernels. ``hr_feat`` is the (high-frequency enhanced) HR
            feature. ``lr_feat`` is the LR feature brought to the HR resolution.
        """
        if use_checkpoint:
            return checkpoint(self._forward, hr_feat, lr_feat)
        else:
            return self._forward(hr_feat, lr_feat)

    def _forward(self, hr_feat, lr_feat):
        """Implementation of ``forward`` (same inputs and outputs)."""
        # Project both inputs to the same compressed channel width.
        compressed_hr_feat = self.hr_channel_compressor(hr_feat)
        compressed_lr_feat = self.lr_channel_compressor(lr_feat)

        if self.semi_conv:
            if self.comp_feat_upsample:
                if self.use_high_pass:
                    # Initial AHPF kernels predicted from the compressed HR features.
                    mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat)
                    mask_hr_init = self.kernel_normalizer(mask_hr_hr_feat, self.highpass_kernel,
                                                          hamming=self.hamming_highpass)
                    # Sharpen the compressed HR features: x + (x - filter(x)).
                    compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(
                        compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1)

                    # Initial ALPF kernels predicted from the sharpened HR features.
                    mask_lr_hr_feat = self.content_encoder(compressed_hr_feat)
                    mask_lr_init = self.kernel_normalizer(mask_lr_hr_feat, self.lowpass_kernel,
                                                          hamming=self.hamming_lowpass)

                    # ALPF logits predicted at LR resolution, upsampled x2 with carafe
                    # (using mask_lr_init) and matched to the HR size.
                    mask_lr_lr_feat_lr = self.content_encoder(compressed_lr_feat)
                    mask_lr_lr_feat = F.interpolate(
                        carafe(mask_lr_lr_feat_lr, mask_lr_init, self.lowpass_kernel, self.up_group, 2),
                        size=compressed_hr_feat.shape[-2:],
                        mode='nearest')
                    # Combined (still un-normalised) ALPF logits.
                    mask_lr = mask_lr_hr_feat + mask_lr_lr_feat

                    # Normalised combined kernels, used only to upsample the LR AHPF
                    # prediction below; mask_lr itself is normalised further down.
                    mask_lr_init = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)

                    # AHPF logits predicted at LR resolution, upsampled x2 with the
                    # low-pass kernels, then combined with the HR prediction.
                    mask_hr_lr_feat = F.interpolate(
                        carafe(self.content_encoder2(compressed_lr_feat), mask_lr_init, self.lowpass_kernel,
                               self.up_group, 2),
                        size=compressed_hr_feat.shape[-2:],
                        mode='nearest')
                    mask_hr = mask_hr_hr_feat + mask_hr_lr_feat
                else:
                    raise NotImplementedError
            else:
                # Predict kernels at both resolutions; add them after nearest upsampling.
                mask_lr = self.content_encoder(compressed_hr_feat) + F.interpolate(
                    self.content_encoder(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
                if self.use_high_pass:
                    mask_hr = self.content_encoder2(compressed_hr_feat) + F.interpolate(
                        self.content_encoder2(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
        else:
            # Fuse the compressed features first, then predict kernels from the sum.
            compressed_x = F.interpolate(compressed_lr_feat, size=compressed_hr_feat.shape[-2:],
                                         mode='nearest') + compressed_hr_feat
            mask_lr = self.content_encoder(compressed_x)
            if self.use_high_pass:
                mask_hr = self.content_encoder2(compressed_x)

        mask_lr = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)

        if self.semi_conv:
            # Content-aware x2 upsampling of lr_feat with the low-pass kernels.
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 2)
        else:
            # Plain resize to hr_feat's size, then low-pass filtering at that resolution.
            lr_feat = resize(
                input=lr_feat,
                size=hr_feat.shape[2:],
                mode=self.upsample_mode,
                align_corners=None if self.upsample_mode == 'nearest' else self.align_corners)
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 1)

        if self.use_high_pass:
            mask_hr = self.kernel_normalizer(mask_hr, self.highpass_kernel, hamming=self.hamming_highpass)
            # High-frequency component of hr_feat: x - filter(x).
            hr_feat_hf = hr_feat - carafe(hr_feat, mask_hr, self.highpass_kernel, self.up_group, 1)
            if self.hr_residual:
                hr_feat = hr_feat_hf + hr_feat
            else:
                hr_feat = hr_feat_hf

        if self.feature_resample:
            # Offset-guided resampling of the upsampled lr_feat.
            lr_feat = self.dysampler(hr_x=compressed_hr_feat,
                                     lr_x=compressed_lr_feat, feat2sample=lr_feat)

        return mask_lr, hr_feat, lr_feat


class LocalSimGuidedSampler(nn.Module):
    """Offset generator used by FreqFusion to resample the upsampled LR feature.

    Offsets are predicted from the (optionally GroupNorm-ed) HR/LR features and
    their local cosine-similarity maps. The maps come from ``compute_similarity``
    with ``local_window`` and dilation 2. ``feat2sample`` is then resampled with
    ``F.grid_sample`` on a grid of ``scale`` times the LR resolution.

    Only ``scale=2`` and ``style='lp'`` are accepted (asserted in ``__init__``),
    so the ``'pl'`` branches below are unreachable.

    Args:
        in_channels: channels of ``hr_x``/``lr_x``. Must be divisible by
            ``groups``. With ``norm`` it is also used for
            ``nn.GroupNorm(in_channels // 8, in_channels)``.
        scale: upsampling factor (must be 2).
        style: must be ``'lp'``.
        groups: number of offset groups.
        use_direct_scale: if True, offsets are modulated by a sigmoid-gated
            scale predicted by ``direct_scale``/``hr_direct_scale``;
            otherwise they are multiplied by the constant 0.25.
        kernel_size: kernel size of the offset convs.
        local_window: window size ``k`` for the similarity maps
            (``k*k - 1`` channels).
        sim_type: stored only. NOTE(review): ``forward`` always uses ``'cos'``.
        norm: apply GroupNorm to ``hr_x``/``lr_x`` before predicting offsets.
        direction_feat: ``'sim'`` (similarity maps only) or ``'sim_concat'``
            (features concatenated with their similarity maps).
    """

    def __init__(self, in_channels, scale=2, style='lp', groups=4, use_direct_scale=True, kernel_size=1, local_window=3,
                 sim_type='cos', norm=True, direction_feat='sim_concat'):
        super().__init__()
        assert scale == 2
        assert style == 'lp'

        self.scale = scale
        self.style = style
        self.groups = groups
        self.local_window = local_window
        self.sim_type = sim_type
        self.direction_feat = direction_feat

        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            # 'lp': an (x, y) offset per group for each of the scale**2 sub-pixels.
            out_channels = 2 * groups * scale ** 2
        # LR-branch offset conv; its input is the similarity feature (hr_sim/lr_sim).
        if self.direction_feat == 'sim':
            self.offset = nn.Conv2d(local_window ** 2 - 1, out_channels, kernel_size=kernel_size,
                                    padding=kernel_size // 2)
        elif self.direction_feat == 'sim_concat':
            self.offset = nn.Conv2d(in_channels + local_window ** 2 - 1, out_channels, kernel_size=kernel_size,
                                    padding=kernel_size // 2)
        else:
            raise NotImplementedError
        normal_init(self.offset, std=0.001)
        if use_direct_scale:
            if self.direction_feat == 'sim':
                self.direct_scale = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                              padding=kernel_size // 2)
            elif self.direction_feat == 'sim_concat':
                self.direct_scale = nn.Conv2d(in_channels + local_window ** 2 - 1, out_channels,
                                              kernel_size=kernel_size, padding=kernel_size // 2)
            else:
                raise NotImplementedError
            constant_init(self.direct_scale, val=0.)

        # HR-branch convs predict 2 * groups channels at HR resolution; they are
        # pixel-unshuffled to the LR grid in get_offset_lp.
        out_channels = 2 * groups
        if self.direction_feat == 'sim':
            self.hr_offset = nn.Conv2d(local_window ** 2 - 1, out_channels, kernel_size=kernel_size,
                                       padding=kernel_size // 2)
        elif self.direction_feat == 'sim_concat':
            self.hr_offset = nn.Conv2d(in_channels + local_window ** 2 - 1, out_channels, kernel_size=kernel_size,
                                       padding=kernel_size // 2)
        else:
            raise NotImplementedError
        normal_init(self.hr_offset, std=0.001)

        if use_direct_scale:
            if self.direction_feat == 'sim':
                self.hr_direct_scale = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                                 padding=kernel_size // 2)
            elif self.direction_feat == 'sim_concat':
                self.hr_direct_scale = nn.Conv2d(in_channels + local_window ** 2 - 1, out_channels,
                                                 kernel_size=kernel_size, padding=kernel_size // 2)
            else:
                raise NotImplementedError
            constant_init(self.hr_direct_scale, val=0.)

        self.norm = norm
        if self.norm:
            self.norm_hr = nn.GroupNorm(in_channels // 8, in_channels)
            self.norm_lr = nn.GroupNorm(in_channels // 8, in_channels)
        else:
            self.norm_hr = nn.Identity()
            self.norm_lr = nn.Identity()
        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        """Base sub-pixel positions, shape ``(1, 2 * groups * scale**2, 1, 1)``."""
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset, scale=None):
        """Resample ``x`` at pixel centres of the offset grid shifted by ``offset``.

        Args:
            x: tensor whose channels are split into ``self.groups`` groups.
            offset: tensor of shape ``(B, 2 * groups * scale**2, H, W)``.
            scale: sub-pixel factor; defaults to ``self.scale``.

        Returns:
            Tensor of shape ``(B, C, scale * H, scale * W)``.
        """
        if scale is None: scale = self.scale
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        # Map absolute coordinates to grid_sample's [-1, 1] range.
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), scale).view(
            B, 2, -1, scale * H, scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, x.size(-2), x.size(-1)), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, scale * H, scale * W)

    def forward(self, hr_x, lr_x, feat2sample):
        """Predict offsets from ``hr_x``/``lr_x`` and resample ``feat2sample`` with them."""
        hr_x = self.norm_hr(hr_x)
        lr_x = self.norm_lr(lr_x)

        if self.direction_feat == 'sim':
            hr_sim = compute_similarity(hr_x, self.local_window, dilation=2, sim='cos')
            lr_sim = compute_similarity(lr_x, self.local_window, dilation=2, sim='cos')
        elif self.direction_feat == 'sim_concat':
            hr_sim = torch.cat([hr_x, compute_similarity(hr_x, self.local_window, dilation=2, sim='cos')], dim=1)
            lr_sim = torch.cat([lr_x, compute_similarity(lr_x, self.local_window, dilation=2, sim='cos')], dim=1)
            hr_x, lr_x = hr_sim, lr_sim
        offset = self.get_offset_lp(hr_x, lr_x, hr_sim, lr_sim)
        return self.sample(feat2sample, offset)

    def get_offset_lp(self, hr_x, lr_x, hr_sim, lr_sim):
        """Offsets on the LR grid, shape ``(B, 2 * groups * scale**2, H_lr, W_lr)``.

        ``self.offset``/``self.hr_offset`` were built for the similarity
        features, so they always receive ``lr_sim``/``hr_sim``. The
        direct-scale convs receive ``lr_x``/``hr_x``.
        """
        if hasattr(self, 'direct_scale'):
            offset = (self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)) * (
                        self.direct_scale(lr_x) + F.pixel_unshuffle(self.hr_direct_scale(hr_x),
                                                                    self.scale)).sigmoid() + self.init_pos
        else:
            # Previously fed lr_x/hr_x here, which mismatches the conv input
            # channels in 'sim' mode (identical to lr_sim/hr_sim in 'sim_concat').
            offset = (self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)) * 0.25 \
                + self.init_pos
        return offset

    def get_offset(self, hr_x, lr_x, hr_sim=None, lr_sim=None):
        """Dispatch to ``get_offset_lp``; similarity maps default to the features."""
        if self.style == 'pl':
            raise NotImplementedError
        if hr_sim is None:
            hr_sim = hr_x
        if lr_sim is None:
            lr_sim = lr_x
        return self.get_offset_lp(hr_x, lr_x, hr_sim, lr_sim)


def compute_similarity(input_tensor, k=3, dilation=1, sim='cos'):
    """Similarity between every pixel and its ``k x k`` (dilated) neighbours.

    Args:
        input_tensor: tensor of shape ``[B, C, H, W]``.
        k: window size; the window has ``k*k`` taps including the centre.
        dilation: spacing between window taps. Borders are zero-padded by
            ``(k // 2) * dilation``. The final reshape assumes the spatial
            size is preserved, which holds for odd ``k``.
        sim: ``'cos'`` for cosine similarity over channels, or ``'dot'`` for
            the channel-wise dot product. Anything else raises
            ``NotImplementedError``.

    Returns:
        Tensor of shape ``[B, k*k - 1, H, W]``: the similarity of each pixel
        with each neighbour, in row-major window order with the centre removed.
    """
    B, C, H, W = input_tensor.shape
    centre = k * k // 2

    # F.unfold yields [B, C*k*k, H*W]; split it into per-tap neighbour maps.
    patches = F.unfold(input_tensor, k, padding=(k // 2) * dilation, dilation=dilation)
    patches = patches.reshape(B, C, k ** 2, H, W)
    anchor = patches[:, :, centre:centre + 1]

    if sim == 'cos':
        scores = F.cosine_similarity(anchor, patches, dim=1)
    elif sim == 'dot':
        scores = (anchor * patches).sum(dim=1)
    else:
        raise NotImplementedError

    # Drop the centre tap (similarity of a pixel with itself).
    keep = torch.tensor([idx for idx in range(k * k) if idx != centre],
                        dtype=torch.long, device=scores.device)
    return scores.index_select(1, keep).view(B, k * k - 1, H, W)


if __name__ == '__main__':
    # Smoke test: fuse a 512x512 high-resolution map with a 256x256
    # low-resolution map (128 channels each) and print the kernel-mask shape.
    # NOTE(review): these sizes are likely memory-hungry on CPU, especially
    # with the pure-PyTorch carafe fallback.
    hr_input = torch.rand(1, 128, 512, 512)
    lr_input = torch.rand(1, 128, 256, 256)
    model = FreqFusion(hr_channels=128, lr_channels=128)
    mask_lr, hr_feat, lr_feat = model(hr_feat=hr_input, lr_feat=lr_input)
    print(mask_lr.shape)
    print(mask_lr.shape)

其中FreqFusion的参数:

图片[3] - Python AI C++笔记 - 【人工智能】关于FreqFusion.py官方代码的研究(修正版) - Python AI C++笔记 - 小竹の笔记本

官方FreqFusion使用示例

小竹提示:在实际使用中要注意,hr_feat 和 lr_feat 的通道数必须分别与 FreqFusion 的前两个参数 hr_channels 和 lr_channels 一致。两者可以不同,例如下文 concat 示例中的 lr_channels=2 * c。同时,在默认设置(semi_conv=True)下,hr_feat 的特征图尺寸必须恰好是 lr_feat 的两倍。

以下是官方GitHub仓库的readme中的示例:

FreqFusion的简洁代码可在此处获得。通过利用它们的频率特性,FreqFusion能够增强低分辨率和高分辨率特征的质量,用法非常简单。

ff = FreqFusion(hr_channels=64, lr_channels=64)
hr_feat = torch.rand(1, 64, 32, 32)
lr_feat = torch.rand(1, 64, 16, 16)
_, hr_feat, lr_feat = ff(hr_feat=hr_feat, lr_feat=lr_feat) # lr_feat [1, 64, 32, 32]

我应该在哪里集成 FreqFusion?

您应该在需要执行上采样的任何地方集成 FreqFusion。FreqFusion 能够充分利用低分辨率和高分辨率特征,它可以非常有效地从低分辨率高级特征中恢复高分辨率、语义准确的特征,同时增强高分辨率低级特征的细节。

特征融合的 concat 版本示例(SegNeXt、SegFormer):

您可以参考ham_head.py

x1, x2, x3, x4 = backbone(img) #x1, x2, x3, x4 in 1/4, 1/8, 1/16, 1/32
x1, x2, x3, x4 = conv1x1(x1), conv1x1(x2), conv1x1(x3), conv1x1(x4) # channel=c
ff1 = FreqFusion(hr_channels=c, lr_channels=c)
ff2 = FreqFusion(hr_channels=c, lr_channels=2 * c)
ff3 = FreqFusion(hr_channels=c, lr_channels=3 * c)
_, x3, x4_up = ff1(hr_feat=x3, lr_feat=x4)
_, x2, x34_up = ff2(hr_feat=x2, lr_feat=torch.cat([x3, x4_up], dim=1))
_, x1, x234_up = ff3(hr_feat=x1, lr_feat=torch.cat([x2, x34_up], dim=1))
x1234 = torch.cat([x1, x234_up], dim=1) # channel=4c, 1/4 img size

特征融合的 concat 版本的另一个示例(您可以尝试 UNet):

x1, x2, x3, x4 = backbone(img) #x1, x2, x3, x4 in 1/4, 1/8, 1/16, 1/32
x1, x2, x3, x4 = conv1x1(x1), conv1x1(x2), conv1x1(x3), conv1x1(x4) # conv1x1s in original FPN to align channel=c
ff1 = FreqFusion(hr_channels=c, lr_channels=c)
ff2 = FreqFusion(hr_channels=c, lr_channels=c)
ff3 = FreqFusion(hr_channels=c, lr_channels=c)
y4 = x4 # channel=c
_, x3, y4_up = ff1(hr_feat=x3, lr_feat=y4)
y3 = conv(torch.cat([x3, y4_up], dim=1)) # 2c -> conv -> channel=c
_, x2, y3_up = ff2(hr_feat=x2, lr_feat=y3)
y2 = conv(torch.cat([x2, y3_up], dim=1)) # 2c -> conv -> channel=c
_, x1, y2_up = ff3(hr_feat=x1, lr_feat=y2)
y1 = conv(torch.cat([x1, y2_up], dim=1)) # 2c -> conv -> channel=c

特征融合的相加(add)版本示例(基于 FPN 的方法):

您可以参考FPN.py。

x1, x2, x3, x4 = backbone(img) #x1, x2, x3, x4 in 1/4, 1/8, 1/16, 1/32
x1, x2, x3, x4 = conv1x1(x1), conv1x1(x2), conv1x1(x3), conv1x1(x4) # conv1x1s in original FPN to align channel=c
ff1 = FreqFusion(hr_channels=c, lr_channels=c)
ff2 = FreqFusion(hr_channels=c, lr_channels=c)
ff3 = FreqFusion(hr_channels=c, lr_channels=c)
y4 = x4
_, x3, y4_up = ff1(hr_feat=x3, lr_feat=y4)
y3 = x3 + y4_up
_, x2, y3_up = ff2(hr_feat=x2, lr_feat=y3)
y2 = x2 + y3_up
_, x1, y2_up = ff3(hr_feat=x1, lr_feat=y2)
y1 = x1 + y2_up

我自己在ConvLSRNet中集成了FreqFusion,最终实验结果并不好,降低了精度。

© 版权声明
THE END
点赞8 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片快捷回复

    暂无评论内容