Invalidate mappings on duplicate keys and class aliases
This commit is contained in:
@@ -1730,6 +1730,110 @@ PatchedInputTypesNode.INPUT_TYPES = build_inputs
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_duplicate_node_mapping_key_with_dynamic_value_skips_node(self):
|
||||
source = '''
|
||||
def build_node():
|
||||
return object()
|
||||
|
||||
|
||||
class DuplicateMappingKeyNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DuplicateMappingKeyNode": DuplicateMappingKeyNode,
|
||||
"DuplicateMappingKeyNode": build_node(),
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "duplicate-mapping-key-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_module_class_alias_patch_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
class AliasPatchedNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"AliasPatchedNode": AliasPatchedNode,
|
||||
}
|
||||
Alias = AliasPatchedNode
|
||||
Alias.RETURN_TYPES = ("MASK",)
|
||||
'''
|
||||
result = self._extract_source(source, "alias-patched-node-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_definition_time_class_attribute_mutation_after_mapping_skips_node(self):
|
||||
source = '''
|
||||
class DefinitionTimeMutatedMappedNode:
|
||||
RETURN_TYPES = ["IMAGE"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DefinitionTimeMutatedMappedNode": DefinitionTimeMutatedMappedNode,
|
||||
}
|
||||
def helper(x=DefinitionTimeMutatedMappedNode.RETURN_TYPES.clear()):
|
||||
pass
|
||||
'''
|
||||
result = self._extract_source(source, "definition-time-mutated-mapped-node-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_unhashable_node_mapping_key_skips_repo_without_raising(self):
|
||||
source = '''
|
||||
KEY = ["UnhashableMappingKeyNode"]
|
||||
|
||||
|
||||
class UnhashableMappingKeyNode:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
KEY: UnhashableMappingKeyNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "unhashable-mapping-key-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_mutated_node_class_mapping_skips_node(self):
|
||||
source = '''
|
||||
class MutatedMappingNode:
|
||||
|
||||
@@ -107,17 +107,35 @@ def _class_attribute_mutation_target_names(stmt):
|
||||
names = set()
|
||||
|
||||
class AttributeMutationVisitor(ast.NodeVisitor):
|
||||
def _visit_function_definition_expressions(self, node):
|
||||
for decorator in node.decorator_list:
|
||||
self.visit(decorator)
|
||||
self.visit(node.args)
|
||||
if node.returns is not None:
|
||||
self.visit(node.returns)
|
||||
for type_param in getattr(node, "type_params", ()):
|
||||
self.visit(type_param)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
return None
|
||||
self._visit_function_definition_expressions(node)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node):
|
||||
return None
|
||||
self._visit_function_definition_expressions(node)
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
return None
|
||||
for decorator in node.decorator_list:
|
||||
self.visit(decorator)
|
||||
for base in node.bases:
|
||||
self.visit(base)
|
||||
for keyword in node.keywords:
|
||||
self.visit(keyword.value)
|
||||
for type_param in getattr(node, "type_params", ()):
|
||||
self.visit(type_param)
|
||||
for child in node.body:
|
||||
self.visit(child)
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
return None
|
||||
self.visit(node.args)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
for target in node.targets:
|
||||
@@ -698,20 +716,73 @@ def _module_dict_entries(node, env, class_bindings, value_converter):
|
||||
for key, value in zip(node.keys, node.values):
|
||||
if key is None:
|
||||
raise UnsupportedStaticExpression("dict unpacking is not supported")
|
||||
key_value = _literal(key, env)
|
||||
try:
|
||||
hash(key_value)
|
||||
except TypeError as exc:
|
||||
raise UnsupportedStaticExpression("unhashable dict key") from exc
|
||||
if key_value in result:
|
||||
raise UnsupportedStaticExpression("duplicate dict key")
|
||||
converted_value = value_converter(value, env, class_bindings)
|
||||
if converted_value is None:
|
||||
continue
|
||||
result[_literal(key, env)] = converted_value
|
||||
raise UnsupportedStaticExpression("unsupported dict value")
|
||||
result[key_value] = converted_value
|
||||
return result
|
||||
|
||||
|
||||
def _class_alias_sources(value, class_aliases, class_bindings):
|
||||
if not isinstance(value, ast.Name):
|
||||
return set()
|
||||
if value.id in class_aliases:
|
||||
return set(class_aliases[value.id])
|
||||
if value.id in class_bindings:
|
||||
return {value.id}
|
||||
return set()
|
||||
|
||||
|
||||
def _update_class_aliases(stmt, class_aliases, class_bindings):
|
||||
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
|
||||
for name in rebound_names:
|
||||
class_aliases.pop(name, None)
|
||||
|
||||
if isinstance(stmt, ast.ClassDef):
|
||||
if stmt.name in class_bindings and not stmt.decorator_list:
|
||||
class_aliases[stmt.name] = {stmt.name}
|
||||
return
|
||||
|
||||
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings)
|
||||
if sources:
|
||||
class_aliases[stmt.targets[0].id] = sources
|
||||
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
||||
sources = _class_alias_sources(stmt.value, class_aliases, class_bindings)
|
||||
if sources:
|
||||
class_aliases[stmt.target.id] = sources
|
||||
|
||||
|
||||
def _expanded_class_attribute_names(names, class_aliases):
|
||||
expanded = set(names)
|
||||
for name in names:
|
||||
expanded.update(class_aliases.get(name, ()))
|
||||
return expanded
|
||||
|
||||
|
||||
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None):
|
||||
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
|
||||
value = _MISSING
|
||||
env = {}
|
||||
class_bindings = {}
|
||||
class_aliases = {}
|
||||
|
||||
def advance_module_state(stmt):
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
_update_class_aliases(stmt, class_aliases, class_bindings)
|
||||
|
||||
for stmt in tree.body:
|
||||
class_attr_names = _class_attribute_mutation_target_names(stmt)
|
||||
class_attr_names = _expanded_class_attribute_names(
|
||||
_class_attribute_mutation_target_names(stmt),
|
||||
class_aliases,
|
||||
)
|
||||
if (
|
||||
value not in (_MISSING, _INVALID)
|
||||
and class_attr_names
|
||||
@@ -724,7 +795,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
||||
if not _name_is_assigned(stmt, name):
|
||||
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||
try:
|
||||
@@ -733,13 +804,13 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
||||
value = _INVALID
|
||||
else:
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if isinstance(stmt, ast.AnnAssign):
|
||||
if not _name_is_assigned(stmt, name):
|
||||
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
||||
try:
|
||||
@@ -748,39 +819,39 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
||||
value = _INVALID
|
||||
else:
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if isinstance(stmt, ast.AugAssign):
|
||||
if _name_is_assigned(stmt, name):
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if isinstance(stmt, ast.Delete):
|
||||
if name in _delete_target_names(stmt):
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if isinstance(stmt, ast.Expr):
|
||||
if name in _mutating_call_target_names(stmt):
|
||||
value = _INVALID
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if isinstance(stmt, _CONTROL_FLOW_TYPES):
|
||||
if name in _assigned_names_in_control_flow(stmt):
|
||||
value = _INVALID
|
||||
if _has_wildcard_import_in_control_flow(stmt):
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if _has_wildcard_import(stmt):
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
continue
|
||||
if name in _bound_names(stmt):
|
||||
value = _INVALID
|
||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||
advance_module_state(stmt)
|
||||
if value in (_MISSING, _INVALID):
|
||||
return {}
|
||||
return value
|
||||
|
||||
Reference in New Issue
Block a user