Skip nontrivial class creation signatures
This commit is contained in:
@@ -1889,6 +1889,63 @@ NODE_CLASS_MAPPINGS = {
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_class_with_base_mapping_skips_node(self):
|
||||
source = '''
|
||||
class Base:
|
||||
def __init_subclass__(cls):
|
||||
cls.RETURN_TYPES = ("MASK",)
|
||||
|
||||
|
||||
class HookedNode(Base):
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"HookedNode": HookedNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "hooked-base-class-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_class_with_metaclass_mapping_skips_node(self):
|
||||
source = '''
|
||||
class Meta(type):
|
||||
def __new__(mcls, name, bases, attrs):
|
||||
attrs["RETURN_TYPES"] = ("MASK",)
|
||||
return super().__new__(mcls, name, bases, attrs)
|
||||
|
||||
|
||||
class MetaNode(metaclass=Meta):
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"MetaNode": MetaNode,
|
||||
}
|
||||
'''
|
||||
result = self._extract_source(source, "metaclass-pack")
|
||||
|
||||
self.assertEqual({}, result["nodes"])
|
||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||
|
||||
def test_node_mapping_key_uses_assignment_time_env(self):
|
||||
source = '''
|
||||
KEY = "Original"
|
||||
|
||||
@@ -605,6 +605,16 @@ def _invalidate_class_bindings(class_bindings, names):
|
||||
class_bindings.pop(name, None)
|
||||
|
||||
|
||||
def _is_trivially_safe_class_def(stmt):
|
||||
return (
|
||||
isinstance(stmt, ast.ClassDef)
|
||||
and not stmt.decorator_list
|
||||
and not stmt.bases
|
||||
and not stmt.keywords
|
||||
and not getattr(stmt, "type_params", ())
|
||||
)
|
||||
|
||||
|
||||
def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
|
||||
names = _mutating_call_target_names(stmt)
|
||||
if _DYNAMIC_NAMESPACE_MUTATION in names:
|
||||
@@ -616,10 +626,10 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
|
||||
_invalidate_env_names(env, names)
|
||||
if isinstance(stmt, ast.ClassDef):
|
||||
if class_bindings is not None:
|
||||
if stmt.decorator_list:
|
||||
class_bindings.pop(stmt.name, None)
|
||||
else:
|
||||
if _is_trivially_safe_class_def(stmt):
|
||||
class_bindings[stmt.name] = (stmt, dict(env))
|
||||
else:
|
||||
class_bindings.pop(stmt.name, None)
|
||||
_invalidate_env_name(env, stmt.name)
|
||||
return
|
||||
if isinstance(stmt, ast.Assign):
|
||||
@@ -999,7 +1009,7 @@ def _update_class_aliases(stmt, class_aliases, class_bindings):
|
||||
class_aliases.pop(name, None)
|
||||
|
||||
if isinstance(stmt, ast.ClassDef):
|
||||
if stmt.name in class_bindings and not stmt.decorator_list:
|
||||
if stmt.name in class_bindings and _is_trivially_safe_class_def(stmt):
|
||||
class_aliases[stmt.name] = {stmt.name}
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user