Harden generated signature matching

This commit is contained in:
2026-07-02 11:47:07 +02:00
parent 38c142d42e
commit 9c17083298
2 changed files with 149 additions and 16 deletions
+82 -11
View File
@@ -362,6 +362,72 @@ def build_index(ctx):
return {"sources": sources, "candidates": candidates, "stats": stats}
def _signature_from_item(it):
inputs_raw = it.get("inputs") or {}
if not isinstance(inputs_raw, dict):
inputs_raw = {}
outputs_raw = it.get("outputs") or []
if not isinstance(outputs_raw, list):
outputs_raw = []
output_names_raw = it.get("output_names") or []
if not isinstance(output_names_raw, list):
output_names_raw = []
inputs = {str(k): str(v) for k, v in inputs_raw.items() if k is not None}
return {
"inputs": inputs,
"required": set(inputs),
"outputs": [str(x) for x in outputs_raw],
"output_names": [str(x) for x in output_names_raw],
}
def _generated_signature_usable(sig):
return isinstance(sig, dict) and isinstance(sig.get("inputs"), dict) and isinstance(sig.get("outputs"), list)
def _normalised_generated_signature(sig):
if not _generated_signature_usable(sig):
return None
try:
inputs = {str(k): str(v) for k, v in sig["inputs"].items() if k is not None}
outputs = [str(x) for x in sig["outputs"]]
required_raw = sig.get("required") or []
if not isinstance(required_raw, (list, set, tuple)):
required_raw = []
output_names_raw = sig.get("output_names") or []
if not isinstance(output_names_raw, list):
output_names_raw = []
return {
"inputs": inputs,
"required": {str(v) for v in required_raw if str(v) in inputs},
"outputs": outputs,
"output_names": [str(x) for x in output_names_raw],
}
except Exception:
return None
def _generated_signature_conflicts(serialized_sig, generated_sig):
if not serialized_sig["inputs"] and not serialized_sig["outputs"]:
return False
generated_inputs = generated_sig["inputs"]
generated_input_types = set(generated_inputs.values())
for name, typ in serialized_sig["inputs"].items():
if name in generated_inputs:
if generated_inputs[name] != typ:
return True
elif typ not in generated_input_types:
return True
if Counter(serialized_sig["outputs"]) - Counter(generated_sig["outputs"]):
return True
return False
def match(ctx, items):
"""
Match a batch of nodes given only their (possibly serialized) signature —
@@ -375,30 +441,35 @@ def match(ctx, items):
Returns a mapping from source node type to candidate list.
"""
out = {}
generated = ctx.get("generated") or _empty_generated_signatures()
generated = ctx.get("generated") or {}
if not isinstance(generated, dict):
generated = {}
generated_sigs = generated.get("sigs") or {}
if not isinstance(generated_sigs, dict):
generated_sigs = {}
generated_meta = generated.get("meta") or {}
if not isinstance(generated_meta, dict):
generated_meta = {}
for it in items:
if not isinstance(it, dict):
continue
t = it.get("type")
if not t or t in out:
continue
gen_sig = generated_sigs.get(t)
if gen_sig is not None:
gen_pack = (generated_meta.get(t) or {}).get("pack")
sig = _signature_from_item(it)
gen_sig = _normalised_generated_signature(generated_sigs.get(t))
if gen_sig is not None and not _generated_signature_conflicts(sig, gen_sig):
gen_meta = generated_meta.get(t) or {}
if not isinstance(gen_meta, dict):
gen_meta = {}
gen_pack = gen_meta.get("pack")
found = _candidates_for(t, gen_sig, gen_pack, ctx)
if found:
out[t] = found
continue
inputs = {k: str(v) for k, v in (it.get("inputs") or {}).items()}
sig = {
"inputs": inputs,
"required": set(inputs),
"outputs": [str(x) for x in (it.get("outputs") or [])],
"output_names": list(it.get("output_names") or []),
}
found = _candidates_for(t, sig, None, ctx)
if found:
out[t] = found