Invalidate mappings on duplicate keys and class aliases
This commit is contained in:
@@ -1730,6 +1730,110 @@ PatchedInputTypesNode.INPUT_TYPES = build_inputs
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_duplicate_node_mapping_key_with_dynamic_value_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
def build_node():
|
||||||
|
return object()
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateMappingKeyNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"DuplicateMappingKeyNode": DuplicateMappingKeyNode,
|
||||||
|
"DuplicateMappingKeyNode": build_node(),
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "duplicate-mapping-key-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_module_class_alias_patch_after_mapping_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
class AliasPatchedNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"AliasPatchedNode": AliasPatchedNode,
|
||||||
|
}
|
||||||
|
Alias = AliasPatchedNode
|
||||||
|
Alias.RETURN_TYPES = ("MASK",)
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "alias-patched-node-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_definition_time_class_attribute_mutation_after_mapping_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
class DefinitionTimeMutatedMappedNode:
|
||||||
|
RETURN_TYPES = ["IMAGE"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"DefinitionTimeMutatedMappedNode": DefinitionTimeMutatedMappedNode,
|
||||||
|
}
|
||||||
|
def helper(x=DefinitionTimeMutatedMappedNode.RETURN_TYPES.clear()):
|
||||||
|
pass
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "definition-time-mutated-mapped-node-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_unhashable_node_mapping_key_skips_repo_without_raising(self):
|
||||||
|
source = '''
|
||||||
|
KEY = ["UnhashableMappingKeyNode"]
|
||||||
|
|
||||||
|
|
||||||
|
class UnhashableMappingKeyNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
KEY: UnhashableMappingKeyNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "unhashable-mapping-key-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_mutated_node_class_mapping_skips_node(self):
|
def test_mutated_node_class_mapping_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
class MutatedMappingNode:
|
class MutatedMappingNode:
|
||||||
|
|||||||
@@ -107,17 +107,35 @@ def _class_attribute_mutation_target_names(stmt):
|
|||||||
names = set()
|
names = set()
|
||||||
|
|
||||||
class AttributeMutationVisitor(ast.NodeVisitor):
|
class AttributeMutationVisitor(ast.NodeVisitor):
|
||||||
|
def _visit_function_definition_expressions(self, node):
|
||||||
|
for decorator in node.decorator_list:
|
||||||
|
self.visit(decorator)
|
||||||
|
self.visit(node.args)
|
||||||
|
if node.returns is not None:
|
||||||
|
self.visit(node.returns)
|
||||||
|
for type_param in getattr(node, "type_params", ()):
|
||||||
|
self.visit(type_param)
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
return None
|
self._visit_function_definition_expressions(node)
|
||||||
|
|
||||||
def visit_AsyncFunctionDef(self, node):
|
def visit_AsyncFunctionDef(self, node):
|
||||||
return None
|
self._visit_function_definition_expressions(node)
|
||||||
|
|
||||||
def visit_ClassDef(self, node):
|
def visit_ClassDef(self, node):
|
||||||
return None
|
for decorator in node.decorator_list:
|
||||||
|
self.visit(decorator)
|
||||||
|
for base in node.bases:
|
||||||
|
self.visit(base)
|
||||||
|
for keyword in node.keywords:
|
||||||
|
self.visit(keyword.value)
|
||||||
|
for type_param in getattr(node, "type_params", ()):
|
||||||
|
self.visit(type_param)
|
||||||
|
for child in node.body:
|
||||||
|
self.visit(child)
|
||||||
|
|
||||||
def visit_Lambda(self, node):
|
def visit_Lambda(self, node):
|
||||||
return None
|
self.visit(node.args)
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
@@ -698,20 +716,73 @@ def _module_dict_entries(node, env, class_bindings, value_converter):
|
|||||||
for key, value in zip(node.keys, node.values):
|
for key, value in zip(node.keys, node.values):
|
||||||
if key is None:
|
if key is None:
|
||||||
raise UnsupportedStaticExpression("dict unpacking is not supported")
|
raise UnsupportedStaticExpression("dict unpacking is not supported")
|
||||||
|
key_value = _literal(key, env)
|
||||||
|
try:
|
||||||
|
hash(key_value)
|
||||||
|
except TypeError as exc:
|
||||||
|
raise UnsupportedStaticExpression("unhashable dict key") from exc
|
||||||
|
if key_value in result:
|
||||||
|
raise UnsupportedStaticExpression("duplicate dict key")
|
||||||
converted_value = value_converter(value, env, class_bindings)
|
converted_value = value_converter(value, env, class_bindings)
|
||||||
if converted_value is None:
|
if converted_value is None:
|
||||||
continue
|
raise UnsupportedStaticExpression("unsupported dict value")
|
||||||
result[_literal(key, env)] = converted_value
|
result[key_value] = converted_value
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _class_alias_sources(value, class_aliases, class_bindings):
|
||||||
|
if not isinstance(value, ast.Name):
|
||||||
|
return set()
|
||||||
|
if value.id in class_aliases:
|
||||||
|
return set(class_aliases[value.id])
|
||||||
|
if value.id in class_bindings:
|
||||||
|
return {value.id}
|
||||||
|
return set()
|
||||||
|
|
||||||
|
|
||||||
|
def _update_class_aliases(stmt, class_aliases, class_bindings):
|
||||||
|
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
|
||||||
|
for name in rebound_names:
|
||||||
|
class_aliases.pop(name, None)
|
||||||
|
|
||||||
|
if isinstance(stmt, ast.ClassDef):
|
||||||
|
if stmt.name in class_bindings and not stmt.decorator_list:
|
||||||
|
class_aliases[stmt.name] = {stmt.name}
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||||
|
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings)
|
||||||
|
if sources:
|
||||||
|
class_aliases[stmt.targets[0].id] = sources
|
||||||
|
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
||||||
|
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings)
|
||||||
|
if sources:
|
||||||
|
class_aliases[stmt.target.id] = sources
|
||||||
|
|
||||||
|
|
||||||
|
def _expanded_class_attribute_names(names, class_aliases):
|
||||||
|
expanded = set(names)
|
||||||
|
for name in names:
|
||||||
|
expanded.update(class_aliases.get(name, ()))
|
||||||
|
return expanded
|
||||||
|
|
||||||
|
|
||||||
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None):
|
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None):
|
||||||
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
|
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
|
||||||
value = _MISSING
|
value = _MISSING
|
||||||
env = {}
|
env = {}
|
||||||
class_bindings = {}
|
class_bindings = {}
|
||||||
|
class_aliases = {}
|
||||||
|
|
||||||
|
def advance_module_state(stmt):
|
||||||
|
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||||
|
_update_class_aliases(stmt, class_aliases, class_bindings)
|
||||||
|
|
||||||
for stmt in tree.body:
|
for stmt in tree.body:
|
||||||
class_attr_names = _class_attribute_mutation_target_names(stmt)
|
class_attr_names = _expanded_class_attribute_names(
|
||||||
|
_class_attribute_mutation_target_names(stmt),
|
||||||
|
class_aliases,
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
value not in (_MISSING, _INVALID)
|
value not in (_MISSING, _INVALID)
|
||||||
and class_attr_names
|
and class_attr_names
|
||||||
@@ -724,7 +795,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
|||||||
if not _name_is_assigned(stmt, name):
|
if not _name_is_assigned(stmt, name):
|
||||||
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||||
try:
|
try:
|
||||||
@@ -733,13 +804,13 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
|||||||
value = _INVALID
|
value = _INVALID
|
||||||
else:
|
else:
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt, ast.AnnAssign):
|
if isinstance(stmt, ast.AnnAssign):
|
||||||
if not _name_is_assigned(stmt, name):
|
if not _name_is_assigned(stmt, name):
|
||||||
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
if isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
||||||
try:
|
try:
|
||||||
@@ -748,39 +819,39 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
|||||||
value = _INVALID
|
value = _INVALID
|
||||||
else:
|
else:
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt, ast.AugAssign):
|
if isinstance(stmt, ast.AugAssign):
|
||||||
if _name_is_assigned(stmt, name):
|
if _name_is_assigned(stmt, name):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt, ast.Delete):
|
if isinstance(stmt, ast.Delete):
|
||||||
if name in _delete_target_names(stmt):
|
if name in _delete_target_names(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt, ast.Expr):
|
if isinstance(stmt, ast.Expr):
|
||||||
if name in _mutating_call_target_names(stmt):
|
if name in _mutating_call_target_names(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
if name in _bound_names(stmt):
|
if name in _bound_names(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if isinstance(stmt, _CONTROL_FLOW_TYPES):
|
if isinstance(stmt, _CONTROL_FLOW_TYPES):
|
||||||
if name in _assigned_names_in_control_flow(stmt):
|
if name in _assigned_names_in_control_flow(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
if _has_wildcard_import_in_control_flow(stmt):
|
if _has_wildcard_import_in_control_flow(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if _has_wildcard_import(stmt):
|
if _has_wildcard_import(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
continue
|
continue
|
||||||
if name in _bound_names(stmt):
|
if name in _bound_names(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
advance_module_state(stmt)
|
||||||
if value in (_MISSING, _INVALID):
|
if value in (_MISSING, _INVALID):
|
||||||
return {}
|
return {}
|
||||||
return value
|
return value
|
||||||
|
|||||||
Reference in New Issue
Block a user