Track mapping and class attribute aliases
This commit is contained in:
@@ -2047,6 +2047,59 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_module_class_attribute_tuple_alias_mutation_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
class TupleAttributeAliasNode:
|
||||||
|
RETURN_TYPES = ["IMAGE"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
RET, = (TupleAttributeAliasNode.RETURN_TYPES,)
|
||||||
|
RET.clear()
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TupleAttributeAliasNode": TupleAttributeAliasNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "tuple-attribute-alias-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_module_class_attribute_transitive_alias_mutation_skips_node(self):
|
||||||
|
source = '''
|
||||||
|
class TransitiveAttributeAliasNode:
|
||||||
|
RETURN_TYPES = ["IMAGE"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
RET = TransitiveAttributeAliasNode.RETURN_TYPES
|
||||||
|
ALIAS = RET
|
||||||
|
ALIAS.clear()
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TransitiveAttributeAliasNode": TransitiveAttributeAliasNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "transitive-attribute-alias-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_module_class_attribute_alias_mutation_after_mapping_skips_node(self):
|
def test_module_class_attribute_alias_mutation_after_mapping_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
class PostMappingAttributeAliasNode:
|
class PostMappingAttributeAliasNode:
|
||||||
@@ -2172,6 +2225,31 @@ NODE_CLASS_MAPPINGS.clear()
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_unpacked_alias_mutation_invalidates_static_node_mapping(self):
|
||||||
|
source = '''
|
||||||
|
class UnpackedAliasMutatedMappingNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"UnpackedAliasMutatedMappingNode": UnpackedAliasMutatedMappingNode,
|
||||||
|
}
|
||||||
|
ALIAS, = (NODE_CLASS_MAPPINGS,)
|
||||||
|
ALIAS.clear()
|
||||||
|
'''
|
||||||
|
result = self._extract_source(source, "unpacked-alias-mutated-mapping-pack")
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_rhs_mutating_call_to_node_mapping_skips_node(self):
|
def test_rhs_mutating_call_to_node_mapping_skips_node(self):
|
||||||
source = '''
|
source = '''
|
||||||
class RhsMutatedMappingNode:
|
class RhsMutatedMappingNode:
|
||||||
|
|||||||
@@ -894,7 +894,9 @@ def _expanded_class_attribute_names(names, class_aliases):
|
|||||||
return expanded
|
return expanded
|
||||||
|
|
||||||
|
|
||||||
def _class_attribute_alias_sources(value, class_aliases, class_bindings):
|
def _class_attribute_alias_sources(value, class_attribute_aliases, class_aliases, class_bindings):
|
||||||
|
if isinstance(value, ast.Name):
|
||||||
|
return set(class_attribute_aliases.get(value.id, ()))
|
||||||
if not isinstance(value, ast.Attribute) or value.attr not in _CLASS_SIGNATURE_ATTRS:
|
if not isinstance(value, ast.Attribute) or value.attr not in _CLASS_SIGNATURE_ATTRS:
|
||||||
return set()
|
return set()
|
||||||
name = _root_name(value.value)
|
name = _root_name(value.value)
|
||||||
@@ -905,6 +907,30 @@ def _class_attribute_alias_sources(value, class_aliases, class_bindings):
|
|||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
|
||||||
|
def _update_class_attribute_alias_from_unpack(
|
||||||
|
target,
|
||||||
|
value,
|
||||||
|
class_attribute_aliases,
|
||||||
|
class_aliases,
|
||||||
|
class_bindings,
|
||||||
|
):
|
||||||
|
if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)):
|
||||||
|
return
|
||||||
|
if len(target.elts) != len(value.elts):
|
||||||
|
return
|
||||||
|
for target_item, value_item in zip(target.elts, value.elts):
|
||||||
|
if not isinstance(target_item, ast.Name):
|
||||||
|
continue
|
||||||
|
sources = _class_attribute_alias_sources(
|
||||||
|
value_item,
|
||||||
|
class_attribute_aliases,
|
||||||
|
class_aliases,
|
||||||
|
class_bindings,
|
||||||
|
)
|
||||||
|
if sources:
|
||||||
|
class_attribute_aliases[target_item.id] = sources
|
||||||
|
|
||||||
|
|
||||||
def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases):
|
def _class_attribute_alias_invalidated_names(stmt, class_attribute_aliases):
|
||||||
names = (
|
names = (
|
||||||
_mutating_call_target_names(stmt)
|
_mutating_call_target_names(stmt)
|
||||||
@@ -924,11 +950,29 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases
|
|||||||
class_attribute_aliases.pop(name, None)
|
class_attribute_aliases.pop(name, None)
|
||||||
|
|
||||||
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||||
sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings)
|
sources = _class_attribute_alias_sources(
|
||||||
|
stmt.value,
|
||||||
|
class_attribute_aliases,
|
||||||
|
class_aliases,
|
||||||
|
class_bindings,
|
||||||
|
)
|
||||||
if sources:
|
if sources:
|
||||||
class_attribute_aliases[stmt.targets[0].id] = sources
|
class_attribute_aliases[stmt.targets[0].id] = sources
|
||||||
|
elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
|
||||||
|
_update_class_attribute_alias_from_unpack(
|
||||||
|
stmt.targets[0],
|
||||||
|
stmt.value,
|
||||||
|
class_attribute_aliases,
|
||||||
|
class_aliases,
|
||||||
|
class_bindings,
|
||||||
|
)
|
||||||
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
||||||
sources = _class_attribute_alias_sources(stmt.value, class_aliases, class_bindings)
|
sources = _class_attribute_alias_sources(
|
||||||
|
stmt.value,
|
||||||
|
class_attribute_aliases,
|
||||||
|
class_aliases,
|
||||||
|
class_bindings,
|
||||||
|
)
|
||||||
if sources:
|
if sources:
|
||||||
class_attribute_aliases[stmt.target.id] = sources
|
class_attribute_aliases[stmt.target.id] = sources
|
||||||
|
|
||||||
@@ -942,6 +986,54 @@ def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribu
|
|||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _module_dict_alias_sources(value, name, aliases):
|
||||||
|
if not isinstance(value, ast.Name):
|
||||||
|
return set()
|
||||||
|
if value.id == name:
|
||||||
|
return {name}
|
||||||
|
return set(aliases.get(value.id, ()))
|
||||||
|
|
||||||
|
|
||||||
|
def _update_module_dict_alias_from_unpack(target, value, name, aliases):
|
||||||
|
if not isinstance(target, (ast.Tuple, ast.List)) or not isinstance(value, (ast.Tuple, ast.List)):
|
||||||
|
return
|
||||||
|
if len(target.elts) != len(value.elts):
|
||||||
|
return
|
||||||
|
for target_item, value_item in zip(target.elts, value.elts):
|
||||||
|
if not isinstance(target_item, ast.Name):
|
||||||
|
continue
|
||||||
|
sources = _module_dict_alias_sources(value_item, name, aliases)
|
||||||
|
if sources:
|
||||||
|
aliases[target_item.id] = sources
|
||||||
|
|
||||||
|
|
||||||
|
def _module_dict_alias_invalidated(stmt, aliases):
|
||||||
|
names = (
|
||||||
|
_mutating_call_target_names(stmt)
|
||||||
|
| _assignment_target_names(stmt)
|
||||||
|
| _delete_target_names(stmt)
|
||||||
|
| _bound_names(stmt)
|
||||||
|
)
|
||||||
|
return any(name in aliases for name in names)
|
||||||
|
|
||||||
|
|
||||||
|
def _update_module_dict_aliases(stmt, name, aliases):
|
||||||
|
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
|
||||||
|
for rebound_name in rebound_names:
|
||||||
|
aliases.pop(rebound_name, None)
|
||||||
|
|
||||||
|
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
|
||||||
|
sources = _module_dict_alias_sources(stmt.value, name, aliases)
|
||||||
|
if sources:
|
||||||
|
aliases[stmt.targets[0].id] = sources
|
||||||
|
elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
|
||||||
|
_update_module_dict_alias_from_unpack(stmt.targets[0], stmt.value, name, aliases)
|
||||||
|
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None:
|
||||||
|
sources = _module_dict_alias_sources(stmt.value, name, aliases)
|
||||||
|
if sources:
|
||||||
|
aliases[stmt.target.id] = sources
|
||||||
|
|
||||||
|
|
||||||
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False):
|
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False):
|
||||||
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
|
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
|
||||||
value = _MISSING
|
value = _MISSING
|
||||||
@@ -949,6 +1041,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
|||||||
class_bindings = {}
|
class_bindings = {}
|
||||||
class_aliases = {}
|
class_aliases = {}
|
||||||
class_attribute_aliases = {}
|
class_attribute_aliases = {}
|
||||||
|
module_dict_aliases = {}
|
||||||
|
|
||||||
def advance_module_state(stmt):
|
def advance_module_state(stmt):
|
||||||
_invalidate_class_bindings(
|
_invalidate_class_bindings(
|
||||||
@@ -958,6 +1051,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
|||||||
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
_apply_module_stmt_to_env(stmt, env, class_bindings)
|
||||||
_update_class_aliases(stmt, class_aliases, class_bindings)
|
_update_class_aliases(stmt, class_aliases, class_bindings)
|
||||||
_update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings)
|
_update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings)
|
||||||
|
_update_module_dict_aliases(stmt, name, module_dict_aliases)
|
||||||
|
|
||||||
for stmt in tree.body:
|
for stmt in tree.body:
|
||||||
class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases)
|
class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases)
|
||||||
@@ -969,6 +1063,8 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
|
|||||||
value = _INVALID
|
value = _INVALID
|
||||||
if name in _mutating_call_target_names(stmt):
|
if name in _mutating_call_target_names(stmt):
|
||||||
value = _INVALID
|
value = _INVALID
|
||||||
|
if _module_dict_alias_invalidated(stmt, module_dict_aliases):
|
||||||
|
value = _INVALID
|
||||||
if isinstance(stmt, ast.Assign):
|
if isinstance(stmt, ast.Assign):
|
||||||
if not _name_is_assigned(stmt, name):
|
if not _name_is_assigned(stmt, name):
|
||||||
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
|
||||||
|
|||||||
Reference in New Issue
Block a user