Fail closed on mapping mutation keys and bare input specs
This commit is contained in:
@@ -39,6 +39,7 @@ class StaticExtractionTests(unittest.TestCase):
|
|||||||
self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],)))
|
self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],)))
|
||||||
self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",)))
|
self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",)))
|
||||||
self.assertEqual("FLOAT", normalise_input_spec(("FLOAT", {"default": 1.0})))
|
self.assertEqual("FLOAT", normalise_input_spec(("FLOAT", {"default": 1.0})))
|
||||||
|
self.assertIsNone(normalise_input_spec("IMAGE"))
|
||||||
|
|
||||||
def test_extracts_static_node_mapping_and_signatures(self):
|
def test_extracts_static_node_mapping_and_signatures(self):
|
||||||
source = '''
|
source = '''
|
||||||
@@ -347,6 +348,58 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_duplicate_node_id_from_mapping_update_skips_static_node(self):
|
||||||
|
source_a = '''
|
||||||
|
class StaticDupNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"DupNode": StaticDupNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
source_b = '''
|
||||||
|
class DynamicDupNode:
|
||||||
|
RETURN_TYPES = ("MASK",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"mask": ("MASK",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {}
|
||||||
|
NODE_CLASS_MAPPINGS.update({
|
||||||
|
"DupNode": DynamicDupNode,
|
||||||
|
})
|
||||||
|
'''
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
Path(tmp, "a.py").write_text(textwrap.dedent(source_a), encoding="utf-8")
|
||||||
|
Path(tmp, "b.py").write_text(textwrap.dedent(source_b), encoding="utf-8")
|
||||||
|
result = extract_repo_signatures(
|
||||||
|
Path(tmp),
|
||||||
|
{
|
||||||
|
"id": "update-duplicate-node-pack",
|
||||||
|
"title": "Update Duplicate Node Pack",
|
||||||
|
"repository": "https://github.com/example/update-duplicate-node-pack",
|
||||||
|
"rank": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_unsupported_reassignment_invalidates_static_env_value(self):
|
def test_unsupported_reassignment_invalidates_static_env_value(self):
|
||||||
source = '''
|
source = '''
|
||||||
def build_inputs():
|
def build_inputs():
|
||||||
@@ -1458,6 +1511,29 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_input_types_with_bare_string_input_spec_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
class BareStringInputSpecNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": "IMAGE",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"BareStringInputSpecNode": BareStringInputSpecNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "bare-string-input-spec-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_dynamic_return_types_reassignment_skips_node(self):
|
def test_dynamic_return_types_reassignment_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
def build_outputs():
|
def build_outputs():
|
||||||
|
|||||||
@@ -1060,7 +1060,9 @@ def _collect_module_env(tree, class_bindings=None):
|
|||||||
|
|
||||||
|
|
||||||
def normalise_input_spec(spec):
|
def normalise_input_spec(spec):
|
||||||
first = spec[0] if isinstance(spec, (list, tuple)) and spec else spec
|
if not isinstance(spec, (list, tuple)) or not spec:
|
||||||
|
return None
|
||||||
|
first = spec[0]
|
||||||
if isinstance(first, list):
|
if isinstance(first, list):
|
||||||
return "COMBO" if all(isinstance(value, str) for value in first) else None
|
return "COMBO" if all(isinstance(value, str) for value in first) else None
|
||||||
return first if isinstance(first, str) else None
|
return first if isinstance(first, str) else None
|
||||||
@@ -2056,6 +2058,90 @@ def _literal_module_dict_string_keys(node, env):
|
|||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
def _mapping_subscript_target_key(target, mapping_name, env):
|
||||||
|
if not isinstance(target, ast.Subscript):
|
||||||
|
return None
|
||||||
|
if _root_name(target.value) != mapping_name:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
key_value = _literal(target.slice, env)
|
||||||
|
except UnsupportedStaticExpression:
|
||||||
|
return None
|
||||||
|
return key_value if isinstance(key_value, str) and key_value else None
|
||||||
|
|
||||||
|
|
||||||
|
def _node_class_mapping_mutation_string_keys(stmt, env):
|
||||||
|
keys = set()
|
||||||
|
|
||||||
|
class MappingMutationKeyVisitor(ast.NodeVisitor):
|
||||||
|
def _visit_function_definition_expressions(self, node):
|
||||||
|
for decorator in node.decorator_list:
|
||||||
|
self.visit(decorator)
|
||||||
|
self.visit(node.args)
|
||||||
|
if node.returns is not None:
|
||||||
|
self.visit(node.returns)
|
||||||
|
for type_param in getattr(node, "type_params", ()):
|
||||||
|
self.visit(type_param)
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
self._visit_function_definition_expressions(node)
|
||||||
|
|
||||||
|
def visit_AsyncFunctionDef(self, node):
|
||||||
|
self._visit_function_definition_expressions(node)
|
||||||
|
|
||||||
|
def visit_ClassDef(self, node):
|
||||||
|
for decorator in node.decorator_list:
|
||||||
|
self.visit(decorator)
|
||||||
|
for base in node.bases:
|
||||||
|
self.visit(base)
|
||||||
|
for keyword in node.keywords:
|
||||||
|
self.visit(keyword.value)
|
||||||
|
for type_param in getattr(node, "type_params", ()):
|
||||||
|
self.visit(type_param)
|
||||||
|
for child in node.body:
|
||||||
|
self.visit(child)
|
||||||
|
|
||||||
|
def visit_Assign(self, node):
|
||||||
|
for target in node.targets:
|
||||||
|
key = _mapping_subscript_target_key(target, "NODE_CLASS_MAPPINGS", env)
|
||||||
|
if key is not None:
|
||||||
|
keys.add(key)
|
||||||
|
self.visit(node.value)
|
||||||
|
|
||||||
|
def visit_AnnAssign(self, node):
|
||||||
|
key = _mapping_subscript_target_key(node.target, "NODE_CLASS_MAPPINGS", env)
|
||||||
|
if key is not None:
|
||||||
|
keys.add(key)
|
||||||
|
if node.value is not None:
|
||||||
|
self.visit(node.value)
|
||||||
|
|
||||||
|
def visit_AugAssign(self, node):
|
||||||
|
key = _mapping_subscript_target_key(node.target, "NODE_CLASS_MAPPINGS", env)
|
||||||
|
if key is not None:
|
||||||
|
keys.add(key)
|
||||||
|
self.visit(node.value)
|
||||||
|
|
||||||
|
def visit_Call(self, node):
|
||||||
|
if isinstance(node.func, ast.Attribute) and _root_name(node.func.value) == "NODE_CLASS_MAPPINGS":
|
||||||
|
if node.func.attr == "update":
|
||||||
|
for arg in node.args:
|
||||||
|
keys.update(_literal_module_dict_string_keys(arg, env))
|
||||||
|
for keyword in node.keywords:
|
||||||
|
if keyword.arg:
|
||||||
|
keys.add(keyword.arg)
|
||||||
|
elif node.func.attr == "setdefault" and node.args:
|
||||||
|
try:
|
||||||
|
key_value = _literal(node.args[0], env)
|
||||||
|
except UnsupportedStaticExpression:
|
||||||
|
key_value = None
|
||||||
|
if isinstance(key_value, str) and key_value:
|
||||||
|
keys.add(key_value)
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
MappingMutationKeyVisitor().visit(stmt)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
def _node_class_mapping_keys(tree):
|
def _node_class_mapping_keys(tree):
|
||||||
if _has_module_wildcard_import(tree):
|
if _has_module_wildcard_import(tree):
|
||||||
return set()
|
return set()
|
||||||
@@ -2071,6 +2157,7 @@ def _node_class_mapping_keys(tree):
|
|||||||
and stmt.value is not None
|
and stmt.value is not None
|
||||||
):
|
):
|
||||||
keys.update(_literal_module_dict_string_keys(stmt.value, env))
|
keys.update(_literal_module_dict_string_keys(stmt.value, env))
|
||||||
|
keys.update(_node_class_mapping_mutation_string_keys(stmt, env))
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user