Skip to content

Commit ac2158e

Browse files
committed
create_empty squash
1 parent f5c799e commit ac2158e

File tree

7 files changed

+682
-10
lines changed

7 files changed

+682
-10
lines changed

src/mdio/__init__.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,24 @@
22

33
from importlib import metadata
44

5-
from mdio.api.io import open_mdio
6-
from mdio.api.io import to_mdio
7-
from mdio.converters import mdio_to_segy
8-
from mdio.converters import segy_to_mdio
9-
105
try:
116
__version__ = metadata.version("multidimio")
127
except metadata.PackageNotFoundError:
138
__version__ = "unknown"
149

10+
from mdio.api.io import open_mdio
11+
from mdio.api.io import to_mdio
12+
from mdio.converters.mdio import mdio_to_segy
13+
from mdio.converters.segy import segy_to_mdio
14+
from mdio.api.create import create_empty
15+
from mdio.api.create import create_empty_like
1516

1617
__all__ = [
1718
"__version__",
1819
"open_mdio",
1920
"to_mdio",
2021
"mdio_to_segy",
2122
"segy_to_mdio",
23+
"create_empty",
24+
"create_empty_like",
2225
]

src/mdio/api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
"""Public API."""
2+

src/mdio/api/create.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Creating MDIO v1 datasets."""
2+
3+
from __future__ import annotations
4+
5+
from datetime import UTC
6+
from datetime import datetime
7+
from typing import TYPE_CHECKING
8+
9+
from mdio.api.io import _normalize_path
10+
from mdio.api.io import open_mdio
11+
from mdio.api.io import to_mdio
12+
from mdio.builder.template_registry import TemplateRegistry
13+
from mdio.builder.xarray_builder import to_xarray_dataset
14+
from mdio.converters.segy import populate_dim_coordinates
15+
from mdio.converters.type_converter import to_structured_type
16+
from mdio.core.grid import Grid
17+
18+
if TYPE_CHECKING:
19+
from pathlib import Path
20+
21+
from segy.schema import HeaderSpec
22+
from upath import UPath
23+
from xarray import Dataset as xr_Dataset
24+
25+
from mdio.builder.schemas import Dataset
26+
from mdio.builder.templates.base import AbstractDatasetTemplate
27+
from mdio.core.dimension import Dimension
28+
29+
30+
def create_empty(  # noqa PLR0913
    mdio_template: AbstractDatasetTemplate | str,
    dimensions: list[Dimension],
    output_path: UPath | Path | str | None,
    headers: HeaderSpec | None = None,
    overwrite: bool = False,
) -> xr_Dataset:
    """Create an empty MDIO v1 dataset with known dimensions.

    Args:
        mdio_template: The MDIO template or template name to use to define the dataset structure.
        dimensions: The dimensions of the MDIO file.
        output_path: The universal path for the output MDIO v1 file.
            If None, the output will not be written to disk.
        headers: SEG-Y v1.0 trace headers. Defaults to None.
        overwrite: Whether to overwrite the output file if it already exists. Defaults to False.

    Returns:
        The output MDIO dataset, restricted to the variables returned by
        ``populate_dim_coordinates`` plus ``trace_mask``.

    Raises:
        FileExistsError: If the output location already exists and overwrite is False.
    """
    # `None` is an accepted value (in-memory only), so it must bypass both the
    # normalization and the existence check instead of being passed through them.
    output_path = _normalize_path(output_path) if output_path is not None else None

    if not overwrite and output_path is not None and output_path.exists():
        err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
        raise FileExistsError(err)

    header_dtype = to_structured_type(headers.dtype) if headers else None
    grid = Grid(dims=dimensions)
    if isinstance(mdio_template, str):
        # A template name is passed in. Get a unit-unaware template from registry
        mdio_template = TemplateRegistry().get(mdio_template)
    # Build the dataset using the template
    mdio_ds: Dataset = mdio_template.build_dataset(
        name=mdio_template.name,
        sizes=grid.shape,
        header_dtype=header_dtype,
    )

    # Convert to xarray dataset
    xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds)

    # Populate coordinates using the grid
    # For empty datasets, we only populate dimension coordinates
    drop_vars_delayed = []
    xr_dataset, drop_vars_delayed = populate_dim_coordinates(xr_dataset, grid, drop_vars_delayed=drop_vars_delayed)

    if headers:
        # Since the headers were provided, the user wants to export to SEG-Y
        # Add a dummy segy_file_header variable used to export to SEG-Y
        xr_dataset["segy_file_header"] = ((), "")

    # Create the Zarr store with the correct structure but with empty arrays
    if output_path is not None:
        to_mdio(xr_dataset, output_path=output_path, mode="w", compute=False)

    # Write the dimension coordinates and trace mask.
    # Note that this narrowed dataset is also what gets returned to the caller.
    xr_dataset = xr_dataset[drop_vars_delayed + ["trace_mask"]]

    if output_path is not None:
        to_mdio(xr_dataset, output_path=output_path, mode="r+", compute=True)

    return xr_dataset
90+
91+
92+
def create_empty_like(  # noqa PLR0913
    input_path: UPath | Path | str,
    output_path: UPath | Path | str | None,
    keep_coordinates: bool = False,
    overwrite: bool = False,
) -> xr_Dataset:
    """Create an empty MDIO v1 dataset with the same structure as an existing one.

    Args:
        input_path: The path of the input MDIO file.
        output_path: The path of the output MDIO file.
            If None, the output will not be written to disk.
        keep_coordinates: Whether to keep the coordinates in the output file.
        overwrite: Whether to overwrite the output file if it exists.

    Returns:
        The output MDIO dataset.

    Raises:
        FileExistsError: If the output location already exists and overwrite is False.
    """
    input_path = _normalize_path(input_path)
    output_path = _normalize_path(output_path) if output_path is not None else None

    if not overwrite and output_path is not None and output_path.exists():
        err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
        raise FileExistsError(err)

    ds = open_mdio(input_path)

    # Copy the structure and, optionally, drop the non-index coordinates.
    # NOTE(review): `copy(data=None)` is xarray's default and keeps the original
    # variable data rather than blanking it — confirm the written output is
    # really empty, or replace the data explicitly.
    # NOTE(review): with `drop=False`, `reset_coords` turns non-index coordinates
    # into data variables instead of keeping them as coordinates — confirm intended.
    ds_output = ds.copy(data=None).reset_coords(drop=not keep_coordinates)

    # Dataset attributes: the name (same as the template name) and apiVersion
    # are kept as-is; only the creation timestamp is refreshed.
    ds_output.attrs["createdOn"] = str(datetime.now(UTC))

    # Coordinates
    if not keep_coordinates:
        for coord_name in ds_output.coords:
            ds_output[coord_name].attrs.pop("unitsV1", None)

    # MDIO attributes. A missing or None "attributes" entry is tolerated here and
    # in the data-variable cleanup below.
    attr = ds_output.attrs.get("attributes")
    if attr is not None:
        # Empty dataset should not have gridOverrides. defaultVariableName,
        # surveyType and gatherType keep their original values.
        attr.pop("gridOverrides", None)

    # All traces should be marked as dead in empty dataset
    if "trace_mask" in ds_output.variables:
        ds_output["trace_mask"][:] = False

    # Data variable: statistics of the source data do not apply to an empty copy.
    # Guarded so that absent MDIO attributes do not raise on the lookup.
    if attr is not None:
        var_name = attr["defaultVariableName"]
        var = ds_output[var_name]
        var.attrs.pop("statsV1", None)
        if not keep_coordinates:
            var.attrs.pop("unitsV1", None)

    # SEG-Y file header: strip the header contents carried over from the input
    if "segy_file_header" in ds_output.variables:
        file_header_attrs = ds_output["segy_file_header"].attrs
        for key in ("textHeader", "binaryHeader", "rawBinaryHeader"):
            file_header_attrs.pop(key, None)

    if output_path is not None:
        to_mdio(ds_output, output_path=output_path, mode="w", compute=True)

    return ds_output

src/mdio/converters/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,2 @@
11
"""MDIO Data conversion API."""
22

3-
from mdio.converters.mdio import mdio_to_segy
4-
from mdio.converters.segy import segy_to_mdio
5-
6-
__all__ = ["mdio_to_segy", "segy_to_mdio"]

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,9 @@ def segy_export_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
5858
"""Make a temp file for the round-trip IBM SEG-Y."""
5959
tmp_dir = tmp_path_factory.mktemp("segy")
6060
return tmp_dir / "teapot_roundtrip.segy"
61+
62+
63+
@pytest.fixture(scope="class")
def empty_mdio_dir(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Provide a class-scoped temporary directory for empty MDIO tests."""
    base_name = "empty_mdio_dir"
    return tmp_path_factory.mktemp(base_name)

0 commit comments

Comments
 (0)