# Source code for aljabr.concrete

# Copyright (c) 2026 François Orieux <francois.orieux@universite-paris-saclay.fr>

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""The ``concrete`` module
=======================

Concrete implementations of common linear operators: convolutions (FFT-based,
direct, circulant), discrete transforms (DFT, DWT, wavelet analysis/synthesis),
diagonal, difference, sampling operators, ...

"""

import types

import numpy as np
import array_api_compat as arr_api

import udft

try:
    import pywt
except ImportError:
    pywt: types.ModuleType | None = None

from .linop import LinOp, Shape, Array

# Public operators exported by this module (also what ``from ... import *``
# brings into scope).
__all__ = [
    "Identity",
    "Diag",
    "DFT",
    "RealDFT",
    "Conv",
    "DirectConv",
    "FreqFilter",
    "CircConv",
    "Diff",
    "Sampling",
    "Slice",
    "DWT",
    "Analysis2",
    "Synthesis2",
]


class Identity(LinOp):
    """Identity operator of specific shape.

    Both `forward` and `adjoint` hand back their argument untouched.

    Parameters
    ----------
    shape : tuple of int
        The shape of the input and output.
    name : str, optional
        Name of the operator.
    """

    def __init__(self, shape: Shape, name: str = "I"):
        super().__init__(shape, shape, name=name)

    def forward(self, point: Array) -> Array:
        """Return `point` itself (no copy)."""
        return point

    def adjoint(self, point: Array) -> Array:
        """Return `point` itself: the identity is its own adjoint."""
        return point

    def asmatrix(self, like: Array | None = None) -> Array:
        """Return the dense ``isize x isize`` identity matrix.

        Parameters
        ----------
        like : Array, optional
            When given, the matrix is created with the array namespace of
            `like`; otherwise NumPy is used.
        """
        if like is None:
            return np.eye(self.isize)
        namespace = arr_api.get_namespace(like)
        return namespace.eye(self.isize)
class Diag(LinOp):
    """Diagonal operator.

    Element-wise multiplication by a fixed array.

    Parameters
    ----------
    diag : Array
        The diagonal values. Input and output have the same shape as `diag`.
        The array namespace is inferred from this array.
    name : str, optional
        Name of the operator.
    """

    def __init__(self, diag: Array, name: str = "D"):
        namespace = arr_api.get_namespace(diag)
        self.diag = namespace.asarray(diag)
        shape = self.diag.shape
        super().__init__(shape, shape, name=name)

    def forward(self, point: Array) -> Array:
        """Return ``diag * point``."""
        return self.diag * point

    def adjoint(self, point: Array) -> Array:
        """Return ``conj(diag) * point``."""
        conjugate = arr_api.get_namespace(self.diag).conj
        return conjugate(self.diag) * point

    def fwadj(self, point: Array) -> Array:
        """Return ``|diag|**2 * point`` (adjoint after forward, fused)."""
        namespace = arr_api.get_namespace(self.diag)
        gain = namespace.abs(self.diag) ** 2
        return gain * point

    def asmatrix(self, like: Array | None = None) -> Array:
        """Return the dense diagonal matrix built from the flattened `diag`.

        Parameters
        ----------
        like : Array, optional
            When given, the namespace of `like` is used; otherwise the
            namespace of `diag`.
        """
        reference = self.diag if like is None else like
        namespace = arr_api.get_namespace(reference)
        flat = namespace.reshape(self.diag, (-1,))
        return namespace.diag(flat)
class DFT(LinOp):
    """Discrete Fourier Transform on the last N axes.

    Input and output share the same shape. `fwadj` returns its input, so
    the transform is treated as unitary.

    Parameters
    ----------
    shape : tuple of int
        The shape of the input.
    ndim : int
        The number of last axes over which to compute the DFT.
    name : str, optional
        Name of the operator.
    """

    def __init__(self, shape: Shape, ndim: int, name: str = "DFT"):
        self._udft = udft
        super().__init__(shape, shape, name=name)
        # Number of trailing axes transformed by udft.
        self.dim = ndim

    def forward(self, point: Array) -> Array:
        """Apply ``udft.dftn`` on the last `dim` axes."""
        transform = self._udft.dftn
        return transform(point, ndim=self.dim)

    def adjoint(self, point: Array) -> Array:
        """Apply ``udft.idftn`` on the last `dim` axes."""
        inverse = self._udft.idftn
        return inverse(point, ndim=self.dim)

    def fwadj(self, point: Array) -> Array:
        """Return `point` unchanged (adjoint∘forward is the identity)."""
        return point
class RealDFT(LinOp):
    """Real Discrete Fourier Transform on the last N axes.

    The output keeps the leading axes of the input; its last axis has
    ``shape[-1] // 2 + 1`` entries (non-redundant half spectrum).

    Parameters
    ----------
    shape : tuple of int
        The shape of the input.
    ndim : int
        The number of last axes over which to compute the DFT.
    name : str, optional
        Name of the operator.
    """

    def __init__(self, shape: Shape, ndim: int, name: str = "rDFT"):
        self._udft = udft
        half_spectrum = shape[-1] // 2 + 1
        super().__init__(shape, shape[:-1] + (half_spectrum,), name=name)
        self.dim = ndim

    def forward(self, point: Array) -> Array:
        """Apply ``udft.rdftn`` on the last `dim` axes."""
        return self._udft.rdftn(point, ndim=self.dim)

    def adjoint(self, point: Array) -> Array:
        """Apply ``udft.irdftn`` back to the last `dim` axes of `ishape`."""
        target = self.ishape[-self.dim :]
        return self._udft.irdftn(point, target)

    def fwadj(self, point: Array) -> Array:
        """Return `point` unchanged (irdftn∘rdftn is treated as identity)."""
        return point
class Conv(LinOp):
    """ND convolution on the last `N` axes.

    Does not assume a periodic or circular boundary condition.

    Parameters
    ----------
    ir : Array
        The impulse response. Must have at least `dim` dimensions; its last
        `dim` axes are the convolution axes. The array namespace is inferred
        from this array.
    ishape : tuple of int
        The shape of the input. Images are on the last `dim` axes.
    dim : int
        Number of last axes over which convolution applies.
    name : str, optional
        Name of the operator.

    Attributes
    ----------
    imp_resp : Array
        The impulse response.
    freq_resp : Array
        The frequency response.
    dim : int
        The last `dim` axes where convolution applies.

    Notes
    -----
    Uses FFT internally for fast computation. The `forward` method is
    equivalent to "valid" boundary condition and `adjoint` is equivalent to
    "full" boundary condition with zero filling. Uses the array namespace of
    the impulse response. Odd and even kernel sizes are both supported.
    """

    def __init__(self, ir: Array, ishape: Shape, dim: int, name: str = "Conv"):
        self._udft = udft
        nbatch = len(ishape) - dim
        # Kernel size along the convolved axes. Always use the *last* `dim`
        # axes of `ir` (the ones `ir2fr` transforms below): any leading axes
        # of `ir` are not convolution axes and must not be paired with the
        # image axes of the input.
        ksize = tuple(ir.shape[-dim:])
        # Batch axes are unchanged; each convolved axis of size N shrinks to
        # N - K + 1 ("valid" convolution).
        #
        # Example: ishape=(B, H, W), ir.shape=(Kh, Kw), dim=2
        #          → oshape = (B, H-Kh+1, W-Kw+1)
        super().__init__(
            ishape=ishape,
            oshape=tuple(ishape[:nbatch])
            + tuple(n - k + 1 for n, k in zip(ishape[nbatch:], ksize)),
            name=name,
        )
        self.dim = dim
        self.imp_resp = ir
        self.freq_resp = self._udft.ir2fr(ir, self.ishape[-dim:])
        self.margins = ir.shape[-dim:]
        # Extract the valid part of the circular convolution result.
        # The kernel origin is at c = K//2 (ifftshift convention, udft's
        # default origin), so output sample n reads inputs n + c - (K-1)
        # ... n + c. It contains no wrapped-around input iff
        #     n >= K - 1 - c = (K-1)//2   and   n <= N - 1 - c,
        # i.e. the window starts at (K-1)//2 and holds N - K + 1 samples.
        # For odd K this start equals K//2; for even K, starting at K//2
        # would drop the first valid sample and keep one wrapped one.
        self._slices: tuple = (slice(None),) * nbatch + tuple(
            slice((k - 1) // 2, (k - 1) // 2 + n - k + 1)
            for n, k in zip(ishape[nbatch:], ksize)
        )

    def _dft(self, point: Array) -> Array:
        """Real DFT over the last `dim` axes."""
        return self._udft.rdftn(point, self.dim)

    def _idft(self, point: Array) -> Array:
        """Inverse real DFT back to the last `dim` axes of `ishape`."""
        return self._udft.irdftn(point, self.ishape[-self.dim :])

    def forward(self, point: Array) -> Array:
        """Valid convolution: circular FFT convolution, then crop."""
        return self._idft(self._dft(point) * self.freq_resp)[self._slices]

    def adjoint(self, point: Array) -> Array:
        """Zero-fill `point` into `ishape`, then correlate with the kernel."""
        xp = arr_api.get_namespace(self.freq_resp)
        out = xp.zeros(self.ishape, dtype=point.dtype)
        # JAX arrays are immutable: use the functional update API.
        if arr_api.is_jax_array(out):
            out = out.at[self._slices].set(point)
        else:
            out[self._slices] = point
        return self._idft(self._dft(out) * xp.conj(self.freq_resp))
class DirectConv(LinOp):
    """Direct convolution.

    The convolution is performed on the last N axes where N = ir.ndim.

    Parameters
    ----------
    ir : Array
        The impulse response (numpy array).
    ishape : tuple of int
        The shape of the input array.
    name : str, optional
        Name of the operator.

    Notes
    -----
    Numpy-only. Uses `scipy.signal.oaconvolve` (Overlap-Add method), which is
    generally faster than FFT-based convolution when one array is much larger
    than the other. Requires scipy. Complex impulse responses are supported:
    the adjoint uses the conjugated kernel.
    """

    def __init__(self, ir: Array, ishape: Shape, name: str = "DConv"):
        try:
            from scipy.signal import oaconvolve
        except ImportError as e:
            raise ImportError("scipy is required for DirectConv") from e
        # Leading input axes not covered by `ir` are left unchanged; the
        # last ir.ndim axes shrink by K - 1 ("valid" convolution).
        nbatch = len(ishape) - len(ir.shape)
        oshape = tuple(
            (
                ishape[idx]
                if idx < nbatch
                else ishape[idx] - ir.shape[idx - nbatch] + 1
            )
            for idx in range(len(ishape))
        )
        super().__init__(
            ishape=ishape,
            oshape=oshape,
            name=name,
        )
        self._conv = oaconvolve
        # oaconvolve needs both operands to have the same number of
        # dimensions: prepend singleton axes, which keep the batch axes
        # unchanged in both "valid" and "full" modes.
        self._ir = np.reshape(
            np.asarray(ir), (len(self.ishape) - ir.ndim) * (1,) + ir.shape
        )

    @property
    def ir(self) -> Array:
        """The impulse response."""
        return np.squeeze(self._ir)

    def forward(self, point: Array) -> Array:
        """Valid convolution of `point` with the impulse response."""
        return self._conv(point, self._ir, mode="valid")

    def adjoint(self, point: Array) -> Array:
        """Full correlation of `point` with the impulse response.

        The adjoint of convolving by h is correlating with h, i.e.
        convolving with the flipped *conjugate* ``conj(h[::-1])``. Without the
        conjugate the result is wrong for complex impulse responses; for real
        ones ``np.conj`` is a no-op.
        """
        return self._conv(point, np.conj(np.flip(self._ir)), mode="full")
class FreqFilter(Diag):
    """Frequency filter in Fourier space.

    Parameters
    ----------
    ir : Array
        The impulse response.
    ishape : tuple of int
        The shape of the input array (used to compute the frequency
        response).
    name : str, optional
        Name of the operator.

    Attributes
    ----------
    diag : Array
        The frequency response of the filter.

    Notes
    -----
    Almost like diagonal but assumes a complex Fourier space and is defined
    by an impulse response. If you have the frequency response, just use
    Diag.
    """

    def __init__(self, ir: Array, ishape: Shape, name: str = "Filter"):
        # The filter is the Diag operator of the frequency response that
        # udft.ir2fr computes for an input of shape `ishape`.
        frequency_response = udft.ir2fr(ir, ishape)
        super().__init__(frequency_response, name=name)
class CircConv(LinOp):
    """Circulant (periodic) convolution.

    Each application is: real DFT over all axes, multiplication by the
    frequency response (through a `FreqFilter`), inverse real DFT.

    Parameters
    ----------
    imp_resp : Array
        The impulse response.
    shape : tuple of int
        Shape of the input and output arrays.
    name : str, optional
        Name of the operator.

    Attributes
    ----------
    imp_resp : Array
        The impulse response.
    """

    def __init__(self, imp_resp: Array, shape: Shape, name: str = "CConv"):
        self._udft = udft
        self.imp_resp = imp_resp
        self.ffilter = FreqFilter(imp_resp, shape)
        super().__init__(ishape=shape, oshape=shape, name=name)

    @property
    def freq_resp(self) -> Array:
        """The frequency response."""
        return self.ffilter.diag

    def _dft(self, arr: Array) -> Array:
        """Real DFT over every axis of the input."""
        naxes = len(self.ishape)
        return self._udft.rdftn(arr, naxes)

    def _idft(self, arr: Array) -> Array:
        """Inverse real DFT back to `oshape`."""
        return self._udft.irdftn(arr, self.oshape)

    def forward(self, point: Array) -> Array:
        """Circular convolution with the impulse response."""
        spectrum = self._dft(point)
        return self._idft(self.ffilter.forward(spectrum))

    def adjoint(self, point: Array) -> Array:
        """Circular correlation (conjugated frequency response)."""
        spectrum = self._dft(point)
        return self._idft(self.ffilter.adjoint(spectrum))

    def fwadj(self, point: Array) -> Array:
        """Filter by the squared modulus of the frequency response."""
        spectrum = self._dft(point)
        return self._idft(self.ffilter.fwadj(spectrum))
class Diff(LinOp):
    """Difference operator.

    Compute the first-order differences along an axis. The output has one
    sample less than the input along that axis.

    Parameters
    ----------
    axis : int
        The axis along which to perform the diff.
    ishape : tuple of int
        The shape of the input array.
    name : str, optional
        Name of the operator. The axis is appended, e.g. ``"Diff[0]"``.

    Attributes
    ----------
    axis : int
        The axis along which the differences are performed.
    """

    def __init__(self, axis: int, ishape: Shape, name: str = "Diff"):
        out_shape = list(ishape)
        out_shape[axis] -= 1
        super().__init__(ishape, tuple(out_shape), name=name + f"[{axis}]")
        self.axis = axis

    def forward(self, point: Array) -> Array:
        """The forward application `A·x`.

        This corresponds to the application of the following matrix in 1D::

            -1  1  0  0
             0 -1  1  0
             0  0 -1  1
        """
        return arr_api.get_namespace(point).diff(point, axis=self.axis)

    def adjoint(self, point: Array) -> Array:
        """The adjoint application `Aᴴ·y`.

        This corresponds to the application of the following matrix in 1D::

            -1  0  0
             1 -1  0
             0  1 -1
             0  0  1
        """
        namespace = arr_api.get_namespace(point)
        # Padding y with one zero on each side yields N values:
        # x[i] = y[i-1] - y[i], with out-of-range y taken as 0.
        padded_diff = namespace.diff(point, prepend=0, append=0, axis=self.axis)
        return -padded_diff
class Sampling(LinOp):
    """Sampling operator using numpy fancy indexing.

    Numpy-only. Index is a tuple of index arrays as in numpy fancy indexing.
    The adjoint scatters values back and *accumulates* repeated indices.

    Parameters
    ----------
    ishape : tuple of int
        The shape of the input array.
    index : tuple of array of int
        Tuple of index arrays, one per dimension, as in numpy fancy indexing.
        All arrays must have the same shape, which becomes the output shape.
        Negative indices count from the end, as in numpy.
    name : str, optional
        Name of the operator.
    """

    def __init__(self, ishape: Shape, index: tuple, name: str = "Sampling"):
        super().__init__(ishape, index[0].shape, name=name)
        self.index = index

    def forward(self, point: Array) -> Array:
        """Return ``point[index]``."""
        return point[self.index]

    def adjoint(self, point: Array) -> Array:
        """Scatter-add `point` into a zero array of shape `ishape`.

        Entries whose indices appear several times are summed. Complex
        inputs are handled by accumulating real and imaginary parts
        separately (``np.bincount`` only accepts real weights).
        """
        # `forward` accepts negative indices (numpy fancy indexing), but
        # np.ravel_multi_index rejects them: wrap them the same way first.
        index = tuple(
            np.where(np.asarray(idx) < 0, np.asarray(idx) + size, idx)
            for idx, size in zip(self.index, self.ishape)
        )
        flat = np.ravel_multi_index(index, self.ishape).ravel()
        length = int(np.prod(self.ishape))
        weights = point.ravel()
        if np.iscomplexobj(point):
            result = np.bincount(
                flat, weights=weights.real, minlength=length
            ) + 1j * np.bincount(flat, weights=weights.imag, minlength=length)
        else:
            result = np.bincount(flat, weights=weights, minlength=length)
        return np.reshape(result, self.ishape)
class Slice(LinOp):
    """Equivalent to obj[::2, 1, ...] etc.

    Parameters
    ----------
    ishape : tuple of int
        The shape of the input array.
    idx : index expression
        The index expression to apply (use ``np.index_exp`` to build it).
        The output shape is inferred as ``np.empty(ishape)[idx].shape``.

    See Also
    --------
    Sampling : when you have an array of indices that can handle multiple
        sampling of the same value.

    Notes
    -----
    Use `np.index_exp` to build the `idx` argument.

    Examples
    --------
    >>> s = Slice((10, 10), idx=np.index_exp[::2, 1])
    >>> y = s.forward(np.empty((10, 10)))  # shape (5,)
    >>> x = s.adjoint(y)  # shape (10, 10)
    """

    def __init__(self, ishape: Shape, idx: tuple):
        # Infer the output shape by indexing a zero-strided, read-only
        # broadcast view rather than a freshly allocated np.empty(ishape):
        # no buffer of prod(ishape) elements is requested, yet the resulting
        # shape is the same for slices, integers, ellipsis and fancy indices.
        probe = np.broadcast_to(np.empty(()), ishape)
        super().__init__(ishape, probe[idx].shape, name=f"S[{idx}]")
        self._idx = idx

    @property
    def idx(self) -> tuple:
        """The index expression."""
        return self._idx

    def forward(self, point: Array) -> Array:
        """Return ``point[idx]``."""
        return point[self._idx]

    def adjoint(self, point: Array) -> Array:
        """Place `point` at `idx` in a zero array of shape `ishape`."""
        xp = arr_api.get_namespace(point)
        out = xp.zeros(self.ishape, dtype=point.dtype)
        # JAX arrays are immutable: use the functional update API.
        if arr_api.is_jax_array(out):
            out = out.at[self._idx].set(point)
        else:
            out[self._idx] = point
        return out
class DWT(LinOp):
    """Unitary Discrete Wavelet Transform.

    Parameters
    ----------
    shape : tuple of int
        The input shape.
    level : int, optional
        The decomposition level.
    wavelet : str, optional
        The wavelet to use.
    name : str, optional
        Name of the operator.

    Attributes
    ----------
    wlt : str
        The wavelet.
    lvl : int
        The decomposition level.

    Notes
    -----
    NumPy-only. Uses pywt internally, always in "periodization" mode. The
    coefficients are packed in a single array with ``pywt.coeffs_to_array``.
    NOTE(review): the output shape is declared equal to `shape`; this
    presumably requires each axis length to be divisible by 2**level —
    confirm for other sizes.
    """

    def __init__(
        self,
        shape: Shape,
        level: int | None = None,
        wavelet: str = "haar",
        name: str = "DWT",
    ):
        if pywt is None:
            raise ImportError("pywt is required for DWT")
        self._pywt = pywt
        super().__init__(shape, shape, name=name)
        self.wlt = wavelet
        self.lvl = level
        self._mode = "periodization"
        # Only the packing layout is needed here, so decompose a dummy array
        # (its values are irrelevant) and keep the slices.
        template = self._pywt.wavedecn(
            np.empty(shape), wavelet=wavelet, mode=self._mode, level=level
        )
        _, self._slices = self._pywt.coeffs_to_array(template)

    def forward(self, point: Array) -> Array:
        """Decompose `point` and pack the coefficients in one array."""
        coeffs = self._pywt.wavedecn(
            point, wavelet=self.wlt, mode=self._mode, level=self.lvl
        )
        packed, _ = self._pywt.coeffs_to_array(coeffs)
        return packed

    def adjoint(self, point: Array) -> Array:
        """Unpack the coefficient array and reconstruct with ``waverecn``."""
        coeffs = self._pywt.array_to_coeffs(point, self._slices)
        return self._pywt.waverecn(coeffs, wavelet=self.wlt, mode=self._mode)

    def fwadj(self, point: Array) -> Array:
        """Return `point` unchanged (the transform is treated as unitary)."""
        return point
class Analysis2(LinOp):
    """2D analysis operator with stationary wavelet decomposition.

    Parameters
    ----------
    shape : tuple of (int, int)
        The input shape.
    level : int
        The decomposition level.
    wavelet : str, optional
        The wavelet to use.
    name : str, optional
        Name of the operator.

    Notes
    -----
    NumPy-only. Uses pywt internally.

    The output is a 3D array where the first axis is the coefficient axis,
    with the approximation coefficients at index 0 and the detail
    coefficients at indices 1 to 3*level. The second and third axes are the
    spatial axes. See `pywt.swt2` documentation for more details on the
    output format.
    """

    def __init__(
        self,
        shape: tuple[int, int],
        level: int,
        wavelet: str = "haar",
        name: str = "A",
    ):
        if pywt is None:
            raise ImportError("pywt is required for Analysis2")
        self._pywt = pywt
        # One approximation band plus three detail bands per level.
        nbands = 3 * level + 1
        super().__init__(shape, (nbands,) + shape, name=name)
        self.wlt = wavelet
        self.lvl = level
        # Forwarded as `norm` to pywt.swt2 / pywt.iswt2.
        self.norm = True

    def forward(self, point: Array) -> Array:
        """Decompose an image into a ``(3*level + 1, H, W)`` cube."""
        coeffs = self._pywt.swt2(
            point,
            wavelet=self.wlt,
            level=self.lvl,
            norm=self.norm,
            trim_approx=True,
        )
        return self.coeffs2cube(coeffs)

    def adjoint(self, point: Array) -> Array:
        """Recombine a coefficient cube into an image with ``pywt.iswt2``."""
        coeffs = self.cube2coeffs(point)
        return self._pywt.iswt2(coeffs, self.wlt, norm=self.norm)

    def cube2coeffs(self, point: Array) -> list:
        """Return pywt coefficients from 3D array."""
        bands = [
            np.squeeze(band) for band in np.split(point, 3 * self.lvl + 1, axis=0)
        ]
        # Group the detail bands three by three (one triplet per level).
        details = [bands[first : first + 3] for first in range(1, len(bands), 3)]
        return [bands[0], *details]

    @staticmethod
    def coeffs2cube(coeffs: list) -> Array:
        """Return 3D array from pywt coefficients."""
        bands = [coeffs[0][np.newaxis, ...]]
        for triplet in coeffs[1:]:
            bands.extend(triplet[band][np.newaxis, ...] for band in range(3))
        return np.concatenate(bands, axis=0)

    def im2coeffs(self, point: Array) -> list:
        """Return pywt coefficients from an image array."""
        blocks = np.split(point, 3 * self.lvl + 1, axis=1)
        details = [blocks[first : first + 3] for first in range(1, len(blocks), 3)]
        return [blocks[0], *details]

    @staticmethod
    def coeffs2im(coeffs: list) -> Array:
        """Return an image array from pywt coefficients."""
        blocks = [coeffs[0]]
        for triplet in coeffs[1:]:
            blocks.extend(triplet[band] for band in range(3))
        return np.concatenate(blocks, axis=1)

    def cube2im(self, cube: Array) -> Array:
        """Convert 3D coefficient cube to image-stacked array."""
        coeffs = self.cube2coeffs(cube)
        return self.coeffs2im(coeffs)

    def im2cube(self, im: Array) -> Array:
        """Convert image-stacked array to 3D coefficient cube."""
        coeffs = self.im2coeffs(im)
        return self.coeffs2cube(coeffs)

    def get_irs(self) -> Array:
        """Return the impulse response of the filter bank.

        The response of `forward` to a unit impulse at index ``[0, 0]``.
        """
        impulse = np.zeros(self.ishape)
        impulse[0, 0] = 1
        return self.forward(impulse)

    def get_frs(self) -> Array:
        """Return the frequency response of the filter bank."""
        irs = self.get_irs()
        return np.ascontiguousarray(np.fft.rfftn(irs, self.ishape[-2:]))
class Synthesis2(LinOp):
    """2D synthesis operator with stationary wavelet decomposition.

    Adjoint of `Analysis2`: `forward` applies the analysis `adjoint`
    (``pywt.iswt2``) and `adjoint` applies the analysis `forward`
    (``pywt.swt2``). The input is the ``(3*level + 1, H, W)`` coefficient
    cube and the output the ``(H, W)`` image.

    Parameters
    ----------
    shape : tuple of (int, int)
        The input shape.
    level : int
        The decomposition level.
    wavelet : str, optional
        The wavelet to use.
    name : str, optional
        Name of the operator.
    """

    def __init__(
        self,
        shape: tuple[int, int],
        level: int,
        wavelet: str = "haar",
        name: str = "S",
    ):
        self.analysis = Analysis2(shape, level, wavelet)
        # Input/output shapes are those of the analysis, swapped.
        super().__init__(self.analysis.oshape, self.analysis.ishape, name=name)
        self.wlt = self.analysis.wlt
        self.lvl = self.analysis.lvl

    def forward(self, point: Array) -> Array:
        """Synthesize an image from a coefficient cube."""
        return self.analysis.adjoint(point)

    def adjoint(self, point: Array) -> Array:
        """Decompose an image into a coefficient cube."""
        return self.analysis.forward(point)

    def cube2coeffs(self, point: Array) -> list:
        """Return pywt coefficients from 3D array."""
        return self.analysis.cube2coeffs(point)

    def coeffs2cube(self, coeffs: list) -> Array:
        """Return 3D array from pywt coefficients."""
        return self.analysis.coeffs2cube(coeffs)

    def im2coeffs(self, point: Array) -> list:
        """Return pywt coefficients from image."""
        return self.analysis.im2coeffs(point)

    def coeffs2im(self, coeffs: list) -> Array:
        """Return image from pywt coefficients."""
        return self.analysis.coeffs2im(coeffs)

    def cube2im(self, cube: Array) -> Array:
        """Convert 3D coefficient cube to image-stacked array."""
        return self.analysis.cube2im(cube)

    def im2cube(self, im: Array) -> Array:
        """Convert image-stacked array to 3D coefficient cube."""
        return self.analysis.im2cube(im)

    def get_irs(self) -> Array:
        """Return the impulse response of the filter bank.

        One image per coefficient channel, with the origin at index
        ``[0, 0]`` as in `Analysis2.get_irs`. Since this operator is the
        adjoint of the analysis, each response is the *circular* reversal
        ``h[-n mod N]`` of the corresponding analysis response. The analysis
        responses come from a real impulse, so no conjugation is applied.
        """
        irs = self.analysis.get_irs()
        # np.flip maps n -> N-1-n; the circular reversal n -> -n mod N needs
        # an extra one-sample roll so the origin stays at index 0 (otherwise
        # the frequency response gets a spurious linear phase).
        return np.roll(np.flip(irs, axis=(1, 2)), shift=1, axis=(1, 2))

    def get_frs(self) -> Array:
        """Return the frequency response of the filter bank."""
        return np.ascontiguousarray(np.fft.rfftn(self.get_irs(), self.ishape[-2:]))