"""Merge broadband MATLAB ``.mat`` shade files (port of Matlab ``merge_shd_files.m``)."""
from __future__ import annotations
import io
from typing import Any
import numpy as np
from at_py.readwrite.mat_bundle import load_mat_normalized
_REQUIRED = ("PlotTitle", "atten", "Pos", "freq0", "freqVec", "pressure", "PlotType")
def _require_scipy_savemat() -> Any:
"""Import ``scipy.io`` or raise with install hint (for ``savemat``)."""
try:
import scipy.io as sio # noqa: PLC0415
except ImportError as e:
from at_py.readwrite.mat_bundle import _import_error_mat
raise _import_error_mat("scipy is required to write merged SHD MAT files.") from e
return sio
def _deep_equal_at(a: Any, b: Any) -> bool:
"""Structural equality similar to Matlab ``isequal`` for normalized MAT payloads."""
if isinstance(a, dict) and isinstance(b, dict):
if set(a.keys()) != set(b.keys()):
return False
return all(_deep_equal_at(a[k], b[k]) for k in a)
if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
if len(a) != len(b):
return False
return all(_deep_equal_at(x, y) for x, y in zip(a, b, strict=True))
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
if a.shape != b.shape:
return False
if a.dtype == b.dtype:
return np.array_equal(a, b)
if np.issubdtype(a.dtype, np.number) and np.issubdtype(b.dtype, np.number):
return np.array_equal(np.asarray(a, dtype=np.float64), np.asarray(b, dtype=np.float64))
return np.array_equal(a, b)
if isinstance(a, (np.generic, float, int)) and isinstance(b, (np.generic, float, int)):
return bool(np.asarray(a).reshape(-1)[0] == np.asarray(b).reshape(-1)[0])
return a == b
def merge_shd_mat_bytes(mats: list[bytes]) -> bytes:
    """Merge several broadband SHD ``.mat`` buffers into one (Matlab ``merge_shd_files``).

    Input requirements:

    * Each buffer must contain the variables ``PlotTitle``, ``atten``, ``Pos``,
      ``freq0``, ``freqVec``, ``pressure`` and ``PlotType`` (as produced by the
      Acoustics Toolbox).
    * ``atten``, ``Pos`` and ``PlotType`` must match across inputs.
    * The receiver grid dimensions (all ``pressure`` axes except frequency) must
      agree.
    * ``pressure`` must have one row along its frequency axis (axis 0) per entry
      of that input's ``freqVec``.

    Frequencies from all inputs are concatenated, checked for duplicates, then sorted
    ascending with ``pressure`` rows permuted to match (same as Matlab).
    ``PlotTitle`` and ``freq0`` are copied from the first input.

    **Output format:** SciPy ``savemat`` **format 5** (MATLAB v7 compatible), not HDF5
    v7.3. Matlab's ``merge_shd_files`` uses ``-v7.3`` for large files. Read the merged
    file with ``read_shd_from_mat`` / ``load_mat_normalized`` the same way as a
    classic MAT file.

    Requires optional dependency **SciPy** (``pip install 'oalib-at-py[mat]'``).

    Args:
        mats: Raw bytes of each input ``.mat`` file.

    Returns:
        Bytes of the merged ``.mat`` file.

    Raises:
        ValueError: If ``mats`` is empty, pressure shapes are inconsistent,
            ``atten``/``Pos``/``PlotType`` differ, or a frequency appears twice.
        KeyError: If an input lacks one of the required variables.
    """
    if not mats:
        raise ValueError("merge_shd_mat_bytes: need at least one .mat buffer")
    bundles = [load_mat_normalized(m) for m in mats]
    ref0 = bundles[0].variables

    freq_parts: list[np.ndarray] = []
    p_parts: list[np.ndarray] = []
    for j, b in enumerate(bundles):
        v = b.variables
        miss = [k for k in _REQUIRED if k not in v]
        if miss:
            raise KeyError(f"merged SHD: missing keys {miss} in input {j}")
        freq = np.asarray(v["freqVec"], dtype=np.float64).reshape(-1)
        p = np.asarray(v["pressure"])
        if j > 0 and p.shape[1:] != p_parts[0].shape[1:]:
            raise ValueError("merge_shd_mat_bytes: pressure field trailing dimensions disagree")
        # Rows are later permuted with the frequency sort order. A row count that does
        # not match this input's frequency count would misalign or silently drop data.
        if p.ndim == 0 or p.shape[0] != freq.size:
            raise ValueError(
                f"merge_shd_mat_bytes: input {j} has {freq.size} frequencies but "
                f"pressure shape {p.shape} (axis 0 must be frequency)"
            )
        if j > 0:
            for key in ("atten", "Pos", "PlotType"):
                if not _deep_equal_at(ref0[key], v[key]):
                    raise ValueError(f"merge_shd_mat_bytes: field {key!r} differs between inputs")
        freq_parts.append(freq)
        p_parts.append(p)

    freq_cat = np.concatenate(freq_parts, axis=0)
    if freq_cat.size != np.unique(freq_cat).size:
        raise ValueError("merge_shd_mat_bytes: duplicate frequencies in inputs")
    # Stable sort; ties cannot occur after the duplicate check above.
    order = np.argsort(freq_cat, kind="mergesort")
    freq_out = freq_cat[order]
    pressure_out = np.concatenate(p_parts, axis=0)[order, ...]

    sio = _require_scipy_savemat()
    buf = io.BytesIO()
    out_vars = {
        "PlotTitle": ref0["PlotTitle"],
        "atten": ref0["atten"],
        "Pos": ref0["Pos"],
        "freq0": ref0["freq0"],
        "freqVec": freq_out,
        "pressure": pressure_out,
        "PlotType": ref0["PlotType"],
    }
    sio.savemat(buf, out_vars, format="5", do_compression=True)
    return buf.getvalue()