# Source code for at_py.readwrite.ssp3d

"""3D SSP grid (Bellhop3D ``SSPFIL``) reader (port of Matlab ``readssp3d.m``)."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from at_py.readwrite.ssp2d import _tokens


@dataclass(frozen=True)
class SSP3DRead:
    """3D SSP cube ``cmat`` shape ``(Nz, Ny, Nx)`` (Matlab ``cmat(iz,:,:)`` is ``Ny``×``Nx``).

    Equality compares the integer counts exactly and every array field with
    :func:`numpy.array_equal` (same shape and same values).  The
    dataclass-generated ``__eq__`` would compare field tuples instead.  That
    evaluates ``ndarray == ndarray`` elementwise and then takes its truth value,
    which raises ``ValueError`` for arrays with more than one element.
    """

    nx: int  # number of x grid points (``Nx`` in the file header)
    ny: int  # number of y grid points (``Ny``)
    nz: int  # number of depth points (``Nz``)
    segx_km: np.ndarray  # x coordinates (``Segx``), expected length nx
    segy_km: np.ndarray  # y coordinates (``Segy``), expected length ny
    segz_km: np.ndarray  # depth coordinates (``Segz``), expected length nz
    cmat: np.ndarray  # sound speeds, shape (nz, ny, nx)

    def __eq__(self, other: object) -> bool:
        # Mirror the dataclass default: only compare against the exact same class.
        if other.__class__ is not self.__class__:
            return NotImplemented
        if (self.nx, self.ny, self.nz) != (other.nx, other.ny, other.nz):
            return False
        return all(
            np.array_equal(mine, theirs)
            for mine, theirs in (
                (self.segx_km, other.segx_km),
                (self.segy_km, other.segy_km),
                (self.segz_km, other.segz_km),
                (self.cmat, other.cmat),
            )
        )
def parse_ssp3d(text: str) -> SSP3DRead:
    """Parse a 3D SSPFIL (Matlab ``readssp3d``).

    Token layout (as produced by ``_tokens``)::

        Nx  Segx(1..Nx)
        Ny  Segy(1..Ny)
        Nz  Segz(1..Nz)
        Nx*Ny*Nz sound-speed values

    For each depth index ``iz``, one ``Nx``×``Ny`` block is read column-wise
    (Matlab ``fscanf(..., [Nx, Ny])``), then transposed into ``cmat(iz, :, :)``.
    In the flat value stream ``x`` therefore varies fastest, then ``y``, then
    ``z``.  That is exactly a C-order reshape to ``(Nz, Ny, Nx)``.

    Parameters
    ----------
    text : str
        Full SSPFIL contents.

    Returns
    -------
    SSP3DRead
        Counts, coordinate vectors and ``cmat`` of shape ``(nz, ny, nx)``.

    Raises
    ------
    ValueError
        If the header is truncated, a count is negative or a token is not
        numeric, or the number of sound-speed values is not exactly
        ``Nx * Ny * Nz``.
    """
    it = iter(_tokens(text))

    def _read_axis(label: str) -> tuple[int, np.ndarray]:
        # One header group: an integer count followed by that many coordinates.
        n = int(next(it))
        if n < 0:
            # Without this, a negative count slips past the value-count checks
            # and later fails with an unrelated numpy "negative dimensions" error.
            raise ValueError(f"SSPFIL (3D): {label} must be non-negative, got {n}")
        seg = np.array([float(next(it)) for _ in range(n)], dtype=np.float64)
        return n, seg

    try:
        nx, segx = _read_axis("Nx")
        ny, segy = _read_axis("Ny")
        nz, segz = _read_axis("Nz")
    except StopIteration as e:
        raise ValueError("SSPFIL (3D): truncated header (Nx/Segx/Ny/Segy/Nz/Segz)") from e

    need = nx * ny * nz
    values = [float(x) for x in it]
    if len(values) < need:
        raise ValueError(
            f"SSPFIL (3D): need {need} sound-speed values, got {len(values)} after header"
        )
    if len(values) > need:
        raise ValueError(
            f"SSPFIL (3D): expected exactly {need} sound-speed values, got {len(values)}"
        )

    # Column-wise [Nx, Ny] blocks transposed per depth == x fastest, then y,
    # then z, i.e. a single C-order reshape of the whole stream.
    cmat = np.asarray(values, dtype=np.float64).reshape((nz, ny, nx))
    return SSP3DRead(
        nx=nx,
        ny=ny,
        nz=nz,
        segx_km=segx,
        segy_km=segy,
        segz_km=segz,
        cmat=cmat,
    )
def parse_ssp3d_bytes(data: bytes) -> SSP3DRead:
    """Parse 3D SSPFIL from UTF-8 text bytes.

    Decodes with ``utf-8-sig``, so a leading UTF-8 byte-order mark is dropped
    instead of being glued onto the first token (``Nx``), where it would make
    the integer conversion fail.  Input without a BOM decodes exactly as with
    plain ``utf-8``.  Undecodable bytes are replaced with U+FFFD rather than
    raising here.  NOTE(review): presumably they then surface as a
    numeric-conversion ``ValueError`` inside :func:`parse_ssp3d`.
    """
    text = data.decode("utf-8-sig", errors="replace")
    return parse_ssp3d(text)
def _fmt_ssp_float(x: float) -> str: return f"{x:.17g}"
def format_ssp3d(s: SSP3DRead) -> str:
    """Format a 3D SSPFIL text (inverse of :func:`parse_ssp3d`).

    Output is seven newline-terminated lines: ``Nx``, the ``Segx`` values,
    ``Ny``, the ``Segy`` values, ``Nz``, the ``Segz`` values, then all
    ``Nx*Ny*Nz`` sound speeds.  Floats use 17 significant digits.

    Raises
    ------
    ValueError
        If a segment array's length does not match its count, or ``cmat`` is
        not shaped ``(nz, ny, nx)``.
    """
    nx, ny, nz = int(s.nx), int(s.ny), int(s.nz)
    segx = np.asarray(s.segx_km, dtype=np.float64).reshape(-1)
    segy = np.asarray(s.segy_km, dtype=np.float64).reshape(-1)
    segz = np.asarray(s.segz_km, dtype=np.float64).reshape(-1)
    if segx.size != nx or segy.size != ny or segz.size != nz:
        raise ValueError("segment arrays must match nx, ny, nz")
    cmat = np.asarray(s.cmat, dtype=np.float64)
    if cmat.shape != (nz, ny, nx):
        raise ValueError(f"cmat shape must be (nz, ny, nx)=({nz}, {ny}, {nx}); got {cmat.shape}")

    # The reader consumes each depth as a column-wise [Nx, Ny] block, i.e.
    # x fastest, then y, then z: a plain C-order flatten of (nz, ny, nx).
    # Flattening the whole cube (instead of concatenating per-depth chunks)
    # also works for nz == 0, where np.concatenate([]) would raise.
    sound = cmat.reshape(-1)

    def _line(values: np.ndarray) -> str:
        # One space-separated line of round-trippable floats.
        return " ".join(_fmt_ssp_float(float(v)) for v in values)

    parts: list[str] = [
        str(nx),
        _line(segx),
        str(ny),
        _line(segy),
        str(nz),
        _line(segz),
        _line(sound),
    ]
    return "\n".join(parts) + "\n"
def format_ssp3d_bytes(s: SSP3DRead, *, encoding: str = "utf-8") -> bytes:
    """Render *s* with :func:`format_ssp3d` and encode the text using *encoding*."""
    rendered = format_ssp3d(s)
    return rendered.encode(encoding)