diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 358a8eef..900a3d57 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -279,18 +279,43 @@ def test_ihfft(x, data): ph.assert_shape("ihfft", out_shape=out.shape, expected=expected_shape) -@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5))) +@given( + n=st.integers(1, 100), + kw=hh.kwargs(d=st.floats(0.1, 5), dtype=hh.real_floating_dtypes), +) def test_fftfreq(n, kw): - out = xp.fft.fftfreq(n, **kw) - ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n}) - + repro_snippet = ph.format_snippet(f"xp.fft.fftfreq({n!r}, **kw) with {kw = }") + try: + out = xp.fft.fftfreq(n, **kw) + ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n}) + + dt = kw.get("dtype", None) + if dt is None: + dt = xp.__array_namespace_info__().default_dtypes()["real floating"] + assert out.dtype == dt + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise -@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5))) +@given( + n=st.integers(1, 100), + kw=hh.kwargs(d=st.floats(0.1, 5), dtype=hh.real_floating_dtypes) +) def test_rfftfreq(n, kw): - out = xp.fft.rfftfreq(n, **kw) - ph.assert_shape( - "rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n} - ) + repro_snippet = ph.format_snippet(f"xp.fft.rfftfreq({n!r}, **kw) with {kw = }") + try: + out = xp.fft.rfftfreq(n, **kw) + ph.assert_shape( + "rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n} + ) + + dt = kw.get("dtype", None) + if dt is None: + dt = xp.__array_namespace_info__().default_dtypes()["real floating"] + assert out.dtype == dt + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])