# Source code for dctkit

import sys
import enum
import jax.numpy as jnp
from jax import config as cfg

# Select the metadata API: the third-party backport is only needed on
# interpreters older than 3.8, the standard library module is used otherwise.
if sys.version_info < (3, 8):
    from importlib_metadata import PackageNotFoundError, version  # pragma: no cover
else:
    # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
    from importlib.metadata import PackageNotFoundError, version  # pragma: no cover

# Distribution name used for the metadata lookup.
# Adjust this if the project is renamed and no longer matches the package name.
dist_name = __name__

try:
    __version__ = version(dist_name)
except PackageNotFoundError:  # pragma: no cover
    # No metadata found for this distribution name: fall back to a placeholder.
    __version__ = "unknown"
finally:
    # The helpers were only needed for the lookup; drop them from the namespace.
    del version, PackageNotFoundError

class FloatDtype(enum.Enum):
    """Floating point types selectable through ``config``."""

    float32 = 1
    float64 = 2


class IntDtype(enum.Enum):
    """Integer types paired with the selected floating point type."""

    int32 = 1
    int64 = 2


class Platform(enum.Enum):
    """Platforms selectable through ``config``."""

    cpu = 1
    gpu = 2


# Module-wide defaults shared by all the modules: data type names, array
# backend and platform. ``config`` overwrites them on its first call and
# records that it ran in ``config_called``.
config_called = False
backend = jnp
platform = Platform.cpu
float_dtype = FloatDtype.float64.name
int_dtype = IntDtype.int64.name


def config(fdtype=FloatDtype.float64, platfm=Platform.cpu):
    """Set the package-wide float/int data types and the JAX platform.

    Only the first call does anything: once ``config_called`` is set, later
    calls return immediately and leave the current settings untouched.

    Args:
        fdtype: ``FloatDtype`` member selecting the floating point type. The
            integer type follows it: ``int64`` for ``float64``, ``int32``
            otherwise.
        platfm: ``Platform`` member; it is stored in ``platform`` and its name
            is written to the JAX ``jax_platform_name`` option.
    """
    global config_called, float_dtype, int_dtype, platform

    # Guard flag: the configuration is applied at most once.
    if config_called:
        return

    float_dtype = fdtype.name
    platform = platfm
    cfg.update('jax_platform_name', platfm.name)

    wants_double = fdtype == FloatDtype.float64
    int_dtype = (IntDtype.int64 if wants_double else IntDtype.int32).name
    if wants_double:
        # The x64 option is only switched on when float64 is requested;
        # it is left untouched for float32.
        cfg.update("jax_enable_x64", True)

    config_called = True