diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 7f1524642b..5168640133 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -18,6 +18,8 @@ import itertools import os +import random +import time import uuid import warnings from abc import ABC, abstractmethod @@ -31,6 +33,7 @@ from pydantic import Field import pyiceberg.expressions.parser as parser +from pyiceberg.exceptions import CommitFailedException from pyiceberg.expressions import AlwaysFalse, AlwaysTrue, And, BooleanExpression, EqualTo, IsNull, Or, Reference from pyiceberg.expressions.visitors import ( ResidualEvaluator, @@ -205,6 +208,22 @@ class TableProperties: MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep" MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1 + COMMIT_NUM_RETRIES = "commit.retry.num-retries" + COMMIT_NUM_RETRIES_DEFAULT = 4 + + COMMIT_MIN_RETRY_WAIT_MS = "commit.retry.min-wait-ms" + COMMIT_MIN_RETRY_WAIT_MS_DEFAULT = 100 + + COMMIT_MAX_RETRY_WAIT_MS = "commit.retry.max-wait-ms" + COMMIT_MAX_RETRY_WAIT_MS_DEFAULT = 60000 + + COMMIT_TOTAL_RETRY_TIME_MS = "commit.retry.total-timeout-ms" + COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT = 1800000 # 30 minutes + + WRITE_DELETE_ISOLATION_LEVEL = "write.delete.isolation-level" + WRITE_UPDATE_ISOLATION_LEVEL = "write.update.isolation-level" + WRITE_ISOLATION_LEVEL_DEFAULT = "serializable" + class Transaction: _table: Table @@ -223,6 +242,7 @@ def __init__(self, table: Table, autocommit: bool = False): self._autocommit = autocommit self._updates = () self._requirements = () + self._snapshot_producers: list[Any] = [] @property def table_metadata(self) -> TableMetadata: @@ -265,6 +285,10 @@ def _stage( return self + def _register_snapshot_producer(self, producer: Any) -> None: + """Register a snapshot producer for retry support.""" + self._snapshot_producers.append(producer) + def _apply( self, updates: tuple[TableUpdate, ...], @@ -546,7 +570,12 @@ def dynamic_partition_overwrite( delete_filter = self._build_partition_predicate( partition_records=partitions_to_overwrite, spec=self.table_metadata.spec(), schema=self.table_metadata.schema() ) - self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties, branch=branch) + self.delete( + delete_filter=delete_filter, + snapshot_properties=snapshot_properties, + branch=branch, + _isolation_level_property=TableProperties.WRITE_UPDATE_ISOLATION_LEVEL, + ) with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: append_files.commit_uuid = append_snapshot_commit_uuid @@ -603,6 +632,7 @@ def overwrite( case_sensitive=case_sensitive, snapshot_properties=snapshot_properties, branch=branch, + _isolation_level_property=TableProperties.WRITE_UPDATE_ISOLATION_LEVEL, ) with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: @@ -620,6 +650,7 @@ def delete( snapshot_properties: dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, branch: str | None = MAIN_BRANCH, + _isolation_level_property: str | None = None, ) -> None: """ Shorthand for deleting record from a table. @@ -647,6 +678,8 @@ def delete( delete_filter = _parse_row_filter(delete_filter) with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).delete() as delete_snapshot: + if _isolation_level_property is not None: + delete_snapshot._isolation_level_property = _isolation_level_property delete_snapshot.delete_by_predicate(delete_filter, case_sensitive) # Check if there are any files that require an actual rewrite of a data file @@ -702,7 +735,10 @@ def delete( with self.update_snapshot( snapshot_properties=snapshot_properties, branch=branch ).overwrite() as overwrite_snapshot: + if _isolation_level_property is not None: + overwrite_snapshot._isolation_level_property = _isolation_level_property overwrite_snapshot.commit_uuid = commit_uuid + overwrite_snapshot.delete_by_predicate(delete_filter, case_sensitive) for original_data_file, replaced_data_files in replaced_files: overwrite_snapshot.delete_data_file(original_data_file) for replaced_data_file in replaced_data_files: @@ -939,17 +975,73 @@ def commit_transaction(self) -> Table: The table with the updates applied. """ if len(self._updates) > 0: - self._requirements += (AssertTableUUID(uuid=self.table_metadata.table_uuid),) - self._table._do_commit( # pylint: disable=W0212 - updates=self._updates, - requirements=self._requirements, + from pyiceberg.utils.properties import property_as_int + + properties = self._table.metadata.properties + num_retries_val = property_as_int( + properties, TableProperties.COMMIT_NUM_RETRIES, TableProperties.COMMIT_NUM_RETRIES_DEFAULT + ) + num_retries = num_retries_val if num_retries_val is not None else TableProperties.COMMIT_NUM_RETRIES_DEFAULT + min_wait_val = property_as_int( + properties, TableProperties.COMMIT_MIN_RETRY_WAIT_MS, TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT + ) + min_wait_ms = min_wait_val if min_wait_val is not None else TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT + max_wait_val = property_as_int( + properties, TableProperties.COMMIT_MAX_RETRY_WAIT_MS, TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT ) + max_wait_ms = max_wait_val if max_wait_val is not None else TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT + total_timeout_val = property_as_int( + properties, TableProperties.COMMIT_TOTAL_RETRY_TIME_MS, TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT + ) + total_timeout_ms = ( + total_timeout_val if total_timeout_val is not None else TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT + ) + start_time = time.monotonic() + + for attempt in range(num_retries + 1): + try: + self._requirements += (AssertTableUUID(uuid=self.table_metadata.table_uuid),) + self._table._do_commit( # pylint: disable=W0212 + updates=self._updates, + requirements=self._requirements, + ) + self._cleanup_uncommitted_manifests() + break + except CommitFailedException: + elapsed_ms = (time.monotonic() - start_time) * 1000 + if attempt == num_retries or not self._snapshot_producers or elapsed_ms >= total_timeout_ms: + raise + + wait = min(min_wait_ms * (2**attempt), max_wait_ms) + jitter = random.uniform(0, 0.25 * wait) + time.sleep((wait + jitter) / 1000.0) + + self._table.refresh() + self._rebuild_snapshot_updates() self._updates = () self._requirements = () return self._table + def _cleanup_uncommitted_manifests(self) -> None: + """Clean up manifests from failed retry attempts after a successful commit.""" + for producer in self._snapshot_producers: + producer._cleanup_uncommitted() + + def _rebuild_snapshot_updates(self) -> None: + """Rebuild snapshot updates for retry by re-executing registered producers.""" + from pyiceberg.table.update import AddSnapshotUpdate, AssertRefSnapshotId, SetSnapshotRefUpdate + + self._updates = tuple(u for u in self._updates if not isinstance(u, (AddSnapshotUpdate, SetSnapshotRefUpdate))) + self._requirements = tuple(r for r in self._requirements if not isinstance(r, (AssertRefSnapshotId, AssertTableUUID))) + + for producer in self._snapshot_producers: + producer._refresh_for_retry() + producer._validate_concurrency() + updates, requirements = producer._commit() + self._stage(updates, requirements) + class CreateTableTransaction(Transaction): """A transaction that involves the creation of a new table.""" @@ -1961,13 +2053,11 @@ def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], Residu # The lambda created here is run in multiple threads. # So we avoid creating _EvaluatorExpression methods bound to a single # shared instance across multiple threads. - return lambda datafile: ( - residual_evaluator_of( - spec=spec, - expr=self.row_filter, - case_sensitive=self.case_sensitive, - schema=self.table_metadata.schema(), - ) + return lambda datafile: residual_evaluator_of( + spec=spec, + expr=self.row_filter, + case_sensitive=self.case_sensitive, + schema=self.table_metadata.schema(), ) @staticmethod diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 7e4c6eb1ec..e10965e621 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -86,6 +86,13 @@ def __repr__(self) -> str: return f"Operation.{self.name}" +class IsolationLevel(str, Enum): + """Transaction isolation level for concurrent write validation.""" + + SERIALIZABLE = "serializable" + SNAPSHOT = "snapshot" + + class UpdateMetrics: added_file_size: int removed_file_size: int diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 1df4e64c24..7ae6d898f6 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -17,6 +17,7 @@ from __future__ import annotations import itertools +import logging import uuid from abc import abstractmethod from collections import defaultdict @@ -80,6 +81,8 @@ if TYPE_CHECKING: from pyiceberg.table import Transaction +logger = logging.getLogger(__name__) + def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str: return f"{commit_uuid}-m{num}.avro" @@ -104,6 +107,8 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]): _target_branch: str | None _predicate: BooleanExpression _case_sensitive: bool + _written_manifests: list[str] + _uncommitted_manifests: list[str] def __init__( self, @@ -123,6 +128,8 @@ def __init__( self._deleted_data_files = set() self.snapshot_properties = snapshot_properties self._manifest_num_counter = itertools.count(0) + self._written_manifests = [] + self._uncommitted_manifests = [] from pyiceberg.table import TableProperties self._compression = self._transaction.table_metadata.properties.get( # type: ignore @@ -134,6 +141,7 @@ def __init__( ) self._predicate = AlwaysFalse() self._case_sensitive = True + self._isolation_level_property: str = TableProperties.WRITE_DELETE_ISOLATION_LEVEL def _validate_target_branch(self, branch: str | None) -> str | None: # if branch is none, write will be written into a staging snapshot @@ -351,11 +359,39 @@ def new_manifest_output(self) -> OutputFile: location_provider = self._transaction._table.location_provider() file_name = _new_manifest_file_name(num=next(self._manifest_num_counter), commit_uuid=self.commit_uuid) file_path = location_provider.new_metadata_location(file_name) + self._written_manifests.append(file_path) return self._io.new_output(file_path) def fetch_manifest_entry(self, manifest: ManifestFile, discard_deleted: bool = True) -> list[ManifestEntry]: return manifest.fetch_manifest_entry(io=self._io, discard_deleted=discard_deleted) + def commit(self) -> None: + self._transaction._register_snapshot_producer(self) + self._transaction._apply(*self._commit()) + + def _cleanup_uncommitted(self) -> None: + """Delete manifest files from failed retry attempts.""" + for path in self._uncommitted_manifests: + try: + self._io.delete(path) + except Exception: + logger.warning("Failed to delete uncommitted manifest: %s", path, exc_info=True) + self._uncommitted_manifests.clear() + + def _refresh_for_retry(self) -> None: + """Reset state for a retry attempt with refreshed metadata.""" + self._uncommitted_manifests.extend(self._written_manifests) + self._written_manifests.clear() + self._parent_snapshot_id = ( + snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None + ) + self._snapshot_id = self._transaction.table_metadata.new_snapshot_id() + self._manifest_num_counter = itertools.count(0) + self.commit_uuid = uuid.uuid4() + + def _validate_concurrency(self) -> None: + """Validate that concurrent changes do not conflict with this operation. No-op by default.""" + def _build_partition_projection(self, spec_id: int) -> BooleanExpression: project = inclusive_projection(self.schema(), self.spec(spec_id), self._case_sensitive) return project(self._predicate) @@ -382,7 +418,8 @@ def _build_delete_files_partition_predicate(self) -> None: self.delete_by_predicate( self._transaction._build_partition_predicate( partition_records=partition_records, schema=self.schema(), spec=self.spec(spec_id) - ) + ), + self._case_sensitive, ) @@ -495,6 +532,57 @@ def files_affected(self) -> bool: """Indicate if any manifest-entries can be dropped.""" return len(self._deleted_entries()) > 0 + def _refresh_for_retry(self) -> None: + """Reset state for a retry attempt, clearing the cached delete computation.""" + super()._refresh_for_retry() + # Clear @cached_property by removing it from the instance __dict__. + # _compute_deletes depends on _parent_snapshot_id which changes on retry. + if "_compute_deletes" in self.__dict__: + del self.__dict__["_compute_deletes"] + + def _validate_concurrency(self) -> None: + """Validate that concurrent changes do not conflict with this delete. + + Note: This method is intentionally duplicated in _OverwriteFiles rather than + extracted to the base class. While the logic is currently identical, Java Iceberg's + BaseOverwriteFiles and BaseRowDelta have divergent validation. Keeping them separate + makes it easier to add RowDelta-specific validation in the future. + """ + from pyiceberg.table import TableProperties + from pyiceberg.table.snapshots import IsolationLevel + from pyiceberg.table.update.validate import ( + _validate_added_data_files, + _validate_deleted_data_files, + _validate_no_new_delete_files, + _validate_no_new_deletes_for_data_files, + ) + + if self._parent_snapshot_id is None: + return + + table = self._transaction._table + parent_snapshot = table.metadata.snapshot_by_id(self._parent_snapshot_id) + if parent_snapshot is None: + return + + isolation_level_str = table.metadata.properties.get( + self._isolation_level_property, TableProperties.WRITE_ISOLATION_LEVEL_DEFAULT + ) + isolation_level = IsolationLevel(isolation_level_str) + conflict_detection_filter = self._predicate if self._predicate != AlwaysFalse() else None + + if isolation_level == IsolationLevel.SERIALIZABLE: + _validate_added_data_files(table, parent_snapshot, conflict_detection_filter, parent_snapshot) + + if conflict_detection_filter is not None: + _validate_no_new_delete_files(table, parent_snapshot, conflict_detection_filter, None, parent_snapshot) + _validate_deleted_data_files(table, parent_snapshot, conflict_detection_filter, parent_snapshot) + + if self._deleted_data_files: + _validate_no_new_deletes_for_data_files( + table, parent_snapshot, conflict_detection_filter, self._deleted_data_files, parent_snapshot + ) + class _FastAppendFiles(_SnapshotProducer["_FastAppendFiles"]): def _existing_manifests(self) -> list[ManifestFile]: @@ -666,6 +754,47 @@ def _get_entries(manifest: ManifestFile) -> list[ManifestEntry]: else: return [] + def _validate_concurrency(self) -> None: + """Validate that concurrent changes do not conflict with this overwrite. + + Note: See _DeleteFiles._validate_concurrency() for why this is intentionally + duplicated rather than extracted to the base class. + """ + from pyiceberg.table import TableProperties + from pyiceberg.table.snapshots import IsolationLevel + from pyiceberg.table.update.validate import ( + _validate_added_data_files, + _validate_deleted_data_files, + _validate_no_new_delete_files, + _validate_no_new_deletes_for_data_files, + ) + + if self._parent_snapshot_id is None: + return + + table = self._transaction._table + parent_snapshot = table.metadata.snapshot_by_id(self._parent_snapshot_id) + if parent_snapshot is None: + return + + isolation_level_str = table.metadata.properties.get( + self._isolation_level_property, TableProperties.WRITE_ISOLATION_LEVEL_DEFAULT + ) + isolation_level = IsolationLevel(isolation_level_str) + conflict_detection_filter = self._predicate if self._predicate != AlwaysFalse() else None + + if isolation_level == IsolationLevel.SERIALIZABLE: + _validate_added_data_files(table, parent_snapshot, conflict_detection_filter, parent_snapshot) + + if conflict_detection_filter is not None: + _validate_no_new_delete_files(table, parent_snapshot, conflict_detection_filter, None, parent_snapshot) + _validate_deleted_data_files(table, parent_snapshot, conflict_detection_filter, parent_snapshot) + + if self._deleted_data_files: + _validate_no_new_deletes_for_data_files( + table, parent_snapshot, conflict_detection_filter, self._deleted_data_files, parent_snapshot + ) + class UpdateSnapshot: _transaction: Transaction diff --git a/tests/integration/test_writes/test_optimistic_concurrency.py b/tests/integration/test_writes/test_optimistic_concurrency.py index 6ddf4c11d5..299c58c833 100644 --- a/tests/integration/test_writes/test_optimistic_concurrency.py +++ b/tests/integration/test_writes/test_optimistic_concurrency.py @@ -20,7 +20,7 @@ from pyspark.sql import SparkSession from pyiceberg.catalog import Catalog -from pyiceberg.exceptions import CommitFailedException +from pyiceberg.exceptions import ValidationException from utils import _create_table @@ -29,15 +29,14 @@ def test_conflict_delete_delete( spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int ) -> None: - """This test should start passing once optimistic concurrency control has been implemented.""" + """Concurrent deletes on the same data should fail with ValidationException.""" identifier = "default.test_conflict" tbl1 = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null]) tbl2 = session_catalog.load_table(identifier) tbl1.delete("string == 'z'") - with pytest.raises(CommitFailedException, match="(branch main has changed: expected id ).*"): - # tbl2 isn't aware of the commit by tbl1 + with pytest.raises(ValidationException): tbl2.delete("string == 'z'") @@ -46,17 +45,13 @@ def test_conflict_delete_delete( def test_conflict_delete_append( spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int ) -> None: - """This test should start passing once optimistic concurrency control has been implemented.""" + """Append after a concurrent delete should succeed via retry.""" identifier = "default.test_conflict" tbl1 = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null]) tbl2 = session_catalog.load_table(identifier) - # This is allowed tbl1.delete("string == 'z'") - - with pytest.raises(CommitFailedException, match="(branch main has changed: expected id ).*"): - # tbl2 isn't aware of the commit by tbl1 - tbl2.append(arrow_table_with_null) + tbl2.append(arrow_table_with_null) @pytest.mark.integration @@ -64,15 +59,14 @@ def test_conflict_delete_append( def test_conflict_append_delete( spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int ) -> None: - """This test should start passing once optimistic concurrency control has been implemented.""" + """Delete after a concurrent append fails with ValidationException under serializable isolation.""" identifier = "default.test_conflict" tbl1 = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null]) tbl2 = session_catalog.load_table(identifier) tbl1.append(arrow_table_with_null) - with pytest.raises(CommitFailedException, match="(branch main has changed: expected id ).*"): - # tbl2 isn't aware of the commit by tbl1 + with pytest.raises(ValidationException): tbl2.delete("string == 'z'") @@ -81,13 +75,10 @@ def test_conflict_append_delete( def test_conflict_append_append( spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int ) -> None: - """This test should start passing once optimistic concurrency control has been implemented.""" + """Concurrent appends should both succeed via retry.""" identifier = "default.test_conflict" tbl1 = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null]) tbl2 = session_catalog.load_table(identifier) tbl1.append(arrow_table_with_null) - - with pytest.raises(CommitFailedException, match="(branch main has changed: expected id ).*"): - # tbl2 isn't aware of the commit by tbl1 - tbl2.append(arrow_table_with_null) + tbl2.append(arrow_table_with_null) diff --git a/tests/table/test_commit_retry.py b/tests/table/test_commit_retry.py new file mode 100644 index 0000000000..11e6cbc80a --- /dev/null +++ b/tests/table/test_commit_retry.py @@ -0,0 +1,558 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any +from unittest.mock import patch + +import pytest + +from pyiceberg.catalog import Catalog +from pyiceberg.exceptions import CommitFailedException, ValidationException +from pyiceberg.schema import Schema +from pyiceberg.table import TableProperties, Transaction +from pyiceberg.table.snapshots import IsolationLevel, Operation +from pyiceberg.types import LongType, NestedField, StringType + + +def test_isolation_level_enum() -> None: + assert IsolationLevel.SERIALIZABLE.value == "serializable" + assert IsolationLevel.SNAPSHOT.value == "snapshot" + assert IsolationLevel("serializable") is IsolationLevel.SERIALIZABLE + assert IsolationLevel("snapshot") is IsolationLevel.SNAPSHOT + + +def test_commit_retry_table_properties() -> None: + assert TableProperties.COMMIT_NUM_RETRIES == "commit.retry.num-retries" + assert TableProperties.COMMIT_NUM_RETRIES_DEFAULT == 4 + assert TableProperties.COMMIT_MIN_RETRY_WAIT_MS == "commit.retry.min-wait-ms" + assert TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT == 100 + assert TableProperties.COMMIT_MAX_RETRY_WAIT_MS == "commit.retry.max-wait-ms" + assert TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT == 60000 + assert TableProperties.COMMIT_TOTAL_RETRY_TIME_MS == "commit.retry.total-timeout-ms" + assert TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT == 1800000 + + +def test_isolation_level_table_properties() -> None: + assert TableProperties.WRITE_DELETE_ISOLATION_LEVEL == "write.delete.isolation-level" + assert TableProperties.WRITE_UPDATE_ISOLATION_LEVEL == "write.update.isolation-level" + assert TableProperties.WRITE_ISOLATION_LEVEL_DEFAULT == "serializable" + + +def _test_schema() -> Schema: + return Schema(NestedField(1, "x", LongType(), required=False)) + + +def test_commit_retry_on_commit_failed(catalog: Catalog) -> None: + """Verify that CommitFailedException triggers retry for append operations.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table("default.retry_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + # Load two references to the same table to simulate concurrent access + tbl1 = catalog.load_table("default.retry_test") + tbl2 = catalog.load_table("default.retry_test") + + # First append succeeds + tbl1.append(df) + + # Second append should succeed via retry (append vs append never conflicts) + import pyiceberg.table as _table_module + + RuntimeTransaction = _table_module.Transaction + original_rebuild = RuntimeTransaction._rebuild_snapshot_updates + rebuild_count = 0 + + def counting_rebuild(self_tx: Any) -> None: + nonlocal rebuild_count + rebuild_count += 1 + original_rebuild(self_tx) + + with patch.object(RuntimeTransaction, "_rebuild_snapshot_updates", counting_rebuild): + tbl2.append(df) + + assert rebuild_count == 1, "Expected exactly one retry via _rebuild_snapshot_updates" + + # Both appends should be visible + refreshed = catalog.load_table("default.retry_test") + result = refreshed.scan().to_arrow() + assert len(result) == 6 + + +def test_no_retry_without_snapshot_producers(catalog: Catalog) -> None: + """Verify that a transaction with no snapshot producers has an empty producer list.""" + catalog.create_namespace("default") + schema = _test_schema() + table = catalog.create_table("default.no_retry_test", schema=schema) + + tx = Transaction(table, autocommit=False) + tx.set_properties({"key": "value"}) + + # No snapshot producers registered + assert len(tx._snapshot_producers) == 0 + + +def test_rebuild_snapshot_updates_preserves_non_snapshot_updates(catalog: Catalog) -> None: + """Verify that non-snapshot updates survive retry.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table("default.rebuild_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1]}) + + tbl1 = catalog.load_table("default.rebuild_test") + tbl2 = catalog.load_table("default.rebuild_test") + + # tbl1 commits first + tbl1.append(df) + + # tbl2 does both property change and append in one transaction + with tbl2.transaction() as tx: + tx.set_properties({"test_key": "test_value"}) + tx.append(df) + + # Both the property and the data should be committed + refreshed = catalog.load_table("default.rebuild_test") + assert refreshed.metadata.properties.get("test_key") == "test_value" + assert len(refreshed.scan().to_arrow()) == 2 + + +def test_refresh_for_retry_resets_producer_state(catalog: Catalog) -> None: + """Verify that _refresh_for_retry resets the necessary fields.""" + catalog.create_namespace("default") + schema = _test_schema() + table = catalog.create_table("default.refresh_test", schema=schema) + + from pyiceberg.table.update.snapshot import _FastAppendFiles + + tx = Transaction(table, autocommit=False) + producer = _FastAppendFiles( + operation=Operation.APPEND, + transaction=tx, + io=table.io, + ) + + original_snapshot_id = producer._snapshot_id + original_uuid = producer.commit_uuid + + producer._refresh_for_retry() + + assert producer._snapshot_id != original_snapshot_id + assert producer.commit_uuid != original_uuid + # parent stays None for empty table + assert producer._parent_snapshot_id is None + + +def test_concurrent_delete_delete_raises_validation_exception(catalog: Catalog) -> None: + """Concurrent deletes on the same data should fail with ValidationException.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table("default.del_del_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.del_del_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.del_del_test") + tbl2 = catalog.load_table("default.del_del_test") + + tbl1.delete("x == 1") + + with pytest.raises(ValidationException): + tbl2.delete("x == 1") + + +def test_concurrent_append_delete_raises_validation_exception(catalog: Catalog) -> None: + """Delete after a concurrent append fails with ValidationException under serializable isolation.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table("default.app_del_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.app_del_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.app_del_test") + tbl2 = catalog.load_table("default.app_del_test") + + tbl1.append(df) + + with pytest.raises(ValidationException): + tbl2.delete("x == 1") + + +def test_concurrent_delete_append_retries_successfully(catalog: Catalog) -> None: + """Append after a concurrent delete should succeed via retry.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table("default.del_app_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.del_app_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.del_app_test") + tbl2 = catalog.load_table("default.del_app_test") + + tbl1.delete("x == 1") + + import pyiceberg.table as _table_module + + RuntimeTransaction = _table_module.Transaction + original_rebuild = RuntimeTransaction._rebuild_snapshot_updates + rebuild_count = 0 + + def counting_rebuild(self_tx: Any) -> None: + nonlocal rebuild_count + rebuild_count += 1 + original_rebuild(self_tx) + + with patch.object(RuntimeTransaction, "_rebuild_snapshot_updates", counting_rebuild): + tbl2.append(df) + + assert rebuild_count == 1 + + refreshed = catalog.load_table("default.del_app_test") + result = refreshed.scan().to_arrow() + # Original 3 rows, minus 1 deleted, plus 3 appended = 5 + assert len(result) == 5 + + +def test_retry_exhaustion_raises_commit_failed(catalog: Catalog) -> None: + """When retries are exhausted, CommitFailedException should be raised.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table( + "default.exhaust_test", + schema=schema, + properties={"commit.retry.num-retries": "0"}, + ) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl1 = catalog.load_table("default.exhaust_test") + tbl2 = catalog.load_table("default.exhaust_test") + + tbl1.append(df) + + with pytest.raises(CommitFailedException): + tbl2.append(df) + + +def test_delete_files_refresh_clears_compute_deletes_cache(catalog: Catalog) -> None: + """Verify that _refresh_for_retry clears the _compute_deletes cached property.""" + catalog.create_namespace("default") + schema = _test_schema() + table = catalog.create_table("default.cache_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + table.append(df) + table = catalog.load_table("default.cache_test") + + from pyiceberg.expressions import EqualTo + from pyiceberg.table.update.snapshot import _DeleteFiles + + tx = Transaction(table, autocommit=False) + producer = _DeleteFiles( + operation=Operation.DELETE, + transaction=tx, + io=table.io, + ) + producer.delete_by_predicate(EqualTo("x", 1)) + + # Access _compute_deletes to populate the cache + _ = producer._compute_deletes + + assert "_compute_deletes" in producer.__dict__ + + producer._refresh_for_retry() + + assert "_compute_deletes" not in producer.__dict__ + + +def test_concurrent_overwrite_overwrite_raises_validation_exception(catalog: Catalog) -> None: + """Concurrent overwrites on the same data should fail with ValidationException.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table("default.ow_ow_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.ow_ow_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.ow_ow_test") + tbl2 = catalog.load_table("default.ow_ow_test") + + tbl1.overwrite(pa.table({"x": [10, 20, 30]}), overwrite_filter="x > 0") + with pytest.raises(ValidationException): + tbl2.overwrite(pa.table({"x": [40, 50, 60]}), overwrite_filter="x > 0") + + +def test_concurrent_overwrite_append_retries_successfully(catalog: Catalog) -> None: + """Append after a concurrent overwrite should succeed via retry.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table("default.ow_app_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.ow_app_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.ow_app_test") + tbl2 = catalog.load_table("default.ow_app_test") + + tbl1.overwrite(pa.table({"x": [10, 20, 30]}), overwrite_filter="x > 0") + tbl2.append(pa.table({"x": [4, 5, 6]})) + + refreshed = catalog.load_table("default.ow_app_test") + result = refreshed.scan().to_arrow() + # overwrite replaced 3 rows with 3 new rows, then append added 3 more = 6 + assert len(result) == 6 + + +def test_snapshot_isolation_allows_concurrent_append_delete(catalog: Catalog) -> None: + """Under snapshot isolation, delete after a concurrent append should succeed via retry.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table( + "default.snapshot_iso_test", + schema=schema, + properties={"write.delete.isolation-level": "snapshot"}, + ) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.snapshot_iso_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.snapshot_iso_test") + tbl2 = catalog.load_table("default.snapshot_iso_test") + + tbl1.append(df) + + # Under serializable this would raise ValidationException, + # but under snapshot isolation _validate_added_data_files is skipped + tbl2.delete("x == 1") + + refreshed = catalog.load_table("default.snapshot_iso_test") + result = refreshed.scan().to_arrow() + # Original 3, delete removes x==1 from original (1 row), append adds 3 = 5 + assert len(result) == 5 + + +def test_uncommitted_manifests_tracked_correctly(catalog: Catalog) -> None: + """Verify that uncommitted manifests are moved to _uncommitted_manifests on retry.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table("default.manifest_track_test", schema=schema) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.manifest_track_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.manifest_track_test") + tbl2 = catalog.load_table("default.manifest_track_test") + + tbl1.append(df) + + import pyiceberg.table as _table_module2 + + RuntimeTransaction2 = _table_module2.Transaction + original_rebuild = RuntimeTransaction2._rebuild_snapshot_updates + uncommitted_count_during_rebuild = 0 + + def checking_rebuild(self_tx: Any) -> None: + nonlocal uncommitted_count_during_rebuild + original_rebuild(self_tx) + for producer in self_tx._snapshot_producers: + uncommitted_count_during_rebuild += len(producer._uncommitted_manifests) + + with patch.object(RuntimeTransaction2, "_rebuild_snapshot_updates", checking_rebuild): + tbl2.append(df) + + # After rebuild, the first attempt's manifests should be in _uncommitted_manifests + assert uncommitted_count_during_rebuild > 0 + + +def test_concurrent_deletes_on_different_partitions_succeed(catalog: Catalog) -> None: + """Concurrent deletes on different partitions should succeed via retry thanks to conflict detection filter.""" + from pyiceberg.partitioning import PartitionField, PartitionSpec + from pyiceberg.transforms import IdentityTransform + + catalog.create_namespace("default") + schema = Schema( + NestedField(1, "category", StringType(), required=False), + NestedField(2, "value", LongType(), required=False), + ) + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="category")) + catalog.create_table("default.part_del_test", schema=schema, partition_spec=spec) + + import pyarrow as pa + + df = pa.table( + { + "category": ["a", "a", "b", "b"], + "value": [1, 2, 3, 4], + } + ) + + tbl = catalog.load_table("default.part_del_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.part_del_test") + tbl2 = catalog.load_table("default.part_del_test") + + # Delete from different partitions should not conflict + tbl1.delete("category == 'a'") + tbl2.delete("category == 'b'") + + refreshed = catalog.load_table("default.part_del_test") + result = refreshed.scan().to_arrow() + assert len(result) == 0 + + +def test_concurrent_partial_deletes_on_different_partitions_succeed(catalog: Catalog) -> None: + """Concurrent partial deletes (CoW rewrite) on different partitions should succeed. + + This tests the auto-computed partition predicate from _build_delete_files_partition_predicate. + """ + from pyiceberg.partitioning import PartitionField, PartitionSpec + from pyiceberg.transforms import IdentityTransform + + catalog.create_namespace("default") + schema = Schema( + NestedField(1, "category", StringType(), required=False), + NestedField(2, "value", LongType(), required=False), + ) + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="category")) + catalog.create_table("default.part_partial_del_test", schema=schema, partition_spec=spec) + + import pyarrow as pa + + df = pa.table( + { + "category": ["a", "a", "b", "b"], + "value": [1, 2, 3, 4], + } + ) + + tbl = catalog.load_table("default.part_partial_del_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.part_partial_del_test") + tbl2 = catalog.load_table("default.part_partial_del_test") + + # Partial delete: only value==1 in partition a, triggers CoW rewrite + tbl1.delete("value == 1") + # Partial delete: only value==3 in partition b, triggers CoW rewrite + tbl2.delete("value == 3") + + refreshed = catalog.load_table("default.part_partial_del_test") + result = refreshed.scan().to_arrow() + # Original 4 rows, minus value==1 and value==3 = 2 rows remaining + assert len(result) == 2 + + +def test_overwrite_uses_update_isolation_level(catalog: Catalog) -> None: + """Verify that overwrite() reads write.update.isolation-level, not write.delete.isolation-level.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table( + "default.update_iso_test", + schema=schema, + properties={ + "write.delete.isolation-level": "serializable", + "write.update.isolation-level": "snapshot", + }, + ) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.update_iso_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.update_iso_test") + tbl2 = catalog.load_table("default.update_iso_test") + + tbl1.append(df) + + # Under write.delete.isolation-level=serializable this would raise ValidationException. + # But overwrite() uses write.update.isolation-level=snapshot, so it succeeds. + tbl2.overwrite(pa.table({"x": [10, 20, 30]}), overwrite_filter="x > 0") + + refreshed = catalog.load_table("default.update_iso_test") + result = refreshed.scan().to_arrow() + # overwrite with x > 0 deletes all rows (including tbl1's append), then adds 3 new rows + assert len(result) == 3 + + +def test_overwrite_with_serializable_update_isolation_raises(catalog: Catalog) -> None: + """Verify that overwrite() raises ValidationException when write.update.isolation-level=serializable.""" + catalog.create_namespace("default") + schema = _test_schema() + catalog.create_table( + "default.update_serial_test", + schema=schema, + properties={ + "write.update.isolation-level": "serializable", + }, + ) + + import pyarrow as pa + + df = pa.table({"x": [1, 2, 3]}) + + tbl = catalog.load_table("default.update_serial_test") + tbl.append(df) + + tbl1 = catalog.load_table("default.update_serial_test") + tbl2 = catalog.load_table("default.update_serial_test") + + tbl1.append(df) + + with pytest.raises(ValidationException): + tbl2.overwrite(pa.table({"x": [10, 20, 30]}), overwrite_filter="x > 0")