Harden generated signature matching
This commit is contained in:
+82
-11
@@ -362,6 +362,72 @@ def build_index(ctx):
|
||||
return {"sources": sources, "candidates": candidates, "stats": stats}
|
||||
|
||||
|
||||
def _signature_from_item(it):
|
||||
inputs_raw = it.get("inputs") or {}
|
||||
if not isinstance(inputs_raw, dict):
|
||||
inputs_raw = {}
|
||||
outputs_raw = it.get("outputs") or []
|
||||
if not isinstance(outputs_raw, list):
|
||||
outputs_raw = []
|
||||
output_names_raw = it.get("output_names") or []
|
||||
if not isinstance(output_names_raw, list):
|
||||
output_names_raw = []
|
||||
|
||||
inputs = {str(k): str(v) for k, v in inputs_raw.items() if k is not None}
|
||||
return {
|
||||
"inputs": inputs,
|
||||
"required": set(inputs),
|
||||
"outputs": [str(x) for x in outputs_raw],
|
||||
"output_names": [str(x) for x in output_names_raw],
|
||||
}
|
||||
|
||||
|
||||
def _generated_signature_usable(sig):
|
||||
return isinstance(sig, dict) and isinstance(sig.get("inputs"), dict) and isinstance(sig.get("outputs"), list)
|
||||
|
||||
|
||||
def _normalised_generated_signature(sig):
|
||||
if not _generated_signature_usable(sig):
|
||||
return None
|
||||
|
||||
try:
|
||||
inputs = {str(k): str(v) for k, v in sig["inputs"].items() if k is not None}
|
||||
outputs = [str(x) for x in sig["outputs"]]
|
||||
required_raw = sig.get("required") or []
|
||||
if not isinstance(required_raw, (list, set, tuple)):
|
||||
required_raw = []
|
||||
output_names_raw = sig.get("output_names") or []
|
||||
if not isinstance(output_names_raw, list):
|
||||
output_names_raw = []
|
||||
return {
|
||||
"inputs": inputs,
|
||||
"required": {str(v) for v in required_raw if str(v) in inputs},
|
||||
"outputs": outputs,
|
||||
"output_names": [str(x) for x in output_names_raw],
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _generated_signature_conflicts(serialized_sig, generated_sig):
|
||||
if not serialized_sig["inputs"] and not serialized_sig["outputs"]:
|
||||
return False
|
||||
|
||||
generated_inputs = generated_sig["inputs"]
|
||||
generated_input_types = set(generated_inputs.values())
|
||||
for name, typ in serialized_sig["inputs"].items():
|
||||
if name in generated_inputs:
|
||||
if generated_inputs[name] != typ:
|
||||
return True
|
||||
elif typ not in generated_input_types:
|
||||
return True
|
||||
|
||||
if Counter(serialized_sig["outputs"]) - Counter(generated_sig["outputs"]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def match(ctx, items):
|
||||
"""
|
||||
Match a batch of nodes given only their (possibly serialized) signature —
|
||||
@@ -375,30 +441,35 @@ def match(ctx, items):
|
||||
Returns a mapping from source node type to candidate list.
|
||||
"""
|
||||
out = {}
|
||||
generated = ctx.get("generated") or _empty_generated_signatures()
|
||||
generated = ctx.get("generated") or {}
|
||||
if not isinstance(generated, dict):
|
||||
generated = {}
|
||||
generated_sigs = generated.get("sigs") or {}
|
||||
if not isinstance(generated_sigs, dict):
|
||||
generated_sigs = {}
|
||||
generated_meta = generated.get("meta") or {}
|
||||
if not isinstance(generated_meta, dict):
|
||||
generated_meta = {}
|
||||
|
||||
for it in items:
|
||||
if not isinstance(it, dict):
|
||||
continue
|
||||
t = it.get("type")
|
||||
if not t or t in out:
|
||||
continue
|
||||
|
||||
gen_sig = generated_sigs.get(t)
|
||||
if gen_sig is not None:
|
||||
gen_pack = (generated_meta.get(t) or {}).get("pack")
|
||||
sig = _signature_from_item(it)
|
||||
gen_sig = _normalised_generated_signature(generated_sigs.get(t))
|
||||
if gen_sig is not None and not _generated_signature_conflicts(sig, gen_sig):
|
||||
gen_meta = generated_meta.get(t) or {}
|
||||
if not isinstance(gen_meta, dict):
|
||||
gen_meta = {}
|
||||
gen_pack = gen_meta.get("pack")
|
||||
found = _candidates_for(t, gen_sig, gen_pack, ctx)
|
||||
if found:
|
||||
out[t] = found
|
||||
continue
|
||||
|
||||
inputs = {k: str(v) for k, v in (it.get("inputs") or {}).items()}
|
||||
sig = {
|
||||
"inputs": inputs,
|
||||
"required": set(inputs),
|
||||
"outputs": [str(x) for x in (it.get("outputs") or [])],
|
||||
"output_names": list(it.get("output_names") or []),
|
||||
}
|
||||
found = _candidates_for(t, sig, None, ctx)
|
||||
if found:
|
||||
out[t] = found
|
||||
|
||||
Reference in New Issue
Block a user