diff --git a/qualibrate/parameters.py b/qualibrate/parameters.py index e165cda..b3e3d49 100644 --- a/qualibrate/parameters.py +++ b/qualibrate/parameters.py @@ -7,7 +7,7 @@ cast, ) -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from qualibrate.utils.logger_m import logger from qualibrate.utils.naming import get_full_class_path @@ -31,6 +31,8 @@ class RunnableParameters(BaseModel): + model_config = ConfigDict(extra="forbid") + @classmethod def serialize(cls, **kwargs: Any) -> Mapping[str, Any]: schema = cls.model_json_schema() diff --git a/qualibrate/qualibration_graph.py b/qualibrate/qualibration_graph.py index ded0051..ee92374 100644 --- a/qualibrate/qualibration_graph.py +++ b/qualibrate/qualibration_graph.py @@ -369,7 +369,7 @@ def _run(self, **passed_parameters: Any) -> None: orchestrator = self._orchestrator_or_error() self.cleanup() nodes = self._get_all_nodes_parameters( - passed_parameters.get("nodes", {}) + passed_parameters.pop("nodes", {}) ) self._parameters = self.parameters.model_validate(passed_parameters) self.full_parameters = self.full_parameters_class.model_validate( diff --git a/tests/unit/test_parameters/test_node_and_graph_parameters.py b/tests/unit/test_parameters/test_node_and_graph_parameters.py index 9e21899..d75dc6a 100644 --- a/tests/unit/test_parameters/test_node_and_graph_parameters.py +++ b/tests/unit/test_parameters/test_node_and_graph_parameters.py @@ -1,6 +1,7 @@ from typing import Optional import pytest +from pydantic import ValidationError from qualibrate.parameters import GraphParameters, NodeParameters @@ -14,6 +15,16 @@ class SampleGraphParameters(GraphParameters): qubits: Optional[list[str]] = None other_param: str = "test" + @pytest.mark.parametrize( + "parameters_class", [SampleNodeParameters, SampleGraphParameters] + ) + def test_forbid_extra_parameters(self, parameters_class): + with pytest.raises(ValidationError) as ex: + parameters_class.model_validate({"invalid_key": None}) + errors = ex.value.errors() + assert errors[0]["type"] == "extra_forbidden" + assert errors[0]["loc"] == ("invalid_key",) + def test_node_targets_name(self): assert NodeParameters.targets_name == "qubits"