A Machine Learning Framework for Geodesics Under Spherical Wasserstein–Fisher–Rao Metric and Its Application for Weighted Sample Generation

Published in: Journal of Scientific Computing 98, Article 5 (2024)

Abstract

The Wasserstein–Fisher–Rao (WFR) distance is a family of metrics that gauge the discrepancy between two Radon measures, taking into account both transportation and weight change. The spherical WFR distance is a projected version of the WFR distance for probability measures, so that the space of Radon measures equipped with WFR can be viewed as a metric cone over the space of probability measures equipped with spherical WFR. Compared with the Wasserstein distance, the geodesics under the spherical WFR metric are less well understood and remain an active research topic. In this paper, we develop a deep learning framework to compute geodesics under the spherical WFR metric, and the learned geodesics can be adopted to generate weighted samples. Our approach is based on a Benamou–Brenier type dynamic formulation for the spherical WFR. To overcome the difficulty in enforcing the boundary constraint brought by the weight change, a Kullback–Leibler divergence term based on the inverse map is introduced into the cost function. Moreover, a new regularization term using the particle velocity is introduced as a substitute for the Hamilton–Jacobi equation satisfied by the potential in the dynamic formulation. When used for sample generation, our framework can be beneficial for applications with given weighted samples, especially in Bayesian inference, in comparison with previous flow models.

Data Availability

The datasets generated during the current study are available from the corresponding author on reasonable request.

Notes

  1. https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html.

References

  1. Ambrosio, L., Gigli, N., Savaré, G.: Gradient Flows: In Metric Spaces and in the Space of Probability Measures. Springer (2005)

  2. Apte, A., Hairer, M., Stuart, A.M., Voss, J.: Sampling the posterior: an approach to non-Gaussian data assimilation. Physica D 230(1–2), 50–64 (2007)

  3. Arjovsky, M., Chintala, S., Bottou, L.: Wasserstein generative adversarial networks. In: International Conference on Machine Learning, pp. 214–223 (2017)

  4. Braides, A.: Gamma-Convergence for Beginners, vol. 22. Clarendon Press (2002)

  5. Brenier, Y., Vorotnikov, D.: On optimal transport of matrix-valued measures. SIAM J. Math. Anal. 52(3), 2849–2873 (2020)

  6. Chen, R.T.Q., Rubanova, Y., Bettencourt, J., Duvenaud, D.K.: Neural ordinary differential equations. Adv. Neural Inf. Process. Syst. 31 (2018)

  7. Chizat, L., Peyré, G., Schmitzer, B., Vialard, F.-X.: An interpolating distance between optimal transport and Fisher–Rao metrics. Found. Comput. Math. 18(1), 1–44 (2018)

  8. Chizat, L., Peyré, G., Schmitzer, B., Vialard, F.-X.: Unbalanced optimal transport: dynamic and Kantorovich formulations. J. Funct. Anal. 274(11), 3090–3123 (2018)

  9. Chwialkowski, K., Strathmann, H., Gretton, A.: A kernel test of goodness of fit. In: International Conference on Machine Learning, pp. 2606–2615 (2016)

  10. De Giorgi, E.: New problems on minimizing movements. In: Ennio De Giorgi: Selected Papers, pp. 699–713 (1993)

  11. Devlin, J., Chang, M.-W., Lee, K., Toutanova, K.: BERT: pre-training of deep bidirectional transformers for language understanding. In: Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 4171–4186. Association for Computational Linguistics (2019)

  12. Evans, L.C.: An Introduction to Mathematical Optimal Control Theory, Version 0.2. Lecture notes, available at http://math.berkeley.edu/~evans/control.course.pdf (1983)

  13. Finlay, C., Jacobsen, J.-H., Nurbekyan, L., Oberman, A.: How to train your neural ODE: the world of Jacobian and kinetic regularization. In: International Conference on Machine Learning, pp. 3154–3164 (2020)

  14. Galichon, A.: A survey of some recent applications of optimal transport methods to econometrics. Econom. J. 20(2), C1–C11 (2017)

  15. Galichon, A.: Optimal Transport Methods in Economics. Princeton University Press (2018)

  16. Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., Bengio, Y.: Generative adversarial nets. Adv. Neural Inf. Process. Syst. 27 (2014)

  17. Gorham, J., Mackey, L.: Measuring sample quality with kernels. In: International Conference on Machine Learning, pp. 1292–1301 (2017)

  18. Gretton, A., Borgwardt, K.M., Rasch, M.J., Schölkopf, B., Smola, A.: A kernel two-sample test. J. Mach. Learn. Res. 13(1), 723–773 (2012)

  19. Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., Courville, A.C.: Improved training of Wasserstein GANs. Adv. Neural Inf. Process. Syst. 30 (2017)

  20. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)

  21. Hu, T., Chen, Z., Sun, H., Bai, J., Ye, M., Cheng, G.: Stein neural sampler. arXiv preprint arXiv:1810.03545 (2018)

  22. Johnson, R., Zhang, T.: A framework of composite functional gradient methods for generative adversarial models. IEEE Trans. Pattern Anal. Mach. Intell. 43(1), 17–32 (2019)

  23. Jordan, R., Kinderlehrer, D., Otto, F.: The variational formulation of the Fokker–Planck equation. SIAM J. Math. Anal. 29(1), 1–17 (1998)

  24. Kingma, D.P., Ba, J.: Adam: a method for stochastic optimization. In: International Conference on Learning Representations (2015)

  25. Kingma, D.P., Welling, M.: Auto-encoding variational bayes. In: International Conference on Learning Representations (2014)

  26. Kingma, D.P., Welling, M., et al.: An introduction to variational autoencoders. Found. Trends® Mach. Learn. 12(4), 307–392 (2019)

  27. Kondratyev, S., Monsaingeon, L., Vorotnikov, D.: A new optimal transport distance on the space of finite Radon measures. Adv. Differ. Equ. 21(11/12), 1117–1164 (2016)

  28. Kondratyev, S., Vorotnikov, D.: Spherical Hellinger–Kantorovich gradient flows. SIAM J. Math. Anal. 51(3), 2053–2084 (2019)

  29. Laschos, V., Mielke, A.: Geometric properties of cones with applications on the Hellinger–Kantorovich space, and a new distance on the space of probability measures. J. Funct. Anal. 276(11), 3529–3576 (2019)

  30. Li, W., Lee, W., Osher, S.: Computational mean-field information dynamics associated with reaction–diffusion equations. J. Comput. Phys., p. 111409 (2022)

  31. Li, W., Ryu, E.K., Osher, S., Yin, W., Gangbo, W.: A parallel method for Earth Mover’s distance. J. Sci. Comput. 75(1), 182–197 (2018)

  32. Li, Y., Swersky, K., Zemel, R.: Generative moment matching networks. In: International Conference on Machine Learning, pp. 1718–1727 (2015)

  33. Liero, M., Mielke, A., Savaré, G.: Optimal entropy-transport problems and a new Hellinger–Kantorovich distance between positive measures. Inventiones Math. 211(3), 969–1117 (2018)

  34. Liu, Q.: Stein variational gradient descent as gradient flow. Adv. Neural Inf. Process. Syst. 30 (2017)

  35. Liu, Q., Lee, J., Jordan, M.: A kernelized Stein discrepancy for goodness-of-fit tests. In: International Conference on Machine Learning, pp. 276–284 (2016)

  36. Liu, Q., Wang, D.: Stein variational gradient descent: a general purpose Bayesian inference algorithm. Adv. Neural Inf. Process. Syst. 29 (2016)

  37. Metropolis, N., Rosenbluth, A.W., Rosenbluth, M.N., Teller, A.H., Teller, E.: Equation of state calculations by fast computing machines. J. Chem. Phys. 21(6), 1087–1092 (1953)

  38. Monge, G.: Mémoire sur la théorie des déblais et des remblais. Mem. Math. Phys. Acad. Royale Sci., pp. 666–704 (1781)

  39. Müller, T., McWilliams, B., Rousselle, F., Gross, M., Novák, J.: Neural importance sampling. ACM Trans. Graphics (ToG) 38(5), 1–19 (2019)

  40. Onken, D., Fung, S.W., Li, X., Ruthotto, L.: OT-Flow: Fast and accurate continuous normalizing flows via optimal transport. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, pp. 9223–9232 (2021)

  41. Papamakarios, G., Nalisnick, E., Rezende, D.J., Mohamed, S., Lakshminarayanan, B.: Normalizing flows for probabilistic modeling and inference. J. Mach. Learn. Res. 22(1), 2617–2680 (2021)

  42. Pele, O., Werman, M.: A linear time histogram metric for improved SIFT matching. In: European Conference on Computer Vision, pp. 495–508. Springer (2008)

  43. Peyré, G., Cuturi, M.: Computational optimal transport: with applications to data science. Found. Trends® Mach. Learn. 11(5–6), 355–607 (2019)

  44. Rezende, D., Mohamed, S.: Variational inference with normalizing flows. In: International Conference on Machine Learning, pp. 1530–1538 (2015)

  45. Rubner, Y., Guibas, L.J., Tomasi, C.: The Earth Mover’s distance, multi-dimensional scaling, and color-based image retrieval. In: Proceedings of the ARPA Image Understanding Workshop, vol. 661, p. 668 (1997)

  46. Rubner, Y., Tomasi, C., Guibas, L.J.: The Earth Mover’s distance as a metric for image retrieval. Int. J. Comput. Vis. 40(2), 99–121 (2000)

  47. Ruthotto, L., Osher, S.J., Li, W., Nurbekyan, L., Fung, S.W.: A machine learning framework for solving high-dimensional mean field game and mean field control problems. Proc. Natl. Acad. Sci. 117(17), 9183–9193 (2020)

  48. Salimans, T., Zhang, H., Radford, A., Metaxas, D.: Improving GANs using optimal transport. In: International Conference on Learning Representations (2018)

  49. Santambrogio, F.: Optimal Transport for Applied Mathematicians. Birkhäuser (2015)

  50. Schiebinger, G., Shu, J., Tabaka, M., Cleary, B., Subramanian, V., Solomon, A., Gould, J., Liu, S., Lin, S., Berube, P., et al.: Optimal-transport analysis of single-cell gene expression identifies developmental trajectories in reprogramming. Cell 176(4), 928–943 (2019)

  51. Tabak, E.G., Vanden-Eijnden, E.: Density estimation by dual ascent of the log-likelihood. Commun. Math. Sci. 8(1), 217–233 (2010)

  52. Theis, L., Oord, A.V.D., Bethge, M.: A note on the evaluation of generative models. In: International Conference on Learning Representations (2016)

  53. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł., Polosukhin, I.: Attention is all you need. Adv. Neural Inf. Process. Syst. 30 (2017)

  54. Vidal, A., Wu Fung, S., Tenorio, L., Osher, S., Nurbekyan, L.: Taming hyperparameter tuning in continuous normalizing flows using the JKO scheme. Sci. Rep. 13(1), 4501 (2023)

  55. Villani, C.: Optimal Transport: Old and New, vol. 338. Springer (2009)

  56. Wang, Z., Zhou, D., Yang, M., Zhang, Y., Rao, C., Wu, H.: Robust document distance with Wasserstein–Fisher–Rao metric. In: Asian Conference on Machine Learning, pp. 721–736 (2020)

  57. Wu, J., Wen, L., Green, P.L., Li, J., Maskell, S.: Ensemble Kalman filter based sequential Monte Carlo sampler for sequential Bayesian inference. Stat. Comput. 32(1), 1–14 (2022)

  58. Xiong, Z., Li, L., Zhu, Y.-N., Zhang, X.: On the convergence of continuous and discrete unbalanced optimal transport models. SIAM J. Numer. Anal., to appear. arXiv preprint arXiv:2303.17267 (2023)

  59. Yang, K.D., Damodaran, K., Venkatachalapathy, S., Soylemezoglu, A.C., Shivashankar, G.V., Uhler, C.: Predicting cell lineages using autoencoders and optimal transport. PLoS Comput. Biol. 16(4), e1007828 (2020)

  60. Zhou, D., Chen, J., Wu, H., Yang, D., Qiu, L.: The Wasserstein–Fisher–Rao metric for waveform based earthquake location. J. Comput. Math. 41(3), 417–438 (2023)

Funding

This work is partially supported by the National Key R&D Program of China Nos. 2020YFA0712000 and 2021YFA1002800. The work of L. Li was partially supported by Shanghai Municipal Science and Technology Major Project 2021SHZDZX0102, NSFC 12371400 and 12031013, and Shanghai Science and Technology Commission Grant No. 21JC1402900.

Author information

Contributions

All authors contributed equally.

Corresponding author

Correspondence to Lei Li.

Ethics declarations

Competing interests

The authors have no relevant financial or non-financial interests to disclose.

Additional information

Publisher's Note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Appendix A: Implementation Details for Solving SWFR Distance with Optimization Method

We can apply a primal-dual hybrid gradient (PDHG) algorithm to compute the SWFR distance; the same algorithm has been used to compute the Earth Mover's distance in [31]. To begin with, we reformulate the optimization problem (3.4) in the convex form:

$$\begin{aligned} \min _{\rho , m, \xi }\left\{ \frac{1}{2} \int _0^1 \int _{\mathbb {R}^d} \left( \frac{|m|^2}{\rho }+\alpha \frac{\xi ^{2}}{\rho } \right) dx\, dt \;:\; \partial _t \rho +\nabla \cdot m =\xi ,\ \int _{\mathbb {R}^d} \xi \, dx=0\right\} . \end{aligned}$$
(A.1)

Then we solve the corresponding min-max problem:

$$\begin{aligned} \max _{\phi ,\lambda }\min _{\rho ,m,\xi }&\ \frac{1}{2}\int _0^1 \int _{\Omega } \left( \frac{|m|^2}{\rho }+\alpha \frac{\xi ^{2}}{\rho } \right) dx\, dt \\&+ \int _0^1\int _{\Omega } \phi (x,t)\left( \partial _t \rho +\nabla \cdot m-\xi \right) dx\, dt+\int _0^1\lambda (t) \int _{\Omega } \xi \, dx\, dt. \end{aligned}$$
(A.2)

We consider the space domain \(\Omega =[0,1]^{d}\) and the time domain [0, 1] for simplicity. Let \(\Omega _{h}\) be the discrete mesh grid of \(\Omega \) with step size h, i.e., \(\Omega _{h}=\{0,h,2h,\ldots ,1\}^d\). The time domain [0, 1] is discretized with step size \(\Delta t\). Let \(N_x=1/h\) be the space grid size and \(N_t=1/\Delta t\) the time grid size. All optimization variables (\(\rho \), \(\xi \), m, \(\phi \) and \(\lambda \)) are defined on this grid.
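For concreteness, here is a minimal NumPy sketch of the grid and the discretized variables (our illustration, not the authors' code; all names and the choice of NumPy are ours, and the array shapes follow the sizes listed after (A.3) below):

```python
import numpy as np

d = 2                          # spatial dimension
N_x, N_t = 32, 32              # space and time grid sizes
h, dt = 1.0 / N_x, 1.0 / N_t   # step sizes

# Optimization variables on the grid; shapes follow the sizes listed after (A.3).
rho = np.ones((N_t + 1,) + (N_x + 1,) * d)       # density at t = 0, dt, ..., 1
m   = np.zeros((N_t,) + (N_x + 1,) * d + (d,))   # momentum, one d-vector per point
xi  = np.zeros((N_t,) + (N_x,) * d)              # source term, interior points only
phi = np.zeros((N_t,) + (N_x,) * d)              # dual potential
lam = np.zeros(N_t)                              # multiplier for the zero-mean constraint
```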

We employ the same discretization for the divergence operator and the boundary conditions as in [58]. We first give some definitions on the discrete space \(\Omega _h\):

$$\begin{aligned} \begin{aligned} \int _{\Omega _h}f(x)\,dx&:=\sum _{x \in \Omega _{h},\, x_i \ne 0\ \forall i} f(x)\, h^d, \quad \left\langle f, g \right\rangle _h:= \int _{\Omega _h}f(x)\, g(x)\, dx,\\ \nabla _{h} \cdot m(x)&:=\sum _{i=1}^{d} D_{h,i} m(x), \quad x\in \Omega _h, \end{aligned} \end{aligned}$$

where \(D_{h,i}\) denotes the discrete (backward) difference operator applied to the i-th component of m in the i-th dimension with step size h:

$$\begin{aligned} D_{h, i} m(x)=\left( m_i \left( x_1,\cdots , x_i, \cdots , x_d\right) -m_i \left( x_1,\cdots , x_i-h, \cdots , x_d\right) \right) / h, \quad h \le x_i \le 1. \end{aligned}$$
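Continuing the NumPy sketch above, the backward-difference divergence can be implemented as follows; the restriction of the remaining dimensions to the interior points \(x_j \ge h\) mirrors the definition of \(\int _{\Omega _h}\) (again a sketch under our naming conventions):

```python
import numpy as np

def backward_diff(a, axis, h):
    """(a(x) - a(x - h e_axis)) / h on the points with x_axis >= h."""
    lo = [slice(None)] * a.ndim
    hi = [slice(None)] * a.ndim
    lo[axis], hi[axis] = slice(None, -1), slice(1, None)
    return (a[tuple(hi)] - a[tuple(lo)]) / h

def div_h(m, h):
    """Discrete divergence of a field m of shape (N_x + 1,)*d + (d,).

    Returns an array on the interior grid points, shape (N_x,)*d.
    """
    d = m.shape[-1]
    out = 0.0
    for i in range(d):
        Di = backward_diff(m[..., i], i, h)
        # restrict the other dimensions to the interior points x_j >= h
        idx = tuple(slice(None) if j == i else slice(1, None) for j in range(d))
        out = out + Di[idx]
    return out
```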

Then the discretization of (A.2) can be written as:

$$\begin{aligned} \max _{\phi ,\lambda }\min _{\rho ,m,\xi }L\left( m, \xi , \rho ,\phi ,\lambda \right) = \sum _{t=1}^{N_t} \Bigg \{&\int _{\Omega _h} \left( \frac{|m_{t}|^2}{2\rho _{t}}+\alpha \frac{\xi _{t}^{2}}{2\rho _{t}}\right) dx \\&+ \left\langle \phi _{t},\frac{\rho _{t+1}-\rho _{t}}{\Delta t}+\nabla _h \cdot m_{t}-\xi _{t} \right\rangle _h +\lambda _t \int _{\Omega _h} \xi _{t}\, dx \Bigg \}. \end{aligned}$$
(A.3)

From the discretized form, we can read off the sizes of the optimization variables: \((N_t+1) \times (N_x+1)^d \) for \(\rho \), \( N_t\times (N_x+1)^{d} \times d \) for m, \( N_t\times (N_x)^{d} \) for \(\xi \) and \(\phi \), and \(N_t\) for \(\lambda \). The boundary conditions are \(\rho _1=\mu \), \(\rho _{N_t+1}=\nu \) (the given end-point measures; with a slight abuse of notation, \(\mu \) and \(\tau \) below denote the primal and dual step sizes) and \(m(x)=0\) for all \(x \in \partial \Omega _h\). The primal-dual hybrid algorithm then iterates as follows, where the superscript k denotes the iteration index:

$$\begin{aligned} \begin{aligned} \left( m_{t}^{k+1},\xi _t^{k+1},\rho _t^{k+1}\right)&=\underset{m^{\star },\xi ^{\star },\rho ^{\star }}{\arg \min }\; L\left( m^{\star },\xi ^{\star },\rho ^{\star },\phi ^{k},\lambda ^{k} \right) \\&\qquad +\frac{1}{2 \mu }\left( \left\| m^{\star }-m^{k}_{t}\right\| _{2}^2+\alpha \left\| \xi ^{\star }-\xi ^k_t \right\| _{2}^2 +\left\| \rho ^{\star }-\rho ^k_t\right\| _{2}^2 \right) , \\ \tilde{m}_t^{k+1}&=2 m_t^{k+1}-m_t^k, \quad \tilde{\xi }_t^{k+1}=2 \xi _t^{k+1}-\xi _t^k, \quad \tilde{\rho }_t^{k+1}=2 \rho _t^{k+1}-\rho _t^k ,\\ \left( \phi _t^{k+1},\lambda _t^{k+1}\right)&=\underset{\phi ^{\star },\lambda ^{\star }}{\arg \max }\; L\left( \tilde{m}_t^{k+1}, \tilde{\xi }_t^{k+1}, \tilde{\rho }_t^{k+1},\phi ^{\star }, \lambda ^{\star }\right) -\frac{1}{2 \tau } \left( \left\| \phi ^{\star }-\phi ^{k}_t\right\| _{2}^2 + \left\| \lambda ^{\star }-\lambda ^{k}_t\right\| _{2}^2 \right) . \end{aligned} \end{aligned}$$
(A.4)
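Structurally, one iteration of (A.4) is the classical PDHG pattern: a proximal primal step, an over-relaxation step, and a dual ascent step. The following skeleton is a sketch only; `solve_primal` and `dual_ascent` are placeholders for the solves (A.5) and (A.6) derived below, and all names are ours:

```python
def pdhg_step(m, xi, rho, phi, lam, mu, tau, h, dt, alpha):
    # 1. proximal primal step: solve the pointwise system (A.5) at every
    #    grid point and time slice (solve_primal is a placeholder here)
    m_new, xi_new, rho_new = solve_primal(m, xi, rho, phi, lam, mu, alpha, dt)
    # 2. extrapolation (over-relaxation) of the primal variables
    m_bar   = 2 * m_new   - m
    xi_bar  = 2 * xi_new  - xi
    rho_bar = 2 * rho_new - rho
    # 3. dual ascent on phi and lambda, cf. (A.6)
    phi_new, lam_new = dual_ascent(m_bar, xi_bar, rho_bar, phi, lam, tau, h, dt)
    return m_new, xi_new, rho_new, phi_new, lam_new
```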

The first step of the algorithm is equivalent to solving the following system:

$$\begin{aligned} \left\{ \begin{array}{l} m^{\star }=\dfrac{\rho ^{\star }(m_t^{k}-\mu \, \text {div}_{h}^{*} \phi _t^{k})}{\rho ^{\star }+\mu }, \\ \xi ^{\star }=\dfrac{\rho ^{\star }(\xi _t^{k}+\frac{1}{\alpha }\mu (\phi _t^{k}-\lambda _t^k))}{\rho ^{\star }+\mu }, \\ -\dfrac{|m^{\star }|^{2}}{2{\rho ^{\star }}^2}-\alpha \dfrac{{\xi ^{\star }}^{2}}{2{\rho ^{\star }}^2}+\dfrac{\phi _{t-1}^{k}-\phi ^{k}_{t}}{\Delta t}+\dfrac{1}{\mu }(\rho ^{\star }-\rho _t^k)=0, \end{array}\right. \end{aligned}$$
(A.5)

where \(\text {div}_{h}^{*}\) denotes the adjoint of the discrete divergence operator. By the definition of the adjoint, we have

$$\begin{aligned} \left\langle \nabla _h \cdot f,g \right\rangle _{h} =\left\langle f,\text {div}_{h}^{*} g \right\rangle _h= \left\langle f, -\nabla _h g \right\rangle _h. \end{aligned}$$

Thus \(\text {div}_{h}^{*}=-\nabla _{h}\) and

$$\begin{aligned} \nabla _h u =\left( \partial _{h, 1} u, \partial _{h, 2} u, \cdots , \partial _{h, d} u\right) , \end{aligned}$$

where \(\partial _{h,i} \) denotes the discrete (forward) difference operator in the i-th dimension with step size h:

$$\begin{aligned} \partial _{h, i} u(x)=\left\{ \begin{array}{l} \left( u\left( x_1,\cdots , x_i+h, \cdots , x_d\right) -u\left( x_1,\cdots , x_i, \cdots , x_d\right) \right) / h, \quad 0\le x_i < 1, \\ 0,\quad x_i=1. \end{array}\right. \end{aligned}$$
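A sketch of this forward-difference gradient, together with a quick numerical check of the adjoint identity \(\left\langle \nabla _h \cdot f,g \right\rangle _{h} = \left\langle f, -\nabla _h g \right\rangle _h\), could read as follows (reusing `div_h` from the earlier sketch; the check requires m to vanish on \(\partial \Omega _h\), matching the boundary condition above):

```python
import numpy as np

def grad_h(u, h):
    """Forward-difference gradient; the negative adjoint of div_h.

    u lives on the interior points, shape (N_x,)*d; it is extended by zero to
    the full grid {0, h, ..., 1}^d before differencing, and the component-i
    difference is set to zero on the face x_i = 1, as in the definition above.
    """
    d = u.ndim
    full = np.pad(u, [(1, 0)] * d)        # zero extension at x_j = 0
    out = np.zeros(full.shape + (d,))
    for i in range(d):
        lo = [slice(None)] * d
        hi = [slice(None)] * d
        lo[i], hi[i] = slice(None, -1), slice(1, None)
        out[tuple(lo) + (i,)] = (full[tuple(hi)] - full[tuple(lo)]) / h
    return out

# Adjointness check on random data with m = 0 on the boundary of the grid:
d, N_x = 2, 8
h = 1.0 / N_x
rng = np.random.default_rng(0)
m_test = rng.standard_normal((N_x + 1,) * d + (d,))
for i in range(d):                        # zero out the boundary faces
    m_test[(slice(None),) * i + (0,)] = 0.0
    m_test[(slice(None),) * i + (-1,)] = 0.0
g_test = rng.standard_normal((N_x,) * d)
lhs = (div_h(m_test, h) * g_test).sum() * h ** d
rhs = -(m_test * grad_h(g_test, h)).sum() * h ** d
assert np.isclose(lhs, rhs)
```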

We also impose the natural boundary condition \(\phi _0=0\), which comes from the first-order optimality condition for the update of \(\rho \). Solving the system (A.5) amounts to finding the roots of a third-order polynomial in \(\rho ^{\star }\), of which the largest real root should be taken. The third step of the algorithm updates the dual variables:

$$\begin{aligned} \left\{ \begin{array}{l} \phi _t^{k+1}=\phi _{t}^{k}+ \tau \left( \dfrac{\tilde{\rho }_{t+1}^{k+1}-\tilde{\rho }_{t}^{k+1}}{\Delta t}+\nabla _h \cdot \tilde{m}_{t}^{k+1}-\tilde{\xi }_t^{k+1}\right) , \\ \lambda _t^{k+1}=\lambda _{t}^{k}+\tau \displaystyle \int _{\Omega _h}\tilde{\xi }_t^{k+1}\, dx. \end{array}\right. \end{aligned}$$
(A.6)
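Putting (A.5) and (A.6) into code: below is a hedged sketch of the primal solve at a single grid point and of the dual ascent step, under our naming conventions and reusing `div_h` from the earlier sketch. The cubic in \(\rho ^{\star }\) is obtained by substituting the first two equations of (A.5) into the third and clearing denominators; `np.roots` is used for clarity, while a vectorized closed-form (Cardano) solve would be preferred in practice:

```python
import numpy as np

def solve_primal_point(m_k, xi_k, rho_k, grad_phi_k, phi_k, phi_prev_k, lam_k,
                       mu, alpha, dt):
    """Solve (A.5) at one grid point; m_k and grad_phi_k are d-vectors."""
    a = m_k + mu * grad_phi_k                    # uses div_h^* = -grad_h
    b = xi_k + (mu / alpha) * (phi_k - lam_k)
    c = (phi_prev_k - phi_k) / dt
    # substituting m*, xi* into the third equation of (A.5) gives the cubic
    #   (r - q)(r + mu)^2 - s = 0,  q = rho_k - mu*c,  s = mu*(|a|^2 + alpha*b^2)/2
    q = rho_k - mu * c
    s = mu * (np.dot(a, a) + alpha * b ** 2) / 2.0
    roots = np.roots([1.0, 2.0 * mu - q, mu ** 2 - 2.0 * mu * q, -(mu ** 2) * q - s])
    r = max(z.real for z in roots if abs(z.imag) < 1e-10)   # largest real root
    return r * a / (r + mu), r * b / (r + mu), r            # m*, xi*, rho*

def dual_ascent(m_bar, xi_bar, rho_bar, phi, lam, tau, h, dt):
    """Dual update (A.6); rho is restricted to the interior grid points."""
    d = xi_bar.ndim - 1                          # leading axis is time
    interior = (slice(1, None),) * d
    phi_new, lam_new = phi.copy(), lam.copy()
    for t in range(phi.shape[0]):
        residual = ((rho_bar[t + 1][interior] - rho_bar[t][interior]) / dt
                    + div_h(m_bar[t], h) - xi_bar[t])
        phi_new[t] = phi[t] + tau * residual
        lam_new[t] = lam[t] + tau * xi_bar[t].sum() * h ** d
    return phi_new, lam_new
```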

Rights and permissions

Springer Nature or its licensor (e.g. a society or other partner) holds exclusive rights to this article under a publishing agreement with the author(s) or other rightsholder(s); author self-archiving of the accepted manuscript version of this article is solely governed by the terms of such publishing agreement and applicable law.


Cite this article

Jing, Y., Chen, J., Li, L. et al. A Machine Learning Framework for Geodesics Under Spherical Wasserstein–Fisher–Rao Metric and Its Application for Weighted Sample Generation. J Sci Comput 98, 5 (2024). https://doi.org/10.1007/s10915-023-02396-y
