import jaxlib.xla_extension


def ret_jax() -> "jaxlib.xla_extension.DeviceArray[dtype=float32, shape=(2, 4)]":
    """Stub for a function returning a ``(2, 4)`` float32 jaxlib ``DeviceArray``.

    The return annotation is quoted on purpose. Keyword arguments inside a
    subscript (``X[dtype=..., shape=...]``) are a ``SyntaxError`` in Python,
    because PEP 637 was rejected. The unquoted form made the whole module
    unparseable. As a string forward reference, the annotation text is kept
    exactly and the file still parses.

    NOTE(review): ``jaxlib.xla_extension.DeviceArray`` appears to have been
    replaced by ``jax.Array`` in newer JAX releases, and ``float32`` is not
    bound in this module. Confirm against the jaxlib version this stub targets.
    """
    ...
