You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			120 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			120 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
"""
This file contains a minimal set of tests for compliance with the extension
array interface test suite, and should contain no other tests.
The test suite for the full functionality of the array is located in
`pandas/tests/arrays/`.

The tests in this file are inherited from the BaseExtensionTests, and only
minimal tweaks should be applied to get the tests passing (by overriding a
parent method).

Additional tests should either be added to one of the BaseExtensionTests
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.

"""
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
from typing import TYPE_CHECKING
 | 
						|
 | 
						|
import numpy as np
 | 
						|
import pytest
 | 
						|
 | 
						|
from pandas._libs import (
 | 
						|
    Period,
 | 
						|
    iNaT,
 | 
						|
)
 | 
						|
from pandas.compat import is_platform_windows
 | 
						|
from pandas.compat.numpy import np_version_gte1p24
 | 
						|
 | 
						|
from pandas.core.dtypes.dtypes import PeriodDtype
 | 
						|
 | 
						|
import pandas._testing as tm
 | 
						|
from pandas.core.arrays import PeriodArray
 | 
						|
from pandas.tests.extension import base
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    import pandas as pd
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture(params=["D", "2D"])
def dtype(request):
    """Provide a ``PeriodDtype`` for each frequency alias in ``params``.

    The fixture is parametrized, so every test depending on it runs once
    with ``freq="D"`` and once with ``freq="2D"``.
    """
    freq_alias = request.param
    return PeriodDtype(freq=freq_alias)
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def data(dtype):
    """Provide a 100-element ``PeriodArray`` built from the integers 1970..2069.

    The integers are passed straight to ``PeriodArray`` together with the
    parametrized ``dtype``; they are presumably interpreted as period
    ordinals (confirm against ``PeriodArray`` constructor semantics).
    """
    ordinals = np.arange(1970, 2070)
    return PeriodArray(ordinals, dtype=dtype)
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def data_for_sorting(dtype):
    """Provide three distinct values ordered [middle, largest, smallest]."""
    unsorted_values = [2018, 2019, 2017]
    return PeriodArray(unsorted_values, dtype=dtype)
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def data_missing(dtype):
    """Provide a length-2 array whose first element is missing (``iNaT``)."""
    values_with_nat = [iNaT, 2017]
    return PeriodArray(values_with_nat, dtype=dtype)
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def data_missing_for_sorting(dtype):
    """Provide [larger, missing, smaller] for sorting tests involving NaT."""
    values_with_gap = [2018, iNaT, 2017]
    return PeriodArray(values_with_gap, dtype=dtype)
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def data_for_grouping(dtype):
    """Provide the grouping pattern [B, B, NA, NA, A, A, B, C].

    ``A < B < C`` are distinct values and ``NA`` is the ``iNaT`` sentinel, so
    the array contains two missing entries and three distinct groups.
    """
    a, b, c = 2017, 2018, 2019
    na = iNaT
    grouping_values = [b, b, na, na, a, a, b, c]
    return PeriodArray(grouping_values, dtype=dtype)
 | 
						|
 | 
						|
 | 
						|
class TestPeriodArray(base.ExtensionTests):
    """Run the shared extension-array test suite against ``PeriodArray``.

    Only the hooks that differ for period data are overridden here.
    """

    def _get_expected_exception(self, op_name, obj, other):
        # Subtraction is special-cased; every other op defers to the base rules.
        if op_name not in ("__sub__", "__rsub__"):
            return super()._get_expected_exception(op_name, obj, other)
        # NOTE(review): returning None presumably tells the base suite that no
        # exception is expected for subtraction — confirm in base.ExtensionTests.
        return None

    def _supports_accumulation(self, ser, op_name: str) -> bool:
        """Only cumulative min/max are treated as supported accumulations."""
        return op_name in ("cummin", "cummax")

    def _supports_reduction(self, obj, op_name: str) -> bool:
        """Only min, max and median are treated as supported reductions."""
        return op_name in ("min", "max", "median")

    def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
        """Check a reduction, with a custom expectation for ``median``.

        For ``median`` the expected value is the median of the series viewed
        as int64, converted back into a ``Period`` with the series' frequency.
        All other reductions use the base implementation.
        """
        if op_name != "median":
            return super().check_reduce(ser, op_name, skipna)

        result = getattr(ser, op_name)(skipna=skipna)

        as_int64 = ser.astype("int64")
        expected_ordinal = getattr(as_int64, op_name)(skipna=skipna)

        # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
        # attribute "freq"
        freq = ser.dtype.freq  # type: ignore[union-attr]
        expected = Period._from_ordinal(int(expected_ordinal), freq=freq)
        tm.assert_almost_equal(result, expected)

    @pytest.mark.parametrize("periods", [1, -2])
    def test_diff(self, data, periods):
        # On Windows with numpy >= 1.24 the inherited test is expected to emit
        # a RuntimeWarning; the underlying cause is not visible in this file.
        warning_expected = is_platform_windows() and np_version_gte1p24
        if not warning_expected:
            super().test_diff(data, periods)
            return
        with tm.assert_produces_warning(RuntimeWarning, check_stacklevel=False):
            super().test_diff(data, periods)

    @pytest.mark.parametrize("na_action", [None, "ignore"])
    def test_map(self, data, na_action):
        """Mapping the identity function must round-trip the array unchanged."""

        def _identity(value):
            return value

        mapped = data.map(_identity, na_action=na_action)
        tm.assert_extension_array_equal(mapped, data)
 | 
						|
 | 
						|
 | 
						|
class Test2DCompat(base.NDArrayBacked2DTests):
    """Run the inherited 2D NDArray-backed tests unchanged for PeriodArray."""
 |