Skip to content

Commit 3d5c30d

Browse files
committed
ColumnTransformer expects list of column names, not a single string
1 parent f862e94 commit 3d5c30d

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

searchgrid.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,10 @@ def _name_steps(steps, default='alt'):
176176

177177
def _name_of_estimator(estimator):
178178
if isinstance(estimator, tuple):
179-
# tuples comes from ColumnTransformers. At the moment,
180-
# sklearn accepts both (estimator, 'name') and ('name', estimator)
179+
# tuples comes from ColumnTransformers. At the moment, sklearn accepts
180+
# both (estimator, list_of_columns) and (list_of_columns, estimator)
181181
tuple_types = {type(tuple_entry) for tuple_entry in estimator}
182-
tuple_types.discard(str)
182+
tuple_types.discard(list)
183183
estimator_type = tuple_types.pop()
184184
else:
185185
estimator_type = type(estimator)
@@ -281,7 +281,7 @@ def make_column_transformer(*transformers, **kwargs):
281281
steps
282282
Each step is specified as one of:
283283
284-
* an (estimator, column_name) or (column_name, estimator) tuple
284+
* an (estimator, [column_names]) or ([column_names], estimator) tuple
285285
* None (meaning no features)
286286
* a list of the above, indicating that a grid search should alternate
287287
over the estimators (or None) in the list

test_searchgrid.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,15 @@ def test_make_pipeline():
114114

115115

116116
def test_make_column_transformer():
117-
t1 = (SelectKBest(), 'column1')
118-
t2 = (SelectKBest(), 'column2')
119-
t3 = (SelectKBest(), 'column3')
120-
t4 = (SelectKBest(), 'column4')
121-
t5 = (SelectPercentile(), 'column5')
122-
t6 = (SelectKBest(), 'column6')
123-
t7 = (SelectKBest(), 'column7')
124-
t8 = (SelectKBest(), 'column8')
125-
t9 = (SelectPercentile(), 'column9')
117+
t1 = (SelectKBest(), ['column1'])
118+
t2 = (SelectKBest(), ['column2a', 'column2b'])
119+
t3 = (SelectKBest(), ['column3'])
120+
t4 = (SelectKBest(), ['column4'])
121+
t5 = (SelectPercentile(), ['column5'])
122+
t6 = (SelectKBest(), ['column6a', 'column6b'])
123+
t7 = (SelectKBest(), ['column7'])
124+
t8 = (SelectKBest(), ['column8'])
125+
t9 = (SelectPercentile(), ['column9'])
126126

127127
in_steps = [[t1, None],
128128
[t2, t3],

0 commit comments

Comments
 (0)