Skip to content

Commit b07f1b6

Browse files
authored
Merge pull request #107 from ev-br/pytest_collection_modify
Refactor `pytest_collection_modify` hook
2 parents d148aa5 + 8e5df5b commit b07f1b6

File tree

2 files changed

+48
-63
lines changed

2 files changed

+48
-63
lines changed

scpdt/impl.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class DTConfig:
6868
NameErrors. Set to True if you want to see these, or if your test
6969
is actually expected to raise NameErrors.
7070
Default is False.
71+
pytest_extra_skips : list
72+
A list of names/modules to skip when run under pytest plugin. Ignored
73+
otherwise.
7174
7275
"""
7376
def __init__(self, *, # DTChecker configuration
@@ -88,6 +91,7 @@ def __init__(self, *, # DTChecker configuration
8891
# Obscure switches
8992
parse_namedtuples=True, # Checker
9093
nameerror_after_exception=False, # Runner
94+
pytest_extra_skips=None, # plugin/collection
9195
):
9296
### DTChecker configuration ###
9397
# The namespace to run examples in
@@ -141,7 +145,8 @@ def __init__(self, *, # DTChecker configuration
141145
'set_title', 'imshow', 'plt.show', '.axis(', '.plot(',
142146
'.bar(', '.title', '.ylabel', '.xlabel', 'set_ylim', 'set_xlim',
143147
'# reformatted', '.set_xlabel(', '.set_ylabel(', '.set_zlabel(',
144-
'.set(xlim=', '.set(ylim=', '.set(xlabel=', '.set(ylabel=', '.xlim('}
148+
'.set(xlim=', '.set(ylim=', '.set(xlabel=', '.set(ylabel=', '.xlim('
149+
'ax.set('}
145150
self.stopwords = stopwords
146151

147152
if pseudocode is None:
@@ -170,6 +175,11 @@ def __init__(self, *, # DTChecker configuration
170175
self.parse_namedtuples = parse_namedtuples
171176
self.nameerror_after_exception = nameerror_after_exception
172177

178+
#### pytest plugin additional switches
179+
if pytest_extra_skips is None:
180+
pytest_extra_skips = []
181+
self.pytest_extra_skips = pytest_extra_skips
182+
173183

174184
def try_convert_namedtuple(got):
175185
# suppose that "got" is smth like MoodResult(statistic=10, pvalue=0.1).

scpdt/plugin.py

+37-62
Original file line numberDiff line numberDiff line change
@@ -57,81 +57,55 @@ def pytest_ignore_collect(collection_path, config):
5757
path_str = str(collection_path)
5858
if "tests" in path_str or "test_" in path_str:
5959
return True
60-
61-
60+
61+
6262
def pytest_collection_modifyitems(config, items):
6363
"""
6464
This hook is executed after test collection and allows you to modify the list of collected items.
6565
66-
The function removes duplicate Doctest items.
66+
The function removes
67+
- duplicate Doctest items (e.g., scipy.stats.norm and scipy.stats.distributions.norm)
68+
- Doctest items from underscored or otherwise private modules (e.g., scipy.special._precompute)
6769
68-
Doctest items are collected from all public modules, including the __all__ attribute in __init__.py.
69-
This may lead to Doctest items being collected and tested more than once.
70-
We therefore need to remove the duplicate items by creating a new list with only unique items.
70+
Note that this functions cooperates with and cleans up after `DTModule.collect`, which does the
71+
bulk of the collection work.
7172
"""
73+
# XXX: The logic in this function can probably be folded into DTModule.collect.
74+
# I (E.B.) quickly tried it and it does not seem to just work. Apparently something
75+
# pytest-y runs in between DTModule.collect and this hook (should that something
76+
# be the proper home for all collection?)
77+
7278
if config.getoption("--doctest-modules"):
73-
seen_test_names = set()
7479
unique_items = []
7580

7681
for item in items:
77-
# Extract the item name, e.g., 'gauss_spline'
78-
# Example item: <DoctestItem scipy.signal._bsplines.gauss_spline>
79-
item_name = str(item).split('.')[-1].strip('>')
80-
81-
# In case the preceding string represents a function or a class,
82-
# We need to keep the object name as both items represent different functions
83-
# eg: <DoctestItem scipy.signal._ltisys.bode>
84-
# <DoctestItem scipy.signal._ltisys.lti.bode>
85-
obj_name = str(item).split('.')[-2]
86-
87-
# Extract the module path from the item's dtest attribute
88-
# Example dtest: <DocTest scipy.signal.__init__.gauss_spline from /scipy/build-install/lib/python3.10/site-packages/scipy/signal/_bsplines.py:226 (5 examples)>
89-
dtest = item.dtest
90-
path = str(dtest).split(' ')[3].split(':')[0]
91-
92-
# Import the module to check if the object name is an attribute of the module
93-
try:
94-
module = import_path(
95-
path,
96-
root=config.rootpath,
97-
mode=config.getoption("importmode"),
98-
)
99-
except ImportError:
100-
module = None
101-
102-
# Combine the module path, object name (if it exists) and item name to create a unique identifier
103-
if module is not None and obj_name != '__init__' and hasattr(module, obj_name) and callable(getattr(module, obj_name)) and obj_name != item_name:
104-
unique_test_name = f"{path}/{obj_name}.{item_name}"
105-
else:
106-
unique_test_name = f"{path}/{item_name}"
107-
108-
# Check if the test name is unique and add it to the unique_items list if it is
109-
if unique_test_name not in seen_test_names:
110-
seen_test_names.add(unique_test_name)
82+
assert isinstance(item.parent, DTModule)
83+
84+
# objects are collected twice: from their public module + from the impl module
85+
# e.g. for `levy_stable` we have
86+
# (Pdb) p item.name, item.parent.name
87+
# ('scipy.stats.levy_stable', 'build-install/lib/python3.10/site-packages/scipy/stats/__init__.py')
88+
# ('scipy.stats.distributions.levy_stable', 'distributions.py')
89+
# so we filter out the second occurence
90+
#
91+
# There are two options:
92+
# - either the impl module has a leading underscore, or
93+
# - it needs to be explicitly listed in 'extra_skips' config key
94+
#
95+
# Note that the last part cannot be automated: scipy.cluster.vq is public, but
96+
# scipy.stats.distributions is not
97+
extra_skips = config.dt_config.pytest_extra_skips
98+
99+
parent_full_name = item.parent.module.__name__
100+
is_public = "._" not in parent_full_name
101+
is_duplicate = parent_full_name in extra_skips or item.name in extra_skips
102+
103+
if is_public and not is_duplicate:
111104
unique_items.append(item)
112105

113106
# Replace the original list of test items with the unique ones
114107
items[:] = unique_items
115108

116-
# Generate a log of the unique items to be doctested
117-
# Extract the DoctestItem name
118-
for item in items:
119-
dtest = item.dtest
120-
path = str(dtest).split(' ')[3].split(':')[0]
121-
122-
# Import the module being doctested
123-
try:
124-
module = import_path(
125-
path,
126-
root=config.rootpath,
127-
mode=config.getoption("importmode"),
128-
)
129-
except ImportError:
130-
module = None
131-
132-
# Use the module and item name to generate a log entry
133-
generate_log(module, item.name)
134-
135109

136110
def copy_local_files(local_resources, destination_dir):
137111
"""
@@ -202,11 +176,12 @@ def collect(self):
202176
optionflags=optionflags,
203177
checker=DTChecker(config=self.config.dt_config)
204178
)
205-
179+
206180
try:
207181
# We utilize scpdt's `find_doctests` function to discover doctests in public, non-deprecated objects in the module
182+
# NB: additional postprocessing in pytest_collection_modifyitems
208183
for test in find_doctests(module, strategy="api", name=module.__name__, config=dt_config):
209-
if test.examples: # skip empty doctests
184+
# if test.examples: # skip empty doctests # FIXME: put this back (simplifies comparing the logs)
210185
yield doctest.DoctestItem.from_parent(
211186
self, name=test.name, runner=runner, dtest=test
212187
)

0 commit comments

Comments
 (0)