Fail closed on walrus bindings and invalid input sections
This commit is contained in:
@@ -934,6 +934,30 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_unhashable_literal_input_key_skips_repo_without_raising(self):
|
||||
source = '''
|
||||
INPUTS = {
|
||||
["bad"]: ("IMAGE",),
|
||||
}
|
||||
|
||||
|
||||
class UnhashableLiteralInputKeyNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return INPUTS
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"UnhashableLiteralInputKeyNode": UnhashableLiteralInputKeyNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "unhashable-literal-input-key-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_post_class_input_reassignment_skips_static_node(self):
|
||||
source = '''
|
||||
def build_inputs():
|
||||
@@ -1057,6 +1081,28 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_input_types_with_present_non_dict_sections_skips_node(self):
|
||||
source = '''
|
||||
class InvalidInputSectionsNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": [],
|
||||
"optional": None,
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"InvalidInputSectionsNode": InvalidInputSectionsNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "invalid-input-sections-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_dynamic_return_types_reassignment_skips_node(self):
|
||||
source = '''
|
||||
def build_outputs():
|
||||
@@ -1157,6 +1203,32 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_function_default_walrus_to_return_types_skips_node(self):
|
||||
source = '''
|
||||
class DefaultWalrusReturnTypesNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
def helper(self, x=(RETURN_TYPES := ("MASK",))):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DefaultWalrusReturnTypesNode": DefaultWalrusReturnTypesNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "default-walrus-return-types-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_function_default_mutation_to_return_types_skips_node(self):
|
||||
source = '''
|
||||
class DefaultMutatedReturnTypesNode:
|
||||
@@ -1844,6 +1916,31 @@ Alias.RETURN_TYPES = ("MASK",)
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_module_class_tuple_alias_patch_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
class TupleAliasPatchedNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TupleAliasPatchedNode": TupleAliasPatchedNode,
|
||||
}
|
||||
Alias, = (TupleAliasPatchedNode,)
|
||||
Alias.RETURN_TYPES = ("MASK",)
|
||||
'''
|
||||
result = self._extract_source(source, "tuple-alias-patched-node-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_definition_time_class_attribute_mutation_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
class DefinitionTimeMutatedMappedNode:
|
||||
|
||||
@@ -50,7 +50,11 @@ def _literal(node, env, allow_mutable_env=True):
|
||||
for key, value in zip(node.keys, node.values):
|
||||
if key is None:
|
||||
raise UnsupportedStaticExpression("dict unpacking is not supported")
|
||||
result[_literal(key, env, allow_mutable_env=False)] = _literal(value, env, allow_mutable_env=False)
|
||||
key_value = _literal(key, env, allow_mutable_env=False)
|
||||
try:
|
||||
result[key_value] = _literal(value, env, allow_mutable_env=False)
|
||||
except TypeError as exc:
|
||||
raise UnsupportedStaticExpression("unhashable dict key") from exc
|
||||
return result
|
||||
if isinstance(node, ast.Name) and node.id in env:
|
||||
value = env[node.id]
|
||||
@@ -197,17 +201,35 @@ def _named_expr_target_names(node):
|
||||
names = set()
|
||||
|
||||
class NamedExprVisitor(ast.NodeVisitor):
|
||||
def _visit_function_definition_expressions(self, child):
|
||||
for decorator in child.decorator_list:
|
||||
self.visit(decorator)
|
||||
self.visit(child.args)
|
||||
if child.returns is not None:
|
||||
self.visit(child.returns)
|
||||
for type_param in getattr(child, "type_params", ()):
|
||||
self.visit(type_param)
|
||||
|
||||
def visit_FunctionDef(self, child):
|
||||
return None
|
||||
self._visit_function_definition_expressions(child)
|
||||
|
||||
def visit_AsyncFunctionDef(self, child):
|
||||
return None
|
||||
self._visit_function_definition_expressions(child)
|
||||
|
||||
def visit_ClassDef(self, child):
|
||||
return None
|
||||
for decorator in child.decorator_list:
|
||||
self.visit(decorator)
|
||||
for base in child.bases:
|
||||
self.visit(base)
|
||||
for keyword in child.keywords:
|
||||
self.visit(keyword.value)
|
||||
for type_param in getattr(child, "type_params", ()):
|
||||
self.visit(type_param)
|
||||
for stmt in child.body:
|
||||
self.visit(stmt)
|
||||
|
||||
def visit_Lambda(self, child):
|
||||
return None
|
||||
self.visit(child.args)
|
||||
|
||||
def visit_NamedExpr(self, child):
|
||||
names.update(_target_names(child.target))
|
||||
@@ -766,6 +788,19 @@ def _class_alias_sources(value, class_aliases, class_bindings):
|
||||
return set()
|
||||
|
||||
|
||||
def _update_class_alias_from_unpack(target, value, class_aliases, class_bindings):
|
||||
if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)):
|
||||
return
|
||||
if len(target.elts) != len(value.elts):
|
||||
return
|
||||
for target_item, value_item in zip(target.elts, value.elts):
|
||||
if not isinstance(target_item, ast.Name):
|
||||
continue
|
||||
sources = _class_alias_sources(value_item, class_aliases, class_bindings)
|
||||
if sources:
|
||||
class_aliases[target_item.id] = sources
|
||||
|
||||
|
||||
def _update_class_aliases(stmt, class_aliases, class_bindings):
|
||||
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
|
||||
for name in rebound_names:
|
||||
@@ -780,6 +815,8 @@ def _update_class_aliases(stmt, class_aliases, class_bindings):
|
||||
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings)
|
||||
if sources:
|
||||
class_aliases[stmt.targets[0].id] = sources
|
||||
elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
|
||||
_update_class_alias_from_unpack(stmt.targets[0], stmt.value, class_aliases, class_bindings)
|
||||
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
||||
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings)
|
||||
if sources:
|
||||
@@ -930,9 +967,12 @@ def _signature_from_class(node_type, cls, display, pack_meta, class_env, input_e
|
||||
inputs = {}
|
||||
required = []
|
||||
for section in ("required", "optional"):
|
||||
values = input_types.get(section) or {}
|
||||
if not isinstance(values, dict):
|
||||
return None
|
||||
if section in input_types:
|
||||
values = input_types[section]
|
||||
if not isinstance(values, dict):
|
||||
return None
|
||||
else:
|
||||
values = {}
|
||||
for name, spec in values.items():
|
||||
inputs[str(name)] = normalise_input_spec(spec)
|
||||
if section == "required":
|
||||
|
||||
Reference in New Issue
Block a user