# Source code for at_py.readwrite.merge_shd_mat

"""Merge broadband MATLAB ``.mat`` shade files (port of Matlab ``merge_shd_files.m``)."""

from __future__ import annotations

import io
from typing import Any

import numpy as np

from at_py.readwrite.mat_bundle import load_mat_normalized

_REQUIRED = ("PlotTitle", "atten", "Pos", "freq0", "freqVec", "pressure", "PlotType")


def _require_scipy_savemat() -> Any:
    """Import ``scipy.io`` or raise with install hint (for ``savemat``)."""
    try:
        import scipy.io as sio  # noqa: PLC0415
    except ImportError as e:
        from at_py.readwrite.mat_bundle import _import_error_mat

        raise _import_error_mat("scipy is required to write merged SHD MAT files.") from e
    return sio


def _deep_equal_at(a: Any, b: Any) -> bool:
    """Structural equality similar to Matlab ``isequal`` for normalized MAT payloads."""
    if isinstance(a, dict) and isinstance(b, dict):
        if set(a.keys()) != set(b.keys()):
            return False
        return all(_deep_equal_at(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        if len(a) != len(b):
            return False
        return all(_deep_equal_at(x, y) for x, y in zip(a, b, strict=True))
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        if a.shape != b.shape:
            return False
        if a.dtype == b.dtype:
            return np.array_equal(a, b)
        if np.issubdtype(a.dtype, np.number) and np.issubdtype(b.dtype, np.number):
            return np.array_equal(np.asarray(a, dtype=np.float64), np.asarray(b, dtype=np.float64))
        return np.array_equal(a, b)
    if isinstance(a, (np.generic, float, int)) and isinstance(b, (np.generic, float, int)):
        return bool(np.asarray(a).reshape(-1)[0] == np.asarray(b).reshape(-1)[0])
    return a == b


def merge_shd_mat_bytes(mats: list[bytes]) -> bytes:
    """Merge several broadband SHD ``.mat`` buffers into one (Matlab ``merge_shd_files``).

    Input buffers must each contain the variables ``PlotTitle``, ``atten``, ``Pos``,
    ``freq0``, ``freqVec``, ``pressure``, and ``PlotType`` (as produced by Acoustics
    Toolbox). ``atten``, ``Pos``, and ``PlotType`` must match across inputs; receiver
    grid dimensions (all ``pressure`` axes except frequency) must agree, and the first
    (frequency) axis of each ``pressure`` must have one entry per ``freqVec`` value.

    Frequencies from all inputs are concatenated, checked for duplicates, then sorted
    ascending with ``pressure`` rows permuted to match (same as Matlab). ``PlotTitle``,
    ``atten``, ``Pos``, ``freq0`` and ``PlotType`` of the merged file are taken from
    the first input.

    **Output format:** SciPy ``savemat`` **format 5** (MATLAB v7 compatible), not HDF5
    v7.3—Matlab’s ``merge_shd_files`` uses ``-v7.3`` for large files; use this merged
    file with ``read_shd_from_mat`` / ``load_mat_normalized`` the same way as classic MAT.

    Requires optional dependency **SciPy** (``pip install 'oalib-at-py[mat]'``).

    Args:
        mats: Raw contents of each input ``.mat`` file; frequency order across
            inputs does not matter.

    Returns:
        Raw contents of the merged ``.mat`` file.

    Raises:
        ValueError: ``mats`` is empty; a ``pressure`` array is 0-d or its first axis
            length differs from the number of values in its ``freqVec``; the
            non-frequency ``pressure`` dimensions disagree between inputs; ``atten``,
            ``Pos`` or ``PlotType`` differ between inputs; or a frequency repeats.
        KeyError: an input lacks one of the required variables.
    """
    if not mats:
        raise ValueError("merge_shd_mat_bytes: need at least one .mat buffer")
    bundles = [load_mat_normalized(m) for m in mats]

    # Per-input validation. Pressure rows are later permuted by the frequency
    # sort order, so axis 0 must hold exactly one row per freqVec entry;
    # otherwise rows would be silently dropped or misaligned.
    ref_trailing: tuple[int, ...] | None = None
    for j, b in enumerate(bundles):
        v = b.variables
        miss = [k for k in _REQUIRED if k not in v]
        if miss:
            raise KeyError(f"merged SHD: missing keys {miss} in input {j}")
        sp = np.asarray(v["pressure"]).shape
        n_freq = np.asarray(v["freqVec"]).size
        if not sp or sp[0] != n_freq:
            raise ValueError(
                f"merge_shd_mat_bytes: pressure shape {sp} in input {j} does not have "
                f"{n_freq} rows along its first (frequency) axis"
            )
        if ref_trailing is None:
            ref_trailing = sp[1:]
        elif sp[1:] != ref_trailing:
            raise ValueError("merge_shd_mat_bytes: pressure field trailing dimensions disagree")

    # Metadata that must be identical across all inputs.
    ref0 = bundles[0].variables
    for b in bundles[1:]:
        for key in ("atten", "Pos", "PlotType"):
            if not _deep_equal_at(ref0[key], b.variables[key]):
                raise ValueError(f"merge_shd_mat_bytes: field {key!r} differs between inputs")

    freq_cat = np.concatenate(
        [np.asarray(b.variables["freqVec"], dtype=np.float64).reshape(-1) for b in bundles],
        axis=0,
    )
    if freq_cat.size != np.unique(freq_cat).size:
        raise ValueError("merge_shd_mat_bytes: duplicate frequencies in inputs")
    pressure_cat = np.concatenate([np.asarray(b.variables["pressure"]) for b in bundles], axis=0)

    # Duplicates were rejected above, so the ascending order is unique.
    order = np.argsort(freq_cat, kind="mergesort")
    freq_out = freq_cat[order]
    pressure_out = pressure_cat[order, ...]

    sio = _require_scipy_savemat()
    out_vars = {
        "PlotTitle": ref0["PlotTitle"],
        "atten": ref0["atten"],
        "Pos": ref0["Pos"],
        "freq0": ref0["freq0"],
        "freqVec": freq_out,
        "pressure": pressure_out,
        "PlotType": ref0["PlotType"],
    }
    buf = io.BytesIO()
    sio.savemat(buf, out_vars, format="5", do_compression=True)
    return buf.getvalue()