Extract builder nodes
This commit is contained in:
+39
-10
@@ -48,10 +48,12 @@ def _assignment_dict(path: Path, name: str) -> dict[str, Any]:
|
||||
|
||||
def _class_return_names(path: Path) -> dict[str, tuple[str, ...]]:
|
||||
tree = ast.parse(path.read_text(encoding="utf-8"))
|
||||
result: dict[str, tuple[str, ...]] = {}
|
||||
classes: dict[str, tuple[list[str], tuple[str, ...]]] = {}
|
||||
for node in tree.body:
|
||||
if not isinstance(node, ast.ClassDef) or not node.name.startswith("SxCP"):
|
||||
if not isinstance(node, ast.ClassDef):
|
||||
continue
|
||||
bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
|
||||
return_names: tuple[str, ...] = ()
|
||||
for item in node.body:
|
||||
if not isinstance(item, ast.Assign):
|
||||
continue
|
||||
@@ -59,7 +61,29 @@ def _class_return_names(path: Path) -> dict[str, tuple[str, ...]]:
|
||||
continue
|
||||
value = _literal_or_none(item.value)
|
||||
if isinstance(value, tuple) and all(isinstance(part, str) for part in value):
|
||||
result[node.name] = value
|
||||
return_names = value
|
||||
classes[node.name] = (bases, return_names)
|
||||
|
||||
def resolve(class_name: str, seen: set[str] | None = None) -> tuple[str, ...]:
|
||||
seen = seen or set()
|
||||
if class_name in seen:
|
||||
return ()
|
||||
seen.add(class_name)
|
||||
bases, return_names = classes.get(class_name, ([], ()))
|
||||
if return_names:
|
||||
return return_names
|
||||
for base_name in bases:
|
||||
inherited = resolve(base_name, seen)
|
||||
if inherited:
|
||||
return inherited
|
||||
return ()
|
||||
|
||||
result: dict[str, tuple[str, ...]] = {}
|
||||
for class_name in classes:
|
||||
if class_name.startswith("SxCP"):
|
||||
return_names = resolve(class_name)
|
||||
if return_names:
|
||||
result[class_name] = return_names
|
||||
return result
|
||||
|
||||
|
||||
@@ -97,6 +121,12 @@ def _category_json_paths() -> list[Path]:
|
||||
return sorted((ROOT / "categories").glob("*.json"))
|
||||
|
||||
|
||||
def _node_python_paths() -> list[Path]:
|
||||
paths = [ROOT / "__init__.py", ROOT / "loop_nodes.py"]
|
||||
paths.extend(sorted(ROOT.glob("node_*.py")))
|
||||
return [path for path in paths if path.exists()]
|
||||
|
||||
|
||||
def _load_category_json(path: Path) -> dict[str, Any]:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, dict) else {}
|
||||
@@ -246,14 +276,13 @@ def print_table(headers: tuple[str, ...], rows: list[tuple[Any, ...]]) -> None:
|
||||
|
||||
|
||||
def main() -> int:
|
||||
init_path = ROOT / "__init__.py"
|
||||
loop_path = ROOT / "loop_nodes.py"
|
||||
category_paths = _category_json_paths()
|
||||
display = _assignment_dict(init_path, "NODE_DISPLAY_NAME_MAPPINGS")
|
||||
loop_display = _assignment_dict(loop_path, "LOOP_NODE_DISPLAY_NAME_MAPPINGS")
|
||||
display.update(loop_display)
|
||||
returns = _class_return_names(init_path)
|
||||
returns.update(_class_return_names(loop_path))
|
||||
display: dict[str, Any] = {}
|
||||
returns: dict[str, tuple[str, ...]] = {}
|
||||
for path in _node_python_paths():
|
||||
display.update(_assignment_dict(path, "NODE_DISPLAY_NAME_MAPPINGS"))
|
||||
display.update(_assignment_dict(path, "LOOP_NODE_DISPLAY_NAME_MAPPINGS"))
|
||||
returns.update(_class_return_names(path))
|
||||
|
||||
print("# Node Display Map")
|
||||
node_rows = []
|
||||
|
||||
@@ -2169,6 +2169,63 @@ def smoke_node_insta_registration() -> None:
|
||||
_expect(pair.get("options", {}).get("hardcore_cast") == "couple", "Insta/OF Prompt Pair lost options metadata")
|
||||
|
||||
|
||||
def smoke_node_builder_registration() -> None:
|
||||
required_nodes = [
|
||||
"SxCPPromptBuilder",
|
||||
"SxCPPromptBuilderFromConfigs",
|
||||
]
|
||||
for node_name in required_nodes:
|
||||
_expect(node_name in sxcp_nodes.NODE_CLASS_MAPPINGS, f"{node_name} missing from node registry")
|
||||
_expect(node_name in sxcp_nodes.NODE_DISPLAY_NAME_MAPPINGS, f"{node_name} missing from display registry")
|
||||
|
||||
builder_node = sxcp_nodes.NODE_CLASS_MAPPINGS["SxCPPromptBuilder"]
|
||||
builder_inputs = builder_node.INPUT_TYPES().get("required") or {}
|
||||
_expect("category" in builder_inputs, "Prompt Builder lost category input")
|
||||
_expect("tooltip" in builder_inputs["category"][1], "Prompt Builder tooltip injection missing")
|
||||
direct_output = builder_node().build(
|
||||
"woman",
|
||||
"random",
|
||||
1,
|
||||
41,
|
||||
123,
|
||||
"full",
|
||||
"any",
|
||||
"standard",
|
||||
True,
|
||||
0.5,
|
||||
0.0,
|
||||
"random",
|
||||
1,
|
||||
0,
|
||||
-1,
|
||||
-1,
|
||||
Trigger,
|
||||
True,
|
||||
)
|
||||
direct_row = json.loads(direct_output[3])
|
||||
_expect_row_base(direct_row, "node_builder.direct_row")
|
||||
_expect(direct_output[0] == direct_row.get("prompt"), "Prompt Builder prompt output drifted from metadata")
|
||||
_expect(direct_output[4] == direct_row.get("main_category"), "Prompt Builder category output drifted from metadata")
|
||||
_expect_trigger_once("node_builder.direct_prompt", direct_output[0], Trigger)
|
||||
|
||||
config_node = sxcp_nodes.NODE_CLASS_MAPPINGS["SxCPPromptBuilderFromConfigs"]
|
||||
config_inputs = config_node.INPUT_TYPES()
|
||||
_expect("category_config" in (config_inputs.get("optional") or {}), "Prompt Builder From Configs lost category_config input")
|
||||
config_output = config_node().build(
|
||||
1,
|
||||
41,
|
||||
123,
|
||||
category_config=pb.build_category_config_json("woman", "random"),
|
||||
cast_config=pb.build_cast_config_json("solo_woman", 1, 0),
|
||||
generation_profile=pb.build_generation_profile_json(profile="balanced"),
|
||||
)
|
||||
config_row = json.loads(config_output[3])
|
||||
_expect_row_base(config_row, "node_builder.config_row")
|
||||
_expect(config_output[0] == config_row.get("prompt"), "Prompt Builder From Configs prompt output drifted from metadata")
|
||||
_expect(config_output[4] == config_row.get("main_category"), "Prompt Builder From Configs category output drifted from metadata")
|
||||
_expect_text("node_builder.config_caption", config_output[2], 20)
|
||||
|
||||
|
||||
def smoke_node_profile_filter_registration() -> None:
|
||||
required_nodes = [
|
||||
"SxCPGenerationProfile",
|
||||
@@ -2280,6 +2337,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
||||
("node_hardcore_position_registration", smoke_node_hardcore_position_registration),
|
||||
("node_formatter_registration", smoke_node_formatter_registration),
|
||||
("node_insta_registration", smoke_node_insta_registration),
|
||||
("node_builder_registration", smoke_node_builder_registration),
|
||||
("node_profile_filter_registration", smoke_node_profile_filter_registration),
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user