1 # -*- coding: utf-8 -*-
2 # the above encoding declaration is needed to have non-ascii characters in this file (anywhere even in comments)
3 # from __future__ import unicode_literals # no, since we want to match the return type of str() which is bytes in py2
4 import sys
5
6 import numpy as np
7 import pandas as pd
8 import pytest
9
10 from qdscreen import QDForest, qd_screen
11 from qdscreen.compat import PY2
12 from qdscreen.main import get_adjacency_matrix, remove_redundancies
13 from qdscreen.sklearn import QDScreen
14
15
def df_strict1():
    """Build a dataframe whose two columns X and Y are strictly (one-to-one) redundant.

    Y is a bijective relabelling of X (a -> b, b -> c, c -> a), so each column fully determines the other.
    """
    relabel = {"a": "b", "b": "c", "c": "a"}
    result = pd.DataFrame({
        'X': ["a", "a", "b", "b", "a", "c", "c", "a", "b", "c"],
    })
    # resulting 'Y': ["b", "b", "c", "c", "b", "a", "a", "b", "c", "a"]
    result['Y'] = [relabel[value] for value in result['X']]
    return result
24
25
def test_adjacency_strict():
    """Check `get_adjacency_matrix` on strictly redundant data: X and Y are detected as determining each other."""
    names = ['X', 'Y']
    adjacency, _stats = get_adjacency_matrix(df_strict1())

    expected = pd.DataFrame(data=[[False, True],
                                  [True, False]],
                            index=names, columns=names)
    pd.testing.assert_frame_equal(adjacency, expected)
35
36
def quasi_df1():
    """Build a dataframe whose columns X and Y are quasi-redundant, but not strictly redundant.

    In this data H(X|Y) = H(Y|X) = 0.324511, and the relative conditional entropies are
    H(X|Y)/H(X) = H(Y|X)/H(Y) = 0.20657.
    """
    x_column = ["a", "a", "b", "b", "a", "c", "c", "a", "b", "c"]
    y_column = ["a", "b", "c", "c", "b", "a", "a", "b", "c", "a"]
    return pd.DataFrame({'X': x_column, 'Y': y_column})
43
44
def test_adjacency_invalid_thresholds():
    """Passing both `eps_absolute` and `eps_relative` must raise a ValueError, even when both are zero."""
    empty_df = pd.DataFrame()
    with pytest.raises(ValueError):
        get_adjacency_matrix(empty_df, eps_relative=0., eps_absolute=0.)
49
50
@pytest.mark.parametrize("eps_absolute, eps_relative", [
    (None, None),  # strict mode
    (0.2, None),  # too low absolute
    (None, 0.15)  # too low relative
])
def test_adjacency_quasi_low_threshold(eps_absolute, eps_relative):
    """Check that `get_adjacency_matrix` detects no arc on `quasi_df1` in strict mode or with a too low threshold.

    In that data H(X|Y) = H(Y|X) = 0.324511 and the relative H(X|Y)/H(X) = H(Y|X)/H(Y) = 0.20657, so neither
    an absolute threshold of 0.2 nor a relative threshold of 0.15 is enough to link X and Y.
    """
    names = ['X', 'Y']
    no_arc = pd.DataFrame(data=[[False, False],
                                [False, False]],
                          index=names, columns=names)

    thresholds = dict(eps_absolute=eps_absolute, eps_relative=eps_relative)
    adjacency, _stats = get_adjacency_matrix(quasi_df1(), **thresholds)
    pd.testing.assert_frame_equal(adjacency, no_arc)
66
67
@pytest.mark.parametrize("eps_absolute, eps_relative", [
    (0.33, None),  # high enough absolute
    (None, 0.21)  # high enough relative
])
def test_adjacency_quasi_high_threshold(eps_absolute, eps_relative):
    """ Tests that `get_adjacency_matrix` detects quasi-redundancy in quasi mode when the threshold is high enough """

    # in that matrix the H(X|Y) = H(Y|X) = 0.324511 and the relative H(X|Y)/H(X) = H(Y|X)/H(Y) = 0.20657
    # thresholds above those values (0.33 absolute, 0.21 relative): an arc is detected in both directions
    adj_df, df_stats = get_adjacency_matrix(quasi_df1(), eps_absolute=eps_absolute, eps_relative=eps_relative)
    pd.testing.assert_frame_equal(adj_df, pd.DataFrame(data=[
        [False, True],
        [True, False]
    ], index=['X', 'Y'], columns=['X', 'Y']))
82
83
def test_remove_redundancies():
    """Check that `remove_redundancies` keeps a single representative for a class of redundant nodes."""
    names = ['X', 'Y']

    def as_frame(rows):
        # boolean adjacency matrix over the nodes X and Y
        return pd.DataFrame(data=rows, index=names, columns=names)

    # X and Y are mutually redundant
    adj_df = as_frame([[False, True],
                       [True, False]])

    # natural selection order: X represents the class, so only the arc X -> Y remains
    pd.testing.assert_frame_equal(remove_redundancies(adj_df),
                                  as_frame([[False, True],
                                            [False, False]]))

    # reversed selection order: Y represents the class, so only the arc Y -> X remains
    pd.testing.assert_frame_equal(remove_redundancies(adj_df, selection_order=[1, 0]),
                                  as_frame([[False, False],
                                            [True, False]]))
106
107
108 # def test_identify_redundancy_quasi():
109 # df = pd.DataFrame({
110 # 'U': ["a", "b", "d", "a", "b", "c", "a", "b", "d", "c"],
111 # })
112 # df['V'] = df['U'].replace("d", "c") # d -> c
113 # df['W'] = df['V'].replace("c", "b") # c -> b
114 #
115 # adj_df = get_adjacency_matrix(df)
116 # df2 = identify_redundancy(adj_df)
117
118
def get_qd_forest1(is_np):
    """
    Create the reference forest on 10 variables used by the `QDForest` tests.

    It contains two non-trivial trees, 3 -> 5 and 9 -> (1 -> 8, 7). Variables 0, 2, 4 and 6 are isolated roots.

    :param is_np: if True, every returned structure is a numpy array using integer variable indices. In that case
        each "plain" item of the returned tuple is the very same object as its `*_ar` counterpart. If False, the
        variables are named "a" to "j" and the plain items use those names, as follows.
        `adjmat` is a boolean DataFrame indexed by name on both axes.
        `parents` is a DataFrame indexed by name with an 'idx' column (parent index, -1 for roots) and a 'name'
        column (parent name, None for roots).
        `roots` and `roots_wc` are arrays of names.
    :return: a tuple (adjmat, adjmat_ar, parents, parents_ar, roots, roots_ar, roots_wc, roots_wc_ar). The `*_ar`
        items are always the numpy versions. `roots_wc` contains the roots that have at least one child.
    """
    n_vars = 10
    # (parent, child) arcs of the forest: single source for both the adjacency matrix and the parents array
    arcs = [(1, 8), (3, 5), (9, 1), (9, 7)]

    adjmat_ar = np.zeros((n_vars, n_vars), dtype=bool)
    # int64: indeed computing parents from adjmat with np.where returns this dtype
    parents_ar = np.full((n_vars,), -1, dtype=np.int64)
    for parent, child in arcs:
        adjmat_ar[parent, child] = True
        parents_ar[child] = parent

    roots_ar = np.array([0, 2, 3, 4, 6, 9])
    roots_wc_ar = np.array([3, 9])

    if is_np:
        return adjmat_ar, adjmat_ar, parents_ar, parents_ar, roots_ar, roots_ar, roots_wc_ar, roots_wc_ar

    varnames = list("abcdefghij")
    adjmat = pd.DataFrame(adjmat_ar, columns=varnames, index=varnames)
    parents = pd.DataFrame(parents_ar, index=varnames, columns=('idx',))
    # roots have idx -1, which would pick the last name: mask them with None
    parents['name'] = parents.index[parents['idx']].where(parents['idx'] >= 0, None)
    roots = np.array(varnames)[roots_ar]
    roots_wc = np.array(varnames)[roots_wc_ar]

    return adjmat, adjmat_ar, parents, parents_ar, roots, roots_ar, roots_wc, roots_wc_ar
151
152
@pytest.mark.parametrize("from_adjmat", [True, False], ids="from_adjmat={}".format)
@pytest.mark.parametrize("is_np", [True, False], ids="is_np={}".format)
def test_qd_forest(is_np, from_adjmat):
    """Check the QDForest accessors (roots, arcs, parents, adjacency matrix) for every construction mode.

    The forest is built either from its adjacency matrix or from its parents, using either numpy (index-based)
    or pandas (name-based) inputs.
    """
    adjmat, adjmat_ar, parents, parents_ar, roots, roots_ar, roots_wc, roots_wc_ar = get_qd_forest1(is_np)
    has_names = not is_np

    forest = QDForest(adjmat=adjmat) if from_adjmat else QDForest(parents=parents)

    # -- roots, as indices
    assert forest.get_roots(names=False) == list(roots_ar)
    np.testing.assert_array_equal(forest.indices_to_mask(roots_ar), forest.roots_mask_ar)
    assert forest.get_roots_with_children(names=False) == list(roots_wc_ar)

    # -- roots, as names: only available when the forest was built from named (pandas) inputs
    if has_names:
        pd.testing.assert_series_equal(forest.indices_to_mask(roots), forest.roots_mask)
        assert forest.get_roots(names=True) == list(roots)
        assert forest.get_roots_with_children(names=True) == list(roots_wc)
    else:
        with pytest.raises(ValueError):
            forest.get_roots(names=True)
        with pytest.raises(ValueError):
            forest.get_roots_with_children(names=True)

    # -- by default, names are used whenever they are available
    assert forest.get_roots() == forest.get_roots(names=has_names)
    assert forest.get_roots_with_children() == forest.get_roots_with_children(names=has_names)

    # string representation of arcs
    # TODO order is not the same in both modes, see https://github.com/python-qds/qdscreen/issues/9
    if from_adjmat:
        expected_idx_arcs = ['1 -> 8', '3 -> 5', '9 -> 1', '9 -> 7']
        expected_name_arcs = ['b -> i', 'd -> f', 'j -> b', 'j -> h']
    else:
        expected_idx_arcs = ['9 -> 1', '3 -> 5', '9 -> 7', '1 -> 8']
        expected_name_arcs = ['j -> b', 'd -> f', 'j -> h', 'b -> i']

    # -- arcs, as indices
    if not from_adjmat:
        # the adj matrix must not have been computed automatically, so that the parents-based method is tested
        assert forest._adjmat is None
    assert forest.get_arcs_str_list(names=False) == expected_idx_arcs

    # -- arcs, as names
    if has_names:
        if not from_adjmat:
            # same as above: still no adj matrix computed
            assert forest._adjmat is None
        assert forest.get_arcs_str_list(names=True) == expected_name_arcs
    else:
        with pytest.raises(ValueError):
            forest.get_arcs_str_list(names=True)
    # check the default value of `names`
    assert forest.get_arcs_str_list() == forest.get_arcs_str_list(names=has_names)

    # parents and adjacency matrix must match the reference ones, whatever the construction mode
    np.testing.assert_array_equal(forest.parents_indices_ar, parents_ar)
    if has_names:
        pd.testing.assert_frame_equal(forest.parents, parents)

    np.testing.assert_array_equal(forest.adjmat_ar, adjmat_ar)
    if has_names:
        pd.testing.assert_frame_equal(forest.adjmat, adjmat)
215
216
@pytest.mark.parametrize("from_adjmat", [True, False], ids="from_adjmat={}".format)
@pytest.mark.parametrize("is_np", [True, False], ids="is_np={}".format)
def test_qd_forest_str(is_np, from_adjmat):
    """Tests the string representations of a QDForest: "compact", "headers" (default), trees and "full".

    Expected multi-line strings are assembled from explicit lines so that their significant leading whitespace
    cannot be silently lost by reformatting.
    """
    adjmat, _, parents, _, _, _, _, _ = get_qd_forest1(is_np)

    if from_adjmat:
        qd1 = QDForest(adjmat=adjmat)  # a forest created from the adj matrix
    else:
        qd1 = QDForest(parents=parents)  # a forest created from the parents coordinates

    # compact: a one-line summary
    compact_str = qd1.to_str(mode="compact")
    assert compact_str == "QDForest (10 vars = 6 roots + 4 determined by 2 of the roots)"

    # headers: the roots (a * marks the roots having children) and the other nodes
    # NOTE(review): the leading space before " - " and the 3-space indentation of nested tree children below
    # were reconstructed from the library's output layout - confirm against `to_str` / `get_trees_str_list`.
    roots_str = "0, 2, 3*, 4, 6, 9*" if is_np else "a, c, d*, e, g, j*"
    others_str = "1, 5, 7, 8" if is_np else "b, f, h, i"
    headers_str = qd1.to_str(mode="headers")
    assert headers_str == "\n".join([
        "QDForest (10 vars):",
        " - 6 roots (4+2*): %s" % roots_str,
        " - 4 other nodes: %s" % others_str,
    ])
    # this should be the default
    assert qd1.to_str() == headers_str

    # trees: one block per root having children
    # note the u for python 2 as in main.py we use unicode literals to cope with those non-base chars
    if is_np:
        tree_lines = [u"3", u"└─ 5", u"", u"9", u"└─ 1", u"   └─ 8", u"└─ 7", u""]
    else:
        tree_lines = [u"d", u"└─ f", u"", u"j", u"└─ b", u"   └─ i", u"└─ h", u""]
    trees_str = "\n" + "\n".join(qd1.get_trees_str_list())
    assert trees_str == u"\n" + u"\n".join(tree_lines)

    # full: headers followed by the trees
    full_str = qd1.to_str(mode="full")
    assert full_str == headers_str + u"\n" + trees_str
    # this should be the default string representation if the nb vars is small enough
    if PY2:
        assert full_str.encode('utf-8') == str(qd1)
    else:
        assert full_str == str(qd1)
272
273
274 # def test_sklearn_compat():
275 # """Trying to make sure that this compatibility code works: it does NOT with scikit-learn 0.22.2.post1 :("""
276 # from qdscreen.compat import BaseEstimator
277 # assert BaseEstimator()._more_tags()['requires_y'] is False
278 # assert BaseEstimator()._get_tags()['requires_y'] is False
279
280
def test_nans_in_data():
    """Missing values must be supported. See https://github.com/python-qds/qdscreen/issues/28"""
    rows = [
        ["A", "B"],
        ["A", "B"],
        ["N", np.nan],
    ]
    df = pd.DataFrame(rows, columns=["a", "b"])

    # 'b' is expected to be determined by 'a', so only 'a' remains after removal
    selector_model = qd_screen(df).fit_selector_model(df)
    expected = pd.DataFrame([["A"], ["A"], ["N"]], columns=["a"])
    pd.testing.assert_frame_equal(selector_model.remove_qd(df), expected)
298
299
def test_nans_in_data_sklearn():
    """Missing values must be supported by the scikit-learn `QDScreen` transformer on a numpy array.

    See https://github.com/python-qds/qdscreen/issues/28
    """
    X = pd.DataFrame([
        ["A", "B"],
        ["A", "B"],
        ["N", np.nan],
    ]).to_numpy()

    Xsel = QDScreen().fit_transform(X)
    assert Xsel.tolist() == [['A'], ['A'], ['N']]
313
314
def test_issue_37_non_categorical():
    """Regression test for issue 37: `qd_screen` rejects a dataframe containing a non-categorical column."""
    mixed_df = pd.DataFrame({
        "nb": [1, 2],
        "name": ["A", "B"]
    })
    expected_msg = "Provided dataframe columns contains non-categorical"
    with pytest.raises(ValueError, match=expected_msg):
        qd_screen(mixed_df)
322
323
@pytest.mark.skipif(sys.version_info < (3, 6),
                    reason="This test is known to fail for 3.5 and 2.7, see GH#43")
def test_issue_40_nan_then_str():
    """Regression test for issue 40: a column starting with NaN followed by strings.

    Only 'foo' is expected to be kept, and `predict_qd` must rebuild the original dataframe from it.
    """
    df = pd.DataFrame({"foo": ["1", "2"], "bar": [np.nan, "B"]})

    forest = qd_screen(df)
    assert list(forest.roots) == ["foo"]

    selector = forest.fit_selector_model(df)
    reduced_df = selector.remove_qd(df)
    assert list(reduced_df.columns) == ["foo"]

    rebuilt_df = selector.predict_qd(reduced_df)
    pd.testing.assert_frame_equal(df, rebuilt_df)