diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 8f6f7cc..4be5680 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -359,6 +359,61 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_annotated_alias_mutation_invalidates_static_source_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +ALIAS: dict = INPUTS +ALIAS.clear() + + +class AnnotatedAliasMutatedInputNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "AnnotatedAliasMutatedInputNode": AnnotatedAliasMutatedInputNode, +} +''' + result = self._extract_source(source, "annotated-alias-mutated-input-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_wildcard_import_invalidates_static_env_values(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +from something import * + + +class WildcardImportInputNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "WildcardImportInputNode": WildcardImportInputNode, +} +''' + result = self._extract_source(source, "wildcard-import-input-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_annotated_reassignment_invalidates_static_env_value(self): source = ''' def build_inputs(): @@ -659,6 +714,55 @@ NODE_CLASS_MAPPINGS.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_annotated_alias_mutation_invalidates_static_node_mapping(self): + source = ''' +class AnnotatedAliasMutatedMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "AnnotatedAliasMutatedMappingNode": AnnotatedAliasMutatedMappingNode, +} +ALIAS: dict = NODE_CLASS_MAPPINGS +ALIAS.clear() +''' + result = self._extract_source(source, "annotated-alias-mutated-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_wildcard_import_invalidates_static_node_mapping(self): + source = ''' +class WildcardImportMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "WildcardImportMappingNode": WildcardImportMappingNode, +} +from something import * +''' + result = self._extract_source(source, "wildcard-import-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_dynamic_display_mapping_reassignment_falls_back_to_node_type(self): source = ''' def build_displays(): diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 417ff07..cf623ea 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -148,6 +148,10 @@ def _bound_names(stmt): return names +def _has_wildcard_import(stmt): + return isinstance(stmt, ast.ImportFrom) and any(alias.name == "*" for alias in stmt.names) + + def _assignment_target_names(stmt): if isinstance(stmt, ast.Assign): names = set() @@ -276,6 +280,14 @@ def _collect_module_env(tree): continue if isinstance(stmt.target, ast.Name): name = stmt.target.id + if ( + isinstance(stmt.value, ast.Name) + and stmt.value.id in env + and _is_mutable_static_value(env[stmt.value.id]) + ): + env.pop(stmt.value.id, None) + env.pop(name, None) + continue try: env[name] = _literal(stmt.value, env) except UnsupportedStaticExpression: @@ -302,6 +314,9 @@ def _collect_module_env(tree): for name in _assigned_names_in_control_flow(stmt): env.pop(name, None) continue + if _has_wildcard_import(stmt): + env.clear() + continue for name in _bound_names(stmt): env.pop(name, None) return env @@ -425,6 +440,8 @@ def _final_module_dict(tree, env, name, value_converter): continue if isinstance(stmt, ast.AnnAssign): if not _name_is_assigned(stmt, name): + if isinstance(stmt.value, ast.Name) and stmt.value.id == name: + value = _INVALID continue if isinstance(stmt.target, ast.Name) and stmt.value is not None: try: @@ -452,6 +469,9 @@ def _final_module_dict(tree, env, name, value_converter): if name in _assigned_names_in_control_flow(stmt): value = _INVALID continue + if _has_wildcard_import(stmt): + value = _INVALID + continue if name in _bound_names(stmt): value = _INVALID if value in (_MISSING, _INVALID):