Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions RATapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,23 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition:
The problem input used in the compiled RAT code.

"""
hydrate_id = {"bulk in": 1, "bulk out": 2}
prior_id = {"uniform": 1, "gaussian": 2, "jeffreys": 3}

# Ensure backgrounds and resolutions have a source defined
# Ensure all contrast fields are properly defined
for contrast in project.contrasts:
contrast_fields = ["data", "background", "bulk_in", "bulk_out", "scalefactor", "resolution"]

if project.calculation == Calculations.Domains:
contrast_fields.append("domain_ratio")

for field in contrast_fields:
if getattr(contrast, field) == "":
raise ValueError(
f'In the input project, the "{field}" field of contrast "{contrast.name}" does not have a '
f"value defined. A value must be supplied before running the project."
)

# Ensure backgrounds and resolutions have a source defined
background = project.backgrounds[contrast.background]
resolution = project.resolutions[contrast.resolution]
if background.source == "":
Expand Down Expand Up @@ -191,22 +203,7 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition:
contrast_custom_files = [project.custom_files.index(contrast.model[0], True) for contrast in project.contrasts]

# Get details of defined layers
layer_details = []
for layer in project.layers:
if project.absorption:
layer_params = [
project.parameters.index(getattr(layer, attribute), True)
for attribute in list(RATapi.models.AbsorptionLayer.model_fields.keys())[1:-2]
]
else:
layer_params = [
project.parameters.index(getattr(layer, attribute), True)
for attribute in list(RATapi.models.Layer.model_fields.keys())[1:-2]
]
layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float("NaN"))
layer_params.append(hydrate_id[layer.hydrate_with])

layer_details.append(layer_params)
layer_details = get_layer_details(project)

contrast_background_params = []
contrast_background_types = []
Expand Down Expand Up @@ -387,6 +384,35 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition:
return problem


def get_layer_details(project: RATapi.Project) -> list[int]:
"""Get parameter indices for all layers defined in the project."""
hydrate_id = {"bulk in": 1, "bulk out": 2}
layer_details = []

# Get the thickness, SLD, roughness fields from the appropriate model
if project.absorption:
layer_fields = list(RATapi.models.AbsorptionLayer.model_fields.keys())[1:-2]
else:
layer_fields = list(RATapi.models.Layer.model_fields.keys())[1:-2]

for layer in project.layers:
for field in layer_fields:
if getattr(layer, field) == "":
raise ValueError(
f'In the input project, the "{field}" field of layer {layer.name} does not have a value '
f"defined. A value must be supplied before running the project."
)

layer_params = [project.parameters.index(getattr(layer, attribute), True) for attribute in list(layer_fields)]

layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float("NaN"))
layer_params.append(hydrate_id[layer.hydrate_with])

layer_details.append(layer_params)

return layer_details


def make_resample(project: RATapi.Project) -> list[int]:
"""Construct the "resample" field of the problem input required for the compiled RAT code.

Expand Down
63 changes: 63 additions & 0 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,69 @@ def test_background_params_value_indices(self, test_problem, bad_value, request)
check_indices(test_problem)


@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"])
@pytest.mark.parametrize("field", ["data", "background", "bulk_in", "bulk_out", "scalefactor", "resolution"])
def test_undefined_contrast_fields(test_project, field, request):
"""If a field in a contrast is empty, we should raise an error."""
test_project = request.getfixturevalue(test_project)
setattr(test_project.contrasts[0], field, "")

with pytest.raises(
ValueError,
match=f'In the input project, the "{field}" field of contrast '
f'"{test_project.contrasts[0].name}" does not have a value defined. '
f"A value must be supplied before running the project.",
):
make_problem(test_project)


@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"])
def test_undefined_background(test_project, request):
"""If the source field of a background defined in a contrast is empty, we should raise an error."""
test_project = request.getfixturevalue(test_project)
background = test_project.backgrounds[test_project.contrasts[0].background]
background.source = ""

with pytest.raises(
ValueError,
match=f"All backgrounds must have a source defined. For a {background.type} type "
f"background, the source must be defined in "
f'"{RATapi.project.values_defined_in[f"backgrounds.{background.type}.source"]}"',
):
make_problem(test_project)


@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"])
def test_undefined_resolution(test_project, request):
"""If the source field of a resolution defined in a contrast is empty, we should raise an error."""
test_project = request.getfixturevalue(test_project)
resolution = test_project.resolutions[test_project.contrasts[0].resolution]
resolution.source = ""

with pytest.raises(
ValueError,
match=f"Constant resolutions must have a source defined. The source must be defined in "
f'"{RATapi.project.values_defined_in[f"resolutions.{resolution.type}.source"]}"',
):
make_problem(test_project)


@pytest.mark.parametrize("test_project", ["standard_layers_project", "domains_project"])
@pytest.mark.parametrize("field", ["thickness", "SLD", "roughness"])
def test_undefined_layers(test_project, field, request):
"""If the thickness, SLD, or roughness fields of a layer defined in the project are empty, we should raise an
error."""
test_project = request.getfixturevalue(test_project)
setattr(test_project.layers[0], field, "")

with pytest.raises(
ValueError,
match=f'In the input project, the "{field}" field of layer {test_project.layers[0].name} '
f"does not have a value defined. A value must be supplied before running the project.",
):
make_problem(test_project)


def test_append_data_background():
"""Test that background data is correctly added to contrast data."""
data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
Expand Down