Compare wcosmo and astropy timing
The primary pieces of functionality we use are converting from luminosity distance to redshift, calculating the distance-to-redshift Jacobian, and calculating the differential comoving volume.
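For reference, in a flat cosmology these quantities take the standard forms (with d_C the comoving distance):

$$
d_L(z) = (1 + z)\, d_C(z), \qquad
\frac{\mathrm{d} d_L}{\mathrm{d} z} = d_C(z) + \frac{(1 + z)\, c}{H(z)}, \qquad
\frac{\mathrm{d} V_C}{\mathrm{d} z\, \mathrm{d}\Omega} = \frac{c\, d_C^2(z)}{H(z)},
$$

where for FlatwCDM $H(z) = H_0 \sqrt{\Omega_{\mathrm{m}} (1 + z)^3 + (1 - \Omega_{\mathrm{m}})(1 + z)^{3(1 + w_0)}}$.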
Timing the wcosmo implementation is non-trivial because we rely on JIT compilation and also need to make sure we wait until the evaluation is complete. The steps are:
1. JIT compile a wrapper function to call.
2. Burn an evaluation to trigger the compilation.
3. Run the function and use block_until_ready to ensure we measure the full evaluation time.
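Schematically, the pattern looks like this; a toy function stands in for the wcosmo wrappers that are timed in the cells below.

import jax.numpy as jnp
from jax import jit


@jit
def toy_function(x):
    # stand-in for the wcosmo wrapper functions timed in the cells below
    return jnp.sin(x) ** 2


x = jnp.linspace(0.0, 1.0, 1000)
_ = toy_function(x)                            # burn a call: triggers JIT compilation
result = toy_function(x).block_until_ready()   # block so the full evaluation is timed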
We also time wcosmo with the numpy and cupy backends. Note that cupy also requires burning a call to compile the underlying CUDA code.
We manually switch backends, although this can be done automatically using GWPopulation.
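For completeness, a minimal sketch of the automatic route; the helper name is assumed from GWPopulation's API, so check its documentation for the exact call.

import gwpopulation

# Switch the array backend ("numpy", "cupy", or "jax") for gwpopulation and the
# wcosmo functionality it wraps; helper name assumed, not taken from this notebook.
gwpopulation.set_backend("jax")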
[1]:
!pip install wcosmo --quiet --progress-bar off
[2]:
import numpy as np

import wcosmo


def set_backend(backend):
    from importlib import import_module

    # Array-library module for each backend.
    np_modules = dict(
        numpy="numpy",
        jax="jax.numpy",
        cupy="cupy",
    )
    # Matching linear-algebra module providing toeplitz for each backend.
    linalg_modules = dict(
        numpy="scipy.linalg",
        jax="jax.scipy.linalg",
        cupy="cupyx.scipy.linalg",
    )
    # Point wcosmo's module-level xp and toeplitz references at the chosen backend.
    setattr(wcosmo.wcosmo, "xp", import_module(np_modules[backend]))
    setattr(wcosmo.utils, "xp", import_module(np_modules[backend]))
    toeplitz = getattr(import_module(linalg_modules[backend]), "toeplitz")
    setattr(wcosmo.utils, "toeplitz", toeplitz)


# one million random luminosity distances (Mpc) to benchmark with
ndata = np.random.uniform(1, 10, 1000000)
wcosmo + jax + GPU
[3]:
import jax.numpy as jnp
import numpy as np
from jax import jit

set_backend("jax")

jdata = jnp.array(ndata)


@jit
def time_jax_redshift(jdata):
    return wcosmo.z_at_value(wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, jdata)


@jit
def time_jax_dvcdz(jdata):
    return wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(jdata)


# burn evaluations so JIT compilation is not included in the timings below
burn_vals = time_jax_redshift(jdata)
burn_vals = time_jax_dvcdz(jdata)
[4]:
%%time
_ = time_jax_redshift(jdata).block_until_ready()
CPU times: user 1.56 ms, sys: 0 ns, total: 1.56 ms
Wall time: 4.83 ms
[5]:
%%time
_ = time_jax_dvcdz(jdata).block_until_ready()
CPU times: user 632 µs, sys: 173 µs, total: 805 µs
Wall time: 3.69 ms
astropy + cpu
Note that this is very slow, so we only use one percent of the full data; since this is numpy-based, the time scales linearly with the amount of data.
In practice, most people using astropy evaluate z_at_value at many points via interpolation, as is done in wcosmo. A rough sketch of that approach is shown below.
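A minimal sketch of the interpolation approach, assuming a redshift grid chosen by hand to cover the 1-10 Mpc distances generated above (the grid bounds are illustrative, not taken from wcosmo):

import numpy as np
from astropy import cosmology, units

cosmo = cosmology.FlatwCDM(67, 0.3, -1)

# Tabulate luminosity distance on a dense redshift grid ...
z_grid = np.linspace(1e-6, 0.01, 10000)
dl_grid = cosmo.luminosity_distance(z_grid).to(units.Mpc).value

# ... then invert with a single interpolation call (dl_grid is monotonic in z).
z_interp = np.interp(ndata, dl_grid, z_grid)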
[6]:
from astropy import cosmology, units
[7]:
%%time
_ = cosmology.z_at_value(
    cosmology.FlatwCDM(67, 0.3, -1).luminosity_distance,
    ndata[:10000] * units.Mpc,
).value
CPU times: user 40.9 s, sys: 418 ms, total: 41.3 s
Wall time: 52 s
[8]:
%%time
_ = cosmology.FlatwCDM(67, 0.3, -1).differential_comoving_volume(
    ndata[:10000],
).value
CPU times: user 106 ms, sys: 27 µs, total: 106 ms
Wall time: 107 ms
wcosmo + numpy + cpu
[9]:
set_backend("numpy")
[10]:
%%time
_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, ndata
)
CPU times: user 89.5 ms, sys: 75.2 ms, total: 165 ms
Wall time: 86.7 ms
[11]:
%%time
_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(ndata)
CPU times: user 65.8 ms, sys: 89.1 ms, total: 155 ms
Wall time: 79.4 ms
wcosmo + cupy + gpu
The final test uses the cupy backend on the GPU. Typically this is much faster than numpy but slower than the JAX GPU code. Notably, transfer between the CPU and GPU, which can be quite slow for cupy, is not timed here; a rough sketch of how to time it is included below.
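As an aside, a minimal sketch of timing that host-to-device transfer (not part of the benchmark itself):

import time

import cupy
import numpy as np

host = np.random.uniform(1, 10, 1000000)

start = time.perf_counter()
device = cupy.asarray(host)  # copy the array from host (CPU) to device (GPU)
cupy.cuda.stream.get_current_stream().synchronize()  # wait for the copy to finish
print(f"CPU -> GPU transfer: {time.perf_counter() - start:.4f} s")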
[12]:
import cupy

set_backend("cupy")

cdata = cupy.asarray(ndata)

# burn calls to compile the underlying CUDA kernels before timing
_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, cdata
)
_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(cdata)
[13]:
%%time
_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, cdata
)
cupy.cuda.stream.get_current_stream().synchronize()
CPU times: user 5.42 ms, sys: 17 µs, total: 5.44 ms
Wall time: 5.45 ms
[14]:
%%time
_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(cdata)
cupy.cuda.stream.get_current_stream().synchronize()
CPU times: user 73.6 ms, sys: 0 ns, total: 73.6 ms
Wall time: 83.6 ms