diff --git a/CHANGELOG.md b/CHANGELOG.md index 126ff280b49..b296ee382b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Fixed +- Bug was that function marked the axis to be connected, but the trace_kwargs still had unique axes [[#5427](https://github.com/plotly/plotly.py/issues/5427)] +- Change: change the keyword argument for the trace, so that when the graph is initialized, it uses the correct axis instead of the autogenerated one +- Note: The program generates a unique axis label for each subgraph, and then overwrites the label (under this fix) + ### Fixed - Fix issue where user-specified `color_continuous_scale` was ignored when template had `autocolorscale=True` [[#5439](https://github.com/plotly/plotly.py/pull/5439)], with thanks to @antonymilne for the contribution! - Update tests to be compatible with numpy 2.4 [[#5522](https://github.com/plotly/plotly.py/pull/5522)], with thanks to @thunze for the contribution! diff --git a/plotly/_subplots.py b/plotly/_subplots.py index 16a3958637e..48b95f2d506 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -6,8 +6,13 @@ # properties. # Note that this set does not contain `xaxis`/`yaxis` because these behave a # little differently. +from __future__ import annotations import collections +from typing import Literal, Optional, Tuple, TypedDict, TYPE_CHECKING +if TYPE_CHECKING: + from plotly.graph_objects import Layout, XAxis + _single_subplot_types = {"scene", "geo", "polar", "ternary", "map", "mapbox"} _subplot_types = set.union(_single_subplot_types, {"xy", "domain"}) @@ -31,6 +36,17 @@ "SubplotRef", ("subplot_type", "layout_keys", "trace_kwargs") ) +class SubplotSpec(TypedDict): + type : Literal['xy', 'scene', 'polar', 'ternary', 'map', 'mapbox', 'domain'] | str + secondary_y : bool + colspan : int + rowspan : int + # NOTE: that this is the dictionary as defined by the documentation, so the ambiguous name 'l' can't be changed without changing the documentation + l : float # noqa: E741 + r : float + t : float + b : float + def _get_initial_max_subplot_ids(): max_subplot_ids = {subplot_type: 0 for subplot_type in _single_subplot_types} @@ -746,19 +762,10 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): ) grid_ref[r][c] = subplot_refs - _configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir, False) - _configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir, False) - any_secondary_y = any( - spec["secondary_y"] - for spec_row in specs - for spec in spec_row - if spec is not None - ) - if any_secondary_y: - _configure_shared_axes( - layout, grid_ref, specs, "y", shared_yaxes, row_dir, True - ) + _configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir) + _configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir) + # Build inset reference # --------------------- @@ -889,99 +896,195 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): return figure - def _configure_shared_axes( - layout, grid_ref, specs, x_or_y, shared, row_dir, secondary_y -): - rows = len(grid_ref) - cols = len(grid_ref[0]) - - layout_key_ind = ["x", "y"].index(x_or_y) - - if row_dir < 0: - rows_iter = range(rows - 1, -1, -1) - else: - rows_iter = range(rows) - - if secondary_y: - cols_iter = range(cols - 1, -1, -1) - axis_index = 1 - else: - cols_iter = range(cols) - axis_index = 0 - - def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): - if subplot_ref is None: - return first_axis_id - - if x_or_y == "x": - span = spec["colspan"] - else: - span = spec["rowspan"] - - if subplot_ref.subplot_type == "xy" and span == 1: - if first_axis_id is None: - first_axis_name = subplot_ref.layout_keys[layout_key_ind] - first_axis_id = first_axis_name.replace("axis", "") - else: - axis_name = subplot_ref.layout_keys[layout_key_ind] - axis_to_match = layout[axis_name] - axis_to_match.matches = first_axis_id - if remove_label: - axis_to_match.showticklabels = False - - return first_axis_id - - if shared == "columns" or (x_or_y == "x" and shared is True): - for c in cols_iter: - first_axis_id = None - ok_to_remove_label = x_or_y == "x" - for r in rows_iter: - if not grid_ref[r][c]: - continue - if axis_index >= len(grid_ref[r][c]): - continue - subplot_ref = grid_ref[r][c][axis_index] - spec = specs[r][c] - first_axis_id = update_axis_matches( - first_axis_id, subplot_ref, spec, ok_to_remove_label - ) - - elif shared == "rows" or (x_or_y == "y" and shared is True): - for r in rows_iter: - first_axis_id = None - ok_to_remove_label = x_or_y == "y" - for c in cols_iter: - if not grid_ref[r][c]: - continue - if axis_index >= len(grid_ref[r][c]): - continue - subplot_ref = grid_ref[r][c][axis_index] - spec = specs[r][c] - first_axis_id = update_axis_matches( - first_axis_id, subplot_ref, spec, ok_to_remove_label - ) - - elif shared == "all": - first_axis_id = None - for ri, r in enumerate(rows_iter): - for c in cols_iter: - if not grid_ref[r][c]: - continue - if axis_index >= len(grid_ref[r][c]): - continue - subplot_ref = grid_ref[r][c][axis_index] - spec = specs[r][c] - - if x_or_y == "y": - ok_to_remove_label = c < cols - 1 if secondary_y else c > 0 - else: - ok_to_remove_label = ri > 0 if row_dir > 0 else r < rows - 1 - - first_axis_id = update_axis_matches( - first_axis_id, subplot_ref, spec, ok_to_remove_label - ) + layout : Layout, + grid_ref : Tuple[Tuple[SubplotRef]], + specs : Tuple[Tuple[SubplotSpec]], + x_or_y : Literal['x', 'y'], + shared : bool | Literal['rows', 'columns', 'all'], + row_direction : Literal[1, -1] +) -> None: + ''' + Sets the axes to be shared, making them use the same axis + + Parameters: + ----------- + layout (go.Layout) : The layout of the figure to be updating + grid_ref (Tuple[Tuple[SubplotRef]]) : The grid of subplots within the figure; grid_ref[row][column] = subplot at that coordinate + specs (Tuple[Tuple[SubplotSpec]]) : The specifications of each of the subplots within the figure; specs[row][column] = specs of the subplot at that coordinate + x_or_y ('x' | 'y') : The axis to configure + shared ('rows' | 'columns' | 'all' | bool) : The sharing mode, (True is 'columns' mode, False means no sharing) ie share the axis with all subplots in the corresponding row, column, or entire figure + row_direction (1 | -1) : The directional that the rows go + ''' + + row_count : int = len(grid_ref) + column_count : int = len(grid_ref[0]) + + BASE_TRACE_LAYER = 0 + SECOND_Y_LAYER = 1 + + axis_index : int = 0 if x_or_y == 'x' else 1 + + def find_label_and_index(row_order : int | Tuple[int], column_order : int | Tuple[int], trace_layer : int) -> Optional[Tuple[str, Tuple[int, int]]]: + ''' + Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS + + Parameters: + ----------- + row_order (int | Tuple[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable + column_order (int | Tuple[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable + trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1] + Return: + ------- + Returns (Label : str, (Row : int, Column : int)): returning the label found, and the row and column it was found at (uses x_or_y to determine which of the axes' labels to pull) + Return (None): No label was found + ''' + + # Turn them into lists with one element, so that both row_order and column_order are iterables + row_order : Tuple[int] = [row_order] if isinstance(row_order, int) else row_order + column_order : Tuple[int] = [column_order] if isinstance(column_order, int) else column_order + + + + # Iterate through the rows and columns + for row in row_order: + for column in column_order: + if not grid_ref[row][column]: + continue + + subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column] + subplot_spec : SubplotSpec = specs[row][column] + + span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan'] + if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces): + continue + + trace = subplot_traces[trace_layer] + if trace is None or trace.subplot_type != 'xy': + continue + + label_name : str = trace.layout_keys[axis_index] + label : str = label_name.replace("axis", "") + return label, (row, column) + return None + + + def update_trace_axis(axis_label : str, row : int, column : int, trace_layer : int, can_reassign_axis : bool, can_hide_ticks : bool, can_match_axis : bool) -> None: + ''' + Updates the specific subplot trace at the given row and column with the given label, and removes the label visibility if necessary; ONLY WORKS WITH 2D CARTESIAN SUBPLOTS AKA 'xy' TYPE SUBPLOTS + + Parameters: + ----------- + axis_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location + row (int) : The row of the subplot within grid_ref to update + column (int) : The column of the subplot within grid_ref to update + trace_layer (int) : Which axis of traces to look at [Since there can be multiple traces on one subplot ie the secondary_y traces are on layer 1] + can_reassign_axis (bool): If True, can change the unique axis for the shared axis in the trace keywords, otherwise, will keep using the axis name it already has + can_hide_ticks (bool): If the function is allowed to hide the ticks (if True, it will hide the ticks, if False, it will leave the ticks as their current state) + can_match_axis (bool): If the axis should be marked as a match to the axis label + ''' + + if not grid_ref[row][column] or specs[row][column] is None: + return + + subplot_traces : Tuple[Optional[SubplotRef]] = grid_ref[row][column] + subplot_spec : SubplotSpec = specs[row][column] + + span = subplot_spec['colspan'] if x_or_y == 'x' else subplot_spec['rowspan'] + if subplot_spec['type'] != 'xy' or span != 1 or trace_layer >= len(subplot_traces): + return + trace : Optional[SubplotRef] = subplot_traces[trace_layer] + + if trace is None or trace.subplot_type != 'xy' or span != 1: + return + + axis_name : str = trace.layout_keys[axis_index] + axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' + axis : XAxis = layout[axis_name] + + if can_match_axis: + axis.matches = axis_label + + if can_hide_ticks: + axis.showticklabels = False + + if can_reassign_axis: + trace.trace_kwargs[axis_dimension] = axis_label + + def columns_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): + for column in columns: + # Get the label used by all the rows in the column + label_data = find_label_and_index(rows, column, trace_layer) + if label_data is None: + continue + axis_label, (label_row, _) = label_data + + # Set all of the values in the column + for row in rows: + + subplot_spec : Optional[SubplotSpec] = specs[row][column] + if subplot_spec is None: + continue + + # NOTE: Axes sharing is turned off for having secondary_y, so instead of shared axes getting the same axis, they get unique ones; this is to prevent a bug since the left and right side axes are different axes (and shouldn't be the same) + can_reassign_axis : bool = (x_or_y == 'x' and not subplot_spec['secondary_y']) # Every subplot in the same column should share the same axis if in columns mode + can_match_axis : bool = (row != label_row) + can_hide_ticks : bool = can_match_axis and x_or_y == 'x' # Sharing column wise can only hide x-axis; still need all of the different y-axis across plots in the same columns + + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) + + + def rows_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): + for row in rows: + label_data = find_label_and_index(row, columns, trace_layer) + if label_data is None: + continue + axis_label, (_, label_column) = label_data + + for column in columns: + + subplot_spec : Optional[SubplotSpec] = specs[row][column] + if subplot_spec is None: + continue + + # NOTE: Axes sharing is turned off for having secondary_y, so instead of shared axes getting the same axis, they get unique ones; this is to prevent a bug since the left and right side axes are different axes (and shouldn't be the same) + can_reassign_axis : bool = (x_or_y == 'y' and not subplot_spec['secondary_y']) + can_match_axis : bool = (column != label_column) + can_hide_ticks : bool = can_match_axis and x_or_y == 'y' # Sharing row wise can only hide y-axis; still need all of the different x-axis across plots in the same row + + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) + + def all_mode(rows : Tuple[int], columns : Tuple[int], trace_layer : int): + label_data = find_label_and_index(rows, columns, trace_layer) + if label_data is None: + return + axis_label, (label_row, label_column) = label_data + + for row in rows: + for column in columns: + # spec : SubplotSpec = specs[row][column] + can_reassign_axis : bool = False # TODO:: Fix the all mode to allow for the hover to go across all of them as found in Issue #5427 [https://github.com/plotly/plotly.py/issues/5427] + can_match_axis : bool = (row != label_row or column != label_column) + can_hide_ticks : bool = not ((row == label_row and x_or_y == 'x') or (column == label_column and x_or_y == 'y')) # The x-axis is across the first row, and the y-axis is along the first column + update_trace_axis(axis_label, row, column, trace_layer, can_reassign_axis, can_hide_ticks, can_match_axis) + + + rows : Tuple[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) + columns : Tuple[int] = tuple(range(column_count)) + + match(shared, x_or_y): + case ('columns', _) | (True, 'x'): # If columns mode, or shared and x + columns_mode(rows, columns, BASE_TRACE_LAYER) + columns_mode(tuple(reversed(rows)), columns, SECOND_Y_LAYER) + case ('rows', _) | (True, 'y'): # If rows mode, or shared and y + rows_mode(rows, columns, BASE_TRACE_LAYER) + rows_mode(rows, tuple(reversed(columns)), SECOND_Y_LAYER) + case ('all', _): # If all mode + all_mode(rows, columns, BASE_TRACE_LAYER) + all_mode(tuple(reversed(rows)), tuple(reversed(columns)), SECOND_Y_LAYER) + case _: # If reached the other case + return def _init_subplot_xy(layout, secondary_y, x_domain, y_domain, max_subplot_ids=None): if max_subplot_ids is None: diff --git a/tests/test_core/test_subplots/test_make_subplots.py b/tests/test_core/test_subplots/test_make_subplots.py index 52503b0ddef..a25b4e4f088 100644 --- a/tests/test_core/test_subplots/test_make_subplots.py +++ b/tests/test_core/test_subplots/test_make_subplots.py @@ -1465,6 +1465,8 @@ def test_subplot_titles_shared_axes_rows_columns(self): shared_xaxes="rows", shared_yaxes="columns", ) + print(f'Expected {expected}') + print(f'Actual: {fig}') self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json()) def test_subplot_titles_irregular_layout(self): @@ -1848,8 +1850,8 @@ def test_secondary_y_subplots(self): fig.add_scatter(y=[0, 2, 4], name="Fifth", row=2, col=1) fig.add_scatter(y=[2, 1, 3], name="Sixth", row=2, col=1, secondary_y=True) - fig.add_scatter(y=[2, 4, 0], name="Fifth", row=2, col=2) - fig.add_scatter(y=[2, 3, 6], name="Sixth", row=2, col=2, secondary_y=True) + fig.add_scatter(y=[2, 4, 0], name="Seventh", row=2, col=2) + fig.add_scatter(y=[2, 3, 6], name="Eighth", row=2, col=2, secondary_y=True) fig.update_traces(uid=None) @@ -1899,14 +1901,14 @@ def test_secondary_y_subplots(self): "yaxis": "y6", }, { - "name": "Fifth", + "name": "Seventh", "type": "scatter", "xaxis": "x4", "y": [2, 4, 0], "yaxis": "y7", }, { - "name": "Sixth", + "name": "Eighth", "type": "scatter", "xaxis": "x4", "y": [2, 3, 6], diff --git a/tests/test_optional/test_subplots/test_make_subplots.py b/tests/test_optional/test_subplots/test_make_subplots.py index 4552c66a694..b7aefd8d29d 100644 --- a/tests/test_optional/test_subplots/test_make_subplots.py +++ b/tests/test_optional/test_subplots/test_make_subplots.py @@ -56,3 +56,202 @@ def test_add_traces_with_integers(self): expected_data_length = 4 self.assertEqual(expected_data_length, len(fig2.data)) + +class TestSharedAxisOnMakeColumn(TestCase): + """ + Regression test for #5427: traces should reference the primary axis + when shared_xaxes=True, so spike lines and hover sync work correctly. + """ + + def test_xaxes_shared_columns_mode_single_column(self): + """ + When 'columns' mode for shared_xaxis, all of the traces in the same column should reference the same x-axis + """ + + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=3, cols=1, shared_xaxes='columns') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=2, col=1) + fig.add_trace(trace_3, row=3, col=1) + + + # The x-axis of all of the figures should be the same + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 3 have different x-axes") + self.assertEqual(trace_2_xaxis, trace_3_xaxis, "Shared x-axis column don't match: Figure 2 and Figure 3 have different x-axes") + + def test_xaxes_shared_columns_mode_multiple_columns(self): + """ + When 'columns' mode for shared_xaxis, different columns should have different references + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='columns') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=2, col=1) + fig.add_trace(trace_3, row=1, col=2) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of figures that are in the same column should be the same, and different if they are in different columns + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis column don't match: Figure 1 and Figure 2 have different x-axes") + self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis column don't match: Figure 3 and Figure 4 have different x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis column match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis column match: Figure 2 and Figure 4 have the same x-axes") + + def test_yaxes_shared_rows_mode_single_row(self): + """ + When 'rows' mode for shared_xaxis, all of the traces in the same row should reference the same x-axis + """ + + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, YAxis + + fig : Figure = make_subplots(rows=1, cols=3, shared_yaxes='rows') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=1, col=3) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all of the figures should be the same + trace_1_yaxis : YAxis = fig.data[0].yaxis + trace_2_yaxis : YAxis = fig.data[1].yaxis + trace_3_yaxis : YAxis = fig.data[2].yaxis + + self.assertEqual(trace_1_yaxis, trace_2_yaxis, "Shared y-axis row don't match: Figure 1 and Figure 2 have different y-axes") + self.assertEqual(trace_1_yaxis, trace_3_yaxis, "Shared y-axis row don't match: Figure 1 and Figure 3 have different y-axes") + self.assertEqual(trace_2_yaxis, trace_3_yaxis, "Shared y-axis row don't match: Figure 2 and Figure 3 have different y-axes") + + def test_yaxes_shared_rows_mode_multiple_rows(self): + """ + When 'rows' mode for shared_xaxis, different rows should have different references + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, YAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_yaxes='rows') + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of figures in the same row should be the same, and different if they are in different rows + trace_1_yaxis : YAxis = fig.data[0].yaxis + trace_2_yaxis : YAxis = fig.data[1].yaxis + trace_3_yaxis : YAxis = fig.data[2].yaxis + trace_4_yaxis : YAxis = fig.data[3].yaxis + + self.assertEqual(trace_1_yaxis, trace_2_yaxis, "Shared y-axis row don't match: Figure 1 and Figure 2 have different y-axes") + self.assertEqual(trace_3_yaxis, trace_4_yaxis, "Shared y-axis row don't match: Figure 3 and Figure 4 have different y-axes") + self.assertNotEqual(trace_1_yaxis, trace_3_yaxis, "Different y-axis row match: Figure 1 and Figure 3 have the same y-axes") + self.assertNotEqual(trace_2_yaxis, trace_4_yaxis, "Different y-axis row match: Figure 2 and Figure 4 have the same y-axes") + + # def test_xaxes_shared_all_mode(self): # TODO: All mode is currently disabled as it causes all of the graphs to overlap in one subplot + # """ + # When 'all' mode for shared_xaxis, all rows share the same x-axes + # """ + # from plotly.subplots import make_subplots + # from plotly.graph_objects import Figure, Scatter, XAxis + + # fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes='all') + + # trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + # trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + # trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + # trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + # fig.add_trace(trace_1, row=1, col=1) + # fig.add_trace(trace_2, row=1, col=2) + # fig.add_trace(trace_3, row=2, col=1) + # fig.add_trace(trace_4, row=2, col=2) + + # fig.update_xaxes() + # fig.update_layout() + + # # The x-axis of all the figures should be the same + # trace_1_xaxis : XAxis = fig.data[0].xaxis + # trace_2_xaxis : XAxis = fig.data[1].xaxis + # trace_3_xaxis : XAxis = fig.data[2].xaxis + # trace_4_xaxis : XAxis = fig.data[3].xaxis + + # self.assertEqual(trace_1_xaxis, trace_2_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 2 have different x-axes") + # self.assertEqual(trace_3_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 3 and Figure 4 have different x-axes") + # self.assertEqual(trace_1_xaxis, trace_3_xaxis, "Shared x-axis all don't match: Figure 1 and Figure 3 have the same x-axes") + # self.assertEqual(trace_2_xaxis, trace_4_xaxis, "Shared x-axis all don't match: Figure 2 and Figure 4 have the same x-axes") + + def test_xaxes_not_shared_mode(self): + """ + When not shared, all plots have different x-axes + """ + from plotly.subplots import make_subplots + from plotly.graph_objects import Figure, Scatter, XAxis + + fig : Figure = make_subplots(rows=2, cols=2, shared_xaxes=False) + + trace_1 : Scatter = Scatter(x=[1, 2, 3], y=[1, 2, 3]) + trace_2 : Scatter = Scatter(x=[1, 2, 3], y=[4, 5, 6]) + trace_3 : Scatter = Scatter(x=[1, 2, 3], y=[7, 8, 9]) + trace_4 : Scatter = Scatter(x=[1, 2, 3], y=[10, 11, 12]) + + fig.add_trace(trace_1, row=1, col=1) + fig.add_trace(trace_2, row=1, col=2) + fig.add_trace(trace_3, row=2, col=1) + fig.add_trace(trace_4, row=2, col=2) + + + fig.update_xaxes() + fig.update_layout() + + # The x-axis of all of the figures should be different + trace_1_xaxis : XAxis = fig.data[0].xaxis + trace_2_xaxis : XAxis = fig.data[1].xaxis + trace_3_xaxis : XAxis = fig.data[2].xaxis + trace_4_xaxis : XAxis = fig.data[3].xaxis + + self.assertNotEqual(trace_1_xaxis, trace_2_xaxis, "Different x-axis match: Figure 1 and Figure 2 have the same x-axes") + self.assertNotEqual(trace_1_xaxis, trace_3_xaxis, "Different x-axis match: Figure 1 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_1_xaxis, trace_4_xaxis, "Different x-axis match: Figure 1 and Figure 4 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_3_xaxis, "Different x-axis match: Figure 2 and Figure 3 have the same x-axes") + self.assertNotEqual(trace_2_xaxis, trace_4_xaxis, "Different x-axis match: Figure 2 and Figure 4 have the same x-axes") + self.assertNotEqual(trace_3_xaxis, trace_4_xaxis, "Different x-axis match: Figure 3 and Figure 4 have the same x-axes") \ No newline at end of file