
Compare wcosmo and astropy timing

The primary pieces of functionality we use are converting from luminosity distance to redshift, calculating the distance-to-redshift Jacobian, and calculating the differential comoving volume.

Timing the wcosmo implementation is non-trivial as we rely on JIT compilation and also need to make sure we wait until the evaluation is complete. The steps are:

  • JIT-compile a wrapper function to call.

  • burn one evaluation to trigger the compilation.

  • run the function and use block_until_ready to ensure we capture the full evaluation time (a generic sketch follows this list).
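A generic version of this pattern is sketched below; the toy function and array are placeholders for illustration, standing in for the wcosmo calls timed later in the notebook.

[ ]:
import jax.numpy as jnp
from jax import jit


# toy stand-in for the wcosmo function we want to time
@jit
def timed_function(x):
    return jnp.sin(x) ** 2


values = jnp.linspace(0, 1, 1_000_000)

# burn one evaluation so the JIT compilation is excluded from the timing
_ = timed_function(values)

# the timed call: block until the asynchronous computation has finished
# (in the notebook this line would sit in its own %%time cell)
_ = timed_function(values).block_until_ready()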

We also time wcosmo with the numpy and cupy backends. Note that cupy also requires burning a call to compile the underlying CUDA code.

We manually switch backends, although this can be done automatically using GWPopulation.
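For reference, the automatic route via GWPopulation looks roughly like the sketch below (assuming GWPopulation's set_backend helper); we do not use it in this notebook.

[ ]:
import gwpopulation

# assumed interface: switch the array backend for gwpopulation, and hence wcosmo,
# in a single call; alternatives are "numpy" and "cupy"
gwpopulation.set_backend("jax")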

[1]:
!pip install wcosmo --quiet --progress-bar off
[2]:
import numpy as np
import wcosmo


def set_backend(backend):
    """Point wcosmo at the requested array backend ('numpy', 'jax', or 'cupy')."""
    from importlib import import_module

    np_modules = dict(
        numpy="numpy",
        jax="jax.numpy",
        cupy="cupy",
    )
    linalg_modules = dict(
        numpy="scipy.linalg",
        jax="jax.scipy.linalg",
        cupy="cupyx.scipy.linalg",
    )
    # swap the array module wcosmo uses internally
    setattr(wcosmo.wcosmo, "xp", import_module(np_modules[backend]))
    setattr(wcosmo.utils, "xp", import_module(np_modules[backend]))
    # toeplitz comes from the matching scipy-compatible linear algebra module
    toeplitz = getattr(import_module(linalg_modules[backend]), "toeplitz")
    setattr(wcosmo.utils, "toeplitz", toeplitz)


ndata = np.random.uniform(1, 10, 1000000)  # luminosity distances in Mpc used in the timing tests

wcosmo + jax + gpu

[3]:
import jax.numpy as jnp
import numpy as np
from jax import jit


set_backend("jax")

jdata = jnp.array(ndata)


@jit
def time_jax_redshift(jdata):
    return wcosmo.z_at_value(wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, jdata)


@jit
def time_jax_dvcdz(jdata):
    return wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(jdata)


# burn one evaluation of each function to trigger the JIT compilation
burn_vals = time_jax_redshift(jdata)
burn_vals = time_jax_dvcdz(jdata)
[4]:
%%time

_ = time_jax_redshift(jdata).block_until_ready()
CPU times: user 1.56 ms, sys: 0 ns, total: 1.56 ms
Wall time: 4.83 ms
[5]:
%%time

_ = time_jax_dvcdz(jdata).block_until_ready()
CPU times: user 632 µs, sys: 173 µs, total: 805 µs
Wall time: 3.69 ms

astropy + cpu

Note that astropy's z_at_value is very slow in this case, so we only use one percent of the full data. Since this is numpy-based, the time scales roughly linearly with the amount of data, so the full data set would take roughly one hundred times as long.

In practice, most astropy users evaluate z_at_value at many points by interpolating over a precomputed grid, as is done internally in wcosmo.
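For illustration, a minimal sketch of that interpolation approach using only astropy and numpy; the grid bounds, grid size, and sample values below are arbitrary choices for this example.

[ ]:
import numpy as np
from astropy import cosmology, units

cosmo = cosmology.FlatwCDM(67, 0.3, -1)

# evaluate the forward relation once on a grid spanning the redshifts of interest
z_grid = np.geomspace(1e-6, 10, 10000)
dl_grid = cosmo.luminosity_distance(z_grid).to(units.Mpc).value

# invert distance -> redshift for many samples with a single interpolation call
dl_samples = np.random.uniform(1, 10, 10000)
z_samples = np.interp(dl_samples, dl_grid, z_grid)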

[6]:
from astropy import cosmology, units
[7]:
%%time

_ = cosmology.z_at_value(
    cosmology.FlatwCDM(67, 0.3, -1).luminosity_distance,
    ndata[:10000] * units.Mpc,
).value
CPU times: user 40.9 s, sys: 418 ms, total: 41.3 s
Wall time: 52 s
[8]:
%%time

_ = cosmology.FlatwCDM(67, 0.3, -1).differential_comoving_volume(
    ndata[:10000],
).value
CPU times: user 106 ms, sys: 27 µs, total: 106 ms
Wall time: 107 ms

wcosmo + numpy + cpu

[9]:
set_backend("numpy")
[10]:
%%time

_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, ndata
)
CPU times: user 89.5 ms, sys: 75.2 ms, total: 165 ms
Wall time: 86.7 ms
[11]:
%%time

_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(ndata)
CPU times: user 65.8 ms, sys: 89.1 ms, total: 155 ms
Wall time: 79.4 ms

wcosmo + cupy + gpu

The final test uses the cupy backend on the GPU. Typically this is much faster than numpy but slower than the JAX GPU code. Note that transfer between the CPU and GPU, which can be quite slow for cupy, is not timed here.
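As a rough illustration of that transfer cost, the copies could be timed separately along the lines below; the gpu_data name is just for this sketch, and the synchronize call guards against any asynchronous copy.

[ ]:
import cupy

# host -> device copy of the numpy array defined earlier
# (each copy would sit in its own %%time cell to be timed)
gpu_data = cupy.asarray(ndata)
cupy.cuda.stream.get_current_stream().synchronize()

# device -> host copy back to a numpy array
cpu_data = cupy.asnumpy(gpu_data)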

[12]:
import cupy

set_backend("cupy")

cdata = cupy.asarray(ndata)

# burn one evaluation of each function to compile the underlying CUDA kernels
_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, cdata
)
_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(cdata)
[13]:
%%time

_ = wcosmo.z_at_value(
    wcosmo.FlatwCDM(67, 0.3, -1).luminosity_distance, cdata
)
cupy.cuda.stream.get_current_stream().synchronize()
CPU times: user 5.42 ms, sys: 17 µs, total: 5.44 ms
Wall time: 5.45 ms
[14]:
%%time

_ = wcosmo.FlatwCDM(67, 0.3, -1).differential_comoving_volume(cdata)
cupy.cuda.stream.get_current_stream().synchronize()
CPU times: user 73.6 ms, sys: 0 ns, total: 73.6 ms
Wall time: 83.6 ms