Skip to content

Commit 73ab646

Browse files
Implementation of NonBinned and DuplicateIndex grid overrides (#732)
* Working example of DuplicateIndex grid override * Remove hardcoded values + linting * Streamline code * Relocate template mutation to all happen in the same function -- Update logging level * Formatting * Clean up * Update logging * Remove incorrect noqa * Checkpoint * Fully functional demo * pre-commit * Revert debugging changes * pre-commit * Simplifications to DuplicateIndex and GridOverrides * Add base safety for grid override template mutation * Extract template update logic * Resolve nondeterminstic ordering causing incorrect template mutations * Bandaid fix to join non-binned and autoshotwrap overrides -- Should be refactored for a uniform experience * pre-commit * Fix memory consumption on grid calculation and not appropriately calculating trace dimension * Apply proper dimensions to non-binned coordiante arrays -- Fix large memory spike between header scan and trace ingestion --------- Co-authored-by: Altay Sansal <tasansal@users.noreply.github.com>
1 parent a4bbb74 commit 73ab646

File tree

7 files changed

+414
-48
lines changed

7 files changed

+414
-48
lines changed

src/mdio/builder/dataset_builder.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,32 @@ def add_coordinate( # noqa: PLR0913
150150
msg = "Adding coordinate with the same name twice is not allowed"
151151
raise ValueError(msg)
152152

153-
# Validate that all referenced dimensions are already defined
153+
# Resolve referenced dimensions strictly, except allow a single substitution with 'trace' if present.
154154
named_dimensions = []
155+
trace_dim = _get_named_dimension(self._dimensions, "trace")
156+
resolved_dim_names: list[str] = []
157+
trace_used = False
158+
missing_dims: list[str] = []
155159
for dim_name in dimensions:
156160
nd = _get_named_dimension(self._dimensions, dim_name)
161+
if nd is not None:
162+
if dim_name not in resolved_dim_names:
163+
resolved_dim_names.append(dim_name)
164+
continue
165+
if trace_dim is not None and not trace_used and "trace" not in resolved_dim_names:
166+
resolved_dim_names.append("trace")
167+
trace_used = True
168+
else:
169+
missing_dims.append(dim_name)
170+
171+
if missing_dims:
172+
msg = f"Pre-existing dimension named {missing_dims[0]!r} is not found"
173+
raise ValueError(msg)
174+
175+
for resolved_name in resolved_dim_names:
176+
nd = _get_named_dimension(self._dimensions, resolved_name)
157177
if nd is None:
158-
msg = f"Pre-existing dimension named {dim_name!r} is not found"
178+
msg = f"Pre-existing dimension named {resolved_name!r} is not found"
159179
raise ValueError(msg)
160180
named_dimensions.append(nd)
161181

@@ -174,7 +194,7 @@ def add_coordinate( # noqa: PLR0913
174194
self.add_variable(
175195
name=coord.name,
176196
long_name=coord.long_name,
177-
dimensions=dimensions, # dimension names (list[str])
197+
dimensions=tuple(resolved_dim_names), # resolved dimension names
178198
data_type=coord.data_type,
179199
compressor=compressor,
180200
coordinates=[name], # Use the coordinate name as a reference

src/mdio/builder/templates/base.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def build_dataset(
8282
8383
Returns:
8484
Dataset: The constructed dataset
85+
86+
Raises:
87+
ValueError: If coordinate already exists from subclass override.
8588
"""
8689
self._dim_sizes = sizes
8790

@@ -90,6 +93,20 @@ def build_dataset(
9093
self._builder = MDIODatasetBuilder(name=name, attributes=attributes)
9194
self._add_dimensions()
9295
self._add_coordinates()
96+
# Ensure any coordinates declared on the template but not added by _add_coordinates
97+
# are materialized with generic defaults. This handles coordinates added by grid overrides.
98+
for coord_name in self.coordinate_names:
99+
try:
100+
self._builder.add_coordinate(
101+
name=coord_name,
102+
dimensions=self.spatial_dimension_names,
103+
data_type=ScalarType.FLOAT64,
104+
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
105+
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(coord_name)),
106+
)
107+
except ValueError as exc: # coordinate may already exist
108+
if "same name twice" not in str(exc):
109+
raise
93110
self._add_variables()
94111
self._add_trace_mask()
95112

@@ -241,14 +258,21 @@ def _add_coordinates(self) -> None:
241258
)
242259

243260
# Add non-dimension coordinates
261+
# Note: coordinate_names may be modified at runtime by grid overrides,
262+
# so we need to handle dynamic additions gracefully
244263
for name in self.coordinate_names:
245-
self._builder.add_coordinate(
246-
name=name,
247-
dimensions=self.spatial_dimension_names,
248-
data_type=ScalarType.FLOAT64,
249-
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
250-
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(name)),
251-
)
264+
try:
265+
self._builder.add_coordinate(
266+
name=name,
267+
dimensions=self.spatial_dimension_names,
268+
data_type=ScalarType.FLOAT64,
269+
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
270+
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(name)),
271+
)
272+
except ValueError as exc:
273+
# Coordinate may already exist from subclass override
274+
if "same name twice" not in str(exc):
275+
raise
252276

253277
def _add_trace_mask(self) -> None:
254278
"""Add trace mask variables."""

src/mdio/converters/segy.py

Lines changed: 194 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,146 @@ def grid_density_qc(grid: Grid, num_traces: int) -> None:
134134
raise GridTraceSparsityError(grid.shape, num_traces, msg)
135135

136136

137+
def _patch_add_coordinates_for_non_binned(
138+
template: AbstractDatasetTemplate,
139+
non_binned_dims: set[str],
140+
) -> None:
141+
"""Patch template's _add_coordinates to skip adding non-binned dims as dimension coordinates.
142+
143+
When NonBinned override is used, dimensions like 'offset' or 'azimuth' become coordinates
144+
instead of dimensions. However, template subclasses may still try to add them as 1D
145+
dimension coordinates (e.g., with dimensions=("offset",)). Since 'offset' is no longer
146+
a dimension, the builder substitutes 'trace', resulting in wrong coordinate dimensions.
147+
148+
This function patches the template's _add_coordinates method to intercept calls to
149+
builder.add_coordinate and skip adding coordinates that are non-binned dims with
150+
single-element dimension tuples. These coordinates will be added later by build_dataset
151+
with the correct spatial_dimension_names (e.g., (inline, crossline, trace)).
152+
153+
Args:
154+
template: The template to patch
155+
non_binned_dims: Set of dimension names that became coordinates due to NonBinned override
156+
"""
157+
# Check if already patched to avoid duplicate patching
158+
if hasattr(template, "_mdio_non_binned_patched"):
159+
return
160+
161+
# Store the original _add_coordinates method
162+
original_add_coordinates = template._add_coordinates
163+
164+
def patched_add_coordinates() -> None:
165+
"""Wrapper that intercepts builder.add_coordinate calls for non-binned dims."""
166+
# Store the original add_coordinate method from the builder
167+
original_builder_add_coordinate = template._builder.add_coordinate
168+
169+
def filtered_add_coordinate( # noqa: ANN202
170+
name: str,
171+
*,
172+
dimensions: tuple[str, ...],
173+
**kwargs, # noqa: ANN003
174+
):
175+
"""Skip adding non-binned dims as 1D dimension coordinates."""
176+
# Check if this is a non-binned dim being added as a 1D dimension coordinate
177+
# (i.e., the coordinate name matches a non-binned dim and has only 1 dimension)
178+
if name in non_binned_dims and len(dimensions) == 1:
179+
logger.debug(
180+
"Skipping 1D coordinate '%s' with dims %s - will be added with full spatial dims",
181+
name,
182+
dimensions,
183+
)
184+
return template._builder # Return builder for chaining, but don't add
185+
186+
# Otherwise, call the original method
187+
return original_builder_add_coordinate(name, dimensions=dimensions, **kwargs)
188+
189+
# Temporarily replace builder's add_coordinate
190+
template._builder.add_coordinate = filtered_add_coordinate
191+
192+
try:
193+
# Call the original _add_coordinates
194+
original_add_coordinates()
195+
finally:
196+
# Restore the original add_coordinate method
197+
template._builder.add_coordinate = original_builder_add_coordinate
198+
199+
# Replace the template's _add_coordinates method
200+
template._add_coordinates = patched_add_coordinates
201+
202+
# Mark as patched to prevent duplicate patching
203+
template._mdio_non_binned_patched = True
204+
205+
206+
def _update_template_from_grid_overrides(
207+
template: AbstractDatasetTemplate,
208+
grid_overrides: dict[str, Any] | None,
209+
segy_dimensions: list[Dimension],
210+
full_chunk_shape: tuple[int, ...],
211+
chunk_size: tuple[int, ...],
212+
) -> None:
213+
"""Update template attributes to match grid plan results after grid overrides.
214+
215+
This function modifies the template in-place to reflect changes from grid overrides:
216+
- Updates chunk shape if it changed due to overrides
217+
- Updates dimension names if they changed due to overrides
218+
- Adds non-binned dimensions as coordinates for NonBinned override
219+
- Patches _add_coordinates to skip adding non-binned dims as dimension coordinates
220+
221+
Args:
222+
template: The template to update
223+
grid_overrides: Grid override configuration
224+
segy_dimensions: Dimensions returned from grid planning
225+
full_chunk_shape: Original template chunk shape
226+
chunk_size: Chunk size returned from grid planning
227+
"""
228+
# Update template to match grid_plan results after grid overrides
229+
# Extract actual spatial dimensions from segy_dimensions (excluding vertical dimension)
230+
actual_spatial_dims = tuple(dim.name for dim in segy_dimensions[:-1])
231+
232+
# Align chunk_size with actual dimensions - truncate if dimensions were filtered out
233+
num_actual_spatial = len(actual_spatial_dims)
234+
num_chunk_spatial = len(chunk_size) - 1 # Exclude vertical dimension chunk
235+
if num_actual_spatial != num_chunk_spatial:
236+
# Truncate chunk_size to match actual dimensions
237+
chunk_size = chunk_size[:num_actual_spatial] + (chunk_size[-1],)
238+
239+
if full_chunk_shape != chunk_size:
240+
logger.debug(
241+
"Adjusting template chunk shape from %s to %s to match grid after overrides",
242+
full_chunk_shape,
243+
chunk_size,
244+
)
245+
template._var_chunk_shape = chunk_size
246+
247+
# Update dimensions if they don't match grid_plan results
248+
if template.spatial_dimension_names != actual_spatial_dims:
249+
logger.debug(
250+
"Adjusting template dimensions from %s to %s to match grid after overrides",
251+
template.spatial_dimension_names,
252+
actual_spatial_dims,
253+
)
254+
template._dim_names = actual_spatial_dims + (template.trace_domain,)
255+
256+
# If using NonBinned override, expose non-binned dims as logical coordinates on the template instance
257+
# and patch _add_coordinates to skip adding them as 1D dimension coordinates
258+
if grid_overrides and "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides:
259+
non_binned_dims = tuple(grid_overrides["non_binned_dims"])
260+
if non_binned_dims:
261+
logger.debug(
262+
"NonBinned grid override: exposing non-binned dims as coordinates: %s",
263+
non_binned_dims,
264+
)
265+
# Append any missing names; keep existing order and avoid duplicates
266+
existing = set(template.coordinate_names)
267+
to_add = tuple(n for n in non_binned_dims if n not in existing)
268+
if to_add:
269+
template._logical_coord_names = template._logical_coord_names + to_add
270+
271+
# Patch _add_coordinates to skip adding non-binned dims as 1D dimension coordinates
272+
# This prevents them from being added with wrong dimensions (e.g., just "trace")
273+
# They will be added later by build_dataset with full spatial_dimension_names
274+
_patch_add_coordinates_for_non_binned(template, set(non_binned_dims))
275+
276+
137277
def _scan_for_headers(
138278
segy_file_kwargs: SegyFileArguments,
139279
segy_file_info: SegyFileInfo,
@@ -143,7 +283,11 @@ def _scan_for_headers(
143283
"""Extract trace dimensions and index headers from the SEG-Y file.
144284
145285
This is an expensive operation.
146-
It scans the SEG-Y file in chunks by using ProcessPoolExecutor
286+
It scans the SEG-Y file in chunks by using ProcessPoolExecutor.
287+
288+
Note:
289+
If grid_overrides are applied to the template before calling this function,
290+
the chunk_size returned from get_grid_plan should match the template's chunk shape.
147291
"""
148292
full_chunk_shape = template.full_chunk_shape
149293
segy_dimensions, chunk_size, segy_headers = get_grid_plan(
@@ -154,13 +298,15 @@ def _scan_for_headers(
154298
chunksize=full_chunk_shape,
155299
grid_overrides=grid_overrides,
156300
)
157-
if full_chunk_shape != chunk_size:
158-
# TODO(Dmitriy): implement grid overrides
159-
# https://github.com/TGSAI/mdio-python/issues/585
160-
# The returned 'chunksize' is used only for grid_overrides. We will need to use it when full
161-
# support for grid overrides is implemented
162-
err = "Support for changing full_chunk_shape in grid overrides is not yet implemented"
163-
raise NotImplementedError(err)
301+
302+
_update_template_from_grid_overrides(
303+
template=template,
304+
grid_overrides=grid_overrides,
305+
segy_dimensions=segy_dimensions,
306+
full_chunk_shape=full_chunk_shape,
307+
chunk_size=chunk_size,
308+
)
309+
164310
return segy_dimensions, segy_headers
165311

166312

@@ -233,7 +379,8 @@ def _get_coordinates(
233379
if coord_name not in segy_headers.dtype.names:
234380
err = f"Coordinate '{coord_name}' not found in SEG-Y dimensions."
235381
raise ValueError(err)
236-
non_dim_coords[coord_name] = segy_headers[coord_name]
382+
# Copy the data to allow segy_headers to be garbage collected
383+
non_dim_coords[coord_name] = np.array(segy_headers[coord_name])
237384

238385
return dimensions_coords, non_dim_coords
239386

@@ -255,25 +402,54 @@ def populate_non_dim_coordinates(
255402
drop_vars_delayed: list[str],
256403
spatial_coordinate_scalar: int,
257404
) -> tuple[xr_Dataset, list[str]]:
258-
"""Populate the xarray dataset with coordinate variables."""
405+
"""Populate the xarray dataset with coordinate variables.
406+
407+
Memory optimization: Processes coordinates one at a time and explicitly
408+
releases intermediate arrays to reduce peak memory usage.
409+
"""
259410
non_data_domain_dims = grid.dim_names[:-1] # minus the data domain dimension
260-
for coord_name, coord_values in coordinates.items():
411+
412+
# Process coordinates one at a time to minimize peak memory
413+
coord_names = list(coordinates.keys())
414+
for coord_name in coord_names:
415+
coord_values = coordinates.pop(coord_name) # Remove from dict to free memory
261416
da_coord = dataset[coord_name]
262-
tmp_coord_values = dataset[coord_name].values
263417

418+
# Get coordinate shape from dataset (uses dask shape, no memory allocation)
419+
coord_shape = da_coord.shape
420+
421+
# Create output array with fill value
422+
fill_value = da_coord.encoding.get("_FillValue") or da_coord.encoding.get("fill_value")
423+
if fill_value is None:
424+
fill_value = np.nan
425+
tmp_coord_values = np.full(coord_shape, fill_value, dtype=da_coord.dtype)
426+
427+
# Compute slices for this coordinate's dimensions
264428
coord_axes = tuple(non_data_domain_dims.index(coord_dim) for coord_dim in da_coord.dims)
265429
coord_slices = tuple(slice(None) if idx in coord_axes else 0 for idx in range(len(non_data_domain_dims)))
266-
coord_trace_indices = grid.map[coord_slices]
267430

431+
# Read only the required slice from grid map
432+
coord_trace_indices = np.asarray(grid.map[coord_slices])
433+
434+
# Find valid (non-null) indices
268435
not_null = coord_trace_indices != grid.map.fill_value
269-
tmp_coord_values[not_null] = coord_values[coord_trace_indices[not_null]]
270436

437+
# Populate values efficiently
438+
if not_null.any():
439+
valid_indices = coord_trace_indices[not_null]
440+
tmp_coord_values[not_null] = coord_values[valid_indices]
441+
442+
# Apply scalar if needed
271443
if coord_name in SCALE_COORDINATE_KEYS:
272444
tmp_coord_values = _apply_coordinate_scalar(tmp_coord_values, spatial_coordinate_scalar)
273445

446+
# Assign to dataset
274447
dataset[coord_name][:] = tmp_coord_values
275448
drop_vars_delayed.append(coord_name)
276449

450+
# Explicitly release intermediate arrays
451+
del tmp_coord_values, coord_trace_indices, not_null, coord_values
452+
277453
# TODO(Altay): Add verification of reduced coordinates being the same as the first
278454
# https://github.com/TGSAI/mdio-python/issues/645
279455

@@ -554,6 +730,10 @@ def segy_to_mdio( # noqa PLR0913
554730
grid = _build_and_check_grid(segy_dimensions, segy_file_info, segy_headers)
555731

556732
_, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
733+
734+
# Explicitly delete segy_headers to free memory - coordinate values have been copied
735+
del segy_headers
736+
557737
header_dtype = to_structured_type(segy_spec.trace.header.dtype)
558738

559739
if settings.raw_headers:

0 commit comments

Comments
 (0)