Fail closed on mapping mutation keys and bare input specs

This commit is contained in:
2026-07-02 19:46:55 +02:00
parent 7e7479fb6a
commit 9792989216
2 changed files with 164 additions and 1 deletions
@@ -39,6 +39,7 @@ class StaticExtractionTests(unittest.TestCase):
self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],))) self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],)))
self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",))) self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",)))
self.assertEqual("FLOAT", normalise_input_spec(("FLOAT", {"default": 1.0}))) self.assertEqual("FLOAT", normalise_input_spec(("FLOAT", {"default": 1.0})))
self.assertIsNone(normalise_input_spec("IMAGE"))
def test_extracts_static_node_mapping_and_signatures(self): def test_extracts_static_node_mapping_and_signatures(self):
source = ''' source = '''
@@ -347,6 +348,58 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_duplicate_node_id_from_mapping_update_skips_static_node(self):
source_a = '''
class StaticDupNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DupNode": StaticDupNode,
}
'''
source_b = '''
class DynamicDupNode:
RETURN_TYPES = ("MASK",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
},
}
NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS.update({
"DupNode": DynamicDupNode,
})
'''
with tempfile.TemporaryDirectory() as tmp:
Path(tmp, "a.py").write_text(textwrap.dedent(source_a), encoding="utf-8")
Path(tmp, "b.py").write_text(textwrap.dedent(source_b), encoding="utf-8")
result = extract_repo_signatures(
Path(tmp),
{
"id": "update-duplicate-node-pack",
"title": "Update Duplicate Node Pack",
"repository": "https://github.com/example/update-duplicate-node-pack",
"rank": 1,
},
)
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_unsupported_reassignment_invalidates_static_env_value(self): def test_unsupported_reassignment_invalidates_static_env_value(self):
source = ''' source = '''
def build_inputs(): def build_inputs():
@@ -1458,6 +1511,29 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_with_bare_string_input_spec_skips_node(self):
source = '''
class BareStringInputSpecNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": "IMAGE",
},
}
NODE_CLASS_MAPPINGS = {
"BareStringInputSpecNode": BareStringInputSpecNode,
}
'''
result = self._extract_source(source, "bare-string-input-spec-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_dynamic_return_types_reassignment_skips_node(self): def test_dynamic_return_types_reassignment_skips_node(self):
source = ''' source = '''
def build_outputs(): def build_outputs():
+88 -1
View File
@@ -1060,7 +1060,9 @@ def _collect_module_env(tree, class_bindings=None):
def normalise_input_spec(spec): def normalise_input_spec(spec):
first = spec[0] if isinstance(spec, (list, tuple)) and spec else spec if not isinstance(spec, (list, tuple)) or not spec:
return None
first = spec[0]
if isinstance(first, list): if isinstance(first, list):
return "COMBO" if all(isinstance(value, str) for value in first) else None return "COMBO" if all(isinstance(value, str) for value in first) else None
return first if isinstance(first, str) else None return first if isinstance(first, str) else None
@@ -2056,6 +2058,90 @@ def _literal_module_dict_string_keys(node, env):
return keys return keys
def _mapping_subscript_target_key(target, mapping_name, env):
if not isinstance(target, ast.Subscript):
return None
if _root_name(target.value) != mapping_name:
return None
try:
key_value = _literal(target.slice, env)
except UnsupportedStaticExpression:
return None
return key_value if isinstance(key_value, str) and key_value else None
def _node_class_mapping_mutation_string_keys(stmt, env):
keys = set()
class MappingMutationKeyVisitor(ast.NodeVisitor):
def _visit_function_definition_expressions(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
self.visit(node.args)
if node.returns is not None:
self.visit(node.returns)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
def visit_FunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_AsyncFunctionDef(self, node):
self._visit_function_definition_expressions(node)
def visit_ClassDef(self, node):
for decorator in node.decorator_list:
self.visit(decorator)
for base in node.bases:
self.visit(base)
for keyword in node.keywords:
self.visit(keyword.value)
for type_param in getattr(node, "type_params", ()):
self.visit(type_param)
for child in node.body:
self.visit(child)
def visit_Assign(self, node):
for target in node.targets:
key = _mapping_subscript_target_key(target, "NODE_CLASS_MAPPINGS", env)
if key is not None:
keys.add(key)
self.visit(node.value)
def visit_AnnAssign(self, node):
key = _mapping_subscript_target_key(node.target, "NODE_CLASS_MAPPINGS", env)
if key is not None:
keys.add(key)
if node.value is not None:
self.visit(node.value)
def visit_AugAssign(self, node):
key = _mapping_subscript_target_key(node.target, "NODE_CLASS_MAPPINGS", env)
if key is not None:
keys.add(key)
self.visit(node.value)
def visit_Call(self, node):
if isinstance(node.func, ast.Attribute) and _root_name(node.func.value) == "NODE_CLASS_MAPPINGS":
if node.func.attr == "update":
for arg in node.args:
keys.update(_literal_module_dict_string_keys(arg, env))
for keyword in node.keywords:
if keyword.arg:
keys.add(keyword.arg)
elif node.func.attr == "setdefault" and node.args:
try:
key_value = _literal(node.args[0], env)
except UnsupportedStaticExpression:
key_value = None
if isinstance(key_value, str) and key_value:
keys.add(key_value)
self.generic_visit(node)
MappingMutationKeyVisitor().visit(stmt)
return keys
def _node_class_mapping_keys(tree): def _node_class_mapping_keys(tree):
if _has_module_wildcard_import(tree): if _has_module_wildcard_import(tree):
return set() return set()
@@ -2071,6 +2157,7 @@ def _node_class_mapping_keys(tree):
and stmt.value is not None and stmt.value is not None
): ):
keys.update(_literal_module_dict_string_keys(stmt.value, env)) keys.update(_literal_module_dict_string_keys(stmt.value, env))
keys.update(_node_class_mapping_mutation_string_keys(stmt, env))
_apply_module_stmt_to_env(stmt, env, class_bindings) _apply_module_stmt_to_env(stmt, env, class_bindings)
return keys return keys