diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 0745ab9..082a17b 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -470,6 +470,59 @@ alias = NODE_CLASS_MAPPINGS "alias-setdefault-duplicate-node-pack", ) + def test_duplicate_node_id_from_multi_target_mapping_alias_skips_static_node(self): + self._assert_duplicate_node_id_from_alias_mutation_skips_static_node( + 'alias = other = NODE_CLASS_MAPPINGS\nalias["DupNode"] = DynamicDupNode', + "multi-target-alias-duplicate-node-pack", + ) + + def test_ambiguous_mapping_mutation_key_suppresses_static_nodes(self): + source_a = ''' +class StaticDupNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DupNode": StaticDupNode, +} +''' + source_b = ''' +def get_id(): + return "DupNode" + + +class DynamicDupNode: + RETURN_TYPES = ("MASK",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + }, + } + + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS[get_id()] = DynamicDupNode +''' + result = self._extract_two_sources( + source_a, + source_b, + "ambiguous-mapping-key-duplicate-node-pack", + ) + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_unsupported_reassignment_invalidates_static_env_value(self): source = ''' def build_inputs(): diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 9ce6980..5a52d1b 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -1927,6 +1927,13 @@ def _update_module_dict_aliases(stmt, name, aliases, namespace_aliases): sources = _module_dict_alias_sources(stmt.value, name, aliases, namespace_aliases) if sources: aliases[stmt.targets[0].id] = sources + elif isinstance(stmt, ast.Assign) and len(stmt.targets) > 1: + sources = _module_dict_alias_sources(stmt.value, name, aliases, namespace_aliases) + if sources: + for target in stmt.targets: + target_name = _alias_target_name(target) + if target_name is not None: + aliases[target_name] = sources elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: _update_module_dict_alias_from_unpack(stmt.targets[0], stmt.value, name, aliases, namespace_aliases) elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: @@ -2099,40 +2106,49 @@ def _node_class_mappings(tree): def _literal_module_dict_string_keys(node, env): + keys, _ambiguous = _literal_module_dict_string_keys_state(node, env) + return keys + + +def _literal_module_dict_string_keys_state(node, env): if not isinstance(node, ast.Dict): - return set() + return set(), False keys = set() + ambiguous = False for key in node.keys: if key is None: + ambiguous = True continue try: key_value = _literal(key, env) except UnsupportedStaticExpression: + ambiguous = True continue if isinstance(key_value, str) and key_value: keys.add(key_value) - return keys + return keys, ambiguous -def _mapping_subscript_target_key(target, mapping_name, env, aliases=None, namespace_aliases=None): +def _mapping_subscript_target_key_state(target, mapping_name, env, aliases=None, namespace_aliases=None): if not isinstance(target, ast.Subscript): - return None + return None, False if not _module_dict_alias_sources( target.value, mapping_name, aliases or {}, namespace_aliases or set(), ): - return None + return None, False try: key_value = _literal(target.slice, env) except UnsupportedStaticExpression: - return None - return key_value if isinstance(key_value, str) and key_value else None + return None, True + return (key_value, False) if isinstance(key_value, str) and key_value else (None, False) def _node_class_mapping_mutation_string_keys(stmt, env, aliases=None, namespace_aliases=None): keys = set() + ambiguous = False aliases = aliases or {} namespace_aliases = namespace_aliases or set() @@ -2165,44 +2181,51 @@ def _node_class_mapping_mutation_string_keys(stmt, env, aliases=None, namespace_ self.visit(child) def visit_Assign(self, node): + nonlocal ambiguous for target in node.targets: - key = _mapping_subscript_target_key( + key, key_ambiguous = _mapping_subscript_target_key_state( target, "NODE_CLASS_MAPPINGS", env, aliases, namespace_aliases, ) + ambiguous = ambiguous or key_ambiguous if key is not None: keys.add(key) self.visit(node.value) def visit_AnnAssign(self, node): - key = _mapping_subscript_target_key( + nonlocal ambiguous + key, key_ambiguous = _mapping_subscript_target_key_state( node.target, "NODE_CLASS_MAPPINGS", env, aliases, namespace_aliases, ) + ambiguous = ambiguous or key_ambiguous if key is not None: keys.add(key) if node.value is not None: self.visit(node.value) def visit_AugAssign(self, node): - key = _mapping_subscript_target_key( + nonlocal ambiguous + key, key_ambiguous = _mapping_subscript_target_key_state( node.target, "NODE_CLASS_MAPPINGS", env, aliases, namespace_aliases, ) + ambiguous = ambiguous or key_ambiguous if key is not None: keys.add(key) self.visit(node.value) def visit_Call(self, node): + nonlocal ambiguous if ( isinstance(node.func, ast.Attribute) and _module_dict_alias_sources( @@ -2214,21 +2237,37 @@ def _node_class_mapping_mutation_string_keys(stmt, env, aliases=None, namespace_ ): if node.func.attr == "update": for arg in node.args: - keys.update(_literal_module_dict_string_keys(arg, env)) + if isinstance(arg, ast.Dict): + arg_keys, arg_ambiguous = _literal_module_dict_string_keys_state(arg, env) + keys.update(arg_keys) + ambiguous = ambiguous or arg_ambiguous + else: + ambiguous = True for keyword in node.keywords: if keyword.arg: keys.add(keyword.arg) + else: + ambiguous = True elif node.func.attr == "setdefault" and node.args: try: key_value = _literal(node.args[0], env) except UnsupportedStaticExpression: key_value = None + ambiguous = True + if isinstance(key_value, str) and key_value: + keys.add(key_value) + elif node.func.attr == "__setitem__" and node.args: + try: + key_value = _literal(node.args[0], env) + except UnsupportedStaticExpression: + key_value = None + ambiguous = True if isinstance(key_value, str) and key_value: keys.add(key_value) self.generic_visit(node) MappingMutationKeyVisitor().visit(stmt) - return keys + return _INVALID if ambiguous else keys def _node_class_mapping_keys(tree): @@ -2241,21 +2280,28 @@ def _node_class_mapping_keys(tree): namespace_aliases = set() for stmt in tree.body: if isinstance(stmt, ast.Assign) and _name_is_assigned(stmt, "NODE_CLASS_MAPPINGS"): - keys.update(_literal_module_dict_string_keys(stmt.value, env)) + literal_keys, literal_ambiguous = _literal_module_dict_string_keys_state(stmt.value, env) + keys.update(literal_keys) + if literal_ambiguous: + return _INVALID elif ( isinstance(stmt, ast.AnnAssign) and _name_is_assigned(stmt, "NODE_CLASS_MAPPINGS") and stmt.value is not None ): - keys.update(_literal_module_dict_string_keys(stmt.value, env)) - keys.update( - _node_class_mapping_mutation_string_keys( - stmt, - env, - module_dict_aliases, - namespace_aliases, - ) + literal_keys, literal_ambiguous = _literal_module_dict_string_keys_state(stmt.value, env) + keys.update(literal_keys) + if literal_ambiguous: + return _INVALID + mutation_keys = _node_class_mapping_mutation_string_keys( + stmt, + env, + module_dict_aliases, + namespace_aliases, ) + if mutation_keys is _INVALID: + return _INVALID + keys.update(mutation_keys) _apply_module_stmt_to_env(stmt, env, class_bindings) _update_module_dict_aliases( stmt, @@ -2368,6 +2414,9 @@ def extract_repo_signatures(repo_dir, pack_meta): env = _collect_module_env(tree) mappings = _node_class_mappings(tree) mapping_node_types = _node_class_mapping_keys(tree) + if mapping_node_types is _INVALID: + nodes = {} + break displays = _display_mappings(tree) for node_type in sorted(mapping_node_types): prior_path = node_sources.get(node_type)