|
import sys, re, json, pprint |
|
from pathlib import Path |
|
from collections import defaultdict |
|
from inspect import currentframe, getframeinfo, ismodule |
|
|
|
# only `this_tests` is exported by a star-import of this module
__all__ = ['this_tests']

# file name of the JSON test registry read and written by `TestRegistry.registry_save`
DB_NAME = 'test_registry.json'
|
|
|
def _json_set_default(obj): |
|
if isinstance(obj, set): return list(obj) |
|
raise TypeError |
|
|
|
class TestRegistry:
    "Tests register which API they validate using this class."
    # maps a fully qualified API name to the list of test entries (file/test/line) that declared it
    registry = defaultdict(list)
    # None until first used; True after `this_tests_check_on`; False after
    # `this_tests_check_off` or after a completed `this_tests` call
    this_tests_check = None
    # `file::test` ids collected by `this_tests_check_run` while the flag was still True
    missing_this_tests = set()

    @staticmethod
    def this_tests(*funcs):
        """Record the calling test as a test of every API listed in `funcs`.

        Each item is a function/class/module object, a dotted-name string, or the
        string 'na' (skipped). Raises `Exception` if an item cannot be resolved to
        a qualified name or if that name does not start with `fastai.`.
        """
        # two frames up: one for the module-level `this_tests` wrapper, one more for the test itself
        prev_frame = currentframe().f_back.f_back
        file_name, lineno, test_name, _, _ = getframeinfo(prev_frame)
        parent_func_lineno, _ = get_parent_func(lineno, get_lines(file_name))
        entry = {'file': relative_test_path(file_name), 'test': test_name, 'line': parent_func_lineno}
        for func in funcs:
            # 'na' declares that there is no API to register for this test
            if func == 'na': continue
            try:
                func_fq = get_func_fq_name(func)
            except Exception:
                # narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate
                raise Exception(f"'{func}' is not a function") from None
            if not func_fq.startswith('fastai.'):
                raise Exception(f"'{func}' is not in the fastai API") from None
            if entry not in TestRegistry.registry[func_fq]:
                TestRegistry.registry[func_fq].append(entry)
        # the test has called `this_tests`, so it must not be reported as missing
        TestRegistry.this_tests_check = False

    @staticmethod
    def this_tests_check_on():
        "Raise the flag: a `this_tests` call is now expected"
        TestRegistry.this_tests_check = True

    @staticmethod
    def this_tests_check_off():
        "Lower the flag without requiring a `this_tests` call"
        TestRegistry.this_tests_check = False

    @staticmethod
    def this_tests_check_run(file_name, test_name):
        "If the flag is still raised, remember `file_name::test_name` as missing its `this_tests` call"
        if TestRegistry.this_tests_check:
            TestRegistry.missing_this_tests.add(f"{file_name}::{test_name}")

    @staticmethod
    def registry_save():
        "Merge the in-memory registry into the JSON file `DB_NAME` (two directories above this file) and write it"
        if not TestRegistry.registry: return
        path = Path(__file__).parent.parent.resolve()/DB_NAME
        if path.exists():
            with open(path, 'r') as f: old_registry = json.load(f)
            merged = merge_registries(old_registry, TestRegistry.registry)
            # `json.load` returns a plain dict; re-wrap it so that a later
            # `this_tests` call can still auto-create the list for a new API name
            TestRegistry.registry = defaultdict(list, merged)
        with open(path, 'w') as f:
            json.dump(obj=TestRegistry.registry, fp=f, indent=4, sort_keys=True, default=_json_set_default)

    @staticmethod
    def missing_this_tests_alert():
        "Print the sorted list of tests recorded in `missing_this_tests`, if any"
        if not TestRegistry.missing_this_tests: return
        # the separator's indent matches the indent in front of `{tests}` so all ids line up
        tests = '\n '.join(sorted(TestRegistry.missing_this_tests))
        print(f"""
*** Attention ***
Please include `this_tests` call in each of the following tests:
 {tests}
For details see: https://docs.fast.ai/dev/test.html#test-registry""")
|
|
|
|
|
|
|
def a2k(a):
    "Pack a registry entry dict into a (`file::test` key, `line` value) pair"
    key = '::'.join((a['file'], a['test']))
    line = a['line']
    return key, line
|
def k2a(k, v):
    "Unpack a `file::test` key and its `line` value back into a registry entry dict"
    file_name, test_name = k.split('::')
    return dict(file=file_name, line=v, test=test_name)
|
|
|
def merge_lists(a, b):
    "Merge two lists of registry entries, keyed on `file::test`; for a shared key the later `line` wins"
    line_by_key = {}
    # pack: a later entry overwrites the line of an earlier one with the same key
    for entry in (*a, *b):
        key, line = a2k(entry)
        line_by_key[key] = line
    # unpack back into entry dicts
    return [k2a(key, line) for key, line in line_by_key.items()]
|
def merge_registries(a, b):
    "Fold registry `b` into registry `a` in place and return `a`"
    for api_name in b:
        if api_name in a:
            # both sides know this API: merge their entry lists
            a[api_name] = merge_lists(a[api_name], b[api_name])
        else:
            a[api_name] = b[api_name]
    return a
|
|
|
def this_tests(*funcs):
    "Module-level entry point: forward `funcs` to `TestRegistry.this_tests`"
    # keep this a direct call with no extra frames: `TestRegistry.this_tests`
    # looks exactly two frames up to find the calling test
    TestRegistry.this_tests(*funcs)
|
|
|
def str2func(name):
    "Converts 'fastai.foo.bar' into a function 'object' if such exists, otherwise returns `None`"
    if not isinstance(name, str): return None
    module, *subpaths = name.split('.')
    # only already-imported modules are considered; nothing is imported here
    if module not in sys.modules: return None
    obj = sys.modules[module]
    for subpath in subpaths:
        obj = getattr(obj, subpath, None)
        # identity test: `== None` would call a custom `__eq__` on the attribute
        if obj is None: return None
    return obj
|
|
|
def get_func_fq_name(func):
    "Return the dotted `module.name` of `func` (object or dotted-name string); a module yields its own name"
    if ismodule(func):
        return func.__name__
    if isinstance(func, str):
        func = str2func(func)
    # prefer the qualified name, then the plain name
    for attr in ('__qualname__', '__name__'):
        if hasattr(func, attr):
            name = getattr(func, attr)
            return f'{func.__module__}.{name}'
    # a nameless wrapper: describe what it wraps instead
    if hasattr(func, '__wrapped__'):
        return get_func_fq_name(func.__wrapped__)
    if not hasattr(func, '__class__'):
        raise Exception(f"'{func}' is not a func or class")
    # last resort: name the object after its class
    name = func.__class__.__name__
    return f'{func.__module__}.{name}'
|
|
|
def get_parent_func(lineno, lines, ignore_missing=False): |
|
"Find any lines where `elt` is called and return the parent test function" |
|
for idx,l in enumerate(reversed(lines[:lineno])): |
|
if re.match(f'\s*def test', l): return (lineno - idx), l |
|
if re.match(f'\w+', l): break |
|
if ignore_missing: return None |
|
raise LookupError('Could not find parent function for line:', lineno, lines[:lineno]) |
|
|
|
def relative_test_path(test_file:Path)->str:
    "Return the '/'-joined tail of `test_file` starting at its last `tests` component"
    parts = Path(test_file).parts
    # distance of the last `tests` component from the end of the path
    from_end = list(parts)[::-1].index('tests')
    tail = parts[-(from_end + 1):]
    return '/'.join(tail)
|
|
|
def get_lines(file):
    "Read `file` in text mode and return its lines, line endings included"
    with open(file, 'r') as handle:
        lines = handle.readlines()
    return lines
|
|