diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 4be5680..d8730e8 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -414,6 +414,34 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_nested_wildcard_import_invalidates_static_env_values(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +if True: + from something import * + + +class NestedWildcardImportInputNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "NestedWildcardImportInputNode": NestedWildcardImportInputNode, +} +''' + result = self._extract_source(source, "nested-wildcard-import-input-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_annotated_reassignment_invalidates_static_env_value(self): source = ''' def build_inputs(): @@ -763,6 +791,31 @@ from something import * self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_nested_wildcard_import_invalidates_static_node_mapping(self): + source = ''' +class NestedWildcardImportMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "NestedWildcardImportMappingNode": NestedWildcardImportMappingNode, +} +if True: + from something import * +''' + result = self._extract_source(source, "nested-wildcard-import-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_dynamic_display_mapping_reassignment_falls_back_to_node_type(self): source = ''' def build_displays(): diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index cf623ea..11119d7 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -251,6 +251,28 @@ def _assigned_names_in_control_flow(stmt): return names +def _has_wildcard_import_in_control_flow(stmt): + found = False + + class WildcardImportVisitor(ast.NodeVisitor): + def visit_FunctionDef(self, node): + return None + + def visit_AsyncFunctionDef(self, node): + return None + + def visit_ClassDef(self, node): + return None + + def visit_ImportFrom(self, node): + nonlocal found + if _has_wildcard_import(node): + found = True + + WildcardImportVisitor().visit(stmt) + return found + + def _collect_module_env(tree): env = {} for stmt in tree.body: @@ -311,6 +333,9 @@ def _collect_module_env(tree): env.pop(name, None) continue if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): + if _has_wildcard_import_in_control_flow(stmt): + env.clear() + continue for name in _assigned_names_in_control_flow(stmt): env.pop(name, None) continue @@ -374,6 +399,8 @@ def _class_attr(cls, name, env): if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): if name in _assigned_names_in_control_flow(stmt): value = _INVALID + if _has_wildcard_import_in_control_flow(stmt): + value = _INVALID continue if name in _bound_names(stmt): value = _INVALID @@ -468,6 +495,8 @@ def _final_module_dict(tree, env, name, value_converter): if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): if name in _assigned_names_in_control_flow(stmt): value = _INVALID + if _has_wildcard_import_in_control_flow(stmt): + value = _INVALID continue if _has_wildcard_import(stmt): value = _INVALID