
Transport Map from density#

The objective of this example is to show how a transport map can be built in MParT when the unnormalized probability density function of the target density is known.

Problem description#

We consider $T(\mathbf{z};\mathbf{w})$, a monotone triangular transport map parameterized by $\mathbf{w}$ (e.g., polynomial coefficients). This map, which is invertible and has an invertible Jacobian for any parameter $\mathbf{w}$, transports samples $\mathbf{z}^i$ from the reference density $\eta$ to samples $T(\mathbf{z}^i;\mathbf{w})$ from the map-induced density $\tilde{\pi}_{\mathbf{w}}(\mathbf{z})$ defined as:

$$\tilde{\pi}_{\mathbf{w}}(\mathbf{z}) = \eta\left(T^{-1}(\mathbf{z};\mathbf{w})\right)\left|\det \nabla T^{-1}(\mathbf{z};\mathbf{w})\right|,$$

where $\det \nabla T^{-1}$ is the determinant of the inverse map Jacobian at the point $\mathbf{z}$. We refer to $\tilde{\pi}_{\mathbf{w}}(\mathbf{z})$ as the map-induced density or pushforward distribution, and will commonly interchange notation for densities and measures to use the notation $\tilde{\pi} = T_{\sharp}\eta$.

The objective of this example is, knowing some unnormalized target density $\bar{\pi}$, to find the map $T$ that transports samples drawn from $\eta$ to samples drawn from the target $\pi$.
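For intuition, consider the 1D affine map $T(z;\mathbf{w}) = \mu + \sigma z$ with $\sigma > 0$ and reference $\eta = \mathcal{N}(0,1)$. Then $T^{-1}(x) = (x-\mu)/\sigma$, and the definition above gives

$$\tilde{\pi}_{\mathbf{w}}(x) = \eta\left(\frac{x-\mu}{\sigma}\right)\frac{1}{\sigma},$$

which is exactly the $\mathcal{N}(\mu,\sigma^2)$ density: the map transports the reference to a shifted and scaled Gaussian.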

Imports#

First, import MParT and other packages used in this notebook. Note that it is possible to specify the number of threads used by MParT by setting the KOKKOS_NUM_THREADS environment variable before importing MParT.

[1]:
import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

import os
os.environ['KOKKOS_NUM_THREADS'] = '8'

import mpart as mt
print('Kokkos is using', mt.Concurrency(), 'threads')
plt.rcParams['figure.dpi'] = 110
Kokkos::OpenMP::initialize WARNING: You are likely oversubscribing your CPU cores.
  process threads available :   4,  requested thread :   8
Kokkos::OpenMP::initialize WARNING: You are likely oversubscribing your CPU cores.
                                    Detected: 4 cores per node.
                                    Detected: 1 MPI_ranks per node.
                                    Requested: 8 threads per process.
Kokkos is using 8 threads

Target density and exact map#

In this example we use a 2D target density known as the banana density, for which the unnormalized probability density, samples, and the exact transport map are known.

The banana density is defined as:

$$\pi(x_1,x_2) \propto \mathcal{N}_1(x_1)\times\mathcal{N}_1(x_2-x_1^2),$$

where $\mathcal{N}_1$ is the 1D standard normal density.

The exact transport map that transports the 2D standard normal density to $\pi$ is known:

$$T^{\text{true}}(z_1,z_2)=\begin{bmatrix} z_1 \\ z_2+z_1^2 \end{bmatrix}$$
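Since $T^{\text{true}}$ is known in closed form, exact samples from $\pi$ can be drawn by pushing standard normal samples through it; a minimal NumPy sketch (the variable names `z_true` and `x_exact` are illustrative):

# Draw exact banana samples by pushing reference samples through the true map
z_true = np.random.randn(2, 1000)
x_exact = np.vstack((z_true[0], z_true[1] + z_true[0]**2))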

Contours of the target density can be visualized as:

[2]:
# Unnormalized target density required for objective
def target_logpdf(x):
  rv1 = multivariate_normal(np.zeros(1),np.eye(1))
  rv2 = multivariate_normal(np.zeros(1),np.eye(1))
  logpdf1 = rv1.logpdf(x[0])
  logpdf2 = rv2.logpdf(x[1]-x[0]**2)
  logpdf = logpdf1 + logpdf2
  return logpdf

# Grid for plotting
ngrid=100
x1_t = np.linspace(-3,3,ngrid)
x2_t = np.linspace(-3,7.5,ngrid)
xx1,xx2 = np.meshgrid(x1_t,x2_t)

xx = np.vstack((xx1.reshape(1,-1),xx2.reshape(1,-1)))

# Target contours
target_pdf_at_grid = np.exp(target_logpdf(xx))

fig, ax = plt.subplots()
CS1 = ax.contour(xx1, xx2, target_pdf_at_grid.reshape(ngrid,ngrid))
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
h1,_ = CS1.legend_elements()
legend1 = ax.legend([h1[0]], ['target density'])
plt.show()
../../../_images/source_tutorials_python_FromDensity2D_banana_6_0.png

Map training#

Defining objective function and its gradient#

Knowing the closed form of the unnormalized target density $\bar{\pi}$, the objective is to find a map-induced density $\tilde{\pi}_{\mathbf{w}}(\mathbf{z})$ that is a good approximation of the target $\pi$.

In order to characterize this target density, one method is to build a monotone triangular transport map $T$ such that the KL divergence $D_{KL}(\eta\,||\,T^{\sharp}\pi)$ is minimized. If $T$ is a map parameterized by $\mathbf{w}$, the objective function derived from the discrete KL divergence reads:

$$J(\mathbf{w}) = -\frac{1}{N}\sum_{i=1}^{N}\left(\log\pi\left(T(\mathbf{z}^i;\mathbf{w})\right)+\log\det\nabla_{\mathbf{z}}T(\mathbf{z}^i;\mathbf{w})\right),\quad \mathbf{z}^i\sim\mathcal{N}(\mathbf{0},\mathbf{I}_d),$$

where $T$ is the transport map pushing forward the standard normal $\mathcal{N}(\mathbf{0},\mathbf{I}_d)$ to the target density $\pi(\mathbf{z})$. The gradient of this objective function reads:

$$\nabla_{\mathbf{w}}J(\mathbf{w}) = -\frac{1}{N}\sum_{i=1}^{N}\left(\nabla_{\mathbf{w}}T(\mathbf{z}^i;\mathbf{w})^{T}\cdot\nabla_{\mathbf{x}}\log\pi\left(T(\mathbf{z}^i;\mathbf{w})\right)+\nabla_{\mathbf{w}}\log\det\nabla_{\mathbf{z}}T(\mathbf{z}^i;\mathbf{w})\right),\quad \mathbf{z}^i\sim\mathcal{N}(\mathbf{0},\mathbf{I}_d).$$

The objective function and gradient can be defined using MParT as:

[3]:
# KL divergence objective
def obj(coeffs, transport_map, x):
    num_points = x.shape[1]
    transport_map.SetCoeffs(coeffs)
    map_of_x = transport_map.Evaluate(x)
    logpdf = target_logpdf(map_of_x)
    log_det = transport_map.LogDeterminant(x)
    return -np.sum(logpdf + log_det)/num_points

# Gradient of unnormalized target density required for gradient of objective.
# From log pi(x) = -x1^2/2 - (x2 - x1^2)^2/2 + const:
def target_grad_logpdf(x):
  grad1 = -x[0,:] + (2*x[0,:]*(x[1,:]-x[0,:]**2))
  grad2 = (x[0,:]**2-x[1,:])
  return np.vstack((grad1,grad2))

# Gradient of KL divergence objective
def grad_obj(coeffs, transport_map, x):
    num_points = x.shape[1]
    transport_map.SetCoeffs(coeffs)
    map_of_x = transport_map.Evaluate(x)
    sens_vecs = target_grad_logpdf(map_of_x)
    grad_logpdf = transport_map.CoeffGrad(x, sens_vecs)
    grad_log_det = transport_map.LogDeterminantCoeffGrad(x)
    return -np.sum(grad_logpdf + grad_log_det, 1)/num_points

Map parameterization#

For the parameterization of $T$ we use a total order multivariate expansion of Hermite functions. Knowing $T^{\text{true}}$, any parameterization with total order greater than one will include the true solution of the map-finding problem.

[4]:
# Set up map options
map_options = mt.MapOptions()

total_order = 2

# Create dimension 2 triangular map
transport_map = mt.CreateTriangular(2,2,total_order,map_options)
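A quick way to see how many coefficients this total-order-2 parameterization carries is to inspect the coefficient vector; a minimal sketch using `CoeffMap`, which is also passed to `scipy.optimize` below:

# Number of free coefficients in the 2D, total-order-2 triangular map
print('Number of coefficients:', transport_map.CoeffMap().size)  # 9 for this parameterization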

Approximation before optimization#

Coefficients of the triangular map are set to 0 upon creation.

[5]:
# Make reference samples for training
num_points = 10000
z = np.random.randn(2,num_points)

# Make reference samples for testing
test_z = np.random.randn(2,5000)

# Pushed samples
x = transport_map.Evaluate(test_z)

# Before optimization plot
plt.figure()
plt.contour(xx1, xx2, target_pdf_at_grid.reshape(ngrid,ngrid))
plt.scatter(x[0],x[1], facecolor='blue', alpha=0.1, label='Pushed samples')
plt.legend()
plt.show()
../../../_images/source_tutorials_python_FromDensity2D_banana_14_0.png

At initialization, samples are “far” from being distributed according to the banana distribution.

Initial objective and coefficients:

[6]:
# Print initial coeffs and objective
print('==================')
print('Starting coeffs')
print(transport_map.CoeffMap())
print('Initial objective value: {:.2E}'.format(obj(transport_map.CoeffMap(), transport_map, test_z)))
print('==================')
==================
Starting coeffs
[0. 0. 0. 0. 0. 0. 0. 0. 0.]
Initial objective value: 3.40E+00
==================
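Before optimizing, it can be worth checking `grad_obj` against a finite-difference approximation of `obj`; a minimal sketch, assuming `CoeffMap()` returns a NumPy-compatible array (as its use with `scipy.optimize` suggests), with an assumed step size `eps`:

# Finite-difference check of the analytical gradient
eps = 1e-6
w0 = np.array(transport_map.CoeffMap())      # current (zero) coefficients
g_exact = grad_obj(w0, transport_map, z)
f0 = obj(w0, transport_map, z)
g_fd = np.zeros_like(w0)
for k in range(w0.size):
    w_pert = w0.copy()
    w_pert[k] += eps
    g_fd[k] = (obj(w_pert, transport_map, z) - f0) / eps
transport_map.SetCoeffs(w0)                  # restore coefficients before training
print('Max abs gradient difference:', np.abs(g_exact - g_fd).max())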

Minimization#

[7]:
print('==================')
options={'gtol': 1e-4, 'disp': True}
res = minimize(obj, transport_map.CoeffMap(), args=(transport_map, z), jac=grad_obj, method='BFGS', options=options)

# Print final coeffs and objective
print('Final coeffs:')
print(transport_map.CoeffMap())
print('Final objective value: {:.2E}'.format(obj(transport_map.CoeffMap(), transport_map, test_z)))
print('==================')
==================
Optimization terminated successfully.
         Current function value: 2.848868
         Iterations: 17
         Function evaluations: 19
         Gradient evaluations: 19
Final coeffs:
[ 1.64386417e-02  8.40185180e-01  1.11002314e-02  9.88287812e-01
  8.46815229e-01 -7.30824747e-03  7.15717444e-02 -1.61399131e-03
  2.21101864e+00]
Final objective value: 2.83E+00
==================

Approximation after optimization#

Pushed samples#

[8]:
# Pushed samples
x = transport_map.Evaluate(test_z)

# After optimization plot
plt.figure()
plt.contour(xx1, xx2, target_pdf_at_grid.reshape(ngrid,ngrid))
plt.scatter(x[0],x[1], facecolor='blue', alpha=0.1, label='Pushed samples')
plt.legend()
plt.show()
../../../_images/source_tutorials_python_FromDensity2D_banana_22_0.png

After optimization, pushed samples $T(\mathbf{z}^i)$, $\mathbf{z}^i\sim\mathcal{N}(\mathbf{0},\mathbf{I})$, are approximately distributed according to the target $\pi$.

Variance diagnostic#

A commonly used accuracy check when computing maps from density is the so-called variance diagnostic, defined as:

$$\epsilon_\sigma = \frac{1}{2}\mathbb{V}\mathrm{ar}_\rho\left[\log\frac{\rho}{T^{\sharp}\bar{\pi}}\right]$$

This diagnostic is asymptotically equivalent to the minimized KL divergence $D_{KL}(\eta\,||\,T^{\sharp}\pi)$ and should converge to zero as the computed map converges to the true map.

The variance diagnostic can be computed as follows:

[9]:
def variance_diagnostic(tri_map,ref,target_logpdf,x):
  ref_logpdf = ref.logpdf(x.T)
  y = tri_map.Evaluate(x)
  pullback_logpdf = target_logpdf(y) + tri_map.LogDeterminant(x)
  diff = ref_logpdf - pullback_logpdf
  expect = np.mean(diff)
  var = 0.5*np.mean((diff-expect)**2)
  return var
[10]:
# Reference distribution
ref_distribution = multivariate_normal(np.zeros(2),np.eye(2))

# Compute variance diagnostic
var_diag = variance_diagnostic(transport_map,ref_distribution,target_logpdf,test_z)

# Print variance diagnostic
print('==================')
print('Variance diagnostic: {:.2E}'.format(var_diag))
print('==================')
==================
Variance diagnostic: 3.34E-04
==================

Pushforward density#

We can also plot the contours of the unnormalized target density $\bar{\pi}$ and the pushforward approximation $T_{\sharp}\eta$:

[11]:
# Pushforward definition
def push_forward_pdf(tri_map,ref,x):
  xinv = tri_map.Inverse(x,x)
  log_det_grad_x_inverse = - tri_map.LogDeterminant(xinv)
  log_pdf = ref.logpdf(xinv.T)+log_det_grad_x_inverse
  return np.exp(log_pdf)

map_approx_grid = push_forward_pdf(transport_map,ref_distribution,xx)

fig, ax = plt.subplots()
CS1 = ax.contour(xx1, xx2, target_pdf_at_grid.reshape(ngrid,ngrid))
CS2 = ax.contour(xx1, xx2, map_approx_grid.reshape(ngrid,ngrid),linestyles='--')
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
h1,_ = CS1.legend_elements()
h2,_ = CS2.legend_elements()
legend1 = ax.legend([h1[0], h2[0]], ['Unnormalized target', 'TM approximation'])
plt.show()

../../../_images/source_tutorials_python_FromDensity2D_banana_32_0.png