From ec0cb89275baac5f1cc67bdf25d11b4ab7c7b52d Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 7 Jul 2025 14:49:16 +0200 Subject: [PATCH] TST: update expected dtype for sum of decimals with pyarrow 21+ --- pandas/compat/__init__.py | 2 ++ pandas/compat/pyarrow.py | 2 ++ pandas/tests/extension/test_arrow.py | 6 +++++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pandas/compat/__init__.py b/pandas/compat/__init__.py index 8ed19f97958b9..d5dbcb74d29e4 100644 --- a/pandas/compat/__init__.py +++ b/pandas/compat/__init__.py @@ -35,6 +35,7 @@ pa_version_under18p0, pa_version_under19p0, pa_version_under20p0, + pa_version_under21p0, ) if TYPE_CHECKING: @@ -168,4 +169,5 @@ def is_ci_environment() -> bool: "pa_version_under18p0", "pa_version_under19p0", "pa_version_under20p0", + "pa_version_under21p0", ] diff --git a/pandas/compat/pyarrow.py b/pandas/compat/pyarrow.py index 569d702592982..1e1989b276eb6 100644 --- a/pandas/compat/pyarrow.py +++ b/pandas/compat/pyarrow.py @@ -18,6 +18,7 @@ pa_version_under18p0 = _palv < Version("18.0.0") pa_version_under19p0 = _palv < Version("19.0.0") pa_version_under20p0 = _palv < Version("20.0.0") + pa_version_under21p0 = _palv < Version("21.0.0") HAS_PYARROW = _palv >= Version("12.0.1") except ImportError: pa_version_under12p1 = True @@ -30,4 +31,5 @@ pa_version_under18p0 = True pa_version_under19p0 = True pa_version_under20p0 = True + pa_version_under21p0 = True HAS_PYARROW = False diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 1bec5f7303355..8db837b176fe9 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -43,6 +43,7 @@ pa_version_under14p0, pa_version_under19p0, pa_version_under20p0, + pa_version_under21p0, ) from pandas.core.dtypes.dtypes import ( @@ -542,7 +543,10 @@ def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool): else: cmp_dtype = arr.dtype elif arr.dtype.name == "decimal128(7, 3)[pyarrow]": - if op_name not in ["median", "var", "std", "sem", "skew"]: + if op_name == "sum" and not pa_version_under21p0: + # https://github.com/apache/arrow/pull/44184 + cmp_dtype = ArrowDtype(pa.decimal128(38, 3)) + elif op_name not in ["median", "var", "std", "sem", "skew"]: cmp_dtype = arr.dtype else: cmp_dtype = "float64[pyarrow]"