Optimal Mass Transport for Shape Progression Study

PyTorch Implementation of Benamou-Brenier Formulation

optimal transport
shape morphing
Benamou-Brenier's Formulation
pytorch
automatic differentiation
Published

December 16, 2024

Introduction

In this blog post, let us continue from the previous post on the convex formulation of optimal mass transport (OMT). Specifically, let us look at why OMT is needed for interpolating shapes and learn to implement the Benamou-Brenier formulation of OMT in Python. The problem is essentially a PDE-constrained optimization, so its numerical solution is expensive. The algorithm used here is based on an efficient method described in (Haber, Rehman, and Tannenbaum (2010)); it relies on an augmented Lagrangian approach to solve the constrained optimization efficiently. Graphics processing units (GPUs) have also been shown to accelerate the computation considerably (Rehman et al. (2009)). Hence, PyTorch is used to program the problem here, as it is an excellent, easy-to-use scientific computing framework with automatic differentiation (Neidinger (2010)) and GPU-accelerated tensor operations.

Mathematical Formulation

The Benamou-Brenier formulation (Benamou and Brenier (2000)) is used, which interprets OMT as a fluid-flow problem. This approach finds a path of continuous mass transfer that minimizes the "kinetic energy" needed to transform one state into another.

The Benamou-Brenier formulation considers a probability density \(\rho(x, t)\) evolving over time \(t \in [0, 1]\) from an initial distribution \(\rho_0\) to a final distribution \(\rho_1\). The goal is to find a velocity field \(v(x, t)\) that minimizes the action, or “kinetic energy” cost:

\[ \min_{\rho, v} \int_0^1 \int_X \frac{1}{2} \|v(x, t)\|^2 \rho(x, t) \, dx \, dt, \]

subject to the continuity equation:

\[ \frac{\partial \rho}{\partial t} + \nabla \cdot (\rho v) = 0, \]

which ensures mass conservation from \(\rho_0\) to \(\rho_1\).

The problem is then reposed in terms of the momenta \(m_x = \rho v_x\) and \(m_y = \rho v_y\), as in (Haber, Rehman, and Tannenbaum (2010)). A staggered grid is used for \(m_x\), \(m_y\), and \(\rho\) to suppress the non-physical checkerboard instability.
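Concretely (anticipating the parameter shapes used in the implementation below), on an \(n_1 \times n_2\) spatial grid with \(n_3\) time steps, \(m_x\) is stored on a \((n_1 - 1) \times n_2 \times n_3\) grid of x-faces, \(m_y\) on an \(n_1 \times (n_2 - 1) \times n_3\) grid of y-faces, and \(\rho\) on the \(n_1 \times n_2 \times (n_3 - 1)\) interior time slices; the given images \(\rho_0\) and \(\rho_1\) supply the two temporal boundary slices, and zero momentum is imposed at the spatial boundaries.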

So, after substituting the momenta, the problem is rewritten (dropping the constant factor \(\tfrac{1}{2}\), which does not change the minimizer) as:

\[ \min_{m, \rho} \int_0^1 \int_X \frac{m_x^2 + m_y^2}{\rho} \, dx \, dt, \]

such that the discretized continuity equation holds:

\[ \Delta_x m_x + \Delta_y m_y + \Delta_t \rho = 0, \]

where \(\Delta_x\), \(\Delta_y\), and \(\Delta_t\) denote first-order finite-difference operators, with the temporal boundary conditions \(\rho(x, y, 0) = \rho_0\) (the image at time 0) and \(\rho(x, y, 1) = \rho_1\) (the image at time 1).

This constrained optimization problem is converted into an unconstrained saddle-point problem via the augmented Lagrangian method (Nocedal and Wright (1999)), which is then optimized with Adam, a stochastic gradient-based optimizer (Kingma (2014)).
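To make the connection to the implementation explicit, here is a sketch of the augmented Lagrangian that couples the kinetic-energy objective \(f\) with the constraint residual \(c\) through a multiplier field \(p\) and a penalty weight \(\mu\):

\[ \mathcal{L}(m, \rho, p) = f(m, \rho) - \langle p, c(m, \rho) \rangle + \frac{\mu}{2} \| c(m, \rho) \|^2, \]

which is minimized over \(m\) and \(\rho\) for a fixed \(p\), followed by the multiplier update \(p \leftarrow p - \delta \mu \, c\); this is exactly the structure of the training loop implemented below.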

Case Study: Need for a Special Interpolation

Let us take the MRI image samples (Figure 1) of the vertical and horizontal cross-sections of a kid's brain at the ages of 5, 7, and 9 years, as shared by (Bastos et al. (2020)). What if we did not have the MRI scan from when the kid was 7 years old, especially if the kid has a disease? With a suitable interpolation technique, we could use the scans at ages 5 and 9 and examine the sequence of interpolated scans to study which part of the brain is under compression and gauge how it is affected over time.

Figure 1: State of neurodegeneration in a kid at different ages (Bastos et al. (2020)). OT can learn the progression or transformation (T) of the brain deformation from the state at age 5 (\(\rho_0\)) to the state at age 7 (\(\rho_1\)) or age 9 (\(\rho_2\)).

Linear interpolation, or any polynomial interpolation, is not the best choice in such a case. To understand why, let us take a toy case of an aging cell: the task is to generate the interpolation of the shape of a living cell from an image of a young, healthy cell to an image of an old, deformed cell.

(a) A young and healthy cell.
(b) An old deformed cell.
Figure 2: Senescence of a toy cell

Linear Interpolation

The most basic method is linear interpolation, a commonly used technique for generating intermediate images between two given images. It assumes a linear blend of pixel intensities over the interpolation parameter \(t\), which represents time. Mathematically, the interpolant \(I_t\) is given by:

\[ I_t = (1 - t) I_0 + t I_1, \] where \(I_0\) is the image at time 0 and \(I_1\) is the image at time 1.
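As a minimal sketch (assuming I0 and I1 are 2-D grayscale tensors of the same size; this snippet is illustrative and separate from the pipeline below), the whole stack of linear interpolants can be generated in a few lines:

Code
import torch

def linear_interpolation(I0, I1, nt=7):
    # blend pixel intensities independently at nt time steps t in [0, 1]
    t = torch.linspace(0, 1, nt).view(nt, 1, 1)             # (nt, 1, 1) for broadcasting
    return (1 - t) * I0.unsqueeze(0) + t * I1.unsqueeze(0)  # (nt, H, W)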

Figure 3: Interpolated shape of the toy cell at 5 different intermediate stages.

Figure 3 shows the interpolation between the images of the toy cell when it is healthy and when it is old. What essentially happens is a smooth transition between the two images, obtained by linearly interpolating the intensity value of each pixel independently. The change in shape is not something we see in real life; it is not a physical transformation. The masses/pixels do not transport in the spatial domain, rather they teleport, with one shape fading out while the other fades in. This kind of interpolation is therefore often not useful. In the case of the MRI scans, Figure 4 shows how the linearly interpolated shapes at intermediate ages do not make physical sense. Let us now implement an OMT-based interpolation on a toy problem whose initial and final images are randomly sampled MNIST digits, shown in Figure 5.

Figure 4: Linear interpolation of brain MRI at 5 different ages.

PyTorch Implementation of OMT

Example Images

Figure 5: Images to use for OMT.

Click to download the first image file.

Click to download the second image file.

Import Libraries

Code
import matplotlib.pyplot as plt
import torch.optim as optim
import torch
from torch import nn
from torchvision.io import read_image
import torch.nn.functional as F

The most important library here is torch, as it provides automatic differentiation and GPU-accelerated tensor operations. Torchvision is needed for additional image operations and matplotlib for the plots.
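Since GPU acceleration is the main motivation for using PyTorch here, one small addition (not in the original code, but a convenient default) is to select the device based on availability:

Code
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'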

OMT Class

Code
class omtFD(nn.Module):
    def __init__(self, n1, n2, n3, bs=1):
        super().__init__()
        self.n1 = n1
        self.n2 = n2
        self.n3 = n3
        self.bs = bs

        # Staggered-grid unknowns: momenta live on cell faces, log-density on interior time slices
        self.m1 = nn.Parameter(torch.zeros(bs, 1, n1 - 1, n2, n3))       # x-momentum
        self.m2 = nn.Parameter(torch.zeros(bs, 1, n1, n2 - 1, n3))       # y-momentum
        self.rho = nn.Parameter(torch.zeros(bs, 1, n1, n2, n3 - 1) - 3)  # log-density, initialized near exp(-3)

    def OMTconstraintMF(self, I1, I2):
        dx, dy, dt = 1 / self.n1, 1 / self.n2, 1 / self.n3

        ox = torch.zeros(self.bs, 1, 1, self.n2, self.n3).to(self.rho.device)
        oy = torch.zeros(self.bs, 1, self.n1, 1, self.n3).to(self.rho.device)

        I1 = I1.unsqueeze(-1)
        I2 = I2.unsqueeze(-1)
        rhop = torch.exp(self.rho)

        # Add I1 and I2 to rho
        rhoBC = torch.cat((I1, rhop, I2), dim=-1)
        m1BC = torch.cat((ox, self.m1, ox), dim=-3)
        m2BC = torch.cat((oy, self.m2, oy), dim=-2)

        # Compute the Div
        m1x = (m1BC[:, :, :-1, :, :] - m1BC[:, :, 1:, :, :]) / dx
        m2y = (m2BC[:, :, :, :-1, :] - m2BC[:, :, :, 1:, :]) / dy
        rhot = (rhoBC[:, :, :, :, :-1] - rhoBC[:, :, :, :, 1:]) / dt
        return m1x + m2y + rhot

    def OMTgradRho(self, I1, I2):
        I1 = I1.unsqueeze(-1)
        I2 = I2.unsqueeze(-1)

        # Add I1 and I2 to rho
        rhop = torch.exp(self.rho)
        rhoBC = torch.cat((I1, rhop, I2), dim=-1)
        rhox = (rhoBC[:, :, :-1, :, :] - rhoBC[:, :, 1:, :, :])
        rhoy = (rhoBC[:, :, :, :-1, :] - rhoBC[:, :, :, 1:, :])
        rhot = (rhoBC[:, :, :, :, :-1] - rhoBC[:, :, :, :, 1:])

        return rhox, rhoy, rhot

    def OMTobjfunMF(self, I1, I2):
        I1 = I1.unsqueeze(-1)
        I2 = I2.unsqueeze(-1)

        rhop = torch.exp(self.rho)

        ox = torch.zeros(self.bs, 1, 1, self.n2, self.n3).to(self.rho.device)
        oy = torch.zeros(self.bs, 1, self.n1, 1, self.n3).to(self.rho.device)

        # Add I1 and I2 to rho
        rhoBC = torch.cat((I1, rhop, I2), dim=-1)
        m1BC = torch.cat((ox, self.m1, ox), dim=-3)
        m2BC = torch.cat((oy, self.m2, oy), dim=-2)

        # Average quantities to cell center
        m1c = (m1BC[:, :, :-1, :, :] + m1BC[:, :, 1:, :, :]) / 2
        m2c = (m2BC[:, :, :, :-1, :] + m2BC[:, :, :, 1:, :]) / 2
        rhoc = (rhoBC[:, :, :, :, :-1] + rhoBC[:, :, :, :, 1:]) / 2

        f = ((m1c ** 2 + m2c ** 2) / rhoc).mean()
        return f

This class models the dynamics of mass transport with momentum and density fields while ensuring mass conservation via the continuity equation. The learnable parameters defined here are \(m_1\), \(m_2\), and \(\rho\). The class methods are OMTconstraintMF, OMTgradRho, and OMTobjfunMF. A common line you can notice is rhop = torch.exp(self.rho): the exponentiated rhop is used instead of rho so that only positive density values enter the PDE system, which preserves the convexity of the optimization. OMTconstraintMF calculates the residual of mass conservation by first-order finite differences; this residual is later denoted as the constraint (c). OMTgradRho calculates the spatial and temporal finite differences of the density rhop (it is not used in the training loop below). OMTobjfunMF calculates the objective functional (f), the kinetic energy of the Benamou-Brenier formulation. In this function, the staggered variables are averaged to the cell centres, which helps control checkerboard oscillations.
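As a quick sanity check (a sketch with arbitrary sizes, not part of the original post), the class can be instantiated on two random positive images and the objective and constraint residual evaluated directly:

Code
n1, n2, n3 = 32, 32, 10                  # 32x32 spatial grid, 10 time steps
omt = omtFD(n1, n2, n3, bs=1)

# two strictly positive "images" of shape (bs, 1, n1, n2)
I1 = torch.rand(1, 1, n1, n2) + 1e-2
I2 = torch.rand(1, 1, n1, n2) + 1e-2

f = omt.OMTobjfunMF(I1, I2)              # kinetic-energy objective (scalar)
c = omt.OMTconstraintMF(I1, I2)          # continuity residual, shape (1, 1, n1, n2, n3)
print(f.item(), c.abs().mean().item())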

Function to Generate OMT Interpolation

Code
def LearnOMT(x0, x1, nt=20, device='cpu'):
    t = torch.linspace(0, 1, nt)

    eps = 1e-2

    I1_min = x0.min()
    I2_min = x1.min()
    I1_mean = (x0 - I1_min + eps).mean()
    I2_mean = (x1 - I2_min + eps).mean()

    I1 = (x0 - I1_min + eps) / I1_mean
    I2 = (x1 - I2_min + eps) / I2_mean

    bs = I1.shape[0]
    ch = I1.shape[1]
    n1 = I1.shape[2]
    n2 = I1.shape[3]
    n3 = nt

    omt = omtFD(n1, n2, n3, bs=bs * ch).to(device)

    # input shape for omt: (bs, 1, nx, ny)
    # final output shape for omt: (bs, 1, nx, ny, nt)

    I1 = I1.reshape([bs * ch, 1, n1, n2])
    I2 = I2.reshape([bs * ch, 1, n1, n2])

    optimizer = optim.Adam(omt.parameters(), lr=1e-2)

    num_epochs = 1000  # with incomplete optimization you do not see transport; instead one image dims while the other brightens up
    opt_iter = 20
    torch.autograd.set_detect_anomaly(True)  # debugging aid; slows training and can be disabled

    # Train the model
    mu = 1e-2
    delta = 1e-1

    # initialize lagrange multiplier
    p = torch.zeros(n1, n2, n3).to(device)
    for epoch in range(num_epochs):
        # Optimization iterations for fixed p
        for jj in range(opt_iter):
            # Zero the parameter gradients
            optimizer.zero_grad()

            # function evaluation
            f = omt.OMTobjfunMF(I1, I2)
            # Constraint
            c = omt.OMTconstraintMF(I1, I2)

            const = -(p * c).mean() + 0.5 * mu * (c ** 2).mean()

            loss = f + const

            # Backward pass and optimize
            loss.backward()
            torch.nn.utils.clip_grad_value_(omt.parameters(), clip_value=0.5)
            optimizer.step()

        #print('f:', f.item(), '\tconst:', const.item())

        # Update Lagrange multiplier
        with torch.no_grad():
            p = p - delta * mu * c

    rho = omt.rho.detach()
    rhop = torch.exp(rho)

    # Add I1 and I2 to rho
    rhop = torch.cat((I1.unsqueeze(-1), rhop, I2.unsqueeze(-1)), dim=-1)
    rhop = rhop.reshape([bs, ch, n1, n2, n3 + 1])
    rhop = rhop * I1_mean - eps + I1_min  # assuming mass of x0 = mass of x1
    rhop = rhop.detach()   # shape of rhop: (bs, ch, nx, ny, nt+1) includes x0, "nt-1" timesteps, x1

    return rhop

This function starts with normalizations to match the masses of the two images and to add a small background mass. The input tensors x0 and x1 (representing the initial and final states) are shifted to be strictly positive and scaled to have a mean of 1, which ensures numerical stability during the optimization. The Adam optimizer is then set up to train the parameters of the omtFD class.
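A quick way to check that the two inputs carry the same total mass after this normalization (a sketch that simply repeats the first few lines of LearnOMT on already loaded x0 and x1):

Code
eps = 1e-2
I1 = (x0 - x0.min() + eps) / (x0 - x0.min() + eps).mean()
I2 = (x1 - x1.min() + eps) / (x1 - x1.min() + eps).mean()
print(I1.mean().item(), I2.mean().item())   # both equal 1, so the total masses match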

Training Loop:

The training hyperparameters are defined: num_epochs (the number of outer epochs), opt_iter (the number of inner optimization iterations per epoch), mu (the quadratic penalty weight), and delta (the step size of the Lagrange multiplier update). The Lagrange multiplier p is initialized to zero to enforce the mass conservation constraint. For each epoch, the following steps are performed (the first four are repeated for opt_iter inner iterations):

  • The OMT objective function (f) is computed using omt.OMTobjfunMF.
  • The mass conservation constraint (c) is evaluated using omt.OMTconstraintMF.
  • The total loss is computed as the objective (f) plus the augmented Lagrangian constraint terms: the multiplier term \(-\langle p, c \rangle\) and the quadratic penalty \(\tfrac{\mu}{2}\,\|c\|^2\).
  • Backpropagation is performed, gradients are clipped for stability, and the model parameters are updated.
  • After the inner iterations, the Lagrange multiplier p is updated once per epoch, gradually enforcing the constraint over the course of training.

Output Density Interpolation: The optimized density field rho is extracted from the model and exponentiated to compute the final density rhop. The initial and final states (I1 and I2) are concatenated with rhop to include all intermediate densities, forming the complete transport trajectory. The density trajectory rhop is rescaled back to the original input scale and format, ensuring mass conservation. The interpolated density trajectory rhop is returned, with a shape of (batch_size, channels, n1, n2, nt+1). This includes the initial state (x0), the intermediate densities, and the final state (x1).

Running OMT

Code
device = 'cuda:0'  # use 'cpu' if no GPU is available

x0 = torch.load('x0.pt')
x1 = torch.load('x1.pt')
# x0 and x1 are 2-D tensors of shape [H, W]

size_img = 64
x0 = F.interpolate(x0.unsqueeze(0).unsqueeze(0), size = [size_img,size_img])
x1 = F.interpolate(x1.unsqueeze(0).unsqueeze(0), size = [size_img,size_img])

x0 = x0.to(device)
x1 = x1.to(device)

OMTInterpolation = LearnOMT(x0, x1, nt=10, device = device)
# shape of Interpolation = [|B|,|C|,|X|,|Y|,|T|]
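As an aside (a sketch, not part of the original post), a single intermediate frame, for example the one roughly halfway through the transport, can be pulled out of this trajectory directly:

Code
mid = OMTInterpolation.shape[-1] // 2
halfway_frame = OMTInterpolation[0, 0, :, :, mid].cpu()   # interpolated image near t = 0.5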

Generate Linear Interpolation for Comparison

Code
OMTInterpolation = OMTInterpolation[0,0].detach().cpu().numpy()
x0 = x0[0,0].detach().cpu().unsqueeze(-1)   # move to CPU so it can be blended with Frac below
x1 = x1[0,0].detach().cpu().unsqueeze(-1)
Frac = torch.linspace(0,1,OMTInterpolation.shape[-1])
Frac = Frac.unsqueeze(0).unsqueeze(0)

# Linear interpolation: pixel-wise blend, matching the OMT trajectory length
LinearInt = x1*Frac + (1-Frac)*x0

LinearInt = LinearInt.numpy()

Visualization

Code
# shape of Interpolation = [|B|,|C|,|X|,|Y|,|T|]
fig, ax = plt.subplots(nrows=2, ncols=OMTInterpolation.shape[-1], figsize=(50, 5))

for i in range(OMTInterpolation.shape[-1]):
    ax[0, i].imshow(LinearInt[:, :, i])
    ax[0, i].axis('off')
    ax[1, i].imshow(OMTInterpolation[:, :, i])
    ax[1, i].axis('off')

ax[0,0].set_title('Linear Interpolation', loc='left')
ax[1,0].set_title('OMT Interpolation', loc='left')

plt.show()

Conclusion

Figure 6: Comparison of linear interpolation with OMT interpolation

The application of the Benamou-Brenier formulation of optimal mass transport (OMT) to model shape progression demonstrates the power of combining mathematical rigor with modern computational tools. By leveraging the augmented Lagrangian method for the PDE-constrained optimization and the GPU acceleration capabilities of PyTorch, we obtain a computationally viable approach to interpolating images across time. This points to the potential of OMT for predicting missing temporal states in medical imaging, such as generating plausible MRI scans at intermediate ages. The interpolation result can be observed in Figure 6: the OMT interpolation transports mass through space, which looks far more physical than the fade-out/fade-in of the linear blend. Such advancements can contribute significantly to understanding shape morphing and interpolation. Future work could explore extending this approach to more complex 2D shapes or to 3D volumes like those in MRI scans, integrating other imaging modalities, and refining the method to handle even larger datasets.

References

Bastos, Filipa, Mathieu Quinodoz, Marie-Claude Addor, Beryl Royer-Bertrand, Heidi Fodstad, Carlo Rivolta, Claudia Poloni, Andrea Superti-Furga, Eliane Roulet-Perez, and Sébastien Lebon. 2020. “Childhood Neurodegeneration Associated with a Specific UBTF Variant: A New Case Report and Review of the Literature.” BMC Neurology 20 (January).
Benamou, Jean-David, and Yann Brenier. 2000. “A Computational Fluid Mechanics Solution to the Monge-Kantorovich Mass Transfer Problem.” Numerische Mathematik 84 (3): 375–93.
Haber, Eldad, Tauseef Rehman, and Allen Tannenbaum. 2010. “An Efficient Numerical Method for the Solution of the L2 Optimal Mass Transfer Problem.” SIAM Journal on Scientific Computing 32 (January): 197–211. https://doi.org/10.1137/080730238.
Kingma, Diederik P. 2014. “Adam: A Method for Stochastic Optimization.” arXiv Preprint arXiv:1412.6980.
Neidinger, Richard D. 2010. “Introduction to Automatic Differentiation and MATLAB Object-Oriented Programming.” SIAM Review 52 (3): 545–63.
Nocedal, Jorge, and Stephen J Wright. 1999. Numerical Optimization. Springer.
Rehman, Tauseef ur, Eldad Haber, Gallagher Pryor, John Melonakos, and Allen Tannenbaum. 2009. “3D Nonrigid Registration via Optimal Mass Transport on the GPU.” Medical Image Analysis 13 (6): 931–40.