131 lines
5.4 KiB
JavaScript
131 lines
5.4 KiB
JavaScript
import { app } from "../../scripts/app.js";
|
|
|
|
/** Identifier passed as `name` when this extension registers itself with the app. */
const EXTENSION = "ethanfel.prompt_builder.loop_slots";

/** Node types (by registered name) whose carry slots this extension manages. */
const LOOP_NODES = new Set([
  "SxCPForLoopStart",
  "SxCPForLoopEnd",
  "SxCPWhileLoopStart",
  "SxCPWhileLoopEnd",
]);

/** Highest carry slot number for the non-while loop nodes (see getCarryLimit). */
const MAX_CARRY = 18;
|
|
|
|
/**
 * Tells whether a slot is a carry input, i.e. named `initial_value<digits>`.
 * @param {object|undefined} input - Node input slot; a missing slot or name counts as "".
 * @returns {boolean}
 */
function isCarryInput(input) {
  const carryInputPattern = /^initial_value\d+$/;
  const slotName = input?.name || "";
  return carryInputPattern.test(slotName);
}
|
|
|
|
/**
 * Tells whether a slot is a carry output, i.e. named `value<digits>`.
 * @param {object|undefined} output - Node output slot; a missing slot or name counts as "".
 * @returns {boolean}
 */
function isCarryOutput(output) {
  const carryOutputPattern = /^value\d+$/;
  const slotName = output?.name || "";
  return carryOutputPattern.test(slotName);
}
|
|
|
|
/**
 * Extracts the trailing slot number from a slot name (e.g. `value7` -> 7).
 * @param {object|undefined} slot - Node slot; its name is coerced to a string.
 * @returns {number} The trailing number, or -1 when the name does not end in digits.
 */
function carryNumber(slot) {
  const slotName = String(slot?.name || "");
  const trailingDigits = /\d+$/.exec(slotName);
  if (!trailingDigits) return -1;
  return Number(trailingDigits[0]);
}
|
|
|
|
/**
 * Refits a node after its slots changed and asks the canvas to redraw.
 *
 * Height follows the size LiteGraph computes for the current slots/widgets, so the node
 * shrinks or grows with its carry pairs. Width never drops below the node's current
 * width, so a width the user dragged the node to survives connection changes.
 * Nodes without `computeSize` are left unsized; the redraw request still happens.
 *
 * @param {object} node - LiteGraph node (uses `computeSize`, `setSize`, `size` when present).
 */
function resizeNode(node) {
  const computed = node.computeSize?.();
  if (computed) {
    const currentWidth = node.size?.[0] ?? 0;
    node.setSize?.([Math.max(computed[0], currentWidth), computed[1]]);
  }
  app.graph?.setDirtyCanvas(true, true);
}
|
|
|
|
/**
 * Highest carry slot number allowed for a loop node type.
 * @param {string} nodeName - Registered node type name.
 * @returns {number} 19 for the while-loop nodes, MAX_CARRY for every other type.
 */
function getCarryLimit(nodeName) {
  switch (nodeName) {
    case "SxCPWhileLoopStart":
    case "SxCPWhileLoopEnd":
      return 19;
    default:
      return MAX_CARRY;
  }
}
|
|
|
|
/**
 * Lowest carry slot number for a loop node type.
 * @param {string} nodeName - Registered node type name.
 * @returns {number} 0 for the while-loop nodes, 1 for every other type.
 */
function getFirstCarry(nodeName) {
  const zeroBasedTypes = ["SxCPWhileLoopStart", "SxCPWhileLoopEnd"];
  if (zeroBasedTypes.includes(nodeName)) return 0;
  return 1;
}
|
|
|
|
/**
 * Makes sure the carry pair `initial_value<number>` (input) / `value<number>` (output)
 * exists on the node, adding whichever half is missing with type "*".
 * Does nothing when `number` is above the node type's carry limit.
 *
 * @param {object} node - LiteGraph node (uses `inputs`, `outputs`, `addInput`, `addOutput`).
 * @param {string} nodeName - Registered node type name.
 * @param {number} number - Carry slot number to ensure.
 */
function addCarryPair(node, nodeName, number) {
  const limit = getCarryLimit(nodeName);
  if (number > limit) return;

  const wantedInput = `initial_value${number}`;
  const hasInput = node.inputs?.some((slot) => slot.name === wantedInput);
  if (!hasInput) node.addInput(wantedInput, "*");

  const wantedOutput = `value${number}`;
  const hasOutput = node.outputs?.some((slot) => slot.name === wantedOutput);
  if (!hasOutput) node.addOutput(wantedOutput, "*");
}
|
|
|
|
/**
 * Removes the carry pair `initial_value<number>` / `value<number>` from the node.
 * Each half is removed independently and only when it is unconnected: the input when it
 * has no `link`, the output when it has no entries in `links`.
 *
 * @param {object} node - LiteGraph node (uses `inputs`, `outputs`, `removeInput`, `removeOutput`).
 * @param {number} number - Carry slot number to remove.
 */
function removeCarryPair(node, number) {
  const inputs = node.inputs;
  const inputAt = inputs == null ? -1 : inputs.findIndex((slot) => slot.name === `initial_value${number}`);
  const inputFree = inputAt >= 0 && !node.inputs[inputAt]?.link;
  if (inputFree) node.removeInput(inputAt);

  const outputs = node.outputs;
  const outputAt = outputs == null ? -1 : outputs.findIndex((slot) => slot.name === `value${number}`);
  const outputFree = outputAt >= 0 && !(node.outputs[outputAt]?.links?.length);
  if (outputFree) node.removeOutput(outputAt);
}
|
|
|
|
/**
 * Trims unused carry pairs from the end of the node, keeping one empty spare pair right
 * after the highest connected pair. The first pair (getFirstCarry) is never removed.
 *
 * Unused pairs *between* connected pairs are kept on purpose. Removing them (as a
 * "current and previous both unused" rule would) leaves numbering gaps such as
 * value1, value3, and the user can no longer connect the missing number.
 *
 * A pair counts as used when its input has a `link` or its output has any `links`.
 *
 * @param {object} node - LiteGraph node.
 * @param {string} nodeName - Registered node type name.
 */
function trimCarryTail(node, nodeName) {
  const first = getFirstCarry(nodeName);
  const limit = getCarryLimit(nodeName);

  const isUsed = (number) => {
    const input = node.inputs?.find((slot) => slot.name === `initial_value${number}`);
    const output = node.outputs?.find((slot) => slot.name === `value${number}`);
    return Boolean(input?.link || output?.links?.length);
  };

  // Highest connected pair number; first - 1 when nothing is connected.
  let lastUsed = first - 1;
  for (let number = limit; number >= first; number--) {
    if (isUsed(number)) {
      lastUsed = number;
      break;
    }
  }

  // Keep one spare after the last connected pair, and never go below the first pair.
  const keepThrough = Math.max(lastUsed + 1, first);
  for (let number = limit; number > keepThrough; number--) {
    removeCarryPair(node, number);
  }
}
|
|
|
|
/**
 * Normalizes a loop node's slots after creation or configuration:
 *  - ensures the first carry pair exists;
 *  - gives every slot named "flow" shape 5 (LiteGraph's arrow shape, presumably
 *    LiteGraph.ARROW_SHAPE; verify against the bundled LiteGraph);
 *  - trims unused pairs past the last connected one (trimCarryTail);
 *  - adds a free pair after the last connected one if none is left (maybeGrow).
 *
 * The previous version removed *every* unused pair above the first one. That included the
 * spare after the last connected pair, so a reloaded workflow had no free carry slot to
 * connect to.
 *
 * @param {object} node - LiteGraph node.
 * @param {string} nodeName - Registered node type name.
 */
function setupNodeSlots(node, nodeName) {
  addCarryPair(node, nodeName, getFirstCarry(nodeName));

  for (const slot of [...(node.inputs || []), ...(node.outputs || [])]) {
    if (slot.name === "flow") slot.shape = 5;
  }

  trimCarryTail(node, nodeName);
  // maybeGrow appends a pair only when the highest existing pair is connected.
  maybeGrow(node, nodeName);
  resizeNode(node);
}
|
|
|
|
/**
 * Appends the next carry pair when the highest-numbered existing pair is connected, so
 * the user always has a free pair to connect to (up to the node type's carry limit).
 * The highest number is taken over carry inputs and outputs, and is never below
 * getFirstCarry. A pair counts as connected when its input has a `link` or its output
 * has any `links`.
 *
 * @param {object} node - LiteGraph node.
 * @param {string} nodeName - Registered node type name.
 */
function maybeGrow(node, nodeName) {
  let highest = getFirstCarry(nodeName);
  for (const slot of node.inputs || []) {
    if (isCarryInput(slot)) highest = Math.max(highest, carryNumber(slot));
  }
  for (const slot of node.outputs || []) {
    if (isCarryOutput(slot)) highest = Math.max(highest, carryNumber(slot));
  }

  const tailInput = node.inputs?.find((slot) => slot.name === `initial_value${highest}`);
  const tailOutput = node.outputs?.find((slot) => slot.name === `value${highest}`);
  const tailConnected = tailInput?.link || tailOutput?.links?.length;
  if (!tailConnected || highest >= getCarryLimit(nodeName)) return;

  addCarryPair(node, nodeName, highest + 1);
  resizeNode(node);
}
|
|
|
|
app.registerExtension({
  name: EXTENSION,

  /**
   * Wraps the lifecycle hooks of every loop node type so its carry slots grow when the
   * last pair gets connected and shrink when trailing pairs are disconnected. Existing
   * prototype hooks are still called first and their return values passed through.
   */
  async beforeRegisterNodeDef(nodeType, nodeData) {
    if (!LOOP_NODES.has(nodeData.name)) return;
    const nodeName = nodeData.name;

    const onNodeCreated = nodeType.prototype.onNodeCreated;
    nodeType.prototype.onNodeCreated = function () {
      const result = onNodeCreated?.apply(this, arguments);
      // Deferred so slot setup runs after the current synchronous work has finished.
      queueMicrotask(() => setupNodeSlots(this, nodeName));
      return result;
    };

    const onConfigure = nodeType.prototype.onConfigure;
    nodeType.prototype.onConfigure = function () {
      const result = onConfigure?.apply(this, arguments);
      queueMicrotask(() => setupNodeSlots(this, nodeName));
      return result;
    };

    const onConnectionsChange = nodeType.prototype.onConnectionsChange;
    nodeType.prototype.onConnectionsChange = function (type, index, connected, linkInfo) {
      const result = onConnectionsChange?.apply(this, arguments);
      if (!linkInfo) return result;

      const slot = type === LiteGraph.INPUT ? this.inputs?.[index] : this.outputs?.[index];
      if (!isCarryInput(slot) && !isCarryOutput(slot)) return result;

      if (connected) {
        maybeGrow(this, nodeName);
      } else {
        // Trim only after the current call stack unwinds. When a link is dropped onto an
        // input that already has one, LiteGraph's connect() first disconnects the old link
        // (which lands here) and then attaches the new one. Trimming synchronously could
        // delete the very slot that is about to receive the new link.
        queueMicrotask(() => {
          trimCarryTail(this, nodeName);
          resizeNode(this);
        });
      }
      return result;
    };
  },
});
|