# Source code for fvdb.utils.metrics.ssim

# This file contains source code from the fused-ssim library obtained from
# https://github.com/rahul-goel/fused-ssim. The fused-ssim library is licensed under the MIT
# License. Refer to ORSB 5512107 for more. Original license text follows.

# Copyright (c) 2024 Rahul Goel

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Copyright Contributors to the OpenVDB Project
# SPDX-License-Identifier: Apache-2.0
#


from typing import NamedTuple

import torch

import fvdb

# Padding modes accepted by ``fused_ssim``: "same" keeps the full-size SSIM
# map, "valid" crops a 5-pixel border from each spatial edge of the map.
allowed_padding = [
    "same",
    "valid",
]


class FusedSSIMMap(torch.autograd.Function):
    """Autograd wrapper around fvdb's fused SSIM operators.

    ``forward`` returns the per-pixel SSIM map between ``img1`` and ``img2``
    (optionally cropped, see ``padding``). ``backward`` produces a gradient
    only for ``img1``; ``img2`` always receives ``None``.
    """

    # Number of pixels cropped from each spatial edge of the SSIM map when
    # ``padding == "valid"``. NOTE(review): presumably half the fused kernel's
    # window width -- confirm against the C++/CUDA op.
    _VALID_BORDER = 5

    @staticmethod
    def forward(ctx, C1, C2, img1, img2, padding="same", train=True):
        """Compute the SSIM map.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            C1 (float): First SSIM stabilizing constant (``fused_ssim`` passes ``0.01**2``).
            C2 (float): Second SSIM stabilizing constant (``fused_ssim`` passes ``0.03**2``).
            img1 (torch.Tensor): Images the gradient is computed for.
            img2 (torch.Tensor): Reference images compared against ``img1``.
            padding (str): ``"same"`` returns the full map; ``"valid"`` crops
                ``_VALID_BORDER`` pixels from each spatial edge.
            train (bool): Forwarded to the fused op. When ``False`` the partial
                derivatives needed by ``backward`` are not requested, so no
                gradient can be computed.

        Returns:
            torch.Tensor: The (possibly cropped) SSIM map.
        """
        (
            ssim_map,
            dm_dmu1,
            dm_dsigma1_sq,
            dm_dsigma12,
        ) = torch.ops.fvdb._fused_ssim.default(C1, C2, img1, img2, train)

        if padding == "valid":
            b = FusedSSIMMap._VALID_BORDER
            ssim_map = ssim_map[:, :, b:-b, b:-b]

        # The dm_* partial derivatives are intermediates of the fused op that
        # the backward kernel consumes directly.
        ctx.save_for_backward(img1.detach(), img2, dm_dmu1, dm_dsigma1_sq, dm_dsigma12)
        ctx.C1 = C1
        ctx.C2 = C2
        ctx.padding = padding
        ctx.train = train

        return ssim_map

    @staticmethod
    def backward(ctx, opt_grad):
        """Propagate the gradient of the SSIM map back to ``img1``.

        Args:
            ctx: Autograd context populated by ``forward``.
            opt_grad (torch.Tensor): Gradient of the loss w.r.t. the map returned
                by ``forward`` (cropped when ``padding == "valid"``).

        Returns:
            tuple: One entry per ``forward`` input ``(C1, C2, img1, img2, padding,
            train)``; only the ``img1`` slot may be non-``None``.

        Raises:
            RuntimeError: If a gradient for ``img1`` is requested but ``forward``
                ran with ``train=False`` (the derivative buffers were not computed).
        """
        # img1 is input index 2. If it needs no gradient there is nothing to do;
        # img2 never receives a gradient from this op.
        if not ctx.needs_input_grad[2]:
            return None, None, None, None, None, None
        if not ctx.train:
            raise RuntimeError(
                "FusedSSIMMap: gradient requested for an SSIM map computed with "
                "train=False; call with train=True to enable gradients."
            )

        img1, img2, dm_dmu1, dm_dsigma1_sq, dm_dsigma12 = ctx.saved_tensors
        C1, C2, padding = ctx.C1, ctx.C2, ctx.padding

        dL_dmap = opt_grad
        if padding == "valid":
            # The backward op is given a gradient with img1's full shape, so
            # scatter the cropped gradient into a zero-filled buffer.
            b = FusedSSIMMap._VALID_BORDER
            dL_dmap = torch.zeros_like(img1)
            dL_dmap[:, :, b:-b, b:-b] = opt_grad
        # Autograd can pass an expanded (stride-0) gradient, e.g. from ``sum()``;
        # materialize a dense tensor before handing it to the fused kernel.
        dL_dmap = dL_dmap.contiguous()

        grad = torch.ops.fvdb._fused_ssim_backward.default(
            C1, C2, img1, img2, dL_dmap, dm_dmu1, dm_dsigma1_sq, dm_dsigma12
        )
        return None, None, grad, None, None, None


def fused_ssim(img1, img2, padding="same", train=True):
    """Return the mean SSIM between two image batches using the fused fvdb op.

    Args:
        img1 (torch.Tensor): Images to compare; gradients flow back to this tensor.
        img2 (torch.Tensor): Reference images; must have the same shape as
            ``img1``. No gradient is propagated to it.
        padding (str): One of ``allowed_padding`` (``"same"`` or ``"valid"``).
        train (bool): Forwarded to ``FusedSSIMMap``; must be ``True`` when
            gradients through the SSIM value are needed.

    Returns:
        torch.Tensor: Scalar mean of the SSIM map over all of its elements.

    Raises:
        ValueError: If ``padding`` is not an allowed mode or the two images have
            different shapes.
    """
    C1 = 0.01**2
    C2 = 0.03**2

    # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
    if padding not in allowed_padding:
        raise ValueError(f"padding must be one of {allowed_padding}, got {padding!r}")
    # NOTE(review): the fused op presumably assumes both images share one shape;
    # mismatches would otherwise reach the kernel unchecked.
    if img1.shape != img2.shape:
        raise ValueError(
            f"img1 and img2 must have the same shape, got {tuple(img1.shape)} "
            f"and {tuple(img2.shape)}"
        )

    # Give the fused op dense inputs; no-op for already-contiguous tensors.
    img1 = img1.contiguous()
    img2 = img2.contiguous()
    ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2, padding, train)
    return ssim_map.mean()  # type: ignore


from typing import Literal


def ssim(
    img1: torch.Tensor,
    img2: torch.Tensor,
    padding: Literal["same", "valid"] = "same",
    train: bool = True,
) -> torch.Tensor:
    """
    Compute the Structural Similarity Index (SSIM) between two images.

    Args:
        img1 (torch.Tensor): A batch of images of shape ``(B, C, H, W)``.
            Gradients are propagated to this tensor.
        img2 (torch.Tensor): A batch of reference images of shape ``(B, C, H, W)``.
            No gradient is propagated to this tensor.
        padding (str): The padding to use for the images (``"same"`` or ``"valid"``).
            With ``"valid"``, a 5-pixel border is cropped from each spatial edge of
            the SSIM map before averaging. Default is ``"same"``.
        train (bool): Whether or not to compute the gradients through the SSIM loss.
            Default is ``True``.

    Returns:
        ssim (torch.Tensor): A scalar tensor holding the SSIM map averaged over the
        batch, channel and spatial dimensions.
    """
    return fused_ssim(img1, img2, padding, train)