Skip to content

Commit 8b01aa4

Browse files
committed
more refactor
1 parent b440a19 commit 8b01aa4

18 files changed

+2918
-16
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Helper modules for SEG-Y CLI.
2+
3+
This package isolates interactive UI helpers and IO/spec helpers
4+
away from the CLI command definitions to improve separation of concerns.
5+
"""
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import questionary
6+
import typer
7+
from segy.schema.format import TextHeaderEncoding
8+
from segy.schema.header import HeaderField
9+
from segy.schema.segy import SegyStandard
10+
from upath import UPath
11+
12+
from mdio.commands.segy_helpers.text_headers import _format_text_header
13+
from mdio.commands.segy_helpers.text_headers import _pager
14+
from mdio.commands.segy_helpers.text_headers import _read_text_header
15+
16+
if TYPE_CHECKING: # pragma: no cover
17+
from segy.schema.segy import SegySpec
18+
19+
from mdio.builder.templates.base import AbstractDatasetTemplate
20+
21+
TEXT_ENCODING = TextHeaderEncoding.EBCDIC
22+
REVISION_MAP = {
23+
"rev 0": SegyStandard.REV0,
24+
"rev 1": SegyStandard.REV1,
25+
"rev 2": SegyStandard.REV2,
26+
"rev 2.1": SegyStandard.REV21,
27+
}
28+
29+
30+
def prompt_for_segy_standard() -> SegyStandard:
31+
"""Prompt user to select a SEG-Y standard."""
32+
choices = list(REVISION_MAP.keys())
33+
standard_str = questionary.select("Select SEG-Y standard:", choices=choices, default="rev 1").ask()
34+
return SegyStandard(REVISION_MAP[standard_str])
35+
36+
37+
def prompt_for_text_encoding() -> TextHeaderEncoding | None:
38+
"""Prompt user for text header encoding (returns TextHeaderEncoding)."""
39+
choices = [member.name for member in TextHeaderEncoding]
40+
choice = questionary.select("Select text header encoding:", choices=choices, default=TEXT_ENCODING).ask()
41+
if choice is None:
42+
return None
43+
return TextHeaderEncoding(choice)
44+
45+
46+
def prompt_for_header_fields(field_type: str, segy_spec: SegySpec) -> list[HeaderField]:
47+
"""Prompt user to customize header fields with interactive choices."""
48+
49+
def _get_known_fields() -> list[HeaderField]:
50+
"""Get known fields for the given field type."""
51+
if field_type.lower() == "binary":
52+
return segy_spec.binary.header.fields
53+
if field_type.lower() == "trace":
54+
return segy_spec.trace.header.fields
55+
return []
56+
57+
def _format_choice(hf: HeaderField) -> str:
58+
"""Format a header field choice for the checkbox."""
59+
return f"{hf.name} (byte={hf.byte}, format={hf.format})"
60+
61+
if not questionary.confirm(f"Customize {field_type} header fields?", default=False).ask():
62+
return []
63+
64+
fields = []
65+
while True:
66+
action = questionary.select(
67+
f"Customize {field_type} header fields — choose an action:",
68+
choices=[
69+
"Add from known fields",
70+
"Add a new field",
71+
"View current selections",
72+
"Clear selections",
73+
"Done",
74+
],
75+
default="Add a new field",
76+
).ask()
77+
78+
if action == "Add from known fields":
79+
known = _get_known_fields()
80+
choices = [_format_choice(hf) for hf in known]
81+
selected = questionary.checkbox(f"Pick {field_type} header fields to add:", choices=choices).ask()
82+
lookup = {_format_choice(hf): hf for hf in known}
83+
for label in selected:
84+
header_field = lookup.get(label)
85+
if header_field is not None:
86+
fields.append(header_field)
87+
88+
elif action == "Add a new field":
89+
name = questionary.text("Field name (e.g., inline):").ask()
90+
if not name:
91+
print("Name cannot be empty.")
92+
continue
93+
94+
byte_str = questionary.text("Starting byte (integer):").ask()
95+
try:
96+
byte_val = int(byte_str)
97+
except (TypeError, ValueError):
98+
print("Byte must be an integer.")
99+
continue
100+
101+
from segy.schema.format import ScalarType
102+
103+
fmt_choices = [s.value for s in ScalarType]
104+
format_ = questionary.select("Data format:", choices=fmt_choices, default="int32").ask()
105+
if not format_:
106+
print("Format cannot be empty.")
107+
continue
108+
109+
try:
110+
valid_field = HeaderField.model_validate({"name": name, "byte": byte_val, "format": format_})
111+
except Exception as exc: # pydantic validation error
112+
print(f"Invalid field specification: {exc}")
113+
continue
114+
fields.append(valid_field)
115+
116+
elif action == "View current selections":
117+
if not fields:
118+
print("No custom fields selected yet.")
119+
else:
120+
print("Currently selected fields:")
121+
for i, hf in enumerate(fields, start=1):
122+
print(f" {i}. {hf.name} (byte={hf.byte}, format={hf.format})")
123+
124+
elif action == "Clear selections":
125+
if fields and questionary.confirm("Clear all selected fields?", default=False).ask():
126+
fields = []
127+
128+
elif action == "Done":
129+
break
130+
131+
return fields
132+
133+
134+
def prompt_for_mdio_template() -> AbstractDatasetTemplate:
135+
"""Prompt user to select a MDIO template."""
136+
from mdio.builder.template_registry import get_template_registry
137+
138+
registry = get_template_registry()
139+
choices = registry.list_all_templates()
140+
template_name = questionary.select("Select MDIO template:", choices=choices).ask()
141+
142+
if template_name is None:
143+
raise typer.Abort
144+
145+
return registry.get(template_name)
146+
147+
148+
def interactive_text_header(input_path: UPath) -> TextHeaderEncoding:
149+
"""Run textual header preview and return the chosen encoding."""
150+
from segy.standards.registry import get_segy_standard
151+
152+
text_encoding = TextHeaderEncoding.EBCDIC
153+
segy_spec_preview = get_segy_standard(SegyStandard.REV0)
154+
segy_spec_preview.text_header.encoding = text_encoding
155+
segy_spec_preview.endianness = None
156+
157+
if questionary.confirm("Preview textual header now?", default=True).ask():
158+
while True:
159+
main_txt = _read_text_header(input_path, segy_spec_preview)
160+
formatted_txt = _format_text_header(main_txt, segy_spec_preview.text_header.encoding)
161+
_pager(formatted_txt)
162+
163+
did_save = False
164+
if questionary.confirm("Save displayed header(s) to a file?", default=False).ask():
165+
default_hdr_uri = input_path.with_name(f"{input_path.stem}_text_header.txt").as_posix()
166+
out_hdr_uri = questionary.text("Filename for text header:", default=default_hdr_uri).ask()
167+
if out_hdr_uri:
168+
with UPath(out_hdr_uri).open("w") as fp:
169+
fp.write(formatted_txt)
170+
print(f"Textual header saved to '{out_hdr_uri}'.")
171+
did_save = True
172+
173+
if did_save:
174+
break
175+
176+
if not questionary.confirm("Switch encoding and preview again?", default=False).ask():
177+
break
178+
179+
new_enc = prompt_for_text_encoding()
180+
text_encoding = new_enc or text_encoding
181+
segy_spec_preview.text_header.encoding = text_encoding
182+
183+
return text_encoding
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from typing import TYPE_CHECKING
5+
6+
import questionary
7+
import typer
8+
from rich import print # noqa: A004
9+
from upath import UPath
10+
11+
from mdio.commands.segy_helpers.interactive import _interactive_text_header_preview_select_encoding
12+
from mdio.commands.segy_helpers.interactive import prompt_for_header_fields
13+
from mdio.commands.segy_helpers.interactive import prompt_for_segy_standard
14+
15+
if TYPE_CHECKING:
16+
from segy.schema.format import TextHeaderEncoding
17+
from segy.schema.segy import SegySpec
18+
19+
from mdio.builder.templates.base import AbstractDatasetTemplate
20+
21+
22+
def load_mdio_template(mdio_template_name: str) -> AbstractDatasetTemplate:
23+
"""Load MDIO template from registry or fail with Typer.Abort."""
24+
from mdio.builder.template_registry import get_template_registry
25+
26+
registry = get_template_registry()
27+
try:
28+
return registry.get(mdio_template_name)
29+
except KeyError:
30+
typer.secho(f"MDIO template '{mdio_template_name}' not found.", fg="red", err=True)
31+
raise typer.Abort from None
32+
33+
34+
def load_segy_spec(segy_spec_path: UPath) -> SegySpec:
35+
"""Load SEG-Y specification from a file."""
36+
from pydantic import ValidationError
37+
from segy.schema.segy import SegySpec
38+
39+
try:
40+
with segy_spec_path.open("r") as f:
41+
return SegySpec.model_validate_json(f.read())
42+
except FileNotFoundError:
43+
typer.secho(f"SEG-Y specification file '{segy_spec_path}' does not exist.", fg="red", err=True)
44+
raise typer.Abort from None
45+
except ValidationError:
46+
typer.secho(f"Invalid SEG-Y specification file '{segy_spec_path}'.", fg="red", err=True)
47+
raise typer.Abort from None
48+
49+
50+
def create_segy_spec(
51+
input_path: UPath, mdio_template: AbstractDatasetTemplate, preselected_encoding: TextHeaderEncoding | None = None
52+
) -> SegySpec:
53+
"""Create SEG-Y specification interactively."""
54+
from segy.standards.registry import get_segy_standard
55+
56+
# Preview textual header FIRST with EBCDIC by default (before selecting SEG-Y revision)
57+
if preselected_encoding is None:
58+
text_encoding = _interactive_text_header_preview_select_encoding(input_path)
59+
else:
60+
text_encoding = preselected_encoding
61+
62+
# Now prompt for SEG-Y standard and build the final spec
63+
segy_standard = prompt_for_segy_standard()
64+
segy_spec = get_segy_standard(segy_standard)
65+
segy_spec.text_header.encoding = text_encoding
66+
if segy_standard >= 1:
67+
segy_spec.ext_text_header.spec.encoding = text_encoding
68+
segy_spec.endianness = None
69+
70+
# Optionally reduce to only template-required trace headers
71+
is_minimal = questionary.confirm("Import only trace headers required by template?", default=False).ask()
72+
if is_minimal:
73+
required_fields = set(mdio_template.coordinate_names) | set(mdio_template.spatial_dimension_names)
74+
required_fields = required_fields | {"coordinate_scalar"}
75+
new_fields = [field for field in segy_spec.trace.header.fields if field.name in required_fields]
76+
segy_spec.trace.header.fields = new_fields
77+
78+
# Prompt for any customizations
79+
binary_fields = prompt_for_header_fields("binary", segy_spec)
80+
trace_fields = prompt_for_header_fields("trace", segy_spec)
81+
if binary_fields or trace_fields:
82+
segy_spec = segy_spec.customize(binary_header_fields=binary_fields, trace_header_fields=trace_fields)
83+
84+
should_save = questionary.confirm("Save SEG-Y specification?", default=True).ask()
85+
if should_save:
86+
from segy import SegyFile
87+
88+
out_segy_spec_path = input_path.with_name(f"{input_path.stem}_segy_spec.json")
89+
out_segy_spec_uri = out_segy_spec_path.as_posix()
90+
91+
custom_uri = questionary.text("Filename for SEG-Y Specification:", default=out_segy_spec_uri).ask()
92+
custom_path = UPath(custom_uri)
93+
updated_spec = SegyFile(input_path.as_posix(), spec=segy_spec).spec
94+
95+
with custom_path.open(mode="w") as f:
96+
json.dump(updated_spec.model_dump(mode="json"), f, indent=2)
97+
print(f"SEG-Y specification saved to '{custom_uri}'.")
98+
99+
return segy_spec
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import typer
6+
7+
if TYPE_CHECKING: # pragma: no cover - type checking only
8+
from segy.schema.format import TextHeaderEncoding
9+
from segy.schema.segy import SegySpec
10+
from upath import UPath
11+
12+
13+
def _read_text_header(input_path: UPath, segy_spec: SegySpec) -> str:
14+
"""Read file textual header from a SEG-Y file using the provided spec.
15+
16+
Important: Avoid SegyFile.text_header cached properties so that switching encodings reflects immediately.
17+
"""
18+
from segy import SegyFile
19+
20+
segy_file = SegyFile(input_path.as_posix(), spec=segy_spec)
21+
22+
# Clear any cached instances.
23+
if hasattr(segy_file.spec.text_header, "processor"):
24+
del segy_file.spec.text_header.processor
25+
if hasattr(segy_file, "text_header"):
26+
del segy_file.text_header
27+
28+
return segy_file.text_header
29+
30+
31+
def _pager(content: str) -> None:
32+
"""Show content via a pager if available; fallback to plain print."""
33+
try:
34+
typer.echo_via_pager(content)
35+
except Exception:
36+
print(content)
37+
38+
39+
def _format_text_header(main: str, encoding: TextHeaderEncoding) -> str:
40+
"""Format textual headers nicely for display or saving."""
41+
enc_label = getattr(encoding, "value", str(encoding))
42+
lines: list[str] = [
43+
f"Textual Header (encoding={enc_label})",
44+
"-" * 60,
45+
main.rstrip("\n"),
46+
]
47+
return "\n".join(lines)

0 commit comments

Comments
 (0)