From 7856dba05a13230f90de4cc3964ece1f79201eaf Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 13:02:24 -0500 Subject: [PATCH 1/4] Added Import Functionality --- src/inline/plugin.py | 60 +++++++++++++++++++++++++++++++++++++------- tests/test_plugin.py | 55 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 9 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index f8ddfc1..f2e3189 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -159,6 +159,7 @@ def __init__(self): self.check_stmts = [] self.given_stmts = [] self.previous_stmts = [] + self.import_stmts = [] self.prev_stmt_type = PrevStmtType.StmtExpr # the line number of test statement self.lineno = 0 @@ -174,10 +175,18 @@ def __init__(self): self.devices = None self.globs = {} + def write_imports(self): + import_str = "" + for n in self.import_stmts: + import_str += ExtractInlineTest.node_to_source_code(n) + "\n" + return import_str + def to_test(self): + prefix = "\n" + if self.prev_stmt_type == PrevStmtType.CondExpr: if self.assume_stmts == []: - return "\n".join( + return prefix.join( [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] ) @@ -187,11 +196,11 @@ def to_test(self): ) assume_statement = self.assume_stmts[0] assume_node = self.build_assume_node(assume_statement, body_nodes) - return "\n".join(ExtractInlineTest.node_to_source_code(assume_node)) + return prefix.join(ExtractInlineTest.node_to_source_code(assume_node)) else: if self.assume_stmts is None or self.assume_stmts == []: - return "\n".join( + return prefix.join( [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.previous_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] @@ -202,7 +211,7 @@ def to_test(self): ) assume_statement = self.assume_stmts[0] assume_node = self.build_assume_node(assume_statement, body_nodes) - return "\n".join([ExtractInlineTest.node_to_source_code(assume_node)]) + return prefix.join([ExtractInlineTest.node_to_source_code(assume_node)]) def build_assume_node(self, assumption_node, body_nodes): return ast.If(assumption_node, body_nodes, []) @@ -296,6 +305,11 @@ class ExtractInlineTest(ast.NodeTransformer): arg_timeout_str = "timeout" assume = "assume" + + import_str = "import" + from_str = "from" + as_str = "as" + inline_module_imported = False def __init__(self): @@ -360,6 +374,23 @@ def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]): inline_test_calls.append(node) self.collect_inline_test_calls(node.func, inline_test_calls) + def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]): + """ + collect all import calls in the node (should be done first) + """ + + while not isinstance(node, ast.Module) and node.parent != None: + node = node.parent + + if not isinstance(node, ast.Module): + return + + for child in node.children: + if isinstance(child, ast.Import): + import_calls.append(child) + elif isinstance(child, ast.ImportFrom): + import_from_calls.append(child) + def parse_constructor(self, node): """ Parse a constructor call. @@ -931,8 +962,13 @@ def parse_parameterized_test(self): parameterized_test.test_name = self.cur_inline_test.test_name + "_" + str(index) def parse_inline_test(self, node): - inline_test_calls = [] + import_calls = [] + import_from_calls = [] + inline_test_calls = [] + self.collect_inline_test_calls(node, inline_test_calls) + self.collect_import_calls(node, import_calls, import_from_calls) + inline_test_calls.reverse() if len(inline_test_calls) <= 1: @@ -953,14 +989,20 @@ def parse_inline_test(self, node): self.parse_assume(call) inline_test_call_index += 1 - # "given(a, 1)" for call in inline_test_calls[inline_test_call_index:]: - if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str: - self.parse_given(call) - inline_test_call_index += 1 + if isinstance(call.func, ast.Attribute): + if call.func.attr == self.given_str: + self.parse_given(call) + inline_test_call_index += 1 else: break + for import_stmt in import_calls: + self.cur_inline_test.import_stmts.append(import_stmt) + for import_stmt in import_from_calls: + self.cur_inline_test.import_stmts.append(import_stmt) + + # "check_eq" or "check_true" or "check_false" or "check_neq" for call in inline_test_calls[inline_test_call_index:]: # "check_eq(a, 1)" diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 40c3096..37e8c99 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -31,6 +31,61 @@ def m(a): items, reprec = pytester.inline_genitems(x) assert len(items) == 0 + def test_inline_detects_imports(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + import datetime + + def m(a): + b = a + datetime.timedelta(days=365) + itest().given(a, datetime.timedelta(days=1)).check_eq(b, datetime.timedelta(days=366)) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret != 1 + + def test_inline_detects_import_alias(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + import datetime as dt + + def m(a): + b = a + dt.timedelta(days=365) + itest().given(a, dt.timedelta(days=1)).check_eq(b, dt.timedelta(days=366)) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret != 1 + + def test_inline_detects_from_imports(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + from enum import Enum + + class Choice(Enum): + YES = 0 + NO = 1 + + def m(a): + b = a + itest().given(a, Choice.YES).check_eq(b, Choice.YES) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret == 0 + def test_inline_malformed_given(self, pytester: Pytester): checkfile = pytester.makepyfile( """ From 1a9845003aa23cb5bcaa265cfcbc56d750d7843f Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 13:33:12 -0500 Subject: [PATCH 2/4] Added Failed Import Test --- tests/test_plugin.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 37e8c99..953b61f 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -74,7 +74,7 @@ def test_inline_detects_from_imports(self, pytester: Pytester): class Choice(Enum): YES = 0 NO = 1 - + def m(a): b = a itest().given(a, Choice.YES).check_eq(b, Choice.YES) @@ -86,6 +86,21 @@ def m(a): res = pytester.runpytest() assert res.ret == 0 + def test_fail_on_importing_missing_module(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + from scipy import owijef as st + + def m(n, p): + b = st.binom(n, p) + itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 0 + def test_inline_malformed_given(self, pytester: Pytester): checkfile = pytester.makepyfile( """ From 19665d9618b8cb98bb874ca0d6b4d98b2518cf39 Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 13:02:24 -0500 Subject: [PATCH 3/4] Added Import Functionality --- tests/test_plugin.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 953b61f..37e8c99 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -74,7 +74,7 @@ def test_inline_detects_from_imports(self, pytester: Pytester): class Choice(Enum): YES = 0 NO = 1 - + def m(a): b = a itest().given(a, Choice.YES).check_eq(b, Choice.YES) @@ -86,21 +86,6 @@ def m(a): res = pytester.runpytest() assert res.ret == 0 - def test_fail_on_importing_missing_module(self, pytester: Pytester): - checkfile = pytester.makepyfile( - """ - from inline import itest - from scipy import owijef as st - - def m(n, p): - b = st.binom(n, p) - itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p) - """ - ) - for x in (pytester.path, checkfile): - items, reprec = pytester.inline_genitems(x) - assert len(items) == 0 - def test_inline_malformed_given(self, pytester: Pytester): checkfile = pytester.makepyfile( """ From 53e16f728df45ad5a33448d68a35ac8aaea8c0b8 Mon Sep 17 00:00:00 2001 From: hanse141 Date: Sun, 23 Nov 2025 13:33:12 -0500 Subject: [PATCH 4/4] Added Failed Import Test --- tests/test_plugin.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 37e8c99..953b61f 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -74,7 +74,7 @@ def test_inline_detects_from_imports(self, pytester: Pytester): class Choice(Enum): YES = 0 NO = 1 - + def m(a): b = a itest().given(a, Choice.YES).check_eq(b, Choice.YES) @@ -86,6 +86,21 @@ def m(a): res = pytester.runpytest() assert res.ret == 0 + def test_fail_on_importing_missing_module(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + from scipy import owijef as st + + def m(n, p): + b = st.binom(n, p) + itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 0 + def test_inline_malformed_given(self, pytester: Pytester): checkfile = pytester.makepyfile( """