Description
A series of recent works has highlighted the potential of so-called full-field cosmological inference for the analysis of upcoming weak lensing and galaxy clustering surveys, which promises access to the non-Gaussian information contained in the data. However, these approaches require a differentiable forward model of the large-scale structure, and this model is currently one of their main bottlenecks: none of the publicly available differentiable N-body codes can distribute a simulation across many GPUs and nodes, which is necessary to model large cosmological volumes.
To address this, we introduce JaxPM, a JAX-powered library for multi-GPU and multi-node Particle-Mesh N-body simulations. By building on a low-level CUDA library for distributed computing, JaxPM makes it possible, for the first time, to simulate the entire cosmological volume observed by LSST down to Mpc resolution with an automatically differentiable simulation code.
We will present results from our early scaling tests of this approach on the Jean Zay GPU supercomputer.
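To illustrate what "automatically differentiable Particle-Mesh simulation" means in practice, here is a minimal single-device sketch in plain JAX. It is not JaxPM's API: the function names (`cic_paint`, `cic_read`, `pm_forces`, `leapfrog`) are hypothetical. It also assumes grid units with all physical constants absorbed into the Poisson equation and ignores cosmological expansion. The multi-GPU, multi-node decomposition described above is not shown. The sketch deposits particles on a mesh with cloud-in-cell (CIC) weights, solves Poisson's equation with FFTs, interpolates forces back to the particles, and integrates with a kick-drift-kick leapfrog. The whole pipeline can then be differentiated with `jax.grad`.

```python
import jax
import jax.numpy as jnp


def _cic_neighbors(positions, n_mesh):
    """Yield (indices, weights) for the 8 cloud-in-cell neighbours of each particle."""
    base = jnp.floor(positions)
    frac = positions - base  # fractional offset in [0, 1); carries the gradient
    base = base.astype(jnp.int32)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                offset = jnp.array([dx, dy, dz])
                idx = (base + offset) % n_mesh  # periodic boundary conditions
                w = jnp.prod(jnp.where(offset == 1, frac, 1.0 - frac), axis=-1)
                yield idx, w


def cic_paint(positions, n_mesh):
    """Deposit unit-mass particles (positions in grid units) onto an n_mesh^3 grid."""
    mesh = jnp.zeros((n_mesh,) * 3)
    for idx, w in _cic_neighbors(positions, n_mesh):
        # scatter-add accumulates contributions from particles sharing a cell
        mesh = mesh.at[idx[:, 0], idx[:, 1], idx[:, 2]].add(w)
    return mesh


def cic_read(field, positions):
    """Interpolate a gridded field back to particle positions with CIC weights."""
    values = jnp.zeros(positions.shape[0])
    for idx, w in _cic_neighbors(positions, field.shape[0]):
        values = values + w * field[idx[:, 0], idx[:, 1], idx[:, 2]]
    return values


def pm_forces(positions, n_mesh):
    """Gravitational force per particle from an FFT Poisson solve (toy units)."""
    mesh = cic_paint(positions, n_mesh)
    delta_k = jnp.fft.rfftn(mesh / jnp.mean(mesh) - 1.0)  # density contrast
    k = 2 * jnp.pi * jnp.fft.fftfreq(n_mesh)
    kz = 2 * jnp.pi * jnp.fft.rfftfreq(n_mesh)
    kx3, ky3, kz3 = jnp.meshgrid(k, k, kz, indexing="ij")
    k2 = (kx3**2 + ky3**2 + kz3**2).at[0, 0, 0].set(1.0)  # avoid 0/0
    phi_k = (-delta_k / k2).at[0, 0, 0].set(0.0)  # laplacian(phi) = delta
    forces = []
    for ki in (kx3, ky3, kz3):
        # F = -grad(phi), i.e. F_k = -i k phi_k in Fourier space
        f = jnp.fft.irfftn(-1j * ki * phi_k, s=(n_mesh,) * 3)
        forces.append(cic_read(f, positions))
    return jnp.stack(forces, axis=-1)


def leapfrog(positions, velocities, n_mesh, dt, n_steps):
    """Kick-drift-kick integration; lax.scan keeps the loop differentiable."""
    def step(carry, _):
        x, v = carry
        v = v + 0.5 * dt * pm_forces(x, n_mesh)
        x = (x + dt * v) % n_mesh
        v = v + 0.5 * dt * pm_forces(x, n_mesh)
        return (x, v), None

    (x, v), _ = jax.lax.scan(step, (positions, velocities), None, length=n_steps)
    return x, v


# Usage: particles on a regular lattice, perturbed by random displacements.
n_mesh = 16
axis = jnp.arange(n_mesh, dtype=jnp.float32)
q = jnp.stack(jnp.meshgrid(axis, axis, axis, indexing="ij"), -1).reshape(-1, 3)
dq = 0.1 * jax.random.normal(jax.random.PRNGKey(0), q.shape)


def loss(amplitude):
    """Variance of the final density contrast as a function of the initial amplitude."""
    x0 = (q + amplitude * dq) % n_mesh
    x, _ = leapfrog(x0, jnp.zeros_like(x0), n_mesh, dt=0.1, n_steps=5)
    final = cic_paint(x, n_mesh)
    return jnp.mean((final / jnp.mean(final) - 1.0) ** 2)


g = jax.grad(loss)(1.0)  # derivative of the whole simulation w.r.t. its input
```

The scalar `g` is obtained by backpropagating through every force evaluation and time step. This end-to-end differentiability is what full-field inference relies on. In a distributed version, the 3D FFTs and the particle painting and readout are the operations that would have to be split across devices.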