Source code for dctkit.dec.cochain

import dctkit as dt

from dctkit.mesh import simplex as spx
from dctkit.math import spmm
import numpy.typing as npt
from jax import Array
import jax.numpy as jnp
from typeguard import check_type
import warnings
from enum import Enum

# suppress all warnings
# NOTE(review): filterwarnings("ignore") with no category/module installs a
# process-wide filter at import time, so importing this module silences
# warnings for ALL code in the process, not just dctkit. Consider scoping it
# with warnings.catch_warnings() or a category/module-specific filter.
warnings.filterwarnings("ignore")


class Cochain():
    """Parent cochain class.

    Args:
        dim: degree ``k`` of the cochain, i.e. the dimension of the cells it
            is defined on (primal k-simplices, or dual k-cells for a dual
            cochain). This is NOT the dimension of the underlying complex,
            which is available as ``complex.dim``.
        is_primal: True if the cochain is primal, False otherwise.
        complex: the simplicial complex where the cochain is defined.
        coeffs: array of the coefficients of the cochain. A 1-D array is
            promoted to a column of shape ``(N, 1)``; arrays with two or more
            dimensions (vector/tensor-valued cochains) are stored as given.
    """

    def __init__(self, dim: int, is_primal: bool,
                 complex: spx.SimplicialComplex, coeffs: npt.NDArray | Array):
        self.dim = dim
        self.complex = complex
        self.is_primal = is_primal
        check_type(coeffs, npt.NDArray | Array)
        # Scalar-valued cochains may be passed as a flat array; promote them
        # to a column so that matrix-vector ops (e.g. the sparse product in
        # coboundary) always receive a 2-D operand.
        if coeffs.ndim > 1:
            self.coeffs = coeffs
        else:
            self.coeffs = coeffs[:, None]
# Automatically generate Cochain subclass aliases (e.g. CochainP0, CochainD1V, CochainP2T).
class rank(Enum):
    """Value rank of a cochain's coefficients.

    Each member's value is the suffix appended to the generated alias class
    names (e.g. ``CochainP1V`` is a primal, vector-valued 1-cochain); scalar
    cochains get no suffix.
    """

    SCALAR = ""
    VECTOR = "V"
    TENSOR = "T"
# Values spanned by the generated aliases: primality, degree and value rank.
attributes = {'is_primal': (True, False), 'dim': (0, 1, 2, 3),
              'rank': (rank.SCALAR.value, rank.VECTOR.value,
                       rank.TENSOR.value)}
categories = attributes['is_primal']
dimensions = attributes['dim']
ranks = attributes['rank']


def _make_cochain_alias(dim: int, is_primal: bool, name: str) -> type:
    """Build a Cochain subclass with a fixed degree and primality.

    The returned class is constructed as ``Alias(complex, coeffs)`` and
    forwards to ``Cochain.__init__`` with the bound ``dim`` and ``is_primal``.
    This replaces the previous approach of string-templating source code and
    running it through ``exec``.
    """
    # dim/is_primal are captured from this factory's arguments (not from the
    # loop variables below), so each alias keeps its own values.
    def __init__(self, complex, coeffs):
        Cochain.__init__(self, dim, is_primal, complex, coeffs)

    return type(name, (Cochain,), {'__init__': __init__})


for is_primal_ in categories:
    for dim_ in dimensions:
        for rank_ in ranks:
            primal_flag = 'P' if is_primal_ else 'D'
            name = "Cochain" + primal_flag + str(dim_) + rank_
            # Publish the alias as a module-level name (e.g. CochainP0,
            # CochainD2T) exactly as the exec-based version did.
            globals()[name] = _make_cochain_alias(dim_, is_primal_, name)
class CochainP(Cochain):
    """Class for primal cochains (``is_primal`` is always True).

    Args:
        dim: degree of the cochain.
        complex: the simplicial complex where the cochain is defined.
        coeffs: array of the coefficients of the cochain.
    """

    def __init__(self, dim: int, complex: spx.SimplicialComplex,
                 coeffs: npt.NDArray | Array):
        super().__init__(dim=dim, is_primal=True, complex=complex,
                         coeffs=coeffs)
class CochainD(Cochain):
    """Class for dual cochains (``is_primal`` is always False).

    Args:
        dim: degree of the cochain.
        complex: the simplicial complex where the cochain is defined.
        coeffs: array of the coefficients of the cochain.
    """

    def __init__(self, dim: int, complex: spx.SimplicialComplex,
                 coeffs: npt.NDArray | Array):
        super().__init__(dim=dim, is_primal=False, complex=complex,
                         coeffs=coeffs)
def add(c1: Cochain, c2: Cochain) -> Cochain:
    """Adds two cochains.

    The degree, primality and complex of the result are taken from ``c1``;
    no compatibility check against ``c2`` is performed.

    Args:
        c1: a cochain.
        c2: a cochain.

    Returns:
        c1 + c2
    """
    summed_coeffs = c1.coeffs + c2.coeffs
    return Cochain(c1.dim, c1.is_primal, c1.complex, summed_coeffs)
def sub(c1: Cochain, c2: Cochain) -> Cochain:
    """Subtracts two cochains.

    The degree, primality and complex of the result are taken from ``c1``;
    no compatibility check against ``c2`` is performed.

    Args:
        c1: a cochain.
        c2: a cochain.

    Returns:
        c_1 - c_2
    """
    difference = c1.coeffs - c2.coeffs
    return Cochain(c1.dim, c1.is_primal, c1.complex, difference)
def scalar_mul(c: Cochain, k: float) -> Cochain:
    """Multiplies a cochain by a scalar.

    Args:
        c: a cochain.
        k: a scalar.

    Returns:
        a new cochain on the same complex, with the same degree and primality,
        whose coefficients are k*(c.coeffs).
    """
    scaled_coeffs = k * c.coeffs
    return Cochain(c.dim, c.is_primal, c.complex, scaled_coeffs)
def cochain_mul(c1: Cochain, c2: Cochain) -> Cochain:
    """Multiplies two cochains component-wise.

    Both cochains must have the same primality (checked with an assert). The
    degree and complex of the result are taken from ``c1``.

    Args:
        c1: a cochain.
        c2: a cochain.

    Returns:
        cochain with coefficients = c1*c2.
    """
    assert (c1.is_primal == c2.is_primal)
    elementwise_product = c1.coeffs * c2.coeffs
    return Cochain(c1.dim, c1.is_primal, c1.complex, elementwise_product)
def identity(c: Cochain) -> Cochain:
    """Identity operator on cochains.

    Args:
        c: a cochain.

    Returns:
        ``c`` itself (the very same object, not a copy).
    """
    return c
def sin(c: Cochain) -> Cochain:
    """Computes the sine of a cochain, coefficient by coefficient.

    Args:
        c: a cochain.

    Returns:
        a new cochain with the same degree, primality and complex as ``c``
        and coefficients equal to sin(c.coeffs).
    """
    sin_coeffs = dt.backend.sin(c.coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, sin_coeffs)
def arcsin(c: Cochain) -> Cochain:
    """Computes the arcsine of a cochain, coefficient by coefficient.

    Args:
        c: a cochain.

    Returns:
        a new cochain with the same degree, primality and complex as ``c``
        and coefficients equal to arcsin(c.coeffs).
    """
    arcsin_coeffs = dt.backend.arcsin(c.coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, arcsin_coeffs)
def cos(c: Cochain) -> Cochain:
    """Computes the cosine of a cochain, coefficient by coefficient.

    Args:
        c: a cochain.

    Returns:
        a new cochain with the same degree, primality and complex as ``c``
        and coefficients equal to cos(c.coeffs).
    """
    cos_coeffs = dt.backend.cos(c.coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, cos_coeffs)
def arccos(c: Cochain) -> Cochain:
    """Computes the arccos of a cochain, coefficient by coefficient.

    (The previous docstring wrongly said "arcsin".)

    Args:
        c: a cochain.

    Returns:
        a new cochain with the same degree, primality and complex as ``c``
        and coefficients equal to arccos(c.coeffs).
    """
    C = Cochain(c.dim, c.is_primal, c.complex, dt.backend.arccos(c.coeffs))
    return C
def exp(c: Cochain) -> Cochain:
    """Computes the exponential of a cochain, coefficient by coefficient.

    Args:
        c: a cochain.

    Returns:
        a new cochain with the same degree, primality and complex as ``c``
        and coefficients equal to exp(c.coeffs).
    """
    exp_coeffs = dt.backend.exp(c.coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, exp_coeffs)
def log(c: Cochain) -> Cochain:
    """Computes the natural logarithm of a cochain, coefficient by coefficient.

    Args:
        c: a cochain.

    Returns:
        a new cochain with the same degree, primality and complex as ``c``
        and coefficients equal to log(c.coeffs).
    """
    log_coeffs = dt.backend.log(c.coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, log_coeffs)
def sqrt(c: Cochain) -> Cochain:
    """Computes the square root of a cochain, coefficient by coefficient.

    Args:
        c: a cochain.

    Returns:
        a new cochain with the same degree, primality and complex as ``c``
        and coefficients equal to sqrt(c.coeffs).
    """
    sqrt_coeffs = dt.backend.sqrt(c.coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, sqrt_coeffs)
def square(c: Cochain) -> Cochain:
    """Computes the square of a cochain, coefficient by coefficient.

    Args:
        c: a cochain.

    Returns:
        a new cochain with the same degree, primality and complex as ``c``
        and coefficients equal to square(c.coeffs).
    """
    squared_coeffs = dt.backend.square(c.coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, squared_coeffs)
def coboundary(c: Cochain) -> Cochain:
    """Implements the coboundary operator.

    Maps a k-cochain to a (k+1)-cochain with the same primality on the same
    complex.

    - Primal k-cochain: applies the transpose of ``complex.boundary[k + 1]``
      to the coefficients.
    - Dual k-cochain: applies ``complex.boundary[n - k]`` (``n`` =
      ``complex.dim``) untransposed, then multiplies by ``(-1)**(n - k)``.

    Args:
        c: a cochain.

    Returns:
        the cochain obtained by taking the coboundary of c.
    """
    # initialize dc: a placeholder (k+1)-cochain whose coeffs are
    # overwritten below (the empty_like array is never read)
    dc = Cochain(dim=c.dim + 1, is_primal=c.is_primal, complex=c.complex,
                 coeffs=jnp.empty_like(c.coeffs, dtype=dt.float_dtype))
    # apply coboundary matrix (transpose of the primal boundary matrix) to the
    # array of coefficients of the cochain.
    if c.is_primal:
        # get the appropriate (primal) boundary matrix
        cbnd_coo = c.complex.boundary[c.dim + 1]
        # presumably `shape` fixes the number of output rows: one per
        # primal (k+1)-simplex (len of S[k+1]) — confirm in spmm.spmm
        dc.coeffs = spmm.spmm(cbnd_coo, c.coeffs,
                              transpose=True,
                              shape=c.complex.S[c.dim+1].shape[0])
    else:
        # FIXME: change sign of the boundary before applying it?
        # A dual k-cochain has one coefficient per primal (n-k)-simplex, so
        # the boundary from (n-k)- to (n-k-1)-chains is applied directly.
        bnd_coo = c.complex.boundary[c.complex.dim - c.dim]
        dc.coeffs = spmm.spmm(bnd_coo, c.coeffs,
                              transpose=False,
                              shape=c.complex.S[c.complex.dim-c.dim-1].shape[0])
        # NOTE(review): sign convention (-1)**(n-k) for the dual coboundary;
        # it must stay consistent with the sign factors in codifferential.
        # (Indentation was lost in the scraped source; this line is assumed
        # to belong to the dual branch only — confirm against upstream.)
        dc.coeffs *= (-1)**(c.complex.dim - c.dim)
    return dc
def star(c: Cochain) -> Cochain:
    """Implements the diagonal Hodge star operator (see Grinspun et al.).

    A primal k-cochain is mapped to a dual (n-k)-cochain using
    ``complex.hodge_star[k]``. A dual k-cochain is mapped to a primal
    (n-k)-cochain using ``complex.hodge_star_inverse[n - k]``. Here
    ``n = complex.dim``.

    Args:
        c: a cochain.

    Returns:
        the cochain of opposite primality obtained by applying the Hodge star
        operator.
    """
    n = c.complex.dim
    star_c = Cochain(dim=n - c.dim, is_primal=not c.is_primal,
                     complex=c.complex,
                     coeffs=jnp.empty_like(c.coeffs, dtype=dt.float_dtype))
    if c.is_primal:
        diagonal = c.complex.hodge_star[c.dim]
    else:
        diagonal = c.complex.hodge_star_inverse[n - c.dim]
    # Scale each entry along the first axis of the coefficients by the
    # matching diagonal entry (assumes `diagonal` is a 1-D array — confirm).
    # Transposing moves that axis last so it broadcasts for scalar, vector
    # and tensor-valued coefficients alike.
    star_c.coeffs = (diagonal * c.coeffs.T).T
    return star_c
def inner(c1: Cochain, c2: Cochain) -> Array:
    """Computes the inner product between two cochains.

    The product is computed as ``c1`` contracted against ``star(c2)``, and
    depends on the number of dimensions of the coefficient arrays:

    - 1 dimension: dot product.
    - 2 dimensions (scalar/vector-valued): sum of the element-wise product.
    - 3 dimensions (tensor-valued): sum over cells of
      ``trace(c1_i^T @ star(c2)_i)``.

    Args:
        c1: a cochain.
        c2: a cochain.

    Returns:
        inner product between c1 and c2.

    Raises:
        AssertionError: if the complexes' dimensions or the coefficient
            arrays' numbers of dimensions differ.
        ValueError: if the coefficient arrays do not have 1, 2 or 3
            dimensions (previously this crashed with UnboundLocalError).
    """
    n = c1.complex.dim
    # dimension of the complexes must agree
    assert (n == c2.complex.dim)
    # ndim must agree
    assert c1.coeffs.ndim == c2.coeffs.ndim
    ndim = c1.coeffs.ndim
    if ndim not in (1, 2, 3):
        raise ValueError(
            f"inner: unsupported coefficient array ndim {ndim}; "
            "expected 1, 2 or 3")

    # Only apply the (potentially costly) Hodge star once inputs are valid.
    star_c_2 = star(c2)
    if ndim == 1:
        inner_product = dt.backend.dot(c1.coeffs, star_c_2.coeffs)
    elif ndim == 2:
        inner_product = dt.backend.sum(c1.coeffs * star_c_2.coeffs)
    else:
        c1_coeffs_T = dt.backend.transpose(c1.coeffs, axes=(0, 2, 1))
        inner_product_per_cell = dt.backend.trace(
            c1_coeffs_T @ star_c_2.coeffs, axis1=1, axis2=2)
        inner_product = dt.backend.sum(inner_product_per_cell)
    # NOTE: not sure whether we should keep both Jax and numpy as backends and
    # allow for different return types
    check_type(inner_product, npt.NDArray | Array)
    return inner_product
def codifferential(c: Cochain) -> Cochain:
    """Implements the discrete codifferential.

    Computes ``sign * star(coboundary(star(c)))`` and returns it as a
    (k-1)-cochain with the same primality as ``c``, where ``k = c.dim`` and
    ``n = c.complex.dim``. The sign is:

    - primal ``c``: ``(-1)**(n*(k-1)+1)``;
    - dual ``c``: ``(-1)**(k*(n+1-k))``.

    Args:
        c: a cochain.

    Returns:
        the discrete codifferential of c.
    """
    k = c.dim
    n = c.complex.dim
    if c.is_primal:
        sign = (-1)**(n*(k-1)+1)
    else:
        sign = (-1)**(k*(n+1-k))
    star_d_star = star(coboundary(star(c)))
    return Cochain(k-1, c.is_primal, c.complex, sign*star_d_star.coeffs)
def laplacian(c: Cochain) -> Cochain:
    """Implements the discrete Laplace-de Rham (or Laplace-Beltrami) operator.

    (https://en.wikipedia.org/wiki/Laplace%E2%80%93Beltrami_operator)

    Computes ``codifferential(coboundary(c))``. For ``c.dim > 0`` it also adds
    ``coboundary(codifferential(c))``.

    Args:
        c: a cochain.

    Returns:
        a cochain.
    """
    # NOTE(review): for c.dim == c.complex.dim the first term takes the
    # coboundary of a top-degree cochain — confirm the complex supports it.
    delta_d = codifferential(coboundary(c))
    if c.dim == 0:
        return delta_d
    return add(delta_d, coboundary(codifferential(c)))
def transpose(c: Cochain) -> Cochain:
    """Compute the transpose of a tensor-valued cochain.

    The last two axes of the coefficient array are swapped; the first axis
    is left in place.

    Args:
        c: a tensor-valued cochain.

    Returns:
        its transpose.
    """
    swapped = jnp.transpose(c.coeffs, axes=(0, 2, 1))
    return Cochain(c.dim, c.is_primal, c.complex, swapped)
def trace(c: Cochain) -> Cochain:
    """Compute the trace of a tensor-valued cochain.

    The trace is taken over axes 1 and 2 of the coefficients. The resulting
    1-D array is turned into a column by the Cochain constructor.

    Args:
        c: a tensor-valued cochain.

    Returns:
        its trace.
    """
    traces = jnp.trace(c.coeffs, axis1=1, axis2=2)
    return Cochain(c.dim, c.is_primal, c.complex, traces)
def sym(c: Cochain) -> Cochain:
    """Compute the symmetric part of a tensor-valued cochain.

    Returns ``0.5 * (c + transpose(c))``.

    Args:
        c: a tensor-valued cochain.

    Returns:
        its symmetric part.
    """
    c_transposed = transpose(c)
    doubled = add(c, c_transposed)
    return scalar_mul(doubled, 0.5)
def convolution(c: Cochain, kernel: Cochain, kernel_window: int) -> Cochain:
    """Compute the discrete convolution of a scalar 0-cochain with a kernel.

    The coefficients of ``star(c)`` are convolved in "valid" mode with the
    first ``kernel_window`` coefficients of ``kernel``:

        out[i] = sum_{j=0}^{kernel_window-1} star(c)[i + j] * kernel[j]

    for ``i = 0, ..., n - kernel_window``, where ``n = len(c.coeffs)``. The
    remaining ``kernel_window - 1`` trailing entries, where the window does
    not fit, are left at zero (they are meant to be fixed by boundary
    conditions).

    Args:
        c: a scalar 0-cochain.
        kernel: the scalar 0-cochain kernel.
        kernel_window: number of leading kernel coefficients to use. It is
            used as a slice bound, so it must be a (static) integer with
            ``1 <= kernel_window <= min(n, len(kernel.coeffs))``.

    Returns:
        a cochain with the degree, primality and complex of ``c`` holding the
        convolution of star(c) with the kernel.

    Raises:
        ValueError: if ``kernel_window`` is outside the valid range.
    """
    n = len(c.coeffs)
    n_kernel = len(kernel.coeffs)
    # Out-of-range windows would otherwise fail later with opaque shape
    # errors (empty kernel, or a negative slice bound below).
    if not 1 <= kernel_window <= min(n, n_kernel):
        raise ValueError(
            f"kernel_window must be between 1 and {min(n, n_kernel)}, "
            f"got {kernel_window}")

    star_c_coeffs_flatten = star(c).coeffs.flatten()
    # Extract the active part of the kernel
    kernel_non_zero = kernel.coeffs[:kernel_window]
    # NOTE: we reverse the kernel ([::-1]) as required by the convolution
    # def.; "valid" mode returns only outputs where the signals fully overlap
    # (n - kernel_window + 1 values).
    conv_non_zero_coeffs = jnp.convolve(
        star_c_coeffs_flatten, kernel_non_zero.flatten()[::-1], mode="valid")
    conv_coeffs = jnp.zeros_like(star_c_coeffs_flatten)
    # NOTE: last values should be fixed by BCs
    conv_coeffs = conv_coeffs.at[:n - kernel_window + 1].set(
        conv_non_zero_coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, conv_coeffs)
def abs(c: Cochain) -> Cochain:
    """Compute the absolute value of a cochain, coefficient by coefficient.

    NOTE: this public name shadows the builtin ``abs`` inside this module.

    Args:
        c: a cochain.

    Returns:
        its absolute value.
    """
    magnitudes = jnp.abs(c.coeffs)
    return Cochain(c.dim, c.is_primal, c.complex, magnitudes)
def maximum(c_1: Cochain, c_2: Cochain) -> Cochain:
    """Compute the component-wise maximum between two cochains.

    The degree, primality and complex of the result are taken from ``c_1``.

    Args:
        c_1: a cochain.
        c_2: a cochain.

    Returns:
        the component-wise maximum
    """
    largest = jnp.maximum(c_1.coeffs, c_2.coeffs)
    return Cochain(c_1.dim, c_1.is_primal, c_1.complex, largest)