from __future__ import annotations
import struct
from dataclasses import dataclass
from typing import Literal
import numpy as np
from at_py.readwrite.ram import RamTlGridResult, read_ram_tlgrid
@dataclass(frozen=True)
class ShdPosS:
    """Source-side axes of a shade file.

    Attributes:
        x: Source x coordinates, in meters.
        y: Source y coordinates, in meters.
        z: Source depths, in meters.
    """

    x: np.ndarray
    y: np.ndarray
    z: np.ndarray
@dataclass(frozen=True)
class ShdPosR:
    """Receiver-side axes of a shade file.

    Attributes:
        z: Receiver depths, in meters.
        r: Receiver ranges, in meters.
    """

    z: np.ndarray
    r: np.ndarray
[docs]
@dataclass(frozen=True)
class ShdPos:
    """Geometry of a binary ``.shd`` file as returned by ``read_shd_bin``.

    Mirrors the ``Pos`` structure of the Matlab reader: the axis vectors plus
    the element counts stored in the file header.

    Attributes:
        theta: Theta axis vector.
        s: Source axes (x, y, z).
        r: Receiver axes (z, r).
        nsx: Number of source x positions.
        nsy: Number of source y positions.
        nsz: Number of source depths.
        nrz: Number of receiver depths.
        nrr: Number of receiver ranges.
        ntheta: Number of theta values.
        nfreq: Number of frequencies.
    """

    # Axis vectors.
    theta: np.ndarray
    s: ShdPosS
    r: ShdPosR

    # Header counts.
    nsx: int
    nsy: int
    nsz: int
    nrz: int
    nrr: int
    ntheta: int
    nfreq: int
[docs]
@dataclass(frozen=True)
class ShdReadResult:
    """Everything ``read_shd_bin`` extracts from a binary shade file.

    Attributes:
        title: Title string from the file header.
        plot_type: Plot-type string from the file header.
        freq_vec: Frequency vector stored in the file.
        freq0: Scalar ``freq0`` header value.
        atten: Scalar ``atten`` header value.
        pos: Source/receiver grids and their counts.
        pressure: Complex pressure field, dtype ``complex64``, with shape
            ``(ntheta, nsz, nrcvrs_per_range, nrr)``.
    """

    # Header metadata.
    title: str
    plot_type: str
    freq_vec: np.ndarray
    freq0: float
    atten: float

    # Geometry and field.
    pos: ShdPos
    pressure: np.ndarray
def _read_at(fmt: str, data: bytes, offset: int):
"""Unpack ``struct`` format at ``offset``; return ``(values_tuple, next_offset)``."""
size = struct.calcsize(fmt)
return struct.unpack_from(fmt, data, offset), offset + size
def _read_bytes_at(data: bytes, offset: int, n: int) -> bytes:
"""Slice ``n`` bytes from ``data`` at ``offset``."""
return data[offset : offset + n]
[docs]
def read_shd_bin(
    data: bytes,
    *,
    freq: float | None = None,
    xs_km: float | None = None,
    ys_km: float | None = None,
) -> ShdReadResult:
    """Parse AT binary ``.shd``/``.grn`` bytes (ported from Matlab ``read_shd_bin.m``).

    This function is intentionally **bytes-in / objects-out** (no filesystem).

    The payload is a sequence of fixed-size records of ``4 * recl`` bytes,
    where ``recl`` is the little-endian int32 stored at byte 0. Records 0-9
    hold the title, plot type, header counts and axis vectors; the pressure
    records start at record 10.

    Args:
        data: Complete file contents.
        freq: Frequency to select; the nearest entry of the file's frequency
            vector is used (defaults to the first frequency). It is only
            applied when no source x/y is selected, as in the Matlab reader.
        xs_km: Source x position in km; the nearest source x is used.
        ys_km: Source y position in km; the nearest source y is used. Both
            ``xs_km`` and ``ys_km`` must be given to select a source position;
            otherwise the first source x/y is read.

    Returns:
        A :class:`ShdReadResult` whose ``pressure`` is ``complex64`` with
        shape ``(ntheta, nsz, nrcvrs_per_range, nrr)``.

    Raises:
        ValueError: If the record length is not positive, a header count is
            negative, or ``data`` is too short for one of the array reads.
        struct.error: If ``data`` is too short to hold the header fields.
    """
    (recl,) = struct.unpack_from("<i", data, 0)
    if recl <= 0:
        # A zero or negative record length would make every record offset
        # collapse onto (or wrap around) the start of the buffer.
        raise ValueError(f"binary shade file: invalid record length {recl}")
    rec_bytes = 4 * recl

    def _vector(recnum: int, dtype: str, count: int) -> np.ndarray:
        """View ``count`` items of ``dtype`` at the start of record ``recnum``."""
        return np.frombuffer(data, dtype=dtype, count=count, offset=recnum * rec_bytes)

    # Record 0: title (80 bytes after recl). Record 1: 10-byte plot type.
    title = data[4:84].decode("utf-8", errors="replace").rstrip("\x00").rstrip()
    plot_type = data[rec_bytes : rec_bytes + 10].decode("utf-8", errors="replace")

    # Record 2: seven int32 counts followed by two float64 scalars.
    nfreq, ntheta, nsx, nsy, nsz, nrz, nrr, freq0, atten = struct.unpack_from(
        "<7i2d", data, 2 * rec_bytes
    )
    if min(nfreq, ntheta, nsx, nsy, nsz, nrz, nrr) < 0:
        # np.frombuffer treats count=-1 as "read everything", so reject early.
        raise ValueError("binary shade file: negative dimension in header")

    # Records 3-9: axis vectors. Copies detach the arrays from ``data``.
    freq_vec = _vector(3, "<f8", nfreq).copy()
    theta = _vector(4, "<f8", ntheta).copy()
    if plot_type.startswith("TL"):
        # Compressed TL format: records store just [min, max]; expand with linspace.
        x_lo, x_hi = _vector(5, "<f8", 2)
        y_lo, y_hi = _vector(6, "<f8", 2)
        sx = np.linspace(float(x_lo), float(x_hi), nsx)
        sy = np.linspace(float(y_lo), float(y_hi), nsy)
    else:
        sx = _vector(5, "<f8", nsx).copy()
        sy = _vector(6, "<f8", nsy).copy()
    sz = _vector(7, "<f4", nsz).astype(np.float32, copy=True)
    rz = _vector(8, "<f4", nrz).astype(np.float32, copy=True)
    rr = _vector(9, "<f8", nrr).copy()

    # Only the "irregular " plot type stores a single receiver per range;
    # every other plot type stores one record per receiver depth.
    nrcvrs_per_range = 1 if plot_type == "irregular " else nrz
    recs_per_block = ntheta * nsz * nrcvrs_per_range

    # First pressure record of the requested block (data records start at 10).
    if xs_km is None or ys_km is None:
        # No source x/y selected: pick the block of the requested frequency
        # (Matlab uses 1-based indices; these are 0-based).
        ifreq = 0 if freq is None else int(np.argmin(np.abs(freq_vec - freq)))
        recnum = 10 + ifreq * recs_per_block
    else:
        idx_x = int(np.argmin(np.abs(sx - (xs_km * 1000.0))))
        idx_y = int(np.argmin(np.abs(sy - (ys_km * 1000.0))))
        recnum = 10 + (idx_x * nsy + idx_y) * recs_per_block

    pressure = np.zeros((ntheta, nsz, nrcvrs_per_range, nrr), dtype=np.complex64)
    # Records within a block are consecutive in (theta, source depth,
    # receiver) order, so the record number simply advances by one.
    for itheta in range(ntheta):
        for isz in range(nsz):
            for irz in range(nrcvrs_per_range):
                # Each record stores nrr interleaved (real, imag) float32
                # pairs, which is the little-endian complex64 layout.
                pressure[itheta, isz, irz, :] = np.frombuffer(
                    data, dtype="<c8", count=nrr, offset=recnum * rec_bytes
                )
                recnum += 1

    pos = ShdPos(
        theta=theta,
        s=ShdPosS(x=sx, y=sy, z=sz),
        r=ShdPosR(z=rz, r=rr),
        nsx=nsx,
        nsy=nsy,
        nsz=nsz,
        nrz=nrz,
        nrr=nrr,
        ntheta=ntheta,
        nfreq=nfreq,
    )
    return ShdReadResult(
        title=title,
        plot_type=plot_type,
        freq_vec=freq_vec,
        freq0=float(freq0),
        atten=float(atten),
        pos=pos,
        pressure=pressure,
    )
[docs]
@dataclass(frozen=True)
class ShdAscResult:
    """Contents of an ASCII shade file (port of Matlab ``read_shd_asc.m``).

    Attributes:
        plot_title: First line of the file.
        plot_type: Second line of the file.
        freq_vec: Frequency vector.
        freq0: Scalar ``freq0`` header value.
        atten: Scalar ``atten`` header value.
        theta: Theta axis vector.
        source_z: Source depths.
        receiver_z: Receiver depths.
        receiver_r: Receiver ranges.
        pressure: Complex pressure field, dtype ``complex64``, with shape
            ``(nsd, nrd, nrr)``.
    """

    # Header metadata.
    plot_title: str
    plot_type: str
    freq_vec: np.ndarray
    freq0: float
    atten: float

    # Axes.
    theta: np.ndarray
    source_z: np.ndarray
    receiver_z: np.ndarray
    receiver_r: np.ndarray

    # Field.
    pressure: np.ndarray
[docs]
def read_shd_asc(data: str | bytes, *, encoding: str = "utf-8") -> ShdAscResult:
    """Parse ASCII shade file text (port of Matlab ``read_shd_asc.m``).

    Matlab reads a single source-depth slab (``isd = 1`` in the original);
    this port reads **all** ``Nsd`` slabs sequentially, so ``pressure`` has
    shape ``(nsd, nrd, nrr)``. For the common ``Nsd == 1`` case the shape is
    ``(1, nrd, nrr)``; for ``Nsd == 0`` it is an empty ``(0, nrd, nrr)`` array.

    Args:
        data: File contents as text, or as bytes to be decoded.
        encoding: Codec used when ``data`` is bytes (undecodable bytes are
            replaced rather than raising).

    Returns:
        A :class:`ShdAscResult` with a ``complex64`` pressure field.

    Raises:
        ValueError: If the file has fewer than two lines, a header count is
            negative, a token is not a valid number, or the data ends early.
    """
    text = data.decode(encoding, errors="replace") if isinstance(data, bytes) else data
    lines = text.splitlines()
    if len(lines) < 2:
        raise ValueError("ascii shade file: need at least title and plot type lines")
    plot_title = lines[0].strip()
    plot_type = lines[1].strip()

    # Everything after the two header lines is a whitespace-separated stream.
    tokens = iter("\n".join(lines[2:]).split())

    def _take(count: int, dtype: type) -> np.ndarray:
        """Consume ``count`` tokens as floats into a 1-D array of ``dtype``."""
        return np.array([float(next(tokens)) for _ in range(count)], dtype=dtype)

    try:
        nfreq, ntheta, nsd, nrd, nrr = [int(next(tokens)) for _ in range(5)]
        if min(nfreq, ntheta, nsd, nrd, nrr) < 0:
            # Negative counts would otherwise be read as "zero items" and
            # produce a silently wrong result.
            raise ValueError("ascii shade file: negative dimension in header")
        freq0 = float(next(tokens))
        atten = float(next(tokens))
        freq_vec = _take(nfreq, np.float64)
        theta = _take(ntheta, np.float64)
        source_z = _take(nsd, np.float32)
        receiver_z = _take(nrd, np.float32)
        receiver_r = _take(nrr, np.float64)
        flat = _take(2 * nrr * nrd * nsd, np.float32)
    except StopIteration as e:
        raise ValueError("ascii shade file: unexpected end of data") from e

    # Matlab fscanf fills a [2*Nrr, Nrd] matrix column-major for each source
    # depth, i.e. the stream is ordered (source depth, receiver depth, range,
    # real/imag). A single C-order reshape therefore recovers every slab.
    pairs = flat.reshape(nsd, nrd, nrr, 2)
    pressure = (pairs[..., 0] + 1j * pairs[..., 1]).astype(np.complex64)

    return ShdAscResult(
        plot_title=plot_title,
        plot_type=plot_type,
        freq_vec=freq_vec,
        freq0=float(freq0),
        atten=float(atten),
        theta=theta,
        source_z=source_z,
        receiver_z=receiver_z,
        receiver_r=receiver_r,
        pressure=pressure,
    )
[docs]
def read_shd(
    data: bytes | str,
    *,
    freq: float | None = None,
    xs_km: float | None = None,
    ys_km: float | None = None,
    file_type: Literal["shd", "asc", "ram"] | None = None,
    encoding: str = "utf-8",
) -> ShdReadResult | ShdAscResult | RamTlGridResult:
    """Dispatch shade parsing like Matlab ``read_shd.m`` for in-memory payloads.

    - ``file_type='shd'``: binary ``read_shd_bin``
    - ``file_type='asc'``: ASCII ``read_shd_asc``
    - ``file_type='ram'``: RAM ``tl.grid`` via ``read_ram_tlgrid``
    - ``file_type=None``: ``str`` -> ASCII; ``bytes`` -> try binary, then ASCII decode

    Args:
        data: File contents. ``str`` input is encoded with ``encoding`` before
            being handed to a binary parser.
        freq: Forwarded to ``read_shd_bin``.
        xs_km: Forwarded to ``read_shd_bin``.
        ys_km: Forwarded to ``read_shd_bin``.
        file_type: Explicit format, or ``None`` to auto-detect.
        encoding: Codec used to convert between ``str`` and ``bytes``.

    Returns:
        The result object of whichever parser handled the payload.

    Raises:
        ValueError: If ``file_type`` is not one of the supported values, or
            the selected parser rejects the payload.
    """
    if file_type not in (None, "shd", "asc", "ram"):
        raise ValueError(
            f"unsupported shade file_type {file_type!r}; expected 'shd', 'asc', or 'ram'"
        )

    if file_type == "asc":
        return read_shd_asc(data, encoding=encoding)

    if file_type is not None:
        # 'shd' and 'ram' are binary parsers and need bytes.
        payload = data.encode(encoding) if isinstance(data, str) else data
        if file_type == "ram":
            return read_ram_tlgrid(payload)
        return read_shd_bin(payload, freq=freq, xs_km=xs_km, ys_km=ys_km)

    # Auto-detect: text can only be ASCII; bytes are tried as binary first.
    if isinstance(data, str):
        return read_shd_asc(data, encoding=encoding)
    try:
        return read_shd_bin(data, freq=freq, xs_km=xs_km, ys_km=ys_km)
    except (ValueError, struct.error):
        # A payload that is too short for the binary header fails inside
        # ``struct`` with ``struct.error``, which is not a ``ValueError``
        # subclass; both mean "not a binary shade file", so fall back to ASCII.
        return read_shd_asc(data, encoding=encoding)
[docs]
def read_shd_bytes(
    data: bytes,
    *,
    freq: float | None = None,
    xs_km: float | None = None,
    ys_km: float | None = None,
    file_type: Literal["shd", "asc", "ram"] | None = None,
    encoding: str = "utf-8",
) -> ShdReadResult | ShdAscResult | RamTlGridResult:
    """Bytes-only entry point that forwards every argument to :func:`read_shd`."""
    options = {
        "freq": freq,
        "xs_km": xs_km,
        "ys_km": ys_km,
        "file_type": file_type,
        "encoding": encoding,
    }
    return read_shd(data, **options)