diff --git a/RATapi/utils/plotting.py b/RATapi/utils/plotting.py index 8301dea4..92322c98 100644 --- a/RATapi/utils/plotting.py +++ b/RATapi/utils/plotting.py @@ -20,28 +20,258 @@ from RATapi.rat_core import PlotEventData, makeSLDProfile -def plot_errorbars(ax: Axes, x: np.ndarray, y: np.ndarray, err: np.ndarray, one_sided: bool, color: str): - """Plot the error bars. +def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool): + """Extract the plot data for the sld, ref, error plot lines. Parameters ---------- - ax : matplotlib.axes._axes.Axes - The axis on which to draw errorbars - x : np.ndarray - The shifted data x axis data - y : np.ndarray - The shifted data y axis data - err : np.ndarray - The shifted data e data - one_sided : bool - A boolean to indicate whether to draw one sided errorbars - color : str - The hex representing the color of the errorbars + event_data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + q4 : bool, default: False + Controls whether Q^4 is plotted on the reflectivity plot + show_error_bar : bool, default: True + Controls whether the error bars are shown + + Returns + ------- + plot_values : dict + A dict containing the data for the sld, ref, error plot lines. + + """ + results = {"ref": [], "error": [], "sld": [], "sld_resample": []} + + for i, (r, data, sld) in enumerate(zip(event_data.reflectivity, event_data.shiftedData, event_data.sldProfiles)): + # Calculate the divisor + div = 1 if i == 0 and not q4 else 2 ** (4 * (i + 1)) + q4_data = 1 if not q4 or not event_data.dataPresent[i] else data[:, 0] ** 4 + mult = q4_data / div + + # Plot the reflectivity on plot (1,1) + results["ref"].append([r[:, 0], r[:, 1] * mult]) + + if event_data.dataPresent[i]: + sd_x = data[:, 0] + sd_y, sd_e = map(lambda x: x * mult, (data[:, 1], data[:, 2])) + + if show_error_bar: + errors = np.zeros(len(sd_e)) + valid = sd_y - sd_e >= 0 + errors[valid] = sd_e[valid] + valid |= sd_y < 0 + + results["error"].append([sd_x[valid], sd_y[valid], sd_e[valid]]) + + results["sld"].append([]) + for j in range(len(sld)): + results["sld"][-1].append([sld[j][:, 0], sld[j][:, 1]]) + + if event_data.resample[i] == 1 or event_data.modelType == "custom xy": + layers = event_data.resampledLayers[i][0] + results["sld_resample"].append([]) + for j in range(len(event_data.resampledLayers[i])): + layer = event_data.resampledLayers[i][j] + if layers.shape[1] == 4: + layer = np.delete(layer, 2, 1) + new_profile = makeSLDProfile( + layers[0, 1], # Bulk In + layers[-1, 1], # Bulk Out + layer, + event_data.subRoughs[i], # roughness + 1, + ) + + results["sld_resample"][-1].append([new_profile[:, 0] - 49, new_profile[:, 1]]) + + return results + + +class PlotSLDWithBlitting: + """Create a SLD plot that uses blitting to get faster draws. + + The blit plot stores the background from an + initial draw then updates the foreground (lines and error bars) if the background is not changed. + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + fig : matplotlib.pyplot.figure, optional + The figure class that has two subplots + linear_x : bool, default: False + Controls whether the x-axis on reflectivity plot uses the linear scale + q4 : bool, default: False + Controls whether Q^4 is plotted on the reflectivity plot + show_error_bar : bool, default: True + Controls whether the error bars are shown + show_grid : bool, default: False + Controls whether the grid is shown + show_legend : bool, default: True + Controls whether the legend is shown """ - y_error = [[0] * len(err), err] if one_sided else err - ax.errorbar(x=x, y=y, yerr=y_error, fmt="none", ecolor=color, elinewidth=1, capsize=0) - ax.scatter(x=x, y=y, s=3, marker="o", color=color) + + def __init__( + self, + data: PlotEventData, + fig: Optional[matplotlib.pyplot.figure] = None, + linear_x: bool = False, + q4: bool = False, + show_error_bar: bool = True, + show_grid: bool = False, + show_legend: bool = True, + ): + self.figure = fig + self.linear_x = linear_x + self.q4 = q4 + self.show_error_bar = show_error_bar + self.show_grid = show_grid + self.show_legend = show_legend + self.updatePlot(data) + self.event_id = self.figure.canvas.mpl_connect("resize_event", self.resizeEvent) + + def __del__(self): + self.figure.canvas.mpl_disconnect(self.event_id) + + def resizeEvent(self, _event): + """Ensure the background is updated after a resize event.""" + self.__background_changed = True + + def update(self, data: PlotEventData): + """Update the foreground, if background has not changed otherwise it updates full plot. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + if self.__background_changed: + self.updatePlot(data) + else: + self.updateForeground(data) + + def __setattr__(self, name, value): + super().__setattr__(name, value) + if name in ["figure", "linear_x", "q4", "show_error_bar", "show_grid", "show_legend"]: + self.__background_changed = True + + def setAnimated(self, is_animated: bool): + """Set the animated property of foreground plot elements. + + Parameters + ---------- + is_animated : bool + Indicates if the animated property should been set. + """ + for line in self.figure.axes[0].lines: + line.set_animated(is_animated) + for line in self.figure.axes[1].lines: + line.set_animated(is_animated) + for container in self.figure.axes[0].containers: + container[2][0].set_animated(is_animated) + + def adjustErrorBar(self, error_bar_container, x, y, y_error): + """Adjust the error bar data. + + Parameters + ---------- + error_bar_container : Tuple + Tuple containing the artist of the errorbar i.e. (data line, cap lines, bar lines) + x : np.ndarray + The shifted data x axis data + y : np.ndarray + The shifted data y axis data + y_error : np.ndarray + The shifted data y axis error data + """ + line, _, (bars_y,) = error_bar_container + + line.set_data(x, y) + x_base = x + y_base = y + + y_error_top = y_base + y_error + y_error_bottom = y_base - y_error + + new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)] + bars_y.set_segments(new_segments_y) + + def updatePlot(self, data: PlotEventData): + """Update the full plot. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + if self.figure is not None: + self.figure.clf() + self.figure = plot_ref_sld_helper( + data, + self.figure, + linear_x=self.linear_x, + q4=self.q4, + show_error_bar=self.show_error_bar, + show_grid=self.show_grid, + show_legend=self.show_legend, + animated=True, + ) + + self.bg = self.figure.canvas.copy_from_bbox(self.figure.bbox) + for line in self.figure.axes[0].lines: + self.figure.axes[0].draw_artist(line) + for line in self.figure.axes[1].lines: + self.figure.axes[1].draw_artist(line) + for container in self.figure.axes[0].containers: + self.figure.axes[0].draw_artist(container[2][0]) + self.figure.canvas.blit(self.figure.bbox) + self.setAnimated(False) + self.__background_changed = False + + def updateForeground(self, data: PlotEventData): + """Update the plot foreground only. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + self.setAnimated(True) + self.figure.canvas.restore_region(self.bg) + plot_data = _extract_plot_data(data, self.q4, self.show_error_bar) + + offset = 2 if self.show_error_bar else 1 + for i in range( + 0, + len(self.figure.axes[0].lines), + ): + self.figure.axes[0].lines[i].set_data(plot_data["ref"][i // offset][0], plot_data["ref"][i // offset][1]) + self.figure.axes[0].draw_artist(self.figure.axes[0].lines[i]) + + i = 0 + for j in range(len(plot_data["sld"])): + for sld in plot_data["sld"][j]: + self.figure.axes[1].lines[i].set_data(sld[0], sld[1]) + self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i]) + i += 1 + + if plot_data["sld_resample"]: + for resampled in plot_data["sld_resample"][j]: + self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1]) + self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i]) + i += 1 + + for i, container in enumerate(self.figure.axes[0].containers): + self.adjustErrorBar(container, plot_data["error"][i][0], plot_data["error"][i][1], plot_data["error"][i][2]) + self.figure.axes[0].draw_artist(container[2][0]) + self.figure.axes[0].draw_artist(container[0]) + + self.figure.canvas.blit(self.figure.bbox) + self.figure.canvas.flush_events() + self.setAnimated(False) def plot_ref_sld_helper( @@ -54,36 +284,39 @@ def plot_ref_sld_helper( show_error_bar: bool = True, show_grid: bool = False, show_legend: bool = True, + animated=False, ): """Clear the previous plots and updates the ref and SLD plots. Parameters ---------- data : PlotEventData - The plot event data that contains all the information - to generate the ref and sld plots + The plot event data that contains all the information + to generate the ref and sld plots fig : matplotlib.pyplot.figure, optional - The figure class that has two subplots + The figure class that has two subplots delay : bool, default: True - Controls whether to delay 0.005s after plot is created + Controls whether to delay 0.005s after plot is created confidence_intervals : dict or None, default None The Bayesian confidence intervals for reflectivity and SLD. Only relevant if the procedure used is Bayesian (NS or DREAM) linear_x : bool, default: False - Controls whether the x-axis on reflectivity plot uses the linear scale + Controls whether the x-axis on reflectivity plot uses the linear scale q4 : bool, default: False - Controls whether Q^4 is plotted on the reflectivity plot + Controls whether Q^4 is plotted on the reflectivity plot show_error_bar : bool, default: True - Controls whether the error bars are shown + Controls whether the error bars are shown show_grid : bool, default: False - Controls whether the grid is shown + Controls whether the grid is shown show_legend : bool, default: True - Controls whether the lengend is shown + Controls whether the legend is shown + animated : bool, default: False + Controls whether the animated property of foreground plot elements should be set. Returns ------- fig : matplotlib.pyplot.figure - The figure class that has two subplots + The figure class that has two subplots """ preserve_zoom = False @@ -105,68 +338,55 @@ def plot_ref_sld_helper( ref_plot.cla() sld_plot.cla() - for i, (r, sd, sld, name) in enumerate( - zip(data.reflectivity, data.shiftedData, data.sldProfiles, data.contrastNames), - ): - # Calculate the divisor - div = 1 if i == 0 and not q4 else 2 ** (4 * (i + 1)) - q4_data = 1 if not q4 or not data.dataPresent[i] else sd[:, 0] ** 4 - mult = q4_data / div - - # Plot the reflectivity on plot (1,1) - ref_plot.plot(r[:, 0], r[:, 1] * mult, label=name, linewidth=2) + plot_data = _extract_plot_data(data, q4, show_error_bar) + for i, name in enumerate(data.contrastNames): + ref_plot.plot(plot_data["ref"][i][0], plot_data["ref"][i][1], label=name, linewidth=1, animated=animated) color = ref_plot.get_lines()[-1].get_color() # Plot confidence intervals if required if confidence_intervals is not None: + # Calculate the divisor + div = 1 if i == 0 and not q4 else 2 ** (4 * (i + 1)) ref_min, ref_max = confidence_intervals["reflectivity"][i] - mult = (1 if not q4 else r[:, 0] ** 4) / div - ref_plot.fill_between(r[:, 0], ref_min * mult, ref_max * mult, alpha=0.6, color="grey") - - if data.dataPresent[i]: - sd_x = sd[:, 0] - sd_y, sd_e = map(lambda x: x * mult, (sd[:, 1], sd[:, 2])) - - if show_error_bar: - # Plot the errorbars - indices_removed = np.flip(np.nonzero(sd_y - sd_e < 0)[0]) - sd_x_r, sd_y_r, sd_e_r = map(lambda x: np.delete(x, indices_removed), (sd_x, sd_y, sd_e)) - plot_errorbars(ref_plot, sd_x_r, sd_y_r, sd_e_r, False, color) - - # Plot one sided errorbars - indices_selected = [x for x in indices_removed if x not in np.nonzero(sd_y < 0)[0]] - sd_x_s, sd_y_s, sd_e_s = map(lambda x: [x[i] for i in indices_selected], (sd_x, sd_y, sd_e)) - plot_errorbars(ref_plot, sd_x_s, sd_y_s, sd_e_s, True, color) + mult = (1 if not q4 else plot_data["ref"][i][0] ** 4) / div + ref_plot.fill_between(plot_data["ref"][i][0], ref_min * mult, ref_max * mult, alpha=0.6, color="grey") + + if data.dataPresent[i] and show_error_bar: + # Plot the errorbars + ref_plot.errorbar( + x=plot_data["error"][i][0], + y=plot_data["error"][i][1], + yerr=plot_data["error"][i][2], + elinewidth=1, + ecolor=color, + marker=".", + markersize=3, + linestyle="none", + color=color, + capsize=0, + animated=animated, + ) # Plot the slds on plot (1,2) - for j in range(len(sld)): - label = name if len(sld) == 1 else f"{name} Domain {j + 1}" - sld_plot.plot(sld[j][:, 0], sld[j][:, 1], label=label, linewidth=1) + for j in range(len(plot_data["sld"][i])): + label = name if len(plot_data["sld"][i]) == 1 else f"{name} Domain {j + 1}" + sld_plot.plot( + plot_data["sld"][i][j][0], plot_data["sld"][i][j][1], label=label, linewidth=1, animated=animated + ) - # Plot confidence intervals if required - if confidence_intervals is not None: - sld_min, sld_max = confidence_intervals["sld"][i][j] - sld_plot.fill_between(sld[j][:, 0], sld_min, sld_max, alpha=0.6, color="grey") - - if data.resample[i] == 1 or data.modelType == "custom xy": - layers = data.resampledLayers[i][0] - for j in range(len(data.resampledLayers[i])): - layer = data.resampledLayers[i][j] - if layers.shape[1] == 4: - layer = np.delete(layer, 2, 1) - new_profile = makeSLDProfile( - layers[0, 1], # Bulk In - layers[-1, 1], # Bulk Out - layer, - data.subRoughs[i], # roughness - 1, - ) + # Plot confidence intervals if required + if confidence_intervals is not None: + sld_min, sld_max = confidence_intervals["sld"][i][j] + sld_plot.fill_between(plot_data["sld"][i][j][0], sld_min, sld_max, alpha=0.6, color="grey") + if plot_data["sld_resample"]: + for j in range(len(plot_data["sld_resample"][i])): sld_plot.plot( - [row[0] - 49 for row in new_profile], - [row[1] for row in new_profile], + plot_data["sld_resample"][i][j][0], + plot_data["sld_resample"][i][j][1], color=color, linewidth=1, + animated=animated, ) # Format the axis @@ -177,7 +397,7 @@ def plot_ref_sld_helper( ref_plot.set_ylabel("Reflectivity") sld_plot.set_xlabel("$Z (\u00c5)$") - sld_plot.set_ylabel("$SLD (\u00c5^{-2})$") + sld_plot.set_ylabel("$SLD (\u00c5^{-2})$", labelpad=1) if show_legend: ref_plot.legend() @@ -317,7 +537,7 @@ class LivePlot: Parameters ---------- block : bool, default: False - Indicates the plot should block until it is closed + Indicates the plot should block until it is closed """ diff --git a/tests/test_plotting.py b/tests/test_plotting.py index d907ba16..4cdcb0aa 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -105,27 +105,20 @@ def test_figure_axis_formatting(fig: plt.figure) -> None: ] -def test_ref_sld_color_formating(fig: plt.figure) -> None: - """Tests the color formating of the figure.""" +def test_ref_sld_color_formatting(fig: plt.figure) -> None: + """Tests the color formatting of the figure.""" ref_plot = fig.axes[0] sld_plot = fig.axes[1] - assert len(ref_plot.get_lines()) == 3 + assert len(ref_plot.get_lines()) == 6 assert len(sld_plot.get_lines()) == 6 - for axis_ix in range(len(ref_plot.get_lines())): - ax1 = axis_ix * 2 - ax2 = ax1 + 1 - + for i in range(0, len(ref_plot.get_lines()), 2): # Tests whether the color of the line and the errorbars match on the ref_plot - assert ( - ref_plot.containers[ax1][2][0]._original_edgecolor - == ref_plot.containers[ax2][2][0]._original_edgecolor - == ref_plot.get_lines()[axis_ix].get_color() - ) + assert ref_plot.containers[i // 2][2][0]._original_edgecolor == ref_plot.get_lines()[i].get_color() # Tests whether the color of the sld and resampled_sld match on the sld_plot - assert sld_plot.get_lines()[ax1].get_color() == sld_plot.get_lines()[ax2].get_color() + assert sld_plot.get_lines()[i].get_color() == sld_plot.get_lines()[i + 1].get_color() @pytest.mark.parametrize("bayes", [65, 95]) @@ -483,3 +476,24 @@ def test_bayes_validation(input_project, reflectivity_calculation_results): ValueError, match=r"Bayes plots are only available for the results of Bayesian analysis \(NS or DREAM\)" ): RATplot.plot_bayes(input_project, reflectivity_calculation_results) + + +@pytest.mark.parametrize("data", [data(), domains_data()]) +def test_extract_plot_data(data) -> None: + plot_data = RATplot._extract_plot_data(data, False, True) + assert len(plot_data["ref"]) == len(data.reflectivity) + assert len(plot_data["sld"]) == len(data.shiftedData) + + +@patch("RATapi.utils.plotting.plot_ref_sld_helper") +def test_blit_plot(plot_helper, fig: plt.figure) -> None: + plot_helper.return_value = fig + event_data = data() + new_plot = RATplot.PlotSLDWithBlitting(event_data) + assert plot_helper.call_count == 1 + new_plot.update(event_data) + assert plot_helper.call_count == 1 # foreground only is updated so no call to plot helper + new_plot.show_grid = False + new_plot.figure = plt.subplots(1, 2)[0] + new_plot.update(event_data) # plot properties have changed so update should call plot_helper + assert plot_helper.call_count == 2