From ad2a73c18dcff95d844c382c94ab7f73b5571cf3 Mon Sep 17 00:00:00 2001
From: Sebastian Berg <sebastian@sipsolutions.net>
Date: Tue, 4 May 2021 17:43:26 -0500
Subject: [PATCH] MAINT: Adjust NumPy float hashing to Python's slightly
 changed hash

This is necessary because we use Python's double hash, and the
semi-private function that computes it gained a new signature in
Python 3.10 so that it can return the identity-based hash when the
value is NaN.

closes gh-18833, gh-18907
---
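
Not part of the patch, just context for reviewers: a minimal Python
sketch of the CPython behavior change (bpo-43475) that the new
Npy_HashDouble wrapper compensates for.  This is an illustration under
assumed CPython semantics only, not code taken from NumPy or CPython.

    # Illustration: NaN hashing before and after CPython 3.10
    import sys

    a = float("nan")
    b = float("nan")

    if sys.version_info >= (3, 10):
        # The hash of a NaN is now derived from the float object's
        # identity, so two distinct NaN objects normally hash
        # differently, while hashing the same object stays stable.
        print(hash(a) == hash(b))   # typically False
        print(hash(a) == hash(a))   # True
    else:
        # All NaNs used to hash to the same fixed value.
        print(hash(a) == sys.hash_info.nan)  # True
        print(hash(a) == hash(b))            # True

On the NumPy side this means that np.float64(np.nan) and the other
float/complex scalars follow whatever the running Python does, which
is what the new TestHash cases below check.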
 numpy/core/src/common/npy_pycompat.h        | 16 ++++++++++
 numpy/core/src/multiarray/scalartypes.c.src | 13 ++++----
 numpy/core/tests/test_scalarmath.py         | 34 +++++++++++++++++++++
 3 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/numpy/core/src/common/npy_pycompat.h b/numpy/core/src/common/npy_pycompat.h
index aa0b5c1224d..9e94a971090 100644
--- a/numpy/core/src/common/npy_pycompat.h
+++ b/numpy/core/src/common/npy_pycompat.h
@@ -3,4 +3,20 @@
 
 #include "numpy/npy_3kcompat.h"
 
+
+/*
+ * In Python 3.10a7 (or b1), Python started using the identity for the hash
+ * when a value is NaN.  See https://bugs.python.org/issue43475
+ */
+#if PY_VERSION_HEX > 0x030a00a6
+#define Npy_HashDouble _Py_HashDouble
+#else
+static NPY_INLINE Py_hash_t
+Npy_HashDouble(PyObject *NPY_UNUSED(identity), double val)
+{
+    return _Py_HashDouble(val);
+}
+#endif
+
+
 #endif /* _NPY_COMPAT_H_ */
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index a001500b0a9..9930f7791d6 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -3172,7 +3172,7 @@ static npy_hash_t
 static npy_hash_t
 @lname@_arrtype_hash(PyObject *obj)
 {
-    return _Py_HashDouble((double) PyArrayScalar_VAL(obj, @name@));
+    return Npy_HashDouble(obj, (double)PyArrayScalar_VAL(obj, @name@));
 }
 
 /* borrowed from complex_hash */
@@ -3180,14 +3180,14 @@ static npy_hash_t
 c@lname@_arrtype_hash(PyObject *obj)
 {
     npy_hash_t hashreal, hashimag, combined;
-    hashreal = _Py_HashDouble((double)
-            PyArrayScalar_VAL(obj, C@name@).real);
+    hashreal = Npy_HashDouble(
+            obj, (double)PyArrayScalar_VAL(obj, C@name@).real);
 
     if (hashreal == -1) {
         return -1;
     }
-    hashimag = _Py_HashDouble((double)
-            PyArrayScalar_VAL(obj, C@name@).imag);
+    hashimag = Npy_HashDouble(
+            obj, (double)PyArrayScalar_VAL(obj, C@name@).imag);
     if (hashimag == -1) {
         return -1;
     }
@@ -3202,7 +3202,8 @@ c@lname@_arrtype_hash(PyObject *obj)
 static npy_hash_t
 half_arrtype_hash(PyObject *obj)
 {
-    return _Py_HashDouble(npy_half_to_double(PyArrayScalar_VAL(obj, Half)));
+    return Npy_HashDouble(
+            obj, npy_half_to_double(PyArrayScalar_VAL(obj, Half)));
 }
 
 static npy_hash_t
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index d91b4a39146..09a734284a7 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -712,6 +712,40 @@ def test_shift_all_bits(self, type_code, op):
                 assert_equal(res_arr, res_scl)
 
 
+class TestHash:
+    @pytest.mark.parametrize("type_code", np.typecodes['AllInteger'])
+    def test_integer_hashes(self, type_code):
+        scalar = np.dtype(type_code).type
+        for i in range(128):
+            assert hash(i) == hash(scalar(i))
+
+    @pytest.mark.parametrize("type_code", np.typecodes['AllFloat'])
+    def test_float_and_complex_hashes(self, type_code):
+        scalar = np.dtype(type_code).type
+        for val in [np.pi, np.inf, 3, 6.]:
+            numpy_val = scalar(val)
+            # Cast back to Python, in case the NumPy scalar has less precision
+            if numpy_val.dtype.kind == 'c':
+                val = complex(numpy_val)
+            else:
+                val = float(numpy_val)
+            assert val == numpy_val
+            print(repr(numpy_val), repr(val))
+            assert hash(val) == hash(numpy_val)
+
+        if hash(float(np.nan)) != hash(float(np.nan)):
+            # If Python distinguishes different NaNs we do so too (gh-18833)
+            assert hash(scalar(np.nan)) != hash(scalar(np.nan))
+
+    @pytest.mark.parametrize("type_code", np.typecodes['Complex'])
+    def test_complex_hashes(self, type_code):
+        # Test some complex valued hashes specifically:
+        scalar = np.dtype(type_code).type
+        for val in [np.pi+1j, np.inf-3j, 3j, 6.+1j]:
+            numpy_val = scalar(val)
+            assert hash(complex(numpy_val)) == hash(numpy_val)
+
+
 @contextlib.contextmanager
 def recursionlimit(n):
     o = sys.getrecursionlimit()