From 5f5b2d7a89166726d5df4dc816c353f2ce847a27 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Mon, 22 Dec 2025 10:40:48 -0500 Subject: [PATCH 1/2] Ensure SeasonResmpler preserves datetime resolution --- doc/whats-new.rst | 3 +++ xarray/groupers.py | 10 +++++++++- xarray/tests/test_groupby.py | 11 ++++++++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 68f0a0b7aee..a49564649cf 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -49,6 +49,9 @@ Bug Fixes ``np.isclose`` by default to handle accumulated floating point errors from slicing operations. Use ``exact=True`` for exact comparison (:pull:`11035`). By `Ian Hunt-Isaak `_. +- Ensure the :py:class:`~xarray.groupers.SeasonResampler` preserves the datetime + unit of the underlying time index when resampling (:issue:`11048`, + :pull:`11049`). By `Spencer Clark `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/groupers.py b/xarray/groupers.py index a16933e690f..f3cba83dd67 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -14,6 +14,7 @@ from collections import defaultdict from collections.abc import Hashable, Mapping, Sequence from dataclasses import dataclass, field +from functools import partial from itertools import chain, pairwise from typing import TYPE_CHECKING, Any, Literal, cast @@ -40,6 +41,7 @@ Bins, DatetimeLike, GroupIndices, + PDDatetimeUnitOptions, ResampleCompatible, Self, SideOptions, @@ -61,6 +63,11 @@ RESAMPLE_DIM = "__resample_dim__" +def _construct_timestamp_as_unit(unit: PDDatetimeUnitOptions, **kwargs) -> pd.Timestamp: + """Construct a pandas.Timestamp object with a specific resolution.""" + return pd.Timestamp(**kwargs).as_unit(unit) + + @dataclass(init=False) class EncodedGroups: """ @@ -960,7 +967,8 @@ def factorize(self, group: T_Group) -> EncodedGroups: datetime_class = type(first_n_items(group.data, 1).item()) else: index_class = pd.DatetimeIndex - datetime_class = datetime.datetime + unit, _ = np.datetime_data(group.dtype) + datetime_class = partial(_construct_timestamp_as_unit, unit) # these are the seasons that are present unique_coord = index_class( diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b9b9fb151c7..b4ef343716e 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -14,7 +14,7 @@ import xarray as xr from xarray import DataArray, Dataset, Variable, date_range from xarray.core.groupby import _consolidate_slices -from xarray.core.types import InterpOptions, ResampleCompatible +from xarray.core.types import InterpOptions, PDDatetimeUnitOptions, ResampleCompatible from xarray.groupers import ( BinGrouper, EncodedGroups, @@ -3605,6 +3605,15 @@ def test_season_resampler_groupby_identical(self): gb = da.groupby(time=resampler).sum() assert_identical(rs, gb) + def test_season_resampler_preserves_time_unit( + self, time_unit: PDDatetimeUnitOptions + ): + time = date_range("2000", periods=12, freq="MS", unit=time_unit) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + resampler = SeasonResampler(["DJF", "MAM", "JJA", "SON"]) + rs = da.resample(time=resampler).sum() + assert rs.time.dtype == time.dtype + @pytest.mark.parametrize( "chunk", From 81fe4d31b17b78db3430f1b9fafc57a5c5b695eb Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Mon, 22 Dec 2025 11:43:30 -0500 Subject: [PATCH 2/2] Ensure solution works for pandas < 3 --- xarray/groupers.py | 54 ++++++++++++++++++++++++------------ xarray/tests/test_groupby.py | 7 +++-- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index f3cba83dd67..a26741ff3fe 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -12,7 +12,7 @@ import operator from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Callable, Hashable, Mapping, Sequence from dataclasses import dataclass, field from functools import partial from itertools import chain, pairwise @@ -39,6 +39,7 @@ from xarray.core.resample_cftime import CFTimeGrouper from xarray.core.types import ( Bins, + CFTimeDatetime, DatetimeLike, GroupIndices, PDDatetimeUnitOptions, @@ -63,9 +64,14 @@ RESAMPLE_DIM = "__resample_dim__" -def _construct_timestamp_as_unit(unit: PDDatetimeUnitOptions, **kwargs) -> pd.Timestamp: - """Construct a pandas.Timestamp object with a specific resolution.""" - return pd.Timestamp(**kwargs).as_unit(unit) +def _datetime64_via_timestamp(unit: PDDatetimeUnitOptions, **kwargs) -> np.datetime64: + """Construct a numpy.datetime64 object through the pandas.Timestamp + constructor with a specific resolution.""" + # TODO: when pandas 3 is our minimum requirement we will no longer need to + # convert to np.datetime64 values prior to passing to the DatetimeIndex + # constructor. With pandas < 3 the DatetimeIndex constructor does not + # infer the resolution from the resolution of the Timestamp values. + return pd.Timestamp(**kwargs).as_unit(unit).to_numpy() @dataclass(init=False) @@ -962,20 +968,28 @@ def factorize(self, group: T_Group) -> EncodedGroups: counts = agged["count"] index_class: type[CFTimeIndex | pd.DatetimeIndex] + datetime_class: CFTimeDatetime | Callable[..., np.datetime64] if _contains_cftime_datetimes(group.data): index_class = CFTimeIndex datetime_class = type(first_n_items(group.data, 1).item()) else: index_class = pd.DatetimeIndex unit, _ = np.datetime_data(group.dtype) - datetime_class = partial(_construct_timestamp_as_unit, unit) + unit = cast(PDDatetimeUnitOptions, unit) + datetime_class = partial(_datetime64_via_timestamp, unit) # these are the seasons that are present + + # TODO: when pandas 3 is our minimum requirement we will no longer need + # to cast the list to a NumPy array prior to passing to the index + # constructor. unique_coord = index_class( - [ - datetime_class(year=year, month=season_tuples[season][0], day=1) - for year, season in first_items.index - ] + np.array( + [ + datetime_class(year=year, month=season_tuples[season][0], day=1) + for year, season in first_items.index + ] + ) ) # This sorted call is a hack. It's hard to figure out how @@ -983,15 +997,21 @@ def factorize(self, group: T_Group) -> EncodedGroups: # for example "DJF" as first entry or last entry # So we construct the largest possible index and slice it to the # range present in the data. + + # TODO: when pandas 3 is our minimum requirement we will no longer need + # to cast the list to a NumPy array prior to passing to the index + # constructor. complete_index = index_class( - sorted( - [ - datetime_class(year=y, month=m, day=1) - for y, m in itertools.product( - range(year[0].item(), year[-1].item() + 1), - [s[0] for s in season_inds], - ) - ] + np.array( + sorted( + [ + datetime_class(year=y, month=m, day=1) + for y, m in itertools.product( + range(year[0].item(), year[-1].item() + 1), + [s[0] for s in season_inds], + ) + ] + ) ) ) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b4ef343716e..c320931098a 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3607,12 +3607,13 @@ def test_season_resampler_groupby_identical(self): def test_season_resampler_preserves_time_unit( self, time_unit: PDDatetimeUnitOptions - ): + ) -> None: time = date_range("2000", periods=12, freq="MS", unit=time_unit) da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) resampler = SeasonResampler(["DJF", "MAM", "JJA", "SON"]) - rs = da.resample(time=resampler).sum() - assert rs.time.dtype == time.dtype + result = da.resample(time=resampler).sum() + result_unit, _ = np.datetime_data(result.time.dtype) + assert result_unit == time_unit @pytest.mark.parametrize(