Source code for dctkit.math.spmm

import jax.ops as ops
from typing import Tuple
from jax import Array
import numpy.typing as npt


[docs] def spmm(A: Tuple[Array | npt.NDArray, Array | npt.NDArray, Array | npt.NDArray], v: Array | npt.NDArray, transpose=False, shape=None) -> Array: """Performs the matrix-matrix product between a sparse matrix in COO format and a dense matrix or column vector. Args: A: tuple (rows,cols,values) representing the sparse matrix in COO format. v: matrix or column vector. transpose: whether to transpose A before multiplication. shape: the number of rows of the matrix A. Returns: the result of the matrix-matrix product. """ assert v.ndim > 1 rows, cols, vals = A if transpose: vv = v.take(rows, axis=0) else: vv = v.take(cols, axis=0) # NOTE: make vals a column vector # NOTE: the following formula is basically equivalent to # prod = vals[:, None] * vv with the advantage that it works # also in the case that vv is a tensor prod = vals.reshape(-1, *([1] * (vv.ndim - 1))) * vv if transpose: result = ops.segment_sum(prod, cols, shape) else: result = ops.segment_sum(prod, rows, shape) return result