From 8ba18f75e29e94db5823f24d130e30f368500e81 Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Mon, 20 Apr 2026 18:48:45 -0400 Subject: [PATCH 01/10] Fix #5427 - Line not shared across matched axes - Bug was that function marked the axis to be connected, but the trace_kwargs still had unique axes - Change: change the keyword argument for the trace, so that when the graph is initialized, it uses the correct axis instead of the autogenerated one - Note: The program generates a unique axis label for each subgraph, and then overwrites the label (under this fix) --- plotly/_subplots.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plotly/_subplots.py b/plotly/_subplots.py index 16a3958637e..98220b4ad63 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -926,6 +926,7 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): else: axis_name = subplot_ref.layout_keys[layout_key_ind] axis_to_match = layout[axis_name] + subplot_ref.trace_kwargs[axis_name] = first_axis_id # Changes the reference axis in the set up to the initial axis (the axis to match) axis_to_match.matches = first_axis_id if remove_label: axis_to_match.showticklabels = False @@ -981,6 +982,7 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): first_axis_id = update_axis_matches( first_axis_id, subplot_ref, spec, ok_to_remove_label ) + def _init_subplot_xy(layout, secondary_y, x_domain, y_domain, max_subplot_ids=None): From 79dc240153b61501b303a35d82266c8a6c38899e Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Mon, 20 Apr 2026 18:54:46 -0400 Subject: [PATCH 02/10] update change log --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 126ff280b49..24f8eec6855 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Fixed +- Bug was that function marked the axis to be connected, but the trace_kwargs still had unique axes [[#5427](https://github.com/plotly/plotly.py/issues/5427)] +- Change: change the keyword argument for the trace, so that when the graph is initialized, it uses the correct axis instead of the autogenerated one +- Note: The program generates a unique axis label for each subgraph, and then overwrites the label (under this fix) + + ### Fixed - Fix issue where user-specified `color_continuous_scale` was ignored when template had `autocolorscale=True` [[#5439](https://github.com/plotly/plotly.py/pull/5439)], with thanks to @antonymilne for the contribution! - Update tests to be compatible with numpy 2.4 [[#5522](https://github.com/plotly/plotly.py/pull/5522)], with thanks to @thunze for the contribution! From 480f72879addc9b7f41129308e2d1e1038d31515 Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Wed, 22 Apr 2026 02:04:35 -0400 Subject: [PATCH 03/10] Noticed a bug with multiple plots --- plotly/_subplots.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plotly/_subplots.py b/plotly/_subplots.py index 98220b4ad63..edb1cc7b468 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -916,8 +916,10 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): if x_or_y == "x": span = spec["colspan"] + match_axis = 'xaxis' else: span = spec["rowspan"] + match_axis = 'yaxis' if subplot_ref.subplot_type == "xy" and span == 1: if first_axis_id is None: @@ -926,7 +928,7 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): else: axis_name = subplot_ref.layout_keys[layout_key_ind] axis_to_match = layout[axis_name] - subplot_ref.trace_kwargs[axis_name] = first_axis_id # Changes the reference axis in the set up to the initial axis (the axis to match) + subplot_ref.trace_kwargs[match_axis] = first_axis_id # Changes the reference axis in the set up to the initial axis (the axis to match) axis_to_match.matches = first_axis_id if remove_label: axis_to_match.showticklabels = False From a243ceeae865e7e0b0dcf38823ada25671421d2c Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Wed, 22 Apr 2026 02:06:03 -0400 Subject: [PATCH 04/10] Regression tests added - Checks if the axes match when the should and don't match when they shouldn't --- CHANGELOG.md | 1 - .../test_subplots/test_make_subplots.py | 199 ++++++++++++++++++ 2 files changed, 199 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 24f8eec6855..b296ee382b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,6 @@ This project adheres to [Semantic Versioning](http://semver.org/). - Change: change the keyword argument for the trace, so that when the graph is initialized, it uses the correct axis instead of the autogenerated one - Note: The program generates a unique axis label for each subgraph, and then overwrites the label (under this fix) - ### Fixed - Fix issue where user-specified `color_continuous_scale` was ignored when template had `autocolorscale=True` [[#5439](https://github.com/plotly/plotly.py/pull/5439)], with thanks to @antonymilne for the contribution! - Update tests to be compatible with numpy 2.4 [[#5522](https://github.com/plotly/plotly.py/pull/5522)], with thanks to @thunze for the contribution! diff --git a/tests/test_optional/test_subplots/test_make_subplots.py b/tests/test_optional/test_subplots/test_make_subplots.py index 4552c66a694..23df450388c 100644 --- a/tests/test_optional/test_subplots/test_make_subplots.py +++ b/tests/test_optional/test_subplots/test_make_subplots.py @@ -56,3 +56,202 @@ def test_add_traces_with_integers(self): expected_data_length = 4 self.assertEqual(expected_data_length, len(fig2.data)) + +class TestSharedAxisOnMakeColumn(TestCase): + """ + Regression test for #5427: traces should reference the primary axis + when shared_xaxes=True, so spike lines and hover sync work correctly. + """ + + def test_xaxes_shared_columns_mode_single_column(self): + """ + When 'columns' mode for shared_xaxis, all of the traces in the same column should reference the same x-axis + """ + + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=3, cols=1, shared_xaxes='columns') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=2, col=1) + fig.add_trace(trace_3, row=3, col=1) + + + # The x-axis of all of the figures should be the same + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 3 have different x-axes") + self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 2 and Figure 3 have different x-axes") + + def test_xaxes_shared_columns_mode_multiple_columns(self): + """ + When 'columns' mode for shared_xaxis, different columns should have different references + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='columns') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=2, col=1) + fig.add_trace(trace_3, row=1, col=2) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of figures that are in the same column should be the same, and different if they are in different columns + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis column don't match: Figure 3 and Figure 4 have different x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis column match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis column match: Figure 2 and Figure 4 have the same x-axes") + + def test_xaxes_shared_rows_mode_single_row(self): + """ + When 'rows' mode for shared_xaxis, all of the traces in the same row should reference the same x-axis + """ + + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=1, cols=3, shared_xaxes='rows') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=1, col=3) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all of the figures should be the same + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 3 have different x-axes") + self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 2 and Figure 3 have different x-axes") + + def test_xaxes_shared_rows_mode_multiple_rows(self): + """ + When 'rows' mode for shared_xaxis, different rows should have different references + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='rows') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of figures in the same row should be the same, and different if they are in different rows + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis row don't match: Figure 3 and Figure 4 have different x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis row match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis row match: Figure 2 and Figure 4 have the same x-axes") + + def test_xaxes_shared_all_mode(self): + """ + When 'all' mode for shared_xaxis, all rows share the same x-axes + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='all') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all the figures should be the same + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 3 and Figure 4 have different x-axes") + self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 3 have the same x-axes") + self.assertEqual(trace_2_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 2 and Figure 4 have the same x-axes") + + def test_xaxes_not_shared_mode(self): + """ + When not shared, all plots have different x-axes + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes=False) + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all of the figures should be different + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertNotEqual(trace_1_xaxis, trace_2_xaxis, "Different x-axis match: Figure 1 and Figure 2 have the same x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_1_xaxis, trace_4_xaxis, "Different x-axis match: Figure 1 and Figure 4 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_3_xaxis, "Different x-axis match: Figure 2 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis match: Figure 2 and Figure 4 have the same x-axes") + self.assertNotEqual(trace_3_xaxis, trace_4_xaxis, "Different x-axis match: Figure 3 and Figure 4 have the same x-axes") \ No newline at end of file From 4b5cb849473d8104c05a03b85ace12a6df9adb29 Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Tue, 28 Apr 2026 22:41:18 -0400 Subject: [PATCH 05/10] Rewrite to make it pass tests There was an issue with the ordering in finding the next axis, so I ended up rewriting the section to try to fix the error - One of the tests in the make_subplots fails, but unsure if that is unintended, as by the spec, the xaxis should be shared when using shared_axis (but it isn't in the test; each subplot has its own axis) - Also the ruff check fails, but it is because the Dictionary typed is matching the plotly --- plotly/_subplots.py | 263 +++++++++++++++++++++++++++++--------------- 1 file changed, 174 insertions(+), 89 deletions(-) diff --git a/plotly/_subplots.py b/plotly/_subplots.py index edb1cc7b468..e5974b7dcfb 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -8,6 +8,9 @@ # little differently. import collections +import plotly.graph_objects as go +from typing import Literal, Optional, Tuple, TypedDict, Iterable + _single_subplot_types = {"scene", "geo", "polar", "ternary", "map", "mapbox"} _subplot_types = set.union(_single_subplot_types, {"xy", "domain"}) @@ -31,6 +34,16 @@ "SubplotRef", ("subplot_type", "layout_keys", "trace_kwargs") ) +class SubplotSpec(TypedDict): + type : Literal['xy', 'scene', 'polar', 'ternary', 'map', 'mapbox', 'domain'] | str + secondary_y : bool + colspan : int + rowspan : int + l : float + r : float + t : float + b : float + def _get_initial_max_subplot_ids(): max_subplot_ids = {subplot_type: 0 for subplot_type in _single_subplot_types} @@ -889,103 +902,175 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): return figure - def _configure_shared_axes( - layout, grid_ref, specs, x_or_y, shared, row_dir, secondary_y -): - rows = len(grid_ref) - cols = len(grid_ref[0]) - - layout_key_ind = ["x", "y"].index(x_or_y) - - if row_dir < 0: - rows_iter = range(rows - 1, -1, -1) - else: - rows_iter = range(rows) - - if secondary_y: - cols_iter = range(cols - 1, -1, -1) - axis_index = 1 - else: - cols_iter = range(cols) - axis_index = 0 - - def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): - if subplot_ref is None: - return first_axis_id - - if x_or_y == "x": - span = spec["colspan"] - match_axis = 'xaxis' - else: - span = spec["rowspan"] - match_axis = 'yaxis' + layout : go.Layout, + grid_ref : Tuple[Tuple[SubplotRef]], + specs : Tuple[Tuple[SubplotSpec]], + x_or_y : Literal['x', 'y'], + shared : bool | Literal['rows', 'columns', 'all'], + row_direction : Literal[1, -1], + secondary_y : bool +) -> None: + ''' + Sets the axes to be shared, making them use the same axis + + Parameters: + ----------- + layout (go.Layout) : The layout of the figure to be updating + grid_ref (Tuple[Tuple[SubplotRef]]) : The grid of subplots within the figure; grid_ref[row][column] = subplot at that coordinate + specs (Tuple[Tuple[SubplotSpec]]) : The specifications of each of the subplots within the figure; specs[row][column] = specs of the subplot at that coordinate + x_or_y ('x' | 'y') : The axis to make shared (x-axis or y-axis) + shared ('rows' | 'columns' | 'all' | bool) : Share the axis within the row, column, or across all of the subplots (True defaults to columns mode) + row_direction (1 | -1) : The directional that the rows go + secondary_y (bool) : Whether there are different or shared y-axis + ''' + + row_count : int = len(grid_ref) + column_count : int = len(grid_ref[0]) - if subplot_ref.subplot_type == "xy" and span == 1: - if first_axis_id is None: - first_axis_name = subplot_ref.layout_keys[layout_key_ind] - first_axis_id = first_axis_name.replace("axis", "") - else: - axis_name = subplot_ref.layout_keys[layout_key_ind] - axis_to_match = layout[axis_name] - subplot_ref.trace_kwargs[match_axis] = first_axis_id # Changes the reference axis in the set up to the initial axis (the axis to match) - axis_to_match.matches = first_axis_id - if remove_label: - axis_to_match.showticklabels = False - - return first_axis_id - - if shared == "columns" or (x_or_y == "x" and shared is True): - for c in cols_iter: - first_axis_id = None - ok_to_remove_label = x_or_y == "x" - for r in rows_iter: - if not grid_ref[r][c]: - continue - if axis_index >= len(grid_ref[r][c]): - continue - subplot_ref = grid_ref[r][c][axis_index] - spec = specs[r][c] - first_axis_id = update_axis_matches( - first_axis_id, subplot_ref, spec, ok_to_remove_label - ) + rows : Iterable[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) + columns : Iterable[int] = tuple(range(column_count - 1, -1, -1)) if secondary_y else tuple(range(column_count)) - elif shared == "rows" or (x_or_y == "y" and shared is True): - for r in rows_iter: - first_axis_id = None - ok_to_remove_label = x_or_y == "y" - for c in cols_iter: - if not grid_ref[r][c]: - continue - if axis_index >= len(grid_ref[r][c]): + axis_index : int = 1 if secondary_y else 0 + layout_axis_index : int = 0 if x_or_y == 'x' else 1 + + def find_label_and_index(row_order : int | Iterable[int], column_order : int | Iterable[int]) -> Optional[Tuple[str, Tuple[int, int]]]: + ''' + Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists + + Parameters: + ----------- + row_order (int | Iterable[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable + column_order (int | Iterable[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable + + Return: + ------- + Returns (Label : str, (Row : int, Column : int)): returning the label found, and the row and column it was found at (uses x_or_y to determine which of the axes' labels to pull) + Return (None): No label was found + ''' + + # Turn them into lists with one element, so that both row_order and column_order are iterables + row_order : Iterable[int] = [row_order] if isinstance(row_order, int) else row_order + column_order : Iterable[int] = [column_order] if isinstance(column_order, int) else column_order + + + # Iterate through the rows and columns + for row in row_order: + for column in column_order: + if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]): + continue + + subplot_reference : SubplotRef = grid_ref[row][column][axis_index] + spec : SubplotSpec = specs[row][column] + + if subplot_reference is None: + continue + + span = spec['colspan'] if x_or_y == 'x' else spec['rowspan'] + if subplot_reference.subplot_type != 'xy' or span != 1: + continue + + label_name : str = subplot_reference.layout_keys[layout_axis_index] + label : str = label_name.replace("axis", "") + return label, (row, column) + return None + + + def update_trace_axis(matched_label : str, row : int, column : int, can_remove_label : bool) -> None: + ''' + Updates the trace at the given row and column with the given label, and removes the label visibility if necessary + + Parameters: + ----------- + matched_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location + row (int) : The row of the subplot within grid_ref to update + column (int) : The column of the subplot within grid_ref to update + can_remove_label (bool): Whether the label should be visible (only the bottom label should be visible) + can_change_trace_kwargs (bool): If True the label itself can be changed directly to be the exact same axis (ie use the exact same axis in the trace keyword arguments), or if False, can only mark as matching (ie don't change the trace keyword args) + ''' + if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]): + return + + subplot_reference : SubplotRef = grid_ref[row][column][axis_index] + spec : SubplotSpec = specs[row][column] + + if subplot_reference is None: + return + + span = spec['colspan'] if x_or_y == 'x' else spec['rowspan'] + if subplot_reference.subplot_type != 'xy' or span != 1: + return + + axis_name : str = subplot_reference.layout_keys[layout_axis_index] + axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' + axis : go.XAxis = layout[axis_name] + + axis.matches = matched_label + subplot_reference.trace_kwargs[axis_dimension] = matched_label + + if can_remove_label: + axis.showticklabels = False + + def columns_mode(): + for column in columns: + # Get the label used by all the rows in the column + label_data = find_label_and_index(rows, column) + if label_data is None: + continue + column_label, (label_row, _) = label_data + # Set all of the values in the column + + can_remove_label : bool = (x_or_y == 'x') + + for row in rows: + if row == label_row: # Don't update the figure that the label we are matching comes from continue - subplot_ref = grid_ref[r][c][axis_index] - spec = specs[r][c] - first_axis_id = update_axis_matches( - first_axis_id, subplot_ref, spec, ok_to_remove_label - ) + + update_trace_axis(column_label, row, column, can_remove_label) - elif shared == "all": - first_axis_id = None - for ri, r in enumerate(rows_iter): - for c in cols_iter: - if not grid_ref[r][c]: - continue - if axis_index >= len(grid_ref[r][c]): + + def rows_mode(): + for row in rows: + label_data = find_label_and_index(row, columns) + if label_data is None: + continue + row_label, (_, label_column) = label_data + + can_remove_label : bool = (x_or_y == 'y') + + for column in columns: + if column == label_column: # Don't update the figure that the label we are matching comes from + continue + + update_trace_axis(row_label, row, column, can_remove_label) + + def all_mode(): + label_data = find_label_and_index(rows, columns) + if label_data is None: + return + label, (label_row, label_column) = label_data + + for row_index, row in enumerate(rows): + for column in columns: + if row == label_row and column == label_column: # Don't update the figure that the label we are matching comes from continue - subplot_ref = grid_ref[r][c][axis_index] - spec = specs[r][c] - if x_or_y == "y": - ok_to_remove_label = c < cols - 1 if secondary_y else c > 0 + if x_or_y == 'y': + can_remove_label : bool = (column < column_count - 1 if secondary_y else column > 0) else: - ok_to_remove_label = ri > 0 if row_dir > 0 else r < rows - 1 - - first_axis_id = update_axis_matches( - first_axis_id, subplot_ref, spec, ok_to_remove_label - ) - - + can_remove_label : bool = (row_index > 0 if row_direction > 0 else row < row_count - 1) + + update_trace_axis(label, row, column, can_remove_label) + + match(shared, x_or_y, shared): + case ('columns', _, _) | (_, 'x', True): # If columns mode, or shared and x + columns_mode() + case ('rows', _, _) | (_, 'y', True): # If rows mode, or shared and y + rows_mode() + case ('all', _, _): # If all mode + all_mode() + case _: # If reached the other case + return def _init_subplot_xy(layout, secondary_y, x_domain, y_domain, max_subplot_ids=None): if max_subplot_ids is None: From 53314625e404ee31083c2128df6479eca8afbd3e Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Wed, 29 Apr 2026 23:28:01 -0400 Subject: [PATCH 06/10] Undoing the fix, due to breaking 'all' mode - I ended up refactoring the code to try to find the problem, but it turns out my solution breaks when all mode is selected, so I have commented that part out - I am including this, because I thought the refactor is nice, but this ended up not being a good fix --- plotly/_subplots.py | 220 +++++++++--------- .../test_subplots/test_make_subplots.py | 10 +- 2 files changed, 117 insertions(+), 113 deletions(-) diff --git a/plotly/_subplots.py b/plotly/_subplots.py index e5974b7dcfb..301988555d4 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -6,10 +6,12 @@ # properties. # Note that this set does not contain `xaxis`/`yaxis` because these behave a # little differently. +from __future__ import annotations import collections -import plotly.graph_objects as go -from typing import Literal, Optional, Tuple, TypedDict, Iterable +from typing import Literal, Optional, Tuple, TypedDict, TYPE_CHECKING +if TYPE_CHECKING: + from plotly.graph_objects import Layout, XAxis _single_subplot_types = {"scene", "geo", "polar", "ternary", "map", "mapbox"} _subplot_types = set.union(_single_subplot_types, {"xy", "domain"}) @@ -38,8 +40,9 @@ class SubplotSpec(TypedDict): type : Literal['xy', 'scene', 'polar', 'ternary', 'map', 'mapbox', 'domain'] | str secondary_y : bool colspan : int - rowspan : int - l : float + rowspan : int + # NOTE: that this is the dictionary as defined by the documentation, so the ambiguous name 'l' can't be changed without changing the documentation + l : float # noqa: E741 r : float t : float b : float @@ -759,19 +762,10 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): ) grid_ref[r][c] = subplot_refs - _configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir, False) - _configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir, False) - any_secondary_y = any( - spec["secondary_y"] - for spec_row in specs - for spec in spec_row - if spec is not None - ) - if any_secondary_y: - _configure_shared_axes( - layout, grid_ref, specs, "y", shared_yaxes, row_dir, True - ) + _configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir) + _configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir) + # Build inset reference # --------------------- @@ -903,46 +897,40 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): return figure def _configure_shared_axes( - layout : go.Layout, + layout : Layout, grid_ref : Tuple[Tuple[SubplotRef]], specs : Tuple[Tuple[SubplotSpec]], x_or_y : Literal['x', 'y'], shared : bool | Literal['rows', 'columns', 'all'], - row_direction : Literal[1, -1], - secondary_y : bool + row_direction : Literal[1, -1] ) -> None: ''' - Sets the axes to be shared, making them use the same axis + Sets the axes to be shared, making them use the same axis Parameters: ----------- layout (go.Layout) : The layout of the figure to be updating grid_ref (Tuple[Tuple[SubplotRef]]) : The grid of subplots within the figure; grid_ref[row][column] = subplot at that coordinate specs (Tuple[Tuple[SubplotSpec]]) : The specifications of each of the subplots within the figure; specs[row][column] = specs of the subplot at that coordinate - x_or_y ('x' | 'y') : The axis to make shared (x-axis or y-axis) - shared ('rows' | 'columns' | 'all' | bool) : Share the axis within the row, column, or across all of the subplots (True defaults to columns mode) - row_direction (1 | -1) : The directional that the rows go - secondary_y (bool) : Whether there are different or shared y-axis + x_or_y ('x' | 'y') : The axis to configure + shared ('rows' | 'columns' | 'all' | bool) : The sharing mode, (True is 'columns' mode, False means no sharing) ie share the axis with all subplots in the corresponding row, column, or entire figure + row_direction (1 | -1) : The directional that the rows go ''' row_count : int = len(grid_ref) column_count : int = len(grid_ref[0]) - rows : Iterable[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) - columns : Iterable[int] = tuple(range(column_count - 1, -1, -1)) if secondary_y else tuple(range(column_count)) - - axis_index : int = 1 if secondary_y else 0 - layout_axis_index : int = 0 if x_or_y == 'x' else 1 + axis_index : int = 0 if x_or_y == 'x' else 1 - def find_label_and_index(row_order : int | Iterable[int], column_order : int | Iterable[int]) -> Optional[Tuple[str, Tuple[int, int]]]: + def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tuple[int], trace_layer : int) -> Optional[Tuple[str, Tuple[int, int]]]: ''' - Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists + Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS Parameters: ----------- - row_order (int | Iterable[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable - column_order (int | Iterable[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable - + row_order (int | Tuple[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable + column_order (int | Tuple[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable + trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1] Return: ------- Returns (Label : str, (Row : int, Column : int)): returning the label found, and the row and column it was found at (uses x_or_y to determine which of the axes' labels to pull) @@ -950,125 +938,139 @@ def find_label_and_index(row_order : int | Iterable[int], column_order : int | I ''' # Turn them into lists with one element, so that both row_order and column_order are iterables - row_order : Iterable[int] = [row_order] if isinstance(row_order, int) else row_order - column_order : Iterable[int] = [column_order] if isinstance(column_order, int) else column_order + row_order : Tuple[int] = [row_order] if isinstance(row_order, int) else row_order + column_order : Tuple[int] = [column_order] if isinstance(column_order, int) else column_order # Iterate through the rows and columns for row in row_order: for column in column_order: - if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]): + if not grid_ref[row][column]: continue - - subplot_reference : SubplotRef = grid_ref[row][column][axis_index] - spec : SubplotSpec = specs[row][column] - if subplot_reference is None: - continue - - span = spec['colspan'] if x_or_y == 'x' else spec['rowspan'] - if subplot_reference.subplot_type != 'xy' or span != 1: + subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column] + subplot_spec : SubplotSpec = specs[row][column] + + span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan'] + if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces): + continue + + trace = subplot_traces[trace_layer] + if trace is None or trace.subplot_type != 'xy': continue - label_name : str = subplot_reference.layout_keys[layout_axis_index] + label_name : str = trace.layout_keys[axis_index] label : str = label_name.replace("axis", "") return label, (row, column) return None - def update_trace_axis(matched_label : str, row : int, column : int, can_remove_label : bool) -> None: + def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_reassign_axis : bool, can_hide_ticks : bool, can_match_axis : bool) -> None: ''' - Updates the trace at the given row and column with the given label, and removes the label visibility if necessary + Updates the specific subplot trace at the given row and column with the given label, and removes the label visibility if necessary; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS Parameters: ----------- - matched_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location - row (int) : The row of the subplot within grid_ref to update - column (int) : The column of the subplot within grid_ref to update - can_remove_label (bool): Whether the label should be visible (only the bottom label should be visible) - can_change_trace_kwargs (bool): If True the label itself can be changed directly to be the exact same axis (ie use the exact same axis in the trace keyword arguments), or if False, can only mark as matching (ie don't change the trace keyword args) + axis_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location + row (int) : The row of the subplot within grid_ref to update + column (int) : The column of the subplot within grid_ref to update + trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1] + can_reassign_axis (bool): If True, can change the unique axis for the shared axis in the trace keywords, otherwise, will keep using the axis name it already has + can_hide_ticks (bool): If the function is allowed to hide the ticks (if True, it will hide the ticks, if False, it will leave the ticks as their current state) + can_match_axis (bool): If the axis should be marked as a match to the axis label ''' - if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]): - return - - subplot_reference : SubplotRef = grid_ref[row][column][axis_index] - spec : SubplotSpec = specs[row][column] - - if subplot_reference is None: - return - span = spec['colspan'] if x_or_y == 'x' else spec['rowspan'] - if subplot_reference.subplot_type != 'xy' or span != 1: - return + if not grid_ref[row][column] or specs[row][column] is None: + return - axis_name : str = subplot_reference.layout_keys[layout_axis_index] - axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' - axis : go.XAxis = layout[axis_name] + subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column] + subplot_spec : SubplotSpec = specs[row][column] - axis.matches = matched_label - subplot_reference.trace_kwargs[axis_dimension] = matched_label + span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan'] + if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces): + return - if can_remove_label: - axis.showticklabels = False + trace : Optional[SubplotRef] = subplot_traces[trace_layer] + + if trace is None or trace.subplot_type != 'xy' or span != 1: + return + + axis_name : str = trace.layout_keys[axis_index] + axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' + axis : XAxis = layout[axis_name] - def columns_mode(): + if can_match_axis: + axis.matches = axis_label + + if can_hide_ticks: + axis.showticklabels = False + + if can_reassign_axis: + # trace.trace_kwargs[axis_dimension] = axis_label + pass + + def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): for column in columns: # Get the label used by all the rows in the column - label_data = find_label_and_index(rows, column) + label_data = find_label_and_index(rows, column, trace_layer) if label_data is None: continue - column_label, (label_row, _) = label_data - # Set all of the values in the column + axis_label, (label_row, _) = label_data - can_remove_label : bool = (x_or_y == 'x') - + # Set all of the values in the column for row in rows: - if row == label_row: # Don't update the figure that the label we are matching comes from - continue - - update_trace_axis(column_label, row, column, can_remove_label) + subplot_spec : SubplotSpec = specs[row][column] + can_reassign_axis : bool = (x_or_y != 'y' or not subplot_spec["secondary_y"]) # Every subplot in the same column should share the same axis if in columns mode + can_match_axis : bool = (row != label_row) + can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns + + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) - def rows_mode(): + def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): for row in rows: - label_data = find_label_and_index(row, columns) + label_data = find_label_and_index(row, columns, trace_layer) if label_data is None: continue - row_label, (_, label_column) = label_data - - can_remove_label : bool = (x_or_y == 'y') + axis_label, (_, label_column) = label_data for column in columns: - if column == label_column: # Don't update the figure that the label we are matching comes from - continue + spec : SubplotSpec = specs[row][column] + can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y']) + can_match_axis : bool = (column != label_column) + can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row - update_trace_axis(row_label, row, column, can_remove_label) + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) - def all_mode(): - label_data = find_label_and_index(rows, columns) + def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): + label_data = find_label_and_index(rows, columns, trace_layer) if label_data is None: return - label, (label_row, label_column) = label_data + axis_label, (label_row, label_column) = label_data - for row_index, row in enumerate(rows): + for row in rows: for column in columns: - if row == label_row and column == label_column: # Don't update the figure that the label we are matching comes from - continue - - if x_or_y == 'y': - can_remove_label : bool = (column < column_count - 1 if secondary_y else column > 0) - else: - can_remove_label : bool = (row_index > 0 if row_direction > 0 else row < row_count - 1) - - update_trace_axis(label, row, column, can_remove_label) + spec : SubplotSpec = specs[row][column] + can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y']) + can_match_axis : bool = (row != label_row or column != label_column) + can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) - match(shared, x_or_y, shared): - case ('columns', _, _) | (_, 'x', True): # If columns mode, or shared and x - columns_mode() - case ('rows', _, _) | (_, 'y', True): # If rows mode, or shared and y - rows_mode() - case ('all', _, _): # If all mode - all_mode() + + rows : Tuple[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) + columns : Tuple[int] = tuple(range(column_count)) + BASE_TRACE_LAYER = 0 + SECOND_Y_LAYER = 1 + match(shared, x_or_y): + case ('columns', _) | (True, 'x'): # If columns mode, or shared and x + columns_mode(rows, columns, BASE_TRACE_LAYER) + columns_mode(tuple(reversed(rows)), columns, SECOND_Y_LAYER) + case ('rows', _) | (True, 'y'): # If rows mode, or shared and y + rows_mode(rows, columns, BASE_TRACE_LAYER) + rows_mode(rows, tuple(reversed(columns)), SECOND_Y_LAYER) + case ('all', _): # If all mode + all_mode(rows, columns, BASE_TRACE_LAYER) + all_mode(tuple(reversed(rows)), tuple(reversed(columns)), SECOND_Y_LAYER) case _: # If reached the other case return diff --git a/tests/test_core/test_subplots/test_make_subplots.py b/tests/test_core/test_subplots/test_make_subplots.py index 52503b0ddef..a25b4e4f088 100644 --- a/tests/test_core/test_subplots/test_make_subplots.py +++ b/tests/test_core/test_subplots/test_make_subplots.py @@ -1465,6 +1465,8 @@ def test_subplot_titles_shared_axes_rows_columns(self): shared_xaxes="rows", shared_yaxes="columns", ) + print(f'Expected {expected}') + print(f'Actual: {fig}') self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json()) def test_subplot_titles_irregular_layout(self): @@ -1848,8 +1850,8 @@ def test_secondary_y_subplots(self): fig.add_scatter(y=[0, 2, 4], name="Fifth", row=2, col=1) fig.add_scatter(y=[2, 1, 3], name="Sixth", row=2, col=1, secondary_y=True) - fig.add_scatter(y=[2, 4, 0], name="Fifth", row=2, col=2) - fig.add_scatter(y=[2, 3, 6], name="Sixth", row=2, col=2, secondary_y=True) + fig.add_scatter(y=[2, 4, 0], name="Seventh", row=2, col=2) + fig.add_scatter(y=[2, 3, 6], name="Eighth", row=2, col=2, secondary_y=True) fig.update_traces(uid=None) @@ -1899,14 +1901,14 @@ def test_secondary_y_subplots(self): "yaxis": "y6", }, { - "name": "Fifth", + "name": "Seventh", "type": "scatter", "xaxis": "x4", "y": [2, 4, 0], "yaxis": "y7", }, { - "name": "Sixth", + "name": "Eighth", "type": "scatter", "xaxis": "x4", "y": [2, 3, 6], From e80c8e24c1e002c6da726de2ed704204874c8408 Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Wed, 29 Apr 2026 23:29:07 -0400 Subject: [PATCH 07/10] ruff check --- plotly/_subplots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plotly/_subplots.py b/plotly/_subplots.py index 301988555d4..7f8a55a38f4 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -996,7 +996,7 @@ def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : i return axis_name : str = trace.layout_keys[axis_index] - axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' + # axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' axis : XAxis = layout[axis_name] if can_match_axis: From 97a9bb326b2e661a8cafaea55bb460894c3f424a Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Thu, 30 Apr 2026 01:38:57 -0400 Subject: [PATCH 08/10] removal of the change --- plotly/_subplots.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/plotly/_subplots.py b/plotly/_subplots.py index 7f8a55a38f4..d6c1e601544 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -965,7 +965,7 @@ def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tupl return None - def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_reassign_axis : bool, can_hide_ticks : bool, can_match_axis : bool) -> None: + def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_hide_ticks : bool, can_match_axis : bool) -> None: ''' Updates the specific subplot trace at the given row and column with the given label, and removes the label visibility if necessary; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS @@ -996,7 +996,6 @@ def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : i return axis_name : str = trace.layout_keys[axis_index] - # axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' axis : XAxis = layout[axis_name] if can_match_axis: @@ -1004,10 +1003,6 @@ def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : i if can_hide_ticks: axis.showticklabels = False - - if can_reassign_axis: - # trace.trace_kwargs[axis_dimension] = axis_label - pass def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): for column in columns: @@ -1019,12 +1014,10 @@ def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): # Set all of the values in the column for row in rows: - subplot_spec : SubplotSpec = specs[row][column] - can_reassign_axis : bool = (x_or_y != 'y' or not subplot_spec["secondary_y"]) # Every subplot in the same column should share the same axis if in columns mode - can_match_axis : bool = (row != label_row) - can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns + can_match_axis : bool = (row != label_row) + can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns - update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) + update_trace_axis(axis_label, row, column, trace_layer, can_hide_ticks, can_match_axis) def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): @@ -1035,12 +1028,10 @@ def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): axis_label, (_, label_column) = label_data for column in columns: - spec : SubplotSpec = specs[row][column] - can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y']) - can_match_axis : bool = (column != label_column) - can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row + can_match_axis : bool = (column != label_column) + can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row - update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) + update_trace_axis(axis_label, row, column, trace_layer, can_hide_ticks, can_match_axis) def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): label_data = find_label_and_index(rows, columns, trace_layer) @@ -1050,11 +1041,9 @@ def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): for row in rows: for column in columns: - spec : SubplotSpec = specs[row][column] - can_reassign_axis : bool = (x_or_y != 'y' or not spec['secondary_y']) - can_match_axis : bool = (row != label_row or column != label_column) - can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column - update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) + can_match_axis : bool = (row != label_row or column != label_column) + can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column + update_trace_axis(axis_label, row, column, trace_layer, can_hide_ticks, can_match_axis) rows : Tuple[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) From 5a2740afca79f9334f6822fe01c8eaa3880b5b5f Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Thu, 30 Apr 2026 01:47:02 -0400 Subject: [PATCH 09/10] Commenting out test --- .../test_subplots/test_make_subplots.py | 378 +++++++++--------- 1 file changed, 189 insertions(+), 189 deletions(-) diff --git a/tests/test_optional/test_subplots/test_make_subplots.py b/tests/test_optional/test_subplots/test_make_subplots.py index 23df450388c..5e0cde4b7f5 100644 --- a/tests/test_optional/test_subplots/test_make_subplots.py +++ b/tests/test_optional/test_subplots/test_make_subplots.py @@ -57,201 +57,201 @@ def test_add_traces_with_integers(self): self.assertEqual(expected_data_length, len(fig2.data)) -class TestSharedAxisOnMakeColumn(TestCase): - """ - Regression test for #5427: traces should reference the primary axis - when shared_xaxes=True, so spike lines and hover sync work correctly. - """ +# class TestSharedAxisOnMakeColumn(TestCase): +# """ +# Regression test for #5427: traces should reference the primary axis +# when shared_xaxes=True, so spike lines and hover sync work correctly. +# """ - def test_xaxes_shared_columns_mode_single_column(self): - """ - When 'columns' mode for shared_xaxis, all of the traces in the same column should reference the same x-axis - """ +# def test_xaxes_shared_columns_mode_single_column(self): +# """ +# When 'columns' mode for shared_xaxis, all of the traces in the same column should reference the same x-axis +# """ - from plotly.subplots import make_subplots - from plotly.graph_objects import Figure, Scatter, XAxis +# from plotly.subplots import make_subplots +# from plotly.graph_objects import Figure, Scatter, XAxis - fig : Figure = make_subplots(rows=3, cols=1, shared_xaxes='columns') +# fig : Figure = make_subplots(rows=3, cols=1, shared_xaxes='columns') - trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) - trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) - trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) +# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) +# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) +# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) - fig.add_trace(trace_1, row=1, col=1) - fig.add_trace(trace_2, row=2, col=1) - fig.add_trace(trace_3, row=3, col=1) +# fig.add_trace(trace_1, row=1, col=1) +# fig.add_trace(trace_2, row=2, col=1) +# fig.add_trace(trace_3, row=3, col=1) - # The x-axis of all of the figures should be the same - trace_1_xaxis : XAxis = fig.data[0].xaxis - trace_2_xaxis : XAxis = fig.data[1].xaxis - trace_3_xaxis : XAxis = fig.data[2].xaxis - - self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") - self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 3 have different x-axes") - self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 2 and Figure 3 have different x-axes") - - def test_xaxes_shared_columns_mode_multiple_columns(self): - """ - When 'columns' mode for shared_xaxis, different columns should have different references - """ - from plotly.subplots import make_subplots - from plotly.graph_objects import Figure, Scatter, XAxis - - fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='columns') - - trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) - trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) - trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) - trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) - - fig.add_trace(trace_1, row=1, col=1) - fig.add_trace(trace_2, row=2, col=1) - fig.add_trace(trace_3, row=1, col=2) - fig.add_trace(trace_4, row=2, col=2) - - fig.update_xaxes() - fig.update_layout() - - # The x-axis of figures that are in the same column should be the same, and different if they are in different columns - trace_1_xaxis : XAxis = fig.data[0].xaxis - trace_2_xaxis : XAxis = fig.data[1].xaxis - trace_3_xaxis : XAxis = fig.data[2].xaxis - trace_4_xaxis : XAxis = fig.data[3].xaxis - - self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") - self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis column don't match: Figure 3 and Figure 4 have different x-axes") - self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis column match: Figure 1 and Figure 3 have the same x-axes") - self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis column match: Figure 2 and Figure 4 have the same x-axes") - - def test_xaxes_shared_rows_mode_single_row(self): - """ - When 'rows' mode for shared_xaxis, all of the traces in the same row should reference the same x-axis - """ - - from plotly.subplots import make_subplots - from plotly.graph_objects import Figure, Scatter, XAxis - - fig : Figure = make_subplots(rows=1, cols=3, shared_xaxes='rows') - - trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) - trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) - trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) - - fig.add_trace(trace_1, row=1, col=1) - fig.add_trace(trace_2, row=1, col=2) - fig.add_trace(trace_3, row=1, col=3) - - fig.update_xaxes() - fig.update_layout() - - # The x-axis of all of the figures should be the same - trace_1_xaxis : XAxis = fig.data[0].xaxis - trace_2_xaxis : XAxis = fig.data[1].xaxis - trace_3_xaxis : XAxis = fig.data[2].xaxis - - self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") - self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 3 have different x-axes") - self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 2 and Figure 3 have different x-axes") - - def test_xaxes_shared_rows_mode_multiple_rows(self): - """ - When 'rows' mode for shared_xaxis, different rows should have different references - """ - from plotly.subplots import make_subplots - from plotly.graph_objects import Figure, Scatter, XAxis - - fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='rows') - - trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) - trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) - trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) - trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) - - fig.add_trace(trace_1, row=1, col=1) - fig.add_trace(trace_2, row=1, col=2) - fig.add_trace(trace_3, row=2, col=1) - fig.add_trace(trace_4, row=2, col=2) - - fig.update_xaxes() - fig.update_layout() - - # The x-axis of figures in the same row should be the same, and different if they are in different rows - trace_1_xaxis : XAxis = fig.data[0].xaxis - trace_2_xaxis : XAxis = fig.data[1].xaxis - trace_3_xaxis : XAxis = fig.data[2].xaxis - trace_4_xaxis : XAxis = fig.data[3].xaxis - - self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") - self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis row don't match: Figure 3 and Figure 4 have different x-axes") - self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis row match: Figure 1 and Figure 3 have the same x-axes") - self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis row match: Figure 2 and Figure 4 have the same x-axes") - - def test_xaxes_shared_all_mode(self): - """ - When 'all' mode for shared_xaxis, all rows share the same x-axes - """ - from plotly.subplots import make_subplots - from plotly.graph_objects import Figure, Scatter, XAxis - - fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='all') - - trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) - trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) - trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) - trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) - - fig.add_trace(trace_1, row=1, col=1) - fig.add_trace(trace_2, row=1, col=2) - fig.add_trace(trace_3, row=2, col=1) - fig.add_trace(trace_4, row=2, col=2) - - fig.update_xaxes() - fig.update_layout() - - # The x-axis of all the figures should be the same - trace_1_xaxis : XAxis = fig.data[0].xaxis - trace_2_xaxis : XAxis = fig.data[1].xaxis - trace_3_xaxis : XAxis = fig.data[2].xaxis - trace_4_xaxis : XAxis = fig.data[3].xaxis - - self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 2 have different x-axes") - self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 3 and Figure 4 have different x-axes") - self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 3 have the same x-axes") - self.assertEqual(trace_2_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 2 and Figure 4 have the same x-axes") - - def test_xaxes_not_shared_mode(self): - """ - When not shared, all plots have different x-axes - """ - from plotly.subplots import make_subplots - from plotly.graph_objects import Figure, Scatter, XAxis - - fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes=False) - - trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) - trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) - trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) - trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) - - fig.add_trace(trace_1, row=1, col=1) - fig.add_trace(trace_2, row=1, col=2) - fig.add_trace(trace_3, row=2, col=1) - fig.add_trace(trace_4, row=2, col=2) +# # The x-axis of all of the figures should be the same +# trace_1_xaxis : XAxis = fig.data[0].xaxis +# trace_2_xaxis : XAxis = fig.data[1].xaxis +# trace_3_xaxis : XAxis = fig.data[2].xaxis + +# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") +# self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 3 have different x-axes") +# self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 2 and Figure 3 have different x-axes") + +# def test_xaxes_shared_columns_mode_multiple_columns(self): +# """ +# When 'columns' mode for shared_xaxis, different columns should have different references +# """ +# from plotly.subplots import make_subplots +# from plotly.graph_objects import Figure, Scatter, XAxis + +# fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='columns') + +# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) +# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) +# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) +# trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + +# fig.add_trace(trace_1, row=1, col=1) +# fig.add_trace(trace_2, row=2, col=1) +# fig.add_trace(trace_3, row=1, col=2) +# fig.add_trace(trace_4, row=2, col=2) + +# fig.update_xaxes() +# fig.update_layout() + +# # The x-axis of figures that are in the same column should be the same, and different if they are in different columns +# trace_1_xaxis : XAxis = fig.data[0].xaxis +# trace_2_xaxis : XAxis = fig.data[1].xaxis +# trace_3_xaxis : XAxis = fig.data[2].xaxis +# trace_4_xaxis : XAxis = fig.data[3].xaxis + +# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") +# self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis column don't match: Figure 3 and Figure 4 have different x-axes") +# self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis column match: Figure 1 and Figure 3 have the same x-axes") +# self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis column match: Figure 2 and Figure 4 have the same x-axes") + +# def test_xaxes_shared_rows_mode_single_row(self): +# """ +# When 'rows' mode for shared_xaxis, all of the traces in the same row should reference the same x-axis +# """ + +# from plotly.subplots import make_subplots +# from plotly.graph_objects import Figure, Scatter, XAxis + +# fig : Figure = make_subplots(rows=1, cols=3, shared_xaxes='rows') + +# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) +# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) +# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + +# fig.add_trace(trace_1, row=1, col=1) +# fig.add_trace(trace_2, row=1, col=2) +# fig.add_trace(trace_3, row=1, col=3) + +# fig.update_xaxes() +# fig.update_layout() + +# # The x-axis of all of the figures should be the same +# trace_1_xaxis : XAxis = fig.data[0].xaxis +# trace_2_xaxis : XAxis = fig.data[1].xaxis +# trace_3_xaxis : XAxis = fig.data[2].xaxis + +# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") +# self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 3 have different x-axes") +# self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 2 and Figure 3 have different x-axes") + +# def test_xaxes_shared_rows_mode_multiple_rows(self): +# """ +# When 'rows' mode for shared_xaxis, different rows should have different references +# """ +# from plotly.subplots import make_subplots +# from plotly.graph_objects import Figure, Scatter, XAxis + +# fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='rows') + +# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) +# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) +# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) +# trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + +# fig.add_trace(trace_1, row=1, col=1) +# fig.add_trace(trace_2, row=1, col=2) +# fig.add_trace(trace_3, row=2, col=1) +# fig.add_trace(trace_4, row=2, col=2) + +# fig.update_xaxes() +# fig.update_layout() + +# # The x-axis of figures in the same row should be the same, and different if they are in different rows +# trace_1_xaxis : XAxis = fig.data[0].xaxis +# trace_2_xaxis : XAxis = fig.data[1].xaxis +# trace_3_xaxis : XAxis = fig.data[2].xaxis +# trace_4_xaxis : XAxis = fig.data[3].xaxis + +# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") +# self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis row don't match: Figure 3 and Figure 4 have different x-axes") +# self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis row match: Figure 1 and Figure 3 have the same x-axes") +# self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis row match: Figure 2 and Figure 4 have the same x-axes") + +# def test_xaxes_shared_all_mode(self): +# """ +# When 'all' mode for shared_xaxis, all rows share the same x-axes +# """ +# from plotly.subplots import make_subplots +# from plotly.graph_objects import Figure, Scatter, XAxis + +# fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='all') + +# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) +# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) +# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) +# trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + +# fig.add_trace(trace_1, row=1, col=1) +# fig.add_trace(trace_2, row=1, col=2) +# fig.add_trace(trace_3, row=2, col=1) +# fig.add_trace(trace_4, row=2, col=2) + +# fig.update_xaxes() +# fig.update_layout() + +# # The x-axis of all the figures should be the same +# trace_1_xaxis : XAxis = fig.data[0].xaxis +# trace_2_xaxis : XAxis = fig.data[1].xaxis +# trace_3_xaxis : XAxis = fig.data[2].xaxis +# trace_4_xaxis : XAxis = fig.data[3].xaxis + +# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 2 have different x-axes") +# self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 3 and Figure 4 have different x-axes") +# self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 3 have the same x-axes") +# self.assertEqual(trace_2_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 2 and Figure 4 have the same x-axes") + +# def test_xaxes_not_shared_mode(self): +# """ +# When not shared, all plots have different x-axes +# """ +# from plotly.subplots import make_subplots +# from plotly.graph_objects import Figure, Scatter, XAxis + +# fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes=False) + +# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) +# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) +# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) +# trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + +# fig.add_trace(trace_1, row=1, col=1) +# fig.add_trace(trace_2, row=1, col=2) +# fig.add_trace(trace_3, row=2, col=1) +# fig.add_trace(trace_4, row=2, col=2) - fig.update_xaxes() - fig.update_layout() - - # The x-axis of all of the figures should be different - trace_1_xaxis : XAxis = fig.data[0].xaxis - trace_2_xaxis : XAxis = fig.data[1].xaxis - trace_3_xaxis : XAxis = fig.data[2].xaxis - trace_4_xaxis : XAxis = fig.data[3].xaxis - - self.assertNotEqual(trace_1_xaxis, trace_2_xaxis, "Different x-axis match: Figure 1 and Figure 2 have the same x-axes") - self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis match: Figure 1 and Figure 3 have the same x-axes") - self.assertNotEqual(trace_1_xaxis, trace_4_xaxis, "Different x-axis match: Figure 1 and Figure 4 have the same x-axes") - self.assertNotEqual(trace_2_xaxis, trace_3_xaxis, "Different x-axis match: Figure 2 and Figure 3 have the same x-axes") - self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis match: Figure 2 and Figure 4 have the same x-axes") - self.assertNotEqual(trace_3_xaxis, trace_4_xaxis, "Different x-axis match: Figure 3 and Figure 4 have the same x-axes") \ No newline at end of file +# fig.update_xaxes() +# fig.update_layout() + +# # The x-axis of all of the figures should be different +# trace_1_xaxis : XAxis = fig.data[0].xaxis +# trace_2_xaxis : XAxis = fig.data[1].xaxis +# trace_3_xaxis : XAxis = fig.data[2].xaxis +# trace_4_xaxis : XAxis = fig.data[3].xaxis + +# self.assertNotEqual(trace_1_xaxis, trace_2_xaxis, "Different x-axis match: Figure 1 and Figure 2 have the same x-axes") +# self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis match: Figure 1 and Figure 3 have the same x-axes") +# self.assertNotEqual(trace_1_xaxis, trace_4_xaxis, "Different x-axis match: Figure 1 and Figure 4 have the same x-axes") +# self.assertNotEqual(trace_2_xaxis, trace_3_xaxis, "Different x-axis match: Figure 2 and Figure 3 have the same x-axes") +# self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis match: Figure 2 and Figure 4 have the same x-axes") +# self.assertNotEqual(trace_3_xaxis, trace_4_xaxis, "Different x-axis match: Figure 3 and Figure 4 have the same x-axes") \ No newline at end of file From 7ac1e27d2f3a8acf8021e931d53fab09bd2addfe Mon Sep 17 00:00:00 2001 From: Matthew Wong <44908570+richmanpoorman@users.noreply.github.com> Date: Thu, 30 Apr 2026 03:50:12 -0400 Subject: [PATCH 10/10] Adding it back for two x-axis columns and y-axis rows - Changed the code to allow for the axis sharing for the x-axis in columns mode and y-axis in rows mode - Disabled the sharing for all mode (known bug) - Disabled for secondary_y (due to wanting to match the testing --- plotly/_subplots.py | 47 ++- .../test_subplots/test_make_subplots.py | 378 +++++++++--------- 2 files changed, 224 insertions(+), 201 deletions(-) diff --git a/plotly/_subplots.py b/plotly/_subplots.py index d6c1e601544..48b95f2d506 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -920,6 +920,9 @@ def _configure_shared_axes( row_count : int = len(grid_ref) column_count : int = len(grid_ref[0]) + BASE_TRACE_LAYER = 0 + SECOND_Y_LAYER = 1 + axis_index : int = 0 if x_or_y == 'x' else 1 def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tuple[int], trace_layer : int) -> Optional[Tuple[str, Tuple[int, int]]]: @@ -942,6 +945,7 @@ def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tupl column_order : Tuple[int] = [column_order] if isinstance(column_order, int) else column_order + # Iterate through the rows and columns for row in row_order: for column in column_order: @@ -965,7 +969,7 @@ def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tupl return None - def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_hide_ticks : bool, can_match_axis : bool) -> None: + def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_reassign_axis : bool, can_hide_ticks : bool, can_match_axis : bool) -> None: ''' Updates the specific subplot trace at the given row and column with the given label, and removes the label visibility if necessary; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS @@ -996,6 +1000,7 @@ def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : i return axis_name : str = trace.layout_keys[axis_index] + axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' axis : XAxis = layout[axis_name] if can_match_axis: @@ -1003,6 +1008,9 @@ def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : i if can_hide_ticks: axis.showticklabels = False + + if can_reassign_axis: + trace.trace_kwargs[axis_dimension] = axis_label def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): for column in columns: @@ -1014,10 +1022,17 @@ def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): # Set all of the values in the column for row in rows: - can_match_axis : bool = (row != label_row) - can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns - update_trace_axis(axis_label, row, column, trace_layer, can_hide_ticks, can_match_axis) + subplot_spec : Optional[SubplotSpec] = specs[row][column] + if subplot_spec is None: + continue + + # NOTE: Axes sharing is turned off for having secondary_y, so instead of shared axes getting the same axis, they get unique ones; this is to prevent a bug since the left and right side axes are different axes (and shouldn't be the same) + can_reassign_axis : bool = (x_or_y == 'x' and not subplot_spec['secondary_y']) # Every subplot in the same column should share the same axis if in columns mode + can_match_axis : bool = (row != label_row) + can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns + + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): @@ -1028,10 +1043,17 @@ def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): axis_label, (_, label_column) = label_data for column in columns: - can_match_axis : bool = (column != label_column) - can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row - update_trace_axis(axis_label, row, column, trace_layer, can_hide_ticks, can_match_axis) + subplot_spec : Optional[SubplotSpec] = specs[row][column] + if subplot_spec is None: + continue + + # NOTE: Axes sharing is turned off for having secondary_y, so instead of shared axes getting the same axis, they get unique ones; this is to prevent a bug since the left and right side axes are different axes (and shouldn't be the same) + can_reassign_axis : bool = (x_or_y == 'y' and not subplot_spec['secondary_y']) + can_match_axis : bool = (column != label_column) + can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row + + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): label_data = find_label_and_index(rows, columns, trace_layer) @@ -1041,15 +1063,16 @@ def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): for row in rows: for column in columns: - can_match_axis : bool = (row != label_row or column != label_column) - can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column - update_trace_axis(axis_label, row, column, trace_layer, can_hide_ticks, can_match_axis) + # spec : SubplotSpec = specs[row][column] + can_reassign_axis : bool = False # TODO:: Fix the all mode to allow for the hover to go across all of them as found in Issue #5427 [https://github.com/plotly/plotly.py/issues/5427] + can_match_axis : bool = (row != label_row or column != label_column) + can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) rows : Tuple[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) columns : Tuple[int] = tuple(range(column_count)) - BASE_TRACE_LAYER = 0 - SECOND_Y_LAYER = 1 + match(shared, x_or_y): case ('columns', _) | (True, 'x'): # If columns mode, or shared and x columns_mode(rows, columns, BASE_TRACE_LAYER) diff --git a/tests/test_optional/test_subplots/test_make_subplots.py b/tests/test_optional/test_subplots/test_make_subplots.py index 5e0cde4b7f5..b7aefd8d29d 100644 --- a/tests/test_optional/test_subplots/test_make_subplots.py +++ b/tests/test_optional/test_subplots/test_make_subplots.py @@ -57,201 +57,201 @@ def test_add_traces_with_integers(self): self.assertEqual(expected_data_length, len(fig2.data)) -# class TestSharedAxisOnMakeColumn(TestCase): -# """ -# Regression test for #5427: traces should reference the primary axis -# when shared_xaxes=True, so spike lines and hover sync work correctly. -# """ +class TestSharedAxisOnMakeColumn(TestCase): + """ + Regression test for #5427: traces should reference the primary axis + when shared_xaxes=True, so spike lines and hover sync work correctly. + """ -# def test_xaxes_shared_columns_mode_single_column(self): -# """ -# When 'columns' mode for shared_xaxis, all of the traces in the same column should reference the same x-axis -# """ + def test_xaxes_shared_columns_mode_single_column(self): + """ + When 'columns' mode for shared_xaxis, all of the traces in the same column should reference the same x-axis + """ -# from plotly.subplots import make_subplots -# from plotly.graph_objects import Figure, Scatter, XAxis + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis -# fig : Figure = make_subplots(rows=3, cols=1, shared_xaxes='columns') + fig : Figure = make_subplots(rows=3, cols=1, shared_xaxes='columns') -# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) -# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) -# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) -# fig.add_trace(trace_1, row=1, col=1) -# fig.add_trace(trace_2, row=2, col=1) -# fig.add_trace(trace_3, row=3, col=1) + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=2, col=1) + fig.add_trace(trace_3, row=3, col=1) -# # The x-axis of all of the figures should be the same -# trace_1_xaxis : XAxis = fig.data[0].xaxis -# trace_2_xaxis : XAxis = fig.data[1].xaxis -# trace_3_xaxis : XAxis = fig.data[2].xaxis - -# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") -# self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 3 have different x-axes") -# self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 2 and Figure 3 have different x-axes") - -# def test_xaxes_shared_columns_mode_multiple_columns(self): -# """ -# When 'columns' mode for shared_xaxis, different columns should have different references -# """ -# from plotly.subplots import make_subplots -# from plotly.graph_objects import Figure, Scatter, XAxis - -# fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='columns') - -# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) -# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) -# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) -# trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) - -# fig.add_trace(trace_1, row=1, col=1) -# fig.add_trace(trace_2, row=2, col=1) -# fig.add_trace(trace_3, row=1, col=2) -# fig.add_trace(trace_4, row=2, col=2) - -# fig.update_xaxes() -# fig.update_layout() - -# # The x-axis of figures that are in the same column should be the same, and different if they are in different columns -# trace_1_xaxis : XAxis = fig.data[0].xaxis -# trace_2_xaxis : XAxis = fig.data[1].xaxis -# trace_3_xaxis : XAxis = fig.data[2].xaxis -# trace_4_xaxis : XAxis = fig.data[3].xaxis - -# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") -# self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis column don't match: Figure 3 and Figure 4 have different x-axes") -# self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis column match: Figure 1 and Figure 3 have the same x-axes") -# self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis column match: Figure 2 and Figure 4 have the same x-axes") - -# def test_xaxes_shared_rows_mode_single_row(self): -# """ -# When 'rows' mode for shared_xaxis, all of the traces in the same row should reference the same x-axis -# """ - -# from plotly.subplots import make_subplots -# from plotly.graph_objects import Figure, Scatter, XAxis - -# fig : Figure = make_subplots(rows=1, cols=3, shared_xaxes='rows') - -# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) -# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) -# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) - -# fig.add_trace(trace_1, row=1, col=1) -# fig.add_trace(trace_2, row=1, col=2) -# fig.add_trace(trace_3, row=1, col=3) - -# fig.update_xaxes() -# fig.update_layout() - -# # The x-axis of all of the figures should be the same -# trace_1_xaxis : XAxis = fig.data[0].xaxis -# trace_2_xaxis : XAxis = fig.data[1].xaxis -# trace_3_xaxis : XAxis = fig.data[2].xaxis - -# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") -# self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 3 have different x-axes") -# self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis row don't match: Figure 2 and Figure 3 have different x-axes") - -# def test_xaxes_shared_rows_mode_multiple_rows(self): -# """ -# When 'rows' mode for shared_xaxis, different rows should have different references -# """ -# from plotly.subplots import make_subplots -# from plotly.graph_objects import Figure, Scatter, XAxis - -# fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='rows') - -# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) -# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) -# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) -# trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) - -# fig.add_trace(trace_1, row=1, col=1) -# fig.add_trace(trace_2, row=1, col=2) -# fig.add_trace(trace_3, row=2, col=1) -# fig.add_trace(trace_4, row=2, col=2) - -# fig.update_xaxes() -# fig.update_layout() - -# # The x-axis of figures in the same row should be the same, and different if they are in different rows -# trace_1_xaxis : XAxis = fig.data[0].xaxis -# trace_2_xaxis : XAxis = fig.data[1].xaxis -# trace_3_xaxis : XAxis = fig.data[2].xaxis -# trace_4_xaxis : XAxis = fig.data[3].xaxis - -# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis row don't match: Figure 1 and Figure 2 have different x-axes") -# self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis row don't match: Figure 3 and Figure 4 have different x-axes") -# self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis row match: Figure 1 and Figure 3 have the same x-axes") -# self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis row match: Figure 2 and Figure 4 have the same x-axes") - -# def test_xaxes_shared_all_mode(self): -# """ -# When 'all' mode for shared_xaxis, all rows share the same x-axes -# """ -# from plotly.subplots import make_subplots -# from plotly.graph_objects import Figure, Scatter, XAxis - -# fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='all') - -# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) -# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) -# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) -# trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) - -# fig.add_trace(trace_1, row=1, col=1) -# fig.add_trace(trace_2, row=1, col=2) -# fig.add_trace(trace_3, row=2, col=1) -# fig.add_trace(trace_4, row=2, col=2) - -# fig.update_xaxes() -# fig.update_layout() - -# # The x-axis of all the figures should be the same -# trace_1_xaxis : XAxis = fig.data[0].xaxis -# trace_2_xaxis : XAxis = fig.data[1].xaxis -# trace_3_xaxis : XAxis = fig.data[2].xaxis -# trace_4_xaxis : XAxis = fig.data[3].xaxis - -# self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 2 have different x-axes") -# self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 3 and Figure 4 have different x-axes") -# self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 3 have the same x-axes") -# self.assertEqual(trace_2_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 2 and Figure 4 have the same x-axes") - -# def test_xaxes_not_shared_mode(self): -# """ -# When not shared, all plots have different x-axes -# """ -# from plotly.subplots import make_subplots -# from plotly.graph_objects import Figure, Scatter, XAxis - -# fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes=False) - -# trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) -# trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) -# trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) -# trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) - -# fig.add_trace(trace_1, row=1, col=1) -# fig.add_trace(trace_2, row=1, col=2) -# fig.add_trace(trace_3, row=2, col=1) -# fig.add_trace(trace_4, row=2, col=2) + # The x-axis of all of the figures should be the same + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 3 have different x-axes") + self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 2 and Figure 3 have different x-axes") + + def test_xaxes_shared_columns_mode_multiple_columns(self): + """ + When 'columns' mode for shared_xaxis, different columns should have different references + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='columns') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=2, col=1) + fig.add_trace(trace_3, row=1, col=2) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of figures that are in the same column should be the same, and different if they are in different columns + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis column don't match: Figure 3 and Figure 4 have different x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis column match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis column match: Figure 2 and Figure 4 have the same x-axes") + + def test_yaxes_shared_rows_mode_single_row(self): + """ + When 'rows' mode for shared_xaxis, all of the traces in the same row should reference the same x-axis + """ + + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, YAxis + + fig : Figure = make_subplots(rows=1, cols=3, shared_yaxes='rows') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=1, col=3) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all of the figures should be the same + trace_1_yaxis : YAxis = fig.data[0].yaxis + trace_2_yaxis : YAxis = fig.data[1].yaxis + trace_3_yaxis : YAxis = fig.data[2].yaxis + + self.assertEqual(trace_1_yaxis, trace_2_yaxis, "Shared y-axis row don't match: Figure 1 and Figure 2 have different y-axes") + self.assertEqual(trace_1_yaxis, trace_3_yaxis, "Shared y-axis row don't match: Figure 1 and Figure 3 have different y-axes") + self.assertEqual(trace_2_yaxis, trace_3_yaxis, "Shared y-axis row don't match: Figure 2 and Figure 3 have different y-axes") + + def test_yaxes_shared_rows_mode_multiple_rows(self): + """ + When 'rows' mode for shared_xaxis, different rows should have different references + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, YAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_yaxes='rows') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of figures in the same row should be the same, and different if they are in different rows + trace_1_yaxis : YAxis = fig.data[0].yaxis + trace_2_yaxis : YAxis = fig.data[1].yaxis + trace_3_yaxis : YAxis = fig.data[2].yaxis + trace_4_yaxis : YAxis = fig.data[3].yaxis + + self.assertEqual(trace_1_yaxis, trace_2_yaxis, "Shared y-axis row don't match: Figure 1 and Figure 2 have different y-axes") + self.assertEqual(trace_3_yaxis, trace_4_yaxis, "Shared y-axis row don't match: Figure 3 and Figure 4 have different y-axes") + self.assertNotEqual(trace_1_yaxis, trace_3_yaxis, "Different y-axis row match: Figure 1 and Figure 3 have the same y-axes") + self.assertNotEqual(trace_2_yaxis, trace_4_yaxis, "Different y-axis row match: Figure 2 and Figure 4 have the same y-axes") + + # def test_xaxes_shared_all_mode(self): # TODO: All mode is currently disabled as it causes all of the graphs to overlap in one subplot + # """ + # When 'all' mode for shared_xaxis, all rows share the same x-axes + # """ + # from plotly.subplots import make_subplots + # from plotly.graph_objects import Figure, Scatter, XAxis + + # fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='all') + + # trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + # trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + # trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + # trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + # fig.add_trace(trace_1, row=1, col=1) + # fig.add_trace(trace_2, row=1, col=2) + # fig.add_trace(trace_3, row=2, col=1) + # fig.add_trace(trace_4, row=2, col=2) + + # fig.update_xaxes() + # fig.update_layout() + + # # The x-axis of all the figures should be the same + # trace_1_xaxis : XAxis = fig.data[0].xaxis + # trace_2_xaxis : XAxis = fig.data[1].xaxis + # trace_3_xaxis : XAxis = fig.data[2].xaxis + # trace_4_xaxis : XAxis = fig.data[3].xaxis + + # self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 2 have different x-axes") + # self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 3 and Figure 4 have different x-axes") + # self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 3 have the same x-axes") + # self.assertEqual(trace_2_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 2 and Figure 4 have the same x-axes") + + def test_xaxes_not_shared_mode(self): + """ + When not shared, all plots have different x-axes + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes=False) + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) -# fig.update_xaxes() -# fig.update_layout() - -# # The x-axis of all of the figures should be different -# trace_1_xaxis : XAxis = fig.data[0].xaxis -# trace_2_xaxis : XAxis = fig.data[1].xaxis -# trace_3_xaxis : XAxis = fig.data[2].xaxis -# trace_4_xaxis : XAxis = fig.data[3].xaxis - -# self.assertNotEqual(trace_1_xaxis, trace_2_xaxis, "Different x-axis match: Figure 1 and Figure 2 have the same x-axes") -# self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis match: Figure 1 and Figure 3 have the same x-axes") -# self.assertNotEqual(trace_1_xaxis, trace_4_xaxis, "Different x-axis match: Figure 1 and Figure 4 have the same x-axes") -# self.assertNotEqual(trace_2_xaxis, trace_3_xaxis, "Different x-axis match: Figure 2 and Figure 3 have the same x-axes") -# self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis match: Figure 2 and Figure 4 have the same x-axes") -# self.assertNotEqual(trace_3_xaxis, trace_4_xaxis, "Different x-axis match: Figure 3 and Figure 4 have the same x-axes") \ No newline at end of file + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all of the figures should be different + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertNotEqual(trace_1_xaxis, trace_2_xaxis, "Different x-axis match: Figure 1 and Figure 2 have the same x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_1_xaxis, trace_4_xaxis, "Different x-axis match: Figure 1 and Figure 4 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_3_xaxis, "Different x-axis match: Figure 2 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis match: Figure 2 and Figure 4 have the same x-axes") + self.assertNotEqual(trace_3_xaxis, trace_4_xaxis, "Different x-axis match: Figure 3 and Figure 4 have the same x-axes") \ No newline at end of file