Fail closed on ambiguous mapping duplicate keys

This commit is contained in:
2026-07-02 20:21:04 +02:00
parent ecd8f7c082
commit 75224982ba
2 changed files with 123 additions and 21 deletions
@@ -470,6 +470,59 @@ alias = NODE_CLASS_MAPPINGS
"alias-setdefault-duplicate-node-pack", "alias-setdefault-duplicate-node-pack",
) )
def test_duplicate_node_id_from_multi_target_mapping_alias_skips_static_node(self):
self._assert_duplicate_node_id_from_alias_mutation_skips_static_node(
'alias = other = NODE_CLASS_MAPPINGS\nalias["DupNode"] = DynamicDupNode',
"multi-target-alias-duplicate-node-pack",
)
def test_ambiguous_mapping_mutation_key_suppresses_static_nodes(self):
source_a = '''
class StaticDupNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DupNode": StaticDupNode,
}
'''
source_b = '''
def get_id():
return "DupNode"
class DynamicDupNode:
RETURN_TYPES = ("MASK",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
},
}
NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS[get_id()] = DynamicDupNode
'''
result = self._extract_two_sources(
source_a,
source_b,
"ambiguous-mapping-key-duplicate-node-pack",
)
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_unsupported_reassignment_invalidates_static_env_value(self): def test_unsupported_reassignment_invalidates_static_env_value(self):
source = ''' source = '''
def build_inputs(): def build_inputs():
+66 -17
View File
@@ -1927,6 +1927,13 @@ def _update_module_dict_aliases(stmt, name, aliases, namespace_aliases):
sources = _module_dict_alias_sources(stmt.value, name, aliases, namespace_aliases) sources = _module_dict_alias_sources(stmt.value, name, aliases, namespace_aliases)
if sources: if sources:
aliases[stmt.targets[0].id] = sources aliases[stmt.targets[0].id] = sources
elif isinstance(stmt, ast.Assign) and len(stmt.targets) > 1:
sources = _module_dict_alias_sources(stmt.value, name, aliases, namespace_aliases)
if sources:
for target in stmt.targets:
target_name = _alias_target_name(target)
if target_name is not None:
aliases[target_name] = sources
elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
_update_module_dict_alias_from_unpack(stmt.targets[0], stmt.value, name, aliases, namespace_aliases) _update_module_dict_alias_from_unpack(stmt.targets[0], stmt.value, name, aliases, namespace_aliases)
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
@@ -2099,40 +2106,49 @@ def _node_class_mappings(tree):
def _literal_module_dict_string_keys(node, env): def _literal_module_dict_string_keys(node, env):
keys, _ambiguous = _literal_module_dict_string_keys_state(node, env)
return keys
def _literal_module_dict_string_keys_state(node, env):
if not isinstance(node, ast.Dict): if not isinstance(node, ast.Dict):
return set() return set(), False
keys = set() keys = set()
ambiguous = False
for key in node.keys: for key in node.keys:
if key is None: if key is None:
ambiguous = True
continue continue
try: try:
key_value = _literal(key, env) key_value = _literal(key, env)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
ambiguous = True
continue continue
if isinstance(key_value, str) and key_value: if isinstance(key_value, str) and key_value:
keys.add(key_value) keys.add(key_value)
return keys return keys, ambiguous
def _mapping_subscript_target_key(target, mapping_name, env, aliases=None, namespace_aliases=None): def _mapping_subscript_target_key_state(target, mapping_name, env, aliases=None, namespace_aliases=None):
if not isinstance(target, ast.Subscript): if not isinstance(target, ast.Subscript):
return None return None, False
if not _module_dict_alias_sources( if not _module_dict_alias_sources(
target.value, target.value,
mapping_name, mapping_name,
aliases or {}, aliases or {},
namespace_aliases or set(), namespace_aliases or set(),
): ):
return None return None, False
try: try:
key_value = _literal(target.slice, env) key_value = _literal(target.slice, env)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
return None return None, True
return key_value if isinstance(key_value, str) and key_value else None return (key_value, False) if isinstance(key_value, str) and key_value else (None, False)
def _node_class_mapping_mutation_string_keys(stmt, env, aliases=None, namespace_aliases=None): def _node_class_mapping_mutation_string_keys(stmt, env, aliases=None, namespace_aliases=None):
keys = set() keys = set()
ambiguous = False
aliases = aliases or {} aliases = aliases or {}
namespace_aliases = namespace_aliases or set() namespace_aliases = namespace_aliases or set()
@@ -2165,44 +2181,51 @@ def _node_class_mapping_mutation_string_keys(stmt, env, aliases=None, namespace_
self.visit(child) self.visit(child)
def visit_Assign(self, node): def visit_Assign(self, node):
nonlocal ambiguous
for target in node.targets: for target in node.targets:
key = _mapping_subscript_target_key( key, key_ambiguous = _mapping_subscript_target_key_state(
target, target,
"NODE_CLASS_MAPPINGS", "NODE_CLASS_MAPPINGS",
env, env,
aliases, aliases,
namespace_aliases, namespace_aliases,
) )
ambiguous = ambiguous or key_ambiguous
if key is not None: if key is not None:
keys.add(key) keys.add(key)
self.visit(node.value) self.visit(node.value)
def visit_AnnAssign(self, node): def visit_AnnAssign(self, node):
key = _mapping_subscript_target_key( nonlocal ambiguous
key, key_ambiguous = _mapping_subscript_target_key_state(
node.target, node.target,
"NODE_CLASS_MAPPINGS", "NODE_CLASS_MAPPINGS",
env, env,
aliases, aliases,
namespace_aliases, namespace_aliases,
) )
ambiguous = ambiguous or key_ambiguous
if key is not None: if key is not None:
keys.add(key) keys.add(key)
if node.value is not None: if node.value is not None:
self.visit(node.value) self.visit(node.value)
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
key = _mapping_subscript_target_key( nonlocal ambiguous
key, key_ambiguous = _mapping_subscript_target_key_state(
node.target, node.target,
"NODE_CLASS_MAPPINGS", "NODE_CLASS_MAPPINGS",
env, env,
aliases, aliases,
namespace_aliases, namespace_aliases,
) )
ambiguous = ambiguous or key_ambiguous
if key is not None: if key is not None:
keys.add(key) keys.add(key)
self.visit(node.value) self.visit(node.value)
def visit_Call(self, node): def visit_Call(self, node):
nonlocal ambiguous
if ( if (
isinstance(node.func, ast.Attribute) isinstance(node.func, ast.Attribute)
and _module_dict_alias_sources( and _module_dict_alias_sources(
@@ -2214,21 +2237,37 @@ def _node_class_mapping_mutation_string_keys(stmt, env, aliases=None, namespace_
): ):
if node.func.attr == "update": if node.func.attr == "update":
for arg in node.args: for arg in node.args:
keys.update(_literal_module_dict_string_keys(arg, env)) if isinstance(arg, ast.Dict):
arg_keys, arg_ambiguous = _literal_module_dict_string_keys_state(arg, env)
keys.update(arg_keys)
ambiguous = ambiguous or arg_ambiguous
else:
ambiguous = True
for keyword in node.keywords: for keyword in node.keywords:
if keyword.arg: if keyword.arg:
keys.add(keyword.arg) keys.add(keyword.arg)
else:
ambiguous = True
elif node.func.attr == "setdefault" and node.args: elif node.func.attr == "setdefault" and node.args:
try: try:
key_value = _literal(node.args[0], env) key_value = _literal(node.args[0], env)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
key_value = None key_value = None
ambiguous = True
if isinstance(key_value, str) and key_value:
keys.add(key_value)
elif node.func.attr == "__setitem__" and node.args:
try:
key_value = _literal(node.args[0], env)
except UnsupportedStaticExpression:
key_value = None
ambiguous = True
if isinstance(key_value, str) and key_value: if isinstance(key_value, str) and key_value:
keys.add(key_value) keys.add(key_value)
self.generic_visit(node) self.generic_visit(node)
MappingMutationKeyVisitor().visit(stmt) MappingMutationKeyVisitor().visit(stmt)
return keys return _INVALID if ambiguous else keys
def _node_class_mapping_keys(tree): def _node_class_mapping_keys(tree):
@@ -2241,21 +2280,28 @@ def _node_class_mapping_keys(tree):
namespace_aliases = set() namespace_aliases = set()
for stmt in tree.body: for stmt in tree.body:
if isinstance(stmt, ast.Assign) and _name_is_assigned(stmt, "NODE_CLASS_MAPPINGS"): if isinstance(stmt, ast.Assign) and _name_is_assigned(stmt, "NODE_CLASS_MAPPINGS"):
keys.update(_literal_module_dict_string_keys(stmt.value, env)) literal_keys, literal_ambiguous = _literal_module_dict_string_keys_state(stmt.value, env)
keys.update(literal_keys)
if literal_ambiguous:
return _INVALID
elif ( elif (
isinstance(stmt, ast.AnnAssign) isinstance(stmt, ast.AnnAssign)
and _name_is_assigned(stmt, "NODE_CLASS_MAPPINGS") and _name_is_assigned(stmt, "NODE_CLASS_MAPPINGS")
and stmt.value is not None and stmt.value is not None
): ):
keys.update(_literal_module_dict_string_keys(stmt.value, env)) literal_keys, literal_ambiguous = _literal_module_dict_string_keys_state(stmt.value, env)
keys.update( keys.update(literal_keys)
_node_class_mapping_mutation_string_keys( if literal_ambiguous:
return _INVALID
mutation_keys = _node_class_mapping_mutation_string_keys(
stmt, stmt,
env, env,
module_dict_aliases, module_dict_aliases,
namespace_aliases, namespace_aliases,
) )
) if mutation_keys is _INVALID:
return _INVALID
keys.update(mutation_keys)
_apply_module_stmt_to_env(stmt, env, class_bindings) _apply_module_stmt_to_env(stmt, env, class_bindings)
_update_module_dict_aliases( _update_module_dict_aliases(
stmt, stmt,
@@ -2368,6 +2414,9 @@ def extract_repo_signatures(repo_dir, pack_meta):
env = _collect_module_env(tree) env = _collect_module_env(tree)
mappings = _node_class_mappings(tree) mappings = _node_class_mappings(tree)
mapping_node_types = _node_class_mapping_keys(tree) mapping_node_types = _node_class_mapping_keys(tree)
if mapping_node_types is _INVALID:
nodes = {}
break
displays = _display_mappings(tree) displays = _display_mappings(tree)
for node_type in sorted(mapping_node_types): for node_type in sorted(mapping_node_types):
prior_path = node_sources.get(node_type) prior_path = node_sources.get(node_type)