diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index f25107d..12023d1 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -2222,6 +2222,30 @@ globals()["GlobalsClassReturnTypesNode"].RETURN_TYPES.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_globals_get_class_return_types_mutation_after_mapping_skips_node(self): + source = ''' +class GlobalsGetReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "GlobalsGetReturnTypesNode": GlobalsGetReturnTypesNode, +} +globals().get("GlobalsGetReturnTypesNode").RETURN_TYPES.clear() +''' + result = self._extract_source(source, "globals-get-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_module_class_tuple_alias_patch_after_mapping_skips_node(self): source = ''' class TupleAliasPatchedNode: @@ -2346,6 +2370,30 @@ globals()["NODE_CLASS_MAPPINGS"].clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_globals_update_invalidates_static_node_mapping(self): + source = ''' +class GlobalUpdateNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "GlobalUpdateNode": GlobalUpdateNode, +} +globals().update(NODE_CLASS_MAPPINGS={}) +''' + result = self._extract_source(source, "global-update-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_unpacked_alias_mutation_invalidates_static_node_mapping(self): source = ''' class UnpackedAliasMutatedMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index ff85e1d..5f9c018 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -37,6 +37,8 @@ _CONTROL_FLOW_TYPES = (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.Wi if hasattr(ast, "TryStar"): _CONTROL_FLOW_TYPES += (ast.TryStar,) _CLASS_SIGNATURE_ATTRS = {"INPUT_TYPES", "RETURN_NAMES", "RETURN_TYPES"} +_DYNAMIC_NAMESPACE_MUTATION = object() +_NAMESPACE_FUNCTIONS = {"globals", "locals", "vars"} def _literal(node, env, allow_mutable_env=True): @@ -83,23 +85,46 @@ def _is_mutable_static_value(value): return isinstance(value, (dict, list, set)) +def _namespace_call_function_name(node): + if not isinstance(node, ast.Call): + return None + if not isinstance(node.func, ast.Name) or node.func.id not in _NAMESPACE_FUNCTIONS: + return None + if node.args or node.keywords: + return None + return node.func.id + + def _namespace_subscript_name(node): if not isinstance(node, ast.Subscript): return None - if not isinstance(node.value, ast.Call) or not isinstance(node.value.func, ast.Name): - return None - if node.value.func.id not in {"globals", "locals", "vars"}: - return None - if node.value.args or node.value.keywords: + if _namespace_call_function_name(node.value) is None: return None if isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str): return node.slice.value return None +def _namespace_lookup_name(node): + if not isinstance(node, ast.Call): + return None + if not isinstance(node.func, ast.Attribute) or node.func.attr != "get": + return None + if _namespace_call_function_name(node.func.value) is None: + return None + if not node.args: + return None + if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str): + return node.args[0].value + return None + + def _target_names(target): if isinstance(target, ast.Name): return {target.id} + if isinstance(target, ast.Call): + name = _namespace_lookup_name(target) + return {name} if name is not None else set() if isinstance(target, (ast.List, ast.Tuple)): names = set() for item in target.elts: @@ -119,6 +144,9 @@ def _target_names(target): def _root_name(node): while True: + name = _namespace_lookup_name(node) + if name is not None: + return name name = _namespace_subscript_name(node) if name is not None: return name @@ -167,6 +195,33 @@ def _getattr_mutating_method_target_names(node): return _target_names(getattr_call.args[0]) +def _namespace_mutating_call_target_names(node): + if not isinstance(node, ast.Call): + return set() + if not isinstance(node.func, ast.Attribute): + return set() + if _namespace_call_function_name(node.func.value) is None: + return set() + if node.func.attr not in _MUTATING_METHODS: + return set() + if node.func.attr != "update": + return {_DYNAMIC_NAMESPACE_MUTATION} + + names = set() + for keyword in node.keywords: + if keyword.arg is None: + names.add(_DYNAMIC_NAMESPACE_MUTATION) + else: + names.add(keyword.arg) + if node.args or not names: + names.add(_DYNAMIC_NAMESPACE_MUTATION) + return names + + +def _name_invalidated_by(name, names): + return name in names or _DYNAMIC_NAMESPACE_MUTATION in names + + def _attribute_target_base_names(target): if isinstance(target, ast.Attribute): name = _root_name(target.value) @@ -259,6 +314,7 @@ def _class_attribute_mutation_target_names(stmt): def visit_Call(self, node): names.update(_setattr_delattr_target_names(node)) names.update(_getattr_mutating_method_target_names(node)) + names.update(_namespace_mutating_call_target_names(node)) if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: names.update(_attribute_target_base_names(node.func.value)) self.generic_visit(node) @@ -428,6 +484,7 @@ def _mutating_call_target_names(stmt): def visit_Call(self, node): names.update(_setattr_delattr_target_names(node)) names.update(_getattr_mutating_method_target_names(node)) + names.update(_namespace_mutating_call_target_names(node)) if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: names.update(_target_names(node.func.value)) self.generic_visit(node) @@ -550,8 +607,13 @@ def _invalidate_class_bindings(class_bindings, names): def _apply_module_stmt_to_env(stmt, env, class_bindings=None): names = _mutating_call_target_names(stmt) - _invalidate_class_bindings(class_bindings, names) - _invalidate_env_names(env, names) + if _DYNAMIC_NAMESPACE_MUTATION in names: + env.clear() + if class_bindings is not None: + class_bindings.clear() + else: + _invalidate_class_bindings(class_bindings, names) + _invalidate_env_names(env, names) if isinstance(stmt, ast.ClassDef): if class_bindings is not None: if stmt.decorator_list: @@ -1127,7 +1189,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N and value_invalidated_by_names(value, class_attr_names) ): value = _INVALID - if name in _mutating_call_target_names(stmt): + if _name_invalidated_by(name, _mutating_call_target_names(stmt)): value = _INVALID if _module_dict_alias_invalidated(stmt, module_dict_aliases): value = _INVALID