Source code for dctkit.dec.flat

import jax.numpy as jnp
from dctkit.dec import cochain as C
from dctkit.math.spmm import spmm
from jax import Array, vmap
from functools import partial
from typing import Callable, Dict, Optional


def flat(c: C.CochainP0 | C.CochainD0, weights: Array,
         edges: C.CochainP1V | C.CochainD1V,
         interp_func: Optional[Callable] = None,
         interp_func_args: Optional[Dict] = None) -> C.CochainP1 | C.CochainD1:
    """Applies the flat to a vector/tensor-valued cochain representing a discrete
    vector/tensor field to obtain a scalar-valued cochain over primal/dual edges.

    Args:
        c: input vector/tensor-valued 0-cochain representing a primal/dual
            discrete vector/tensor field.
        weights: array of weights that represent the contribution of each
            component of the input cochain to the primal/dual edges where the
            output cochain is defined (i.e. where integration is performed).
            The number of columns must be equal to the number of primal/dual
            target edges. Weights depend on the interpolation scheme chosen for
            the input discrete vector/tensor field.
        edges: vector-valued cochain collecting the primal/dual edges over
            which the discrete vector/tensor field should be integrated.
        interp_func: interpolation function (callable) taking in input the
            cochain c (plus ``interp_func_args`` as keyword arguments) and
            returning the interpolated field coefficients on the target edges,
            with the leading axis indexed by target edge. This is the shape
            produced by the default, which is built as W^T@c.coeffs when this
            argument is None. NOTE: the result is contracted with
            ``edges.coeffs`` through ``vmap(jnp.tensordot)``, so it must be
            array-like.
        interp_func_args: additional keyword arguments for interp_func. If
            None (default), no extra arguments are passed. These arguments are
            ignored when interp_func is None.

    Returns:
        a primal/dual scalar/vector-valued cochain defined over primal/dual
        edges.
    """
    if interp_func is None:
        # contract over the simplices of the input cochain (last axis of
        # weights, first axis of input cochain coeffs)
        def interp_func(x):
            return spmm(weights, x.coeffs, transpose=True,
                        shape=edges.coeffs.shape[0])
        # the default interpolation accepts no extra keyword arguments
        interp_func_args = {}
    elif interp_func_args is None:
        # None sentinel instead of a shared mutable ``{}`` default
        interp_func_args = {}

    weighted_v = interp_func(c, **interp_func_args)

    # contract input vector/tensors with edge vectors (last indices of both
    # coefficient matrices), for each target primal/dual edge
    contract = partial(jnp.tensordot, axes=([-1,], [-1,]))
    # map over target primal/dual edges (leading axis of both operands)
    batch_contract = vmap(contract)
    coch_coeffs = batch_contract(weighted_v, edges.coeffs)

    # the output lives on the same kind (primal/dual) of edges as ``edges``
    if edges.is_primal:
        return C.CochainP1(c.complex, coch_coeffs)
    else:
        return C.CochainD1(c.complex, coch_coeffs)
def flat_DPD(c: C.CochainD0V | C.CochainD0T) -> C.CochainD1:
    """Implements the flat DPD operator for dual 0-cochains.

    The dual edge vectors of the complex are truncated to their first
    ``c.coeffs.shape[1]`` components before integration. Presumably this is
    the number of components of the field; confirm for tensor-valued inputs.
    The weights used are ``c.complex.flat_DPD_weights``.

    Args:
        c: a dual vector/tensor-valued 0-cochain.

    Returns:
        the dual 1-cochain resulting from the application of the flat operator.
    """
    dual_edges = c.complex.dual_edges_vectors[:, :c.coeffs.shape[1]]
    flat_matrix = c.complex.flat_DPD_weights
    return flat(c, flat_matrix, C.CochainD1(c.complex, dual_edges))
def flat_DPP(c: C.CochainD0V | C.CochainD0T) -> C.CochainP1:
    """Implements the flat DPP operator for dual 0-cochains.

    The primal edge vectors of the complex are truncated to their first
    ``c.coeffs.shape[1]`` components before integration. Presumably this is
    the number of components of the field; confirm for tensor-valued inputs.
    The weights used are ``c.complex.flat_DPP_weights``.

    Args:
        c: a dual vector/tensor-valued 0-cochain.

    Returns:
        the primal 1-cochain resulting from the application of the flat
        operator.
    """
    primal_edges = c.complex.primal_edges_vectors[:, :c.coeffs.shape[1]]
    flat_matrix = c.complex.flat_DPP_weights
    return flat(c, flat_matrix, C.CochainP1(c.complex, primal_edges))
def flat_PDP(c: C.CochainP0V | C.CochainP0T) -> C.CochainP1:
    """Applies the PDP flat to a primal 0-cochain.

    The field is integrated over the primal edges of the complex. Only the
    first ``c.coeffs.shape[1]`` components of each primal edge vector are
    kept, and ``c.complex.flat_PDP_weights`` is used as the weight matrix.

    Args:
        c: a primal vector/tensor-valued 0-cochain.

    Returns:
        the primal 1-cochain obtained by applying the flat operator.
    """
    cplx = c.complex
    n_comp = c.coeffs.shape[1]
    truncated_edges = cplx.primal_edges_vectors[:, :n_comp]
    pdp_weights = cplx.flat_PDP_weights
    target = C.CochainP1(cplx, truncated_edges)
    return flat(c, pdp_weights, target)
def flat_dual_upw(c: C.CochainD0V | C.CochainD0T) -> C.CochainD1:
    """Implements the flat upwind operator for dual 0-cochains.

    The dual edge vectors of the complex are truncated to their first
    ``c.coeffs.shape[1]`` components before integration. Presumably this is
    the number of components of the field; confirm for tensor-valued inputs.
    The weights used are ``c.complex.flat_dual_upw_weights``.

    Args:
        c: a dual vector/tensor-valued 0-cochain.

    Returns:
        the dual 1-cochain resulting from the application of the flat operator.
    """
    dual_edges = c.complex.dual_edges_vectors[:, :c.coeffs.shape[1]]
    flat_matrix = c.complex.flat_dual_upw_weights
    return flat(c, flat_matrix, C.CochainD1(c.complex, dual_edges))