diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 37e20b1..2d3599f 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -29,6 +29,20 @@ class StaticExtractionTests(unittest.TestCase): }, ) + def _extract_two_sources(self, source_a, source_b, pack_id): + with tempfile.TemporaryDirectory() as tmp: + Path(tmp, "a.py").write_text(textwrap.dedent(source_a), encoding="utf-8") + Path(tmp, "b.py").write_text(textwrap.dedent(source_b), encoding="utf-8") + return extract_repo_signatures( + Path(tmp), + { + "id": pack_id, + "title": "Sample Pack", + "repository": f"https://github.com/example/{pack_id}", + "rank": 1, + }, + ) + def _skip_if_syntax_unsupported(self, source): try: compile(textwrap.dedent(source), "", "exec") @@ -400,6 +414,64 @@ NODE_CLASS_MAPPINGS.update({ self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def _assert_duplicate_node_id_from_alias_mutation_skips_static_node(self, mutation, pack_id): + source_a = ''' +class StaticDupNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DupNode": StaticDupNode, +} +''' + source_b = f''' +class DynamicDupNode: + RETURN_TYPES = ("MASK",) + + @classmethod + def INPUT_TYPES(cls): + return {{ + "required": {{ + "mask": ("MASK",), + }}, + }} + + +NODE_CLASS_MAPPINGS = {{}} +alias = NODE_CLASS_MAPPINGS +{mutation} +''' + result = self._extract_two_sources(source_a, source_b, pack_id) + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_duplicate_node_id_from_mapping_alias_subscript_skips_static_node(self): + self._assert_duplicate_node_id_from_alias_mutation_skips_static_node( + 'alias["DupNode"] = DynamicDupNode', + "alias-subscript-duplicate-node-pack", + ) + + def test_duplicate_node_id_from_mapping_alias_update_skips_static_node(self): + self._assert_duplicate_node_id_from_alias_mutation_skips_static_node( + 'alias.update({"DupNode": DynamicDupNode})', + "alias-update-duplicate-node-pack", + ) + + def test_duplicate_node_id_from_mapping_alias_setdefault_skips_static_node(self): + self._assert_duplicate_node_id_from_alias_mutation_skips_static_node( + 'alias.setdefault("DupNode", DynamicDupNode)', + "alias-setdefault-duplicate-node-pack", + ) + def test_unsupported_reassignment_invalidates_static_env_value(self): source = ''' def build_inputs(): diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index f012137..bad507d 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -2058,10 +2058,15 @@ def _literal_module_dict_string_keys(node, env): return keys -def _mapping_subscript_target_key(target, mapping_name, env): +def _mapping_subscript_target_key(target, mapping_name, env, aliases=None, namespace_aliases=None): if not isinstance(target, ast.Subscript): return None - if _root_name(target.value) != mapping_name: + if not _module_dict_alias_sources( + target.value, + mapping_name, + aliases or {}, + namespace_aliases or set(), + ): return None try: key_value = _literal(target.slice, env) @@ -2070,8 +2075,10 @@ def _mapping_subscript_target_key(target, mapping_name, env): return key_value if isinstance(key_value, str) and key_value else None -def _node_class_mapping_mutation_string_keys(stmt, env): +def _node_class_mapping_mutation_string_keys(stmt, env, aliases=None, namespace_aliases=None): keys = set() + aliases = aliases or {} + namespace_aliases = namespace_aliases or set() class MappingMutationKeyVisitor(ast.NodeVisitor): def _visit_function_definition_expressions(self, node): @@ -2103,26 +2110,52 @@ def _node_class_mapping_mutation_string_keys(stmt, env): def visit_Assign(self, node): for target in node.targets: - key = _mapping_subscript_target_key(target, "NODE_CLASS_MAPPINGS", env) + key = _mapping_subscript_target_key( + target, + "NODE_CLASS_MAPPINGS", + env, + aliases, + namespace_aliases, + ) if key is not None: keys.add(key) self.visit(node.value) def visit_AnnAssign(self, node): - key = _mapping_subscript_target_key(node.target, "NODE_CLASS_MAPPINGS", env) + key = _mapping_subscript_target_key( + node.target, + "NODE_CLASS_MAPPINGS", + env, + aliases, + namespace_aliases, + ) if key is not None: keys.add(key) if node.value is not None: self.visit(node.value) def visit_AugAssign(self, node): - key = _mapping_subscript_target_key(node.target, "NODE_CLASS_MAPPINGS", env) + key = _mapping_subscript_target_key( + node.target, + "NODE_CLASS_MAPPINGS", + env, + aliases, + namespace_aliases, + ) if key is not None: keys.add(key) self.visit(node.value) def visit_Call(self, node): - if isinstance(node.func, ast.Attribute) and _root_name(node.func.value) == "NODE_CLASS_MAPPINGS": + if ( + isinstance(node.func, ast.Attribute) + and _module_dict_alias_sources( + node.func.value, + "NODE_CLASS_MAPPINGS", + aliases, + namespace_aliases, + ) + ): if node.func.attr == "update": for arg in node.args: keys.update(_literal_module_dict_string_keys(arg, env)) @@ -2148,6 +2181,8 @@ def _node_class_mapping_keys(tree): keys = set() env = {} class_bindings = {} + module_dict_aliases = {} + namespace_aliases = set() for stmt in tree.body: if isinstance(stmt, ast.Assign) and _name_is_assigned(stmt, "NODE_CLASS_MAPPINGS"): keys.update(_literal_module_dict_string_keys(stmt.value, env)) @@ -2157,8 +2192,22 @@ def _node_class_mapping_keys(tree): and stmt.value is not None ): keys.update(_literal_module_dict_string_keys(stmt.value, env)) - keys.update(_node_class_mapping_mutation_string_keys(stmt, env)) + keys.update( + _node_class_mapping_mutation_string_keys( + stmt, + env, + module_dict_aliases, + namespace_aliases, + ) + ) _apply_module_stmt_to_env(stmt, env, class_bindings) + _update_module_dict_aliases( + stmt, + "NODE_CLASS_MAPPINGS", + module_dict_aliases, + namespace_aliases, + ) + _update_namespace_aliases(stmt, namespace_aliases) return keys