Skip to content

Commit

Permalink
Merge pull request #83 from qua-platform/extra_params
Browse files Browse the repository at this point in the history
Forbid extra args to runnable parameters
  • Loading branch information
nulinspiratie authored Jan 30, 2025
2 parents 3441ea4 + 111d5e9 commit b077bed
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
4 changes: 3 additions & 1 deletion qualibrate/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
cast,
)

from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator

from qualibrate.utils.logger_m import logger
from qualibrate.utils.naming import get_full_class_path
Expand All @@ -31,6 +31,8 @@


class RunnableParameters(BaseModel):
model_config = ConfigDict(extra="forbid")

@classmethod
def serialize(cls, **kwargs: Any) -> Mapping[str, Any]:
schema = cls.model_json_schema()
Expand Down
2 changes: 1 addition & 1 deletion qualibrate/qualibration_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _run(self, **passed_parameters: Any) -> None:
orchestrator = self._orchestrator_or_error()
self.cleanup()
nodes = self._get_all_nodes_parameters(
passed_parameters.get("nodes", {})
passed_parameters.pop("nodes", {})
)
self._parameters = self.parameters.model_validate(passed_parameters)
self.full_parameters = self.full_parameters_class.model_validate(
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/test_parameters/test_node_and_graph_parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

import pytest
from pydantic import ValidationError

from qualibrate.parameters import GraphParameters, NodeParameters

Expand All @@ -14,6 +15,16 @@ class SampleGraphParameters(GraphParameters):
qubits: Optional[list[str]] = None
other_param: str = "test"

@pytest.mark.parametrize(
"parameters_class", [SampleNodeParameters, SampleGraphParameters]
)
def test_forbid_extra_parameters(self, parameters_class):
with pytest.raises(ValidationError) as ex:
parameters_class.model_validate({"invalid_key": None})
errors = ex.value.errors()
assert errors[0]["type"] == "extra_forbidden"
assert errors[0]["loc"] == ("invalid_key",)

def test_node_targets_name(self):
assert NodeParameters.targets_name == "qubits"

Expand Down

0 comments on commit b077bed

Please sign in to comment.