Source code for at_py.readwrite.mat_at

"""Acoustics Toolbox ``.mat`` readers built on :mod:`at_py.readwrite.mat_bundle`.

These are **separate entry points** from :func:`at_py.readwrite.read_shd`,
:func:`at_py.readwrite.read_modes`, and :func:`at_py.readwrite.read_ts` so that
binary/ASCII dispatch stays unambiguous (no magic sniffing of arbitrary ``bytes`` as MAT).

Each function calls :func:`~at_py.readwrite.mat_bundle.load_mat_normalized` and maps the
documented variable sets used by the Matlab ``read_*.m`` ``.mat`` branches.
"""

from __future__ import annotations

from typing import Any

import numpy as np

from at_py.readwrite.mat_bundle import load_mat_normalized
from at_py.readwrite.modes import ModeHalfspace, ModesReadResult
from at_py.readwrite.shd import ShdPos, ShdPosR, ShdPosS, ShdReadResult
from at_py.readwrite.ts import TsPos, TsPosR, TsReadResult


def _as_str(x: Any, *, key: str) -> str:
    """Coerce normalized MAT values to ``str`` (char arrays, scalars, etc.).

    Accepted inputs:

    - ``str`` (returned unchanged) and ``bytes`` (decoded as UTF-8, replacing
      undecodable bytes).
    - NumPy string arrays (``S``/``U`` dtype). Arrays holding one character per
      element (``U1``/``S1``, e.g. a MATLAB char array loaded without joining
      characters) are concatenated along their first row; arrays of whole
      strings yield their first element. ``S`` data is decoded as UTF-8 instead
      of being rendered as ``"b'...'"``.
    - Size-1 non-numeric (e.g. object) arrays, unwrapped recursively.
    - Numeric arrays of character codes; NUL codes are dropped and the result
      is stripped of surrounding whitespace.

    Raises:
        TypeError: if ``x`` is not string-like (``key`` names the variable).
    """
    if isinstance(x, str):
        return x
    if isinstance(x, bytes):
        return x.decode("utf-8", errors="replace")
    if isinstance(x, np.ndarray):
        kind = x.dtype.kind
        if kind in "SU":
            if x.size == 0:
                return ""
            # U1 items are 4 bytes wide, S1 items 1 byte: one character per element.
            single_char = x.dtype.itemsize == (4 if kind == "U" else 1)
            if single_char:
                row = x.reshape(-1) if x.ndim <= 1 else x.reshape(x.shape[0], -1)[0]
                parts = row.tolist()
            else:
                parts = [x.reshape(-1)[0].item()]
            if kind == "S":
                return b"".join(parts).decode("utf-8", errors="replace")
            return "".join(parts)
        if x.size == 1 and kind not in "biufc":
            return _as_str(x.reshape(-1)[0], key=key)
        flat = x.reshape(-1)
        if flat.size and kind in "biufc":
            return "".join(chr(int(v)) for v in flat if int(v) != 0).strip()
    raise TypeError(f"{key!r}: expected a string-like value, got {type(x).__name__}")


def _as_float(x: Any, *, key: str) -> float:
    """Coerce a size-1 value to a Python ``float``.

    Raises:
        ValueError: if ``x`` does not contain exactly one element; ``key`` names
            the offending variable in the message.
    """
    a = np.asarray(x, dtype=np.float64).reshape(-1)
    if a.size != 1:
        raise ValueError(f"{key!r}: expected a scalar, got {a.size} elements")
    return float(a[0])


def _as_int(x: Any, *, key: str) -> int:
    """Coerce a size-1 value to a Python ``int`` (float values are truncated by ``int``).

    Raises:
        ValueError: if ``x`` does not contain exactly one element; ``key`` names
            the offending variable in the message.
    """
    a = np.asarray(x).reshape(-1)
    if a.size != 1:
        raise ValueError(f"{key!r}: expected a scalar integer, got {a.size} elements")
    return int(a[0])


def _as_float64_vec(x: Any, *, key: str) -> np.ndarray:
    """Flatten ``x`` into a 1-D ``float64`` array, rejecting empty input.

    Raises:
        ValueError: if ``x`` has no elements (``key`` names the variable).
    """
    flat = np.ravel(np.asarray(x, dtype=np.float64))
    if not flat.size:
        raise ValueError(f"{key!r}: expected a non-empty numeric vector")
    return flat


def _as_complex_array(x: Any, *, key: str) -> np.ndarray:
    """Return ``x`` as a complex ndarray without losing precision.

    Complex input is returned as-is. Boolean/integer/float input is promoted to
    the smallest complex dtype that represents it exactly (``float32`` ->
    ``complex64``, ``float64`` -> ``complex128``); callers that need a fixed
    precision cast the result themselves.

    Raises:
        TypeError: for non-numeric dtypes (``key`` names the variable).
    """
    a = np.asarray(x)
    if a.dtype.kind == "c":
        return a
    if a.dtype.kind in "biuf":
        # Promoting straight to complex64 would silently truncate float64 data
        # (e.g. a real-valued Modes.k) to single precision.
        return a.astype(np.result_type(a.dtype, np.complex64))
    raise TypeError(f"{key!r}: expected a numeric or complex array, got {a.dtype!r}")


def _nested_get(d: Any, *parts: str, key: str) -> Any:
    """Follow ``parts`` through nested struct dicts, i.e. ``d[p1][p2]...``.

    Raises:
        TypeError: if a value along the path is not a ``dict``.
        KeyError: if a field along the path is absent.
    """
    absent = object()
    node = d
    for field in parts:
        if isinstance(node, dict):
            child = node.get(field, absent)
        else:
            raise TypeError(f"{key!r}: expected struct dict at {field!r}, got {type(node).__name__}")
        if child is absent:
            raise KeyError(f"{key!r}: missing field {field!r}")
        node = child
    return node


def _mater_item_to_label(item: Any, *, where: str) -> bytes:
    """Convert one ``Mater`` entry to an 8-byte label (truncated or NUL-padded).

    Raises:
        TypeError: for unsupported entry types (``where`` prefixes the message).
    """
    if isinstance(item, bytes):
        raw = item
    elif isinstance(item, str):
        raw = item.encode("latin-1", errors="replace")
    elif isinstance(item, np.ndarray):
        if item.dtype.kind == "U":
            # ``tobytes`` on a unicode array yields UTF-32 code units, not text;
            # join the characters and encode them instead.
            raw = "".join(item.reshape(-1).tolist()).encode("latin-1", errors="replace")
        else:
            raw = item.tobytes()
    else:
        raise TypeError(f"{where}: unsupported Mater element {type(item).__name__}")
    return raw[:8].ljust(8, b"\x00")


def _mater_to_labels(mater: Any, *, nmedia: int, key: str) -> list[bytes]:
    """Kraken ``Mater`` is often a cell/char matrix; normalize to 8-byte labels.

    Supported layouts:

    - ``list`` (cell array) of ``bytes`` / ``str`` / ndarray entries;
    - ``S``/``U`` char matrix with ``nmedia`` rows, or a string array with
      exactly ``nmedia`` elements;
    - ``uint8`` matrix with ``nmedia`` rows (each row cut at its first NUL).

    Text is encoded as Latin-1 (unencodable characters become ``?``) and each
    label is truncated or NUL-padded to exactly 8 bytes.

    Raises:
        ValueError: if a list layout does not hold ``nmedia`` entries.
        TypeError: for unsupported entry types or layouts.
    """
    if isinstance(mater, list):
        labels = [_mater_item_to_label(item, where=f"{key}[{i}]") for i, item in enumerate(mater)]
        if len(labels) != nmedia:
            raise ValueError(f"{key!r}: expected {nmedia} medium labels, got {len(labels)}")
        return labels

    arr = np.asarray(mater)
    if arr.dtype.kind in "SU":
        if arr.ndim == 2 and arr.shape[0] == nmedia:
            # Char matrix: one row per medium.
            return [_mater_item_to_label(arr[i, :], where=f"{key}[{i}]") for i in range(nmedia)]
        flat = arr.reshape(-1)
        if flat.size == nmedia:
            # One string per medium; elements are np.str_ / np.bytes_ scalars.
            return [_mater_item_to_label(flat[i], where=f"{key}[{i}]") for i in range(nmedia)]
    if arr.dtype == np.uint8 and arr.ndim == 2 and arr.shape[0] == nmedia:
        # Raw byte matrix: cut each row at its first NUL before padding.
        return [arr[i].tobytes().split(b"\x00", 1)[0][:8].ljust(8, b"\x00") for i in range(nmedia)]
    raise TypeError(f"{key!r}: unsupported Mater layout {type(mater).__name__}")


def _halfspace_bc(bc_val: Any) -> str:
    """Decode a halfspace ``BC`` value to a one-character code (``"?"`` if empty).

    Handles ``str``/``bytes`` (and their NumPy scalar forms), string arrays
    (first element, whitespace-stripped), integer arrays/scalars holding a
    character code, and falls back to ``str(value)`` otherwise.
    """
    if isinstance(bc_val, bytes):
        text = bc_val.decode("latin-1", errors="replace").strip()
    elif isinstance(bc_val, str):
        text = bc_val.strip()
    elif isinstance(bc_val, (np.ndarray, np.generic)):
        arr = np.asarray(bc_val).reshape(-1)
        if arr.size == 0:
            text = ""
        elif arr.dtype.kind in "SU":
            first = arr[0].item()
            if isinstance(first, bytes):
                first = first.decode("latin-1", errors="replace")
            text = first.strip()
        elif np.issubdtype(arr.dtype, np.integer):
            # Character code (e.g. a MATLAB char saved as an integer type).
            text = chr(int(arr[0]))
        else:
            text = str(arr[0])
    else:
        text = str(bc_val).strip()
    return text[:1] or "?"


def _halfspace_from_mat(d: Any, *, key: str) -> ModeHalfspace:
    """Build :class:`~at_py.readwrite.modes.ModeHalfspace` from a MAT struct.

    ``d`` must be a struct dict with ``BC`` (or ``bc``), ``cp``, ``cs``, ``rho``
    and ``depth``. ``cp``/``cs`` must be scalars (complex allowed); ``rho`` and
    ``depth`` are stored as ``float32``.

    Raises:
        TypeError: ``d`` is not a dict.
        KeyError: a required field is missing (message prefixed with ``key``).
        ValueError: ``cp`` or ``cs`` is not a scalar.
    """
    if not isinstance(d, dict):
        raise TypeError(f"{key!r}: expected struct dict, got {type(d).__name__}")
    bc_val = d.get("BC", d.get("bc"))
    if bc_val is None:
        raise KeyError(f"{key!r}: missing BC")

    def _field(name: str) -> Any:
        """Field ``name`` of ``d``; KeyError with context if missing."""
        value = d.get(name)
        if value is None:
            raise KeyError(f"{key!r}: missing {name}")
        return value

    def _c128(name: str) -> np.complex128:
        """Scalar complex field ``name`` from struct ``d``."""
        z = np.asarray(_field(name), dtype=np.complex128).reshape(-1)
        if z.size != 1:
            raise ValueError(f"{key!r}.{name}: expected scalar complex")
        return np.complex128(z[0])

    return ModeHalfspace(
        bc=_halfspace_bc(bc_val),
        cp=_c128("cp"),
        cs=_c128("cs"),
        rho=np.float32(_as_float(_field("rho"), key=f"{key}.rho")),
        depth=np.float32(_as_float(_field("depth"), key=f"{key}.depth")),
    )


def read_ts_from_mat(data: bytes) -> TsReadResult:
    """Read a time-series ``.mat`` file (port of Matlab ``read_ts.m``).

    The file must define ``PlotTitle``, ``Pos`` (with ``Pos.r.z``), ``tout`` and
    ``RTS``. ``RTS`` is transposed exactly like the Matlab branch
    (``RTS = RTS.'``), so the returned array's first axis must match
    ``len(tout)``.

    Raises:
        KeyError: a required variable (or ``Pos.r.z``) is missing.
        ValueError: ``RTS`` is not 2-D, or ``RTS.'`` does not line up with ``tout``.
    """
    variables = load_mat_normalized(data).variables
    absent = next(
        (name for name in ("PlotTitle", "Pos", "tout", "RTS") if name not in variables),
        None,
    )
    if absent is not None:
        raise KeyError(f"read_ts_from_mat: missing required variable {absent!r}")

    plot_title = _as_str(variables["PlotTitle"], key="PlotTitle")
    receiver_z = np.asarray(
        _nested_get(variables["Pos"], "r", "z", key="Pos.r.z"), dtype=np.float64
    ).reshape(-1)
    tout = _as_float64_vec(variables["tout"], key="tout")

    rts_in = np.asarray(variables["RTS"], dtype=np.float64)
    if rts_in.ndim != 2:
        raise ValueError(f"RTS: expected a 2-D array, got shape {rts_in.shape}")
    rts = np.ascontiguousarray(rts_in.T)  # Matlab: RTS = RTS.'
    n_times = tout.shape[0]
    if rts.shape[0] != n_times:
        raise ValueError(f"RTS.' shape {rts.shape} incompatible with tout length {n_times}")

    return TsReadResult(
        plot_title=plot_title,
        pos=TsPos(r=TsPosR(z=receiver_z)),
        tout=tout,
        rts=rts,
    )
def read_shd_from_mat(
    data: bytes,
    *,
    freq: float | None = None,
    xs_km: float | None = None,
    ys_km: float | None = None,
    greens_function: bool = False,
) -> ShdReadResult:
    """Port of Matlab ``read_shd.m`` branches ``shdmat`` / ``grnmat``.

    Required variables (as loaded by Matlab ``load``): ``PlotTitle``, ``PlotType``,
    ``freqVec``, ``freq0``, ``atten``, ``Pos``, ``pressure``.

    - ``greens_function=True`` applies ``Pos.r.r = Pos.r.r.'`` like ``grnmat``.
    - ``freq`` selects the nearest entry in ``freqVec`` (default: first index),
      matching the binary reader when ``freq`` is omitted.
    - ``xs_km`` / ``ys_km`` select nearest source grid point in meters via
      ``abs(Pos.s.x - xs*1000)`` / ``abs(Pos.s.y - ys*1000)`` when ``pressure``
      has leading source axes ``(Nsx, Nsy, ...)``.

    ``pressure`` must be shaped ``(Nfreq, Ntheta, Nsz, Nrcvrs, Nrr)`` or
    ``(Nsx, Nsy, Nfreq, Ntheta, Nsz, Nrcvrs, Nrr)``. ``Nrcvrs`` is ``Nrz``,
    except for ``PlotType == "irregular "`` where the receiver-depth axis has
    length 1. Without ``xs_km``/``ys_km`` a 7-D pressure uses the first source
    grid point.

    Returns:
        :class:`ShdReadResult` whose ``pressure`` is the ``complex64`` slice
        ``(Ntheta, Nsz, Nrcvrs, Nrr)`` at the selected frequency.

    Raises:
        KeyError: a required variable or ``Pos`` field is missing.
        TypeError: ``Pos`` is not a struct.
        ValueError: inconsistent shapes, or only one of ``xs_km``/``ys_km`` given.
    """
    bundle = load_mat_normalized(data)
    v = bundle.variables
    req = ("PlotTitle", "PlotType", "freqVec", "freq0", "atten", "Pos", "pressure")
    for k in req:
        if k not in v:
            raise KeyError(f"read_shd_from_mat: missing required variable {k!r}")

    title = _as_str(v["PlotTitle"], key="PlotTitle")
    plot_type = _as_str(v["PlotType"], key="PlotType")
    if len(plot_type) < 10:
        # PlotType is compared below as a 10-character field (e.g. "irregular ").
        plot_type = plot_type.ljust(10)
    freq_vec = np.asarray(v["freqVec"], dtype=np.float64).reshape(-1)
    freq0 = _as_float(v["freq0"], key="freq0")
    atten = _as_float(v["atten"], key="atten")

    pos = v["Pos"]
    if not isinstance(pos, dict):
        raise TypeError(f"Pos: expected struct dict, got {type(pos).__name__}")

    def _pos_vec(*path: str, dtype: Any) -> np.ndarray:
        """Flattened ``Pos.<path>`` array converted to ``dtype``."""
        return np.asarray(
            _nested_get(pos, *path, key="Pos." + ".".join(path)), dtype=dtype
        ).reshape(-1)

    theta = _pos_vec("theta", dtype=np.float64)
    sx = _pos_vec("s", "x", dtype=np.float64)
    sy = _pos_vec("s", "y", dtype=np.float64)
    sz = _pos_vec("s", "z", dtype=np.float32)
    rz = _pos_vec("r", "z", dtype=np.float32)
    rr = _pos_vec("r", "r", dtype=np.float64)
    if greens_function:
        # grnmat applies Pos.r.r = Pos.r.r.'; on the flattened 1-D vector used
        # here that transpose leaves the values unchanged.
        rr = np.ascontiguousarray(rr)

    # Explicit counts stored in Pos take precedence over the vector lengths.
    nsx = _as_int(pos.get("Nsx", sx.size), key="Pos.Nsx")
    nsy = _as_int(pos.get("Nsy", sy.size), key="Pos.Nsy")
    nsz = _as_int(pos.get("Nsz", sz.size), key="Pos.Nsz")
    nrz = _as_int(pos.get("Nrz", rz.size), key="Pos.Nrz")
    nrr = _as_int(pos.get("Nrr", rr.size), key="Pos.Nrr")
    ntheta = _as_int(pos.get("Ntheta", theta.size), key="Pos.Ntheta")
    nfreq = _as_int(pos.get("Nfreq", freq_vec.size), key="Pos.Nfreq")

    p = _as_complex_array(v["pressure"], key="pressure").astype(np.complex64, copy=False)

    if (xs_km is None) != (ys_km is None):
        raise ValueError("read_shd_from_mat: xs_km and ys_km must both be set or both omitted")

    # Irregular grids carry one receiver depth per range, so the pressure's
    # receiver-depth axis has length 1 there; every other plot type uses Nrz.
    pt10 = plot_type[:10].ljust(10)
    nrcvrs_per_range = 1 if pt10 == "irregular " else nrz

    exp7 = (nsx, nsy, nfreq, ntheta, nsz, nrcvrs_per_range, nrr)
    exp5 = (nfreq, ntheta, nsz, nrcvrs_per_range, nrr)
    p_shape = tuple(int(n) for n in p.shape)
    if xs_km is not None and ys_km is not None:
        if p.ndim != 7:
            raise ValueError(
                f"pressure: expected 7-D array {exp7} for xs_km/ys_km selection, got {p.shape}"
            )
        if p_shape != exp7:
            raise ValueError(f"pressure: expected shape {exp7}, got {p.shape}")
        ix = int(np.argmin(np.abs(sx - float(xs_km) * 1000.0)))
        iy = int(np.argmin(np.abs(sy - float(ys_km) * 1000.0)))
        sub = p[ix, iy]
    elif p.ndim == 7:
        if p_shape != exp7:
            raise ValueError(f"pressure: expected shape {exp7} for 7-D pressure, got {p.shape}")
        sub = p[0, 0]
    elif p.ndim == 5:
        sub = p
    else:
        raise ValueError(
            f"pressure: expected rank 5 {exp5} or rank 7 {exp7}, got rank {p.ndim} {p.shape}"
        )
    if tuple(int(n) for n in sub.shape) != exp5:
        raise ValueError(f"pressure: expected shape {exp5} after source selection, got {sub.shape}")

    ifreq = 0 if freq is None else int(np.argmin(np.abs(freq_vec - float(freq))))
    pressure = np.asarray(sub[ifreq], dtype=np.complex64, order="C")

    shd_pos = ShdPos(
        theta=theta,
        s=ShdPosS(x=sx, y=sy, z=sz),
        r=ShdPosR(z=rz, r=rr),
        nsx=nsx,
        nsy=nsy,
        nsz=nsz,
        nrz=nrz,
        nrr=nrr,
        ntheta=ntheta,
        nfreq=nfreq,
    )
    return ShdReadResult(
        title=title,
        plot_type=plot_type,
        freq_vec=freq_vec,
        freq0=freq0,
        atten=atten,
        pos=shd_pos,
        pressure=pressure,
    )
def _resolve_modes_phi_k_for_frequency(
    m: dict[str, Any],
    *,
    freq_index: int,
    nfreq: int,
) -> tuple[np.ndarray, np.ndarray, int]:
    """Select 2-D ``phi``, 1-D ``k`` for ``freq_index``, and mode count *M* at that frequency.

    Args:
        m: the ``Modes`` struct dict; this function reads ``phi``, ``k`` and ``M``.
        freq_index: 0-based frequency index; ignored when ``nfreq == 1``.
        nfreq: number of frequencies stored in the file.

    Returns:
        ``(phi, k, M)``: ``phi`` is ``complex64`` with shape ``(nmat, M)``, ``k`` is
        a 1-D ``complex128`` vector of length ``M``, and ``M`` is the mode count
        at the selected frequency.

    Accepted layouts:

    - ``nfreq == 1``: ``phi`` 1-D (treated as a single column), 2-D, or 3-D with a
      trailing axis of length 1; ``k`` 1-D or 2-D with a singleton axis; ``M`` is
      the first element of ``Modes.M``.
    - otherwise: ``phi`` 3-D ``(nmat, Mmax, nfreq)``; ``k`` 2-D ``(Mmax, nfreq)``
      or ``(nfreq, Mmax)``. ``Modes.M`` is either a scalar (``phi`` axis 1 must
      equal it) or one count per frequency (``phi`` is cut to the first
      ``Modes.M[freq_index]`` columns). ``k`` is cut to ``M`` entries in both cases.

    Raises:
        KeyError: ``phi``, ``k`` or ``M`` is missing from ``m``.
        ValueError: shapes are inconsistent with these layouts or with ``Modes.M``.
    """
    phi_raw = _as_complex_array(m["phi"], key="Modes.phi").astype(np.complex64, copy=False)
    k_raw = _as_complex_array(m["k"], key="Modes.k").astype(np.complex128, copy=False)
    m_arr = np.asarray(m["M"])
    m_flat = m_arr.reshape(-1)
    if nfreq == 1:
        # Single frequency: normalize phi to 2-D (nmat, M) and k to 1-D.
        if phi_raw.ndim == 1:
            # A 1-D phi is treated as one column (a single mode).
            phi_raw = phi_raw.reshape(-1, 1)
        elif phi_raw.ndim == 3:
            if phi_raw.shape[2] == 1:
                phi_raw = phi_raw[:, :, 0]
            else:
                raise ValueError(
                    f"Modes.phi: Nfreq=1 but phi has shape {phi_raw.shape}; "
                    "expected third axis length 1 if 3-D"
                )
        elif phi_raw.ndim > 2:
            raise ValueError(
                f"Modes.phi: Nfreq=1 expected 1-D, 2-D, or 3-D with third axis 1, "
                f"got {phi_raw.ndim} dims"
            )
        if k_raw.ndim == 2:
            # Accept column (M, 1) or row (1, M) vectors.
            if k_raw.shape[1] == 1:
                k_vec = k_raw[:, 0].reshape(-1)
            elif k_raw.shape[0] == 1:
                k_vec = k_raw[0, :].reshape(-1)
            else:
                raise ValueError(
                    f"Modes.k: Nfreq=1 expected 1-D or 2-D with a singleton axis, "
                    f"got shape {k_raw.shape}"
                )
        else:
            k_vec = k_raw.reshape(-1)
        # NOTE(review): the _as_int fallback only runs when Modes.M is empty, in
        # which case there is no element for it to read either — looks dead.
        num_modes = int(m_flat[0]) if m_flat.size else _as_int(m["M"], key="Modes.M")
        # Catches a 0-d phi, which none of the branches above reshape.
        if phi_raw.ndim != 2:
            raise ValueError(f"Modes.phi: expected 2-D after normalization, got {phi_raw.shape}")
        if phi_raw.shape[1] != num_modes or k_vec.shape[0] != num_modes:
            raise ValueError(
                f"Modes.phi / k inconsistent with Modes.M={num_modes}: "
                f"phi {phi_raw.shape}, k {k_vec.shape}"
            )
        return phi_raw, k_vec.astype(np.complex128, copy=False), num_modes
    # Multi-frequency: phi must be stacked as (nmat, Mmax, nfreq).
    if phi_raw.ndim != 3:
        raise ValueError(
            f"Modes.phi: for Nfreq={nfreq}, expected shape (nmat, M, Nfreq), got {phi_raw.shape}"
        )
    if phi_raw.shape[2] != nfreq:
        raise ValueError(f"Modes.phi: third axis length {phi_raw.shape[2]} != Nfreq={nfreq}")
    if m_flat.size == 1:
        # Same mode count at every frequency: phi's mode axis must match exactly.
        num_modes = int(m_flat[0])
        if phi_raw.shape[1] != num_modes:
            raise ValueError(
                f"Modes.phi: second axis length {phi_raw.shape[1]} != Modes.M ({num_modes})"
            )
        phi_f = phi_raw[:, :, freq_index]
    elif m_flat.size == nfreq:
        # Per-frequency mode counts: keep the first Modes.M[freq_index] columns.
        num_modes = int(m_flat[freq_index])
        if phi_raw.shape[1] < num_modes:
            raise ValueError(
                f"Modes.phi: second axis length {phi_raw.shape[1]} < "
                f"Modes.M[{freq_index}]={num_modes}"
            )
        phi_f = phi_raw[:, :num_modes, freq_index]
    else:
        raise ValueError(
            f"Modes.M: expected 1 or {nfreq} elements when Nfreq={nfreq}, got shape {m_arr.shape}"
        )
    if k_raw.ndim != 2:
        raise ValueError(
            f"Modes.k: for Nfreq={nfreq}, expected a 2-D array (M, Nfreq) or (Nfreq, M), "
            f"got shape {k_raw.shape}"
        )
    # The (M, Nfreq) layout is tested first, so it wins when M == Nfreq.
    if k_raw.shape[1] == nfreq:
        k_f = np.asarray(k_raw[:num_modes, freq_index], dtype=np.complex128).reshape(-1)
    elif k_raw.shape[0] == nfreq:
        k_f = np.asarray(k_raw[freq_index, :num_modes], dtype=np.complex128).reshape(-1)
    else:
        raise ValueError(
            f"Modes.k: expected first or second axis length {nfreq}, got shape {k_raw.shape}"
        )
    # Slicing can yield fewer than num_modes entries if k is too short; reject that.
    if phi_f.shape[1] != num_modes or k_f.shape[0] != num_modes:
        raise ValueError(
            f"after selecting frequency index {freq_index}: phi shape {phi_f.shape}, "
            f"k length {k_f.shape[0]} vs Modes.M ({num_modes})"
        )
    return phi_f, k_f, num_modes
def read_modes_from_mat(
    data: bytes,
    freq: float,
    *,
    modes: list[int] | None = None,
) -> ModesReadResult:
    """Port of Matlab ``read_modes.m`` for ``.mod.mat`` (``load`` provides ``Modes``).

    Expects a variable ``Modes`` matching Kraken/Krakel-style saves (see
    ``krakelM.m``) with fields ``title``, ``Nmedia``, ``N``, ``depth``,
    ``Mater``, ``rho``, ``freqVec``, ``z``, ``M``, ``Top``, ``Bot``, ``k`` and
    ``phi``. ``Nfreq`` is optional and defaults to ``len(freqVec)``.

    ``modes`` uses **1-based** indices like :func:`~at_py.readwrite.read_modes_bin`;
    indices outside ``1..M`` are dropped.

    **Multi-frequency** ``.mod.mat`` files must store eigenvectors and wavenumbers
    in a **stacked** layout so the nearest frequency to ``freq`` can be selected
    (same idea as Matlab ``read_modes_bin``):

    - ``phi`` with shape ``(nmat, M, Nfreq)``
    - ``k`` with shape ``(M, Nfreq)`` or ``(Nfreq, M)``

    ``Modes.M`` may be a scalar (same *M* at every frequency) or a vector of length
    ``Nfreq`` (mode count per frequency). Single-frequency files keep the
    Krakel-style 2-D ``phi`` and 1-D ``k``.

    Returns:
        :class:`ModesReadResult` for the ``freqVec`` entry nearest to ``freq``.
        ``phi``/``k`` hold only the selected modes; ``num_modes`` is the mode
        count stored for that frequency.

    Raises:
        KeyError: ``Modes`` or one of its required fields is missing.
        TypeError: ``Modes`` is not a struct.
        ValueError: inconsistent field sizes, or no requested mode index is valid.
    """
    bundle = load_mat_normalized(data)
    v = bundle.variables
    if "Modes" not in v:
        raise KeyError("read_modes_from_mat: missing required variable 'Modes'")
    m = v["Modes"]
    if not isinstance(m, dict):
        raise TypeError(f"Modes: expected struct dict, got {type(m).__name__}")
    # Report every missing field by name up front instead of a bare KeyError later.
    required = (
        "title", "Nmedia", "N", "depth", "Mater", "rho", "freqVec",
        "z", "M", "Top", "Bot", "k", "phi",
    )
    for field in required:
        if field not in m:
            raise KeyError(f"read_modes_from_mat: Modes is missing required field {field!r}")

    title = _as_str(m["title"], key="Modes.title")
    nmedia = _as_int(m["Nmedia"], key="Modes.Nmedia")
    n_per_medium = np.asarray(m["N"], dtype=np.int32).reshape(-1)
    if n_per_medium.size != nmedia:
        raise ValueError("Modes.N: length must match Nmedia")
    depth = np.asarray(m["depth"], dtype=np.float32).reshape(-1)
    if depth.size != nmedia + 1:
        raise ValueError("Modes.depth: expected Nmedia+1 interface depths")
    rho_med = np.asarray(m["rho"], dtype=np.float32).reshape(-1)
    if rho_med.size != nmedia:
        raise ValueError("Modes.rho: expected Nmedia values")

    freq_vec = np.asarray(m["freqVec"], dtype=np.float64).reshape(-1)
    nfreq = _as_int(m.get("Nfreq", freq_vec.size), key="Modes.Nfreq")
    if freq_vec.size != nfreq:
        raise ValueError("Modes.freqVec: length must match Nfreq")
    freq_index = int(np.argmin(np.abs(freq_vec - float(freq))))

    z = np.asarray(m["z"], dtype=np.float32).reshape(-1)
    ntot = int(z.size)
    phi_all, k_all, num_modes_file = _resolve_modes_phi_k_for_frequency(
        m, freq_index=freq_index, nfreq=nfreq
    )
    nmat = int(phi_all.shape[0])
    mater = _mater_to_labels(m["Mater"], nmedia=nmedia, key="Modes.Mater")
    top = _halfspace_from_mat(m["Top"], key="Modes.Top")
    bot = _halfspace_from_mat(m["Bot"], key="Modes.Bot")

    requested = range(1, num_modes_file + 1) if modes is None else modes
    mode_list = [mm for mm in requested if 1 <= mm <= num_modes_file]
    if not mode_list:
        raise ValueError("read_modes_from_mat: no valid mode indices in requested subset")
    idx0 = [mm - 1 for mm in mode_list]  # 1-based -> 0-based column indices
    phi = phi_all[:, idx0]
    k_sel = k_all[idx0]

    return ModesReadResult(
        title=title,
        nfreq=nfreq,
        nmedia=nmedia,
        ntot=ntot,
        nmat=nmat,
        n_per_medium=n_per_medium,
        mater=mater,
        depth=depth,
        rho=rho_med,
        freq_vec=freq_vec,
        z=z,
        num_modes=num_modes_file,
        top=top,
        bot=bot,
        phi=phi,
        k=k_sel,
    )