Source code for dctkit.dec.wedge

import jax.numpy as jnp
from dctkit.dec import cochain as C
from scipy.special import factorial


def wedge(c_1: C.Cochain, c_2: C.Cochain) -> C.Cochain:
    """Compute the wedge product of two cochains.

    The cochains are first ordered so that the lower-dimensional one comes
    first: if ``c_1.dim > c_2.dim`` the product is evaluated as
    ``(-1)**(c_1.dim*c_2.dim) * wedge(c_2, c_1)``. Two 0-cochains are simply
    multiplied through ``C.cochain_mul``. In every other case the
    per-simplex contributions are summed and scaled by
    ``1/(c_1.dim + c_2.dim + 1)!``; when both dimensions are positive the
    contributions come from the precomputed ``cup_lookup`` tables of the
    complex (following Definition 7.2.1 in Hirani's thesis).

    Args:
        c_1: the first cochain.
        c_2: the second cochain.

    Returns:
        a new cochain of dimension ``c_1.dim + c_2.dim`` representing the
        wedge product of `c_1` and `c_2`. It is built on the complex of
        `c_1` and is primal exactly when `c_1` is primal.

    Raises:
        Exception: If attempting a primal-dual wedge product, which is
            undefined.
        AssertionError: If computing a dual wedge product with dimension
            greater than 1, which is not defined.
    """
    # if c_1.dim > c_2.dim, swap the operands using the graded
    # anticommutativity of the wedge product
    if c_2.dim < c_1.dim:
        return C.scalar_mul(wedge(c_2, c_1), (-1)**(c_1.dim*c_2.dim))

    # so from now on, c_1.dim <= c_2.dim

    # if c_1 and c_2 are 0-cochains, wedge = cochain product
    if c_1.dim == 0 and c_2.dim == 0:
        return C.cochain_mul(c_1, c_2)

    wedge_coch_dim = c_1.dim + c_2.dim
    S = c_1.complex

    # select the simplex index tables (primal/dual) matching the operands
    if c_1.is_primal and c_2.is_primal:
        # primal wedge
        S_list = S.S
        type_ = "primal"
    elif (not c_1.is_primal) and (not c_2.is_primal):
        # dual wedge is only defined for wedge_coch_dim <= 1. Raise explicitly
        # instead of using `assert`, so that the check is still enforced when
        # Python runs with optimizations enabled (-O strips assert statements).
        if wedge_coch_dim > 1:
            raise AssertionError(
                "The dual wedge product is only defined when the dimension "
                "of the resulting cochain is at most 1.")
        S_list = S.S_dual
        type_ = "dual"
    else:
        raise Exception("The primal-dual wedge product is not defined.")

    c_2_simplices = S_list[c_2.dim]
    num_c_2_dim_simplex = c_2_simplices.shape[0]

    if c_1.dim == 0:
        # in this case the wedge product has an easier formulation: column i
        # holds c_2.dim times the value of c_1 at the i-th vertex of each
        # c_2.dim-simplex, and every row is then multiplied by the
        # corresponding coefficient of c_2.
        # NOTE(review): the flatten/broadcast below presumably relies on the
        # coefficients being stored as column vectors -- confirm against
        # the Cochain class.
        wedge_vec = jnp.zeros((num_c_2_dim_simplex, c_2.dim+1))
        for i in range(c_2.dim+1):
            wedge_vec = wedge_vec.at[:, i].set(
                c_2.dim*c_1.coeffs[c_2_simplices[:, i]].flatten())
        wedge_vec *= c_2.coeffs
    else:
        # fetch the precomputed tables for this (type, dim, dim) combination
        # with a single dictionary access
        tables = S.cup_lookup[(type_, c_1.dim, c_2.dim)]
        # signs of the permutations
        sgn_perm_vec = tables["sgn_perm_vec"]
        # cup product and sgn orientations tables
        lookup = tables["lookup"]
        sgn_orient = tables["sgn_orient"]

        # compute the value of the cup product extracting values from the
        # lookup table. We are following Definition 7.2.1 in Hirani's thesis
        cup_prod_no_sign = c_1.coeffs[lookup[:, 0]]*c_2.coeffs[lookup[:, 1]]
        cup_prod = cup_prod_no_sign.squeeze(-1)*sgn_orient[:, 0]*sgn_orient[:, 1]
        wedge_vec = sgn_perm_vec*cup_prod

    # normalize by (wedge_coch_dim + 1)!, computed exactly as an integer
    weight = 1/factorial(wedge_coch_dim+1, exact=True)
    wedge_coch_coeffs = weight*jnp.sum(wedge_vec, axis=1)
    return C.Cochain(wedge_coch_dim, c_1.is_primal, S, wedge_coch_coeffs)