diff --git a/gates/gate.py b/gates/gate.py index f32a034..699d5aa 100644 --- a/gates/gate.py +++ b/gates/gate.py @@ -47,13 +47,14 @@ class ImageGate: def run(self, image, routes, unique_id): from comfy_execution.graph_utils import ExecutionBlocker from . import gate_server + import comfy.model_management as mm gate_bus.GateBus.arm(unique_id) gate_server.send_preview(unique_id, image, routes) try: - chosen_1 = gate_bus.GateBus.wait(unique_id) + chosen_1 = gate_bus.GateBus.wait( + unique_id, should_cancel=mm.processing_interrupted) except gate_bus.GateCancelled: - import comfy.model_management as mm raise mm.InterruptProcessingException() mask = mask_from_stash(gate_bus.GateBus.pop_mask(unique_id), image) diff --git a/gates/gate_bus.py b/gates/gate_bus.py index fd735e2..b2499f8 100644 --- a/gates/gate_bus.py +++ b/gates/gate_bus.py @@ -27,10 +27,10 @@ class GateBus: cls.messages[str(node_id)] = int(message) @classmethod - def wait(cls, node_id, period=0.1): + def wait(cls, node_id, period=0.1, should_cancel=None): sid = str(node_id) while sid not in cls.messages: - if cls.cancelled: + if cls.cancelled or (should_cancel is not None and should_cancel()): cls.cancelled = False raise GateCancelled() time.sleep(period) diff --git a/tests/test_gate_bus.py b/tests/test_gate_bus.py index 46376a6..a7fe16f 100644 --- a/tests/test_gate_bus.py +++ b/tests/test_gate_bus.py @@ -65,3 +65,10 @@ def test_wait_payload_should_cancel_raises(): gb.GateBus.arm("p") with pytest.raises(gb.GateCancelled): gb.GateBus.wait_payload("p", should_cancel=lambda: True) + +def test_wait_should_cancel_raises(): + # image gate: ComfyUI Interrupt (should_cancel) must abort the wait too + gb.GateBus.arm("7") + with pytest.raises(gb.GateCancelled): + gb.GateBus.wait("7", should_cancel=lambda: True) + assert gb.GateBus.cancelled is False