diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 674a1b6..b0f045c 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -2978,6 +2978,56 @@ globals().update(NODE_CLASS_MAPPINGS={}) self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_globals_alias_subscript_assignment_invalidates_static_node_mapping(self): + source = ''' +class GlobalAliasSubscriptAssignmentNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "GlobalAliasSubscriptAssignmentNode": GlobalAliasSubscriptAssignmentNode, +} +G = globals() +G["NODE_CLASS_MAPPINGS"] = {} +''' + result = self._extract_source(source, "global-alias-subscript-assignment-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_globals_alias_update_invalidates_static_node_mapping(self): + source = ''' +class GlobalAliasUpdateNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "GlobalAliasUpdateNode": GlobalAliasUpdateNode, +} +G = globals() +G.update(NODE_CLASS_MAPPINGS={}) +''' + result = self._extract_source(source, "global-alias-update-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_arbitrary_call_invalidates_static_node_mapping(self): source = ''' class ArbitraryCallMappingNode: @@ -3493,6 +3543,52 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_default_observed_by_arbitrary_call_skips_node(self): + source = ''' +class DefaultObservedInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls, value=observe(INPUT_TYPES)): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DefaultObservedInputTypesNode": DefaultObservedInputTypesNode, +} +''' + result = self._extract_source(source, "default-observed-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_input_types_return_annotation_observed_by_arbitrary_call_skips_node(self): + source = ''' +class AnnotationObservedInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls) -> observe(INPUT_TYPES): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "AnnotationObservedInputTypesNode": AnnotationObservedInputTypesNode, +} +''' + result = self._extract_source(source, "annotation-observed-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_write_artifact_is_deterministic(self): with tempfile.TemporaryDirectory() as tmp: out_one = Path(tmp, "one.json") diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index cb75fbf..78c112a 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -142,6 +142,19 @@ def _target_names(target): return set() +def _direct_target_names(target): + if isinstance(target, ast.Name): + return {target.id} + if isinstance(target, (ast.List, ast.Tuple)): + names = set() + for item in target.elts: + names.update(_direct_target_names(item)) + return names + if isinstance(target, ast.Starred): + return _direct_target_names(target.value) + return set() + + def _root_name(node): while True: name = _namespace_lookup_name(node) @@ -1099,11 +1112,17 @@ def _input_types(cls, env, decorator_env): for stmt in cls.body: mutating_targets = _mutating_call_target_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt) - if "INPUT_TYPES" in mutating_targets or aliases.intersection(mutating_targets): - value = _INVALID - if "INPUT_TYPES" in observed_targets or aliases.intersection(observed_targets): + input_types_invalidated = ( + "INPUT_TYPES" in mutating_targets + or bool(aliases.intersection(mutating_targets)) + or "INPUT_TYPES" in observed_targets + or bool(aliases.intersection(observed_targets)) + ) + if input_types_invalidated: value = _INVALID if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": + if input_types_invalidated: + continue if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed): value = _INVALID continue @@ -1435,6 +1454,139 @@ def _module_dict_alias_invalidated(stmt, aliases): return any(name in aliases for name in names) +def _namespace_alias_sources(value, aliases): + if _namespace_call_function_name(value) is not None: + return True + if isinstance(value, ast.Name): + return value.id in aliases + if isinstance(value, (ast.Tuple, ast.List)): + return any(_namespace_alias_sources(item, aliases) for item in value.elts) + return False + + +def _namespace_alias_subscript_name(node, aliases): + if not isinstance(node, ast.Subscript): + return None + if not isinstance(node.value, ast.Name) or node.value.id not in aliases: + return None + if isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str): + return node.slice.value + return None + + +def _namespace_alias_target_names(target, aliases): + name = _namespace_alias_subscript_name(target, aliases) + if name is not None: + return {name} + if isinstance(target, (ast.Tuple, ast.List)): + names = set() + for item in target.elts: + names.update(_namespace_alias_target_names(item, aliases)) + return names + if isinstance(target, ast.Starred): + return _namespace_alias_target_names(target.value, aliases) + if isinstance(target, (ast.Attribute, ast.Subscript)): + return _namespace_alias_target_names(target.value, aliases) + return set() + + +def _namespace_alias_mutation_target_names(stmt, aliases): + names = set() + + class NamespaceAliasMutationVisitor(ast.NodeVisitor): + def _visit_function_definition_expressions(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + self.visit(node.args) + if node.returns is not None: + self.visit(node.returns) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + + def visit_FunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_AsyncFunctionDef(self, node): + self._visit_function_definition_expressions(node) + + def visit_ClassDef(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for keyword in node.keywords: + self.visit(keyword.value) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + for child in node.body: + self.visit(child) + + def visit_Lambda(self, node): + self.visit(node.args) + + def visit_Assign(self, node): + for target in node.targets: + names.update(_namespace_alias_target_names(target, aliases)) + self.visit(node.value) + + def visit_AnnAssign(self, node): + names.update(_namespace_alias_target_names(node.target, aliases)) + if node.value is not None: + self.visit(node.value) + + def visit_AugAssign(self, node): + names.update(_namespace_alias_target_names(node.target, aliases)) + self.visit(node.value) + + def visit_Delete(self, node): + for target in node.targets: + names.update(_namespace_alias_target_names(target, aliases)) + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute): + if isinstance(node.func.value, ast.Name) and node.func.value.id in aliases: + if node.func.attr == "update": + for keyword in node.keywords: + names.add(_DYNAMIC_NAMESPACE_MUTATION if keyword.arg is None else keyword.arg) + if node.args or not node.keywords: + names.add(_DYNAMIC_NAMESPACE_MUTATION) + elif node.func.attr in _MUTATING_METHODS: + names.add(_DYNAMIC_NAMESPACE_MUTATION) + namespace_name = _namespace_alias_subscript_name(node.func.value, aliases) + if namespace_name is not None and node.func.attr in _MUTATING_METHODS: + names.add(namespace_name) + self.generic_visit(node) + + NamespaceAliasMutationVisitor().visit(stmt) + return names + + +def _update_namespace_aliases(stmt, aliases): + direct_names = set() + if isinstance(stmt, ast.Assign): + for target in stmt.targets: + direct_names.update(_direct_target_names(target)) + elif isinstance(stmt, (ast.AnnAssign, ast.AugAssign)): + direct_names.update(_direct_target_names(stmt.target)) + elif isinstance(stmt, ast.Delete): + for target in stmt.targets: + direct_names.update(_direct_target_names(target)) + direct_names.update(_bound_names(stmt)) + aliases.difference_update(direct_names) + + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + if _namespace_alias_sources(stmt.value, aliases): + aliases.add(stmt.targets[0].id) + elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: + for target_item, value_item in _unpack_target_value_pairs(stmt.targets[0], stmt.value): + target_name = _alias_target_name(target_item) + if target_name is not None and _namespace_alias_sources(value_item, aliases): + aliases.add(target_name) + elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: + if _namespace_alias_sources(stmt.value, aliases): + aliases.add(stmt.target.id) + + def _update_module_dict_aliases(stmt, name, aliases): rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) for rebound_name in rebound_names: @@ -1460,6 +1612,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N class_aliases = {} class_attribute_aliases = {} module_dict_aliases = {} + namespace_aliases = set() def advance_module_state(stmt): _invalidate_class_bindings( @@ -1470,6 +1623,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N _update_class_aliases(stmt, class_aliases, class_bindings) _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings) _update_module_dict_aliases(stmt, name, module_dict_aliases) + _update_namespace_aliases(stmt, namespace_aliases) for stmt in tree.body: class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases) @@ -1483,6 +1637,8 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N value = _INVALID if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)): value = _INVALID + if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)): + value = _INVALID if _module_dict_alias_invalidated(stmt, module_dict_aliases): value = _INVALID if isinstance(stmt, ast.Assign):