Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load and dump parameters #72

Merged
merged 3 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
381 changes: 376 additions & 5 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ qualang_tools = { version = ">0.17.4", python = ">=3.9,<3.12" }
networkx = "~3.2.0" # max available version with python 3.9 support
jsonpointer = "^3.0.0"
types-networkx = "~3.2.1.20240918"
datamodel-code-generator = "^0.26.3"


[tool.poetry.group.test.dependencies]
Expand Down
43 changes: 35 additions & 8 deletions qualibrate/qualibration_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import traceback
from collections.abc import Generator, Sequence
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager
from contextvars import ContextVar
from copy import copy
Expand Down Expand Up @@ -36,12 +36,14 @@
from qualibrate.utils.exceptions import StopInspection
from qualibrate.utils.logger_m import logger
from qualibrate.utils.node.comined_method import InstanceOrClassMethod
from qualibrate.utils.node.content import read_node_content, read_node_data
from qualibrate.utils.node.content import (
parse_node_content,
read_node_content,
read_node_data,
)
from qualibrate.utils.node.loaders.base_loader import BaseLoader
from qualibrate.utils.node.loaders.quam_loader import QuamLoader
from qualibrate.utils.node.path_solver import (
get_node_dir_path,
get_node_quam_filepath,
)
from qualibrate.utils.node.record_state_update import (
record_state_update_getattr,
Expand Down Expand Up @@ -325,6 +327,7 @@ def _load_from_id(
node_id: int,
base_path: Optional[Path] = None,
custom_loaders: Optional[Sequence[type[BaseLoader]]] = None,
build_params_class: bool = False,
) -> Optional["QualibrationNode[ParametersType]"]:
if base_path is None:
try:
Expand All @@ -345,12 +348,25 @@ def _load_from_id(
return None
node_content = read_node_content(node_dir, node_id, base_path)
if node_content is not None:
quam_filepath = get_node_quam_filepath(
node_content["data"], node_dir
quam_machine, parameters = parse_node_content(
node_content,
node_id,
node_dir,
build_params_class,
)
if quam_filepath is not None:
quam_machine = QuamLoader().load(quam_filepath)
if quam_machine is not None:
self.machine = quam_machine
if parameters is not None:
if build_params_class:
self.parameters_class = cast(
ParametersType, parameters
).__class__
self._parameters = cast(ParametersType, parameters)
else:
self._parameters = self.parameters.model_construct(
**cast(Mapping[str, Any], parameters)
)

data = read_node_data(node_dir, node_id, base_path, custom_loaders)
if data is not None:
self.results = data
Expand All @@ -375,6 +391,7 @@ def load_from_id(
node_id=node_id,
base_path=base_path,
custom_loaders=custom_loaders,
build_params_class=isinstance(caller, type),
)

def _post_run(
Expand Down Expand Up @@ -742,3 +759,13 @@ def add_node(
)

nodes[node.name] = node


if __name__ == "__main__":
from pathlib import Path

# path = Path("/home/maxim_v4s/Downloads/")
path = Path.home().joinpath(".qualibrate/user_storage/init_project")
node = QualibrationNode.load_from_id(12)
print(node.results)
print(node.parameters)
8 changes: 7 additions & 1 deletion qualibrate/storage/local_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,13 @@ def save(self, node: NodeTypeVar) -> None:

# Save results
self.data_handler.name = node.name
DataHandler.node_data = {"quam": "./quam_state.json"}
DataHandler.node_data = {
"quam": "./quam_state.json",
"parameters": {
"model": node.parameters.model_dump(mode="json"),
"schema": node.parameters.__class__.model_json_schema(),
},
}
node_contents = (
self.data_handler.generate_node_contents()
) # TODO directly access idx
Expand Down
78 changes: 76 additions & 2 deletions qualibrate/utils/node/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,36 @@
from datetime import datetime
from itertools import chain
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast

import datamodel_code_generator as dmcg
from datamodel_code_generator.format import DatetimeClassType
from datamodel_code_generator.model import get_data_model_types
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser
from pydantic import Field # noqa: F401

from qualibrate import NodeParameters
from qualibrate.utils.logger_m import logger
from qualibrate.utils.node.loaders import DEFAULT_LOADERS, SUPPORTED_EXTENSIONS
from qualibrate.utils.node.loaders import (
DEFAULT_LOADERS,
SUPPORTED_EXTENSIONS,
QuamLoader,
)
from qualibrate.utils.node.loaders.base_loader import BaseLoader
from qualibrate.utils.node.path_solver import (
get_data_filepath,
get_node_dir_path,
get_node_filepath,
get_node_quam_filepath,
resolve_and_check_relative,
)

DATA_MODEL_TYPES = get_data_model_types(
dmcg.DataModelType.PydanticV2BaseModel,
target_python_version=dmcg.PythonVersion.PY_311,
target_datetime_class=DatetimeClassType.Datetime,
)


def read_raw_node_file(
node_filepath: Path,
Expand Down Expand Up @@ -152,6 +170,26 @@ def read_node_content(
return node_content


def parse_node_content(
node_content: Mapping[str, Any],
node_id: int,
node_dir: Path,
build_params_class: bool,
) -> tuple[Optional[Any], Optional[Union[NodeParameters, Mapping[str, Any]]]]:
quam_machine = None
if "data" in node_content:
quam_filepath = get_node_quam_filepath(node_content["data"], node_dir)
if quam_filepath is not None:
quam_machine = QuamLoader().load(quam_filepath)
parameters_data = node_content.get("data", {}).get("parameters")
if parameters_data is None:
return quam_machine, None
if not isinstance(parameters_data, dict):
return quam_machine, None
parameters = load_parameters(parameters_data, node_id, build_params_class)
return quam_machine, parameters


def _get_filename_and_subreference(
filepath: Union[str, Path],
) -> tuple[Path, Optional[str]]:
Expand Down Expand Up @@ -249,3 +287,39 @@ def read_node_data(
results, loaders_instances, supported_extensions, node_dir
)
return results


def load_parameters(
parameters: Mapping[str, Any],
node_id: int,
build_params_class: bool,
) -> Optional[Union[NodeParameters, Mapping[str, Any]]]:
model = parameters.get("model")
if not build_params_class:
return model
schema = parameters.get("schema")
if schema is None or model is None:
return None
class_name = f"LoadedNode{node_id}Parameters"
parser = JsonSchemaParser(
json.dumps(schema),
data_model_type=DATA_MODEL_TYPES.data_model,
data_model_root_type=DATA_MODEL_TYPES.root_model,
data_model_field_type=DATA_MODEL_TYPES.field_model,
data_type_manager_type=DATA_MODEL_TYPES.data_type_manager,
use_standard_collections=True,
use_generic_container_types=True,
class_name=class_name,
base_class="qualibrate.parameters.NodeParameters",
additional_imports=["datetime.datetime"],
dump_resolve_reference_action=(
DATA_MODEL_TYPES.dump_resolve_reference_action
),
)
model_class_str = str(parser.parse())
exec(model_class_str)
params_class = locals().get(class_name)

if params_class is None:
logger.error("Can't build parameters class correctly")
return cast(NodeParameters, params_class).model_validate(model)
2 changes: 2 additions & 0 deletions qualibrate/utils/node/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"NumpyArrayLoader",
"ImageLoader",
"XarrayLoader",
"DEFAULT_LOADERS",
"SUPPORTED_EXTENSIONS",
]


Expand Down