Bugfix/repro2 #31

Merged · 2 commits · Apr 19, 2024
2 changes: 2 additions & 0 deletions docs/tutorial.rst
@@ -19,5 +19,7 @@ And to launch distributed multinode or multiprocess experiment go to :ref:`distr
tutorials/local_communicator_tutorial
tutorials/distr_communicator_tutorial
tutorials/inference_tutorial
+ tutorials/batching_tutorial



62 changes: 62 additions & 0 deletions docs/tutorials/batching_tutorial.rst
@@ -0,0 +1,62 @@
.. _batching_tutorial:

*how-to:* Make batchers
======================================
Recall that the Stalactite framework runs train/infer loops inside each agent: the PartyMaster and the PartyMembers.
Each of these loops uses its own batcher, and these batchers have to be kept synchronized.

To customize batch initialization, override the ``make_batcher`` method in your PartyMaster class.
For example, you can initialize the batcher as in ``stalactite/ml/honest/linear_regression/party_master.py``, or provide your own implementation.

.. code-block:: python

    def make_batcher(
        self,
        uids: Optional[List[str]] = None,
        party_members: Optional[List[str]] = None,
        is_infer: bool = False,
    ) -> Batcher:
        # Validate inputs before using them in the log message.
        if uids is None:
            raise RuntimeError('Master must initialize batcher with collected uids.')
        assert party_members is not None, "Master is trying to initialize batcher without members list"
        logger.info("Master %s: making a batcher for %s uids" % (self.id, len(uids)))
        self.check_if_ready()
        batch_size = self._eval_batch_size if is_infer else self._batch_size
        epochs = 1 if is_infer else self.epochs
        return ListBatcher(epochs=epochs, members=party_members, uids=uids, batch_size=batch_size)


Party members use their own batchers, which are initialized in the PartyMember class.
You can find an example of such initialization in ``stalactite/ml/honest/base.py``.
The key point is that all batchers must share the same ``uids``, ``batch_size``, and ``epochs`` for training and inference to work correctly.

.. code-block:: python

    def make_batcher(
        self,
        uids: Optional[List[str]] = None,
        party_members: Optional[List[str]] = None,
        is_infer: bool = False,
    ) -> Batcher:
        epochs = 1 if is_infer else self.epochs
        batch_size = self._eval_batch_size if is_infer else self._batch_size
        uids_to_use = self._uids_to_use_test if is_infer else self._uids_to_use
        if uids_to_use is None:
            raise RuntimeError("Cannot create batcher, you must call `register_records_uids` first.")
        return self._create_batcher(epochs=epochs, uids=uids_to_use, batch_size=batch_size)

    def _create_batcher(self, epochs: int, uids: List[str], batch_size: int) -> Batcher:
        """Create a batcher for training.

        :param epochs: Number of training epochs.
        :param uids: List of unique identifiers for dataset rows.
        :param batch_size: Size of the training batch.
        """
        logger.info("Member %s: making a batcher for uids" % self.id)
        if not self.is_consequently:
            return ListBatcher(epochs=epochs, members=None, uids=uids, batch_size=batch_size)
        else:
            return ConsecutiveListBatcher(
                epochs=epochs, members=self.members, uids=uids, batch_size=batch_size
            )
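
To see why synchronization matters, here is a minimal standalone sketch: the master and a member build batchers over the same ``uids`` with the same ``epochs`` and ``batch_size``, so both sides iterate over identical batches. (The import path for ``ListBatcher`` and the uid values are illustrative assumptions, not part of this PR.)

.. code-block:: python

    from stalactite.batching import ListBatcher  # assumed import path

    uids = [str(i) for i in range(1000)]

    # Master-side batcher: knows the member ids.
    master_batcher = ListBatcher(
        epochs=2, members=["member-0", "member-1"], uids=uids, batch_size=250
    )
    # Member-side batcher: same uids, epochs and batch_size, but no member list.
    member_batcher = ListBatcher(epochs=2, members=None, uids=uids, batch_size=250)

    # Since both batchers share uids, epochs and batch_size, the i-th batch
    # on the master lines up with the i-th batch on each member.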
1 change: 1 addition & 0 deletions examples/configs/efficientnet-splitNN-mnist-local.yml
@@ -4,6 +4,7 @@ common:
world_size: 2
experiment_label: experiment-efficientnet-mnist-local
reports_export_folder: "../../reports"
+ seed: 22

vfl_model:
epochs: 1
2 changes: 1 addition & 1 deletion examples/configs/linreg-mnist-local.yml
@@ -4,6 +4,7 @@ common:
world_size: 2
experiment_label: test-experiment-mnist-local
reports_export_folder: ../../reports
+ seed: 22

vfl_model:
epochs: 2
@@ -19,7 +20,6 @@ vfl_model:
vfl_model_path: ../../saved_models/linreg_model

data:
- random_seed: 0
dataset_size: 5000
dataset: 'mnist'
host_path_data_dir: ../../data/sber_ds_vfl/mnist_vfl_parts2
2 changes: 1 addition & 1 deletion examples/configs/linreg-mnist-single.yml
@@ -1,7 +1,7 @@
common:
report_train_metrics_iteration: 1
report_test_metrics_iteration: 1
- world_size: 2
+ world_size: 1
experiment_label: experiment-mnist-centralized
reports_export_folder: "../../reports"

1 change: 1 addition & 0 deletions examples/configs/logreg-sbol-smm-local.yml
@@ -13,6 +13,7 @@ vfl_model:
is_consequently: False
use_class_weights: True
learning_rate: 0.05
+ weight_decay: 0.02
do_train: True
do_predict: True
do_save_model: True
3 changes: 2 additions & 1 deletion examples/configs/mlp-splitNN-sbol-smm-local.yml
@@ -4,14 +4,15 @@ common:
world_size: 2
experiment_label: experiment-mlp-sbol-smm-local
reports_export_folder: "../../reports"
+ seed: 22

vfl_model:
epochs: 1
batch_size: 250
eval_batch_size: 200
vfl_model_name: mlp
is_consequently: False
- use_class_weights: False
+ use_class_weights: True
learning_rate: 0.01
do_train: True
do_predict: False
3 changes: 2 additions & 1 deletion examples/utils/local_arbitered_experiment.py
@@ -19,7 +19,7 @@
from examples.utils.prepare_mnist import load_data as load_mnist
from examples.utils.prepare_sbol_smm import load_data as load_sbol_smm
from stalactite.helpers import reporting, run_local_agents

+ from stalactite.utils import seed_all

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
@@ -90,6 +90,7 @@ def run(config_path: Optional[str] = None):
)

config = VFLConfig.load_and_validate(config_path)
+ seed_all(config.common.seed)
master_processor, processors = load_processors(config)

with reporting(config):
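
``seed_all`` is imported from ``stalactite.utils``, but its body is not part of this diff. A minimal sketch of such a helper, consistent with the removed ``random_seed`` field's description ("including random, numpy, torch"); the actual implementation may differ:

.. code-block:: python

    import random

    import numpy as np
    import torch


    def seed_all(seed: int) -> None:
        """Seed Python, NumPy and PyTorch RNGs for reproducibility (sketch)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)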
16 changes: 12 additions & 4 deletions examples/utils/local_experiment.py
@@ -25,6 +25,7 @@
from examples.utils.prepare_mnist import load_data as load_mnist
from examples.utils.prepare_sbol_smm import load_data as load_sbol_smm
from stalactite.helpers import reporting, run_local_agents
+ from stalactite.utils import seed_all

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
@@ -43,10 +44,12 @@ def load_processors(config: VFLConfig):
"""
if config.data.dataset.lower() == "mnist":

- binary = False if config.vfl_model.vfl_model_name == "efficientnet" else True
+ binary = False if config.vfl_model.vfl_model_name in ["efficientnet", "logreg"] else True

if len(os.listdir(config.data.host_path_data_dir)) == 0:
- load_mnist(config.data.host_path_data_dir, config.common.world_size, binary=binary)
+ load_mnist(
+     save_path=Path(config.data.host_path_data_dir), parts_num=config.common.world_size, binary=binary
+ )

dataset = {}
for m in range(config.common.world_size):
@@ -96,7 +99,10 @@ def run(config_path: Optional[str] = None):
)

config = VFLConfig.load_and_validate(config_path)
+ seed_all(config.common.seed)
master_processor, processors = load_processors(config)
+ if config.data.dataset_size == -1:
+     config.data.dataset_size = len(master_processor.dataset[config.data.train_split][config.data.uids_key])

with reporting(config):
shared_party_info = dict()
@@ -134,7 +140,8 @@
do_predict=config.vfl_model.do_predict,
model_name=config.vfl_model.vfl_model_name if
config.vfl_model.vfl_model_name in ["resnet", "mlp", "efficientnet"] else None,
- model_params=config.master.master_model_params
+ model_params=config.master.master_model_params,
+ seed=config.common.seed
)

member_ids = [f"member-{member_rank}" for member_rank in range(config.common.world_size)]
@@ -158,7 +165,8 @@
do_save_model=config.vfl_model.do_save_model,
model_path=config.vfl_model.vfl_model_path,
model_params=config.member.member_model_params,
- use_inner_join=True if member_rank == 0 else False
+ use_inner_join=True if member_rank == 0 else False,
+ seed=config.common.seed

)
for member_rank, member_uid in enumerate(member_ids)
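
A note on the ``dataset_size`` handling added above: in this script, setting ``dataset_size: -1`` in the data config now means "use the full train split"; the placeholder is replaced at runtime by the number of train uids. An illustrative config fragment:

.. code-block:: yaml

    data:
      dataset_size: -1  # resolved to len(train uids) at runtime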
6 changes: 4 additions & 2 deletions examples/utils/local_experiment_single.py
@@ -11,6 +11,7 @@
from examples.utils.prepare_mnist import load_data as load_mnist
from examples.utils.prepare_sbol_smm import load_data as load_sbol_smm
from stalactite.helpers import reporting
+ from stalactite.utils import seed_all

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
@@ -23,7 +24,7 @@ def load_processors(config_path: str):
if config.data.dataset.lower() == "mnist":

if len(os.listdir(config.data.host_path_data_dir)) == 0:
- load_mnist(config.data.host_path_data_dir, parts_num=1, is_single=True, binary=True)
+ load_mnist(Path(config.data.host_path_data_dir), parts_num=1, binary=True)

dataset = {0: datasets.load_from_disk(
os.path.join(f"{config.data.host_path_data_dir}/part_{0}")
@@ -36,7 +37,7 @@ def load_processors(config_path: str):
elif config.data.dataset.lower() == "sbol_smm":

if len(os.listdir(config.data.host_path_data_dir)) == 0:
- load_sbol_smm(os.path.dirname(config.data.host_path_data_dir), parts_num=1, is_single=True, sbol_only=False)
+ load_sbol_smm(os.path.dirname(config.data.host_path_data_dir), parts_num=1, sbol_only=False)

dataset = {0: datasets.load_from_disk(
os.path.join(f"{config.data.host_path_data_dir}/part_{0}")
@@ -59,6 +60,7 @@ def run(config_path: Optional[str] = None):
os.path.join(Path(__file__).parent.parent.parent, 'configs/linreg-mnist-local.yml')
)
config = VFLConfig.load_and_validate(config_path)
+ seed_all(config.common.seed)
model_name = config.vfl_model.vfl_model_name

with reporting(config):
5 changes: 2 additions & 3 deletions examples/utils/prepare_mnist.py
@@ -136,14 +136,13 @@ def save_splitted_dataset(ds_list, path, part_dir_name='part_', clean_dir=False)
ds.save_to_disk(part_path)


- def load_data(save_path, parts_num, binary: bool = True, is_single: bool = False):
+ def load_data(save_path: Path, parts_num: int, binary: bool = True):
"""

The input is the original MNIST dataset.
1. Labels filtered and replaced so that the task is binary.

"""

make_validation = True
test_size = 0.15
stratify_by_column = 'label'
@@ -170,7 +169,7 @@ def load_data(save_path, parts_num, binary: bool = True, is_single: bool = False
if make_validation:
train_train, train_val = make_train_val_split(mnist['train'], test_size=test_size,
stratify_by_column=stratify_by_column, shuffle=shuffle, seed=seed)
- if not is_single:
+ if parts_num != 1:

train_train_labels = train_train.select_columns(["image_idx", "label"])
train_val_labels = train_val.select_columns(["image_idx", "label"])
8 changes: 4 additions & 4 deletions examples/utils/prepare_sbol_smm.py
@@ -28,7 +28,7 @@ def split_save_datasets(df, train_users, test_users, columns, postfix_sample, di
)


- def load_data(data_dir_path: str, parts_num: int = 2, is_single: bool = False, sbol_only: bool = False):
+ def load_data(data_dir_path: str, parts_num: int = 2, sbol_only: bool = False):
sbol_path = os.path.join(data_dir_path, "sbol")
smm_path = os.path.join(data_dir_path, "smm")

@@ -44,7 +44,7 @@ def load_data(data_dir_path: str, parts_num: int = 2, is_single: bool = False, s
sbol_labels[["user_id"]], shuffle=True, random_state=seed, test_size=0.15
)

- if not is_single:
+ if parts_num != 1:
logger.info("Save vfl dataset labels part...")
split_save_datasets(df=sbol_labels, train_users=users_train, test_users=users_test,
columns=["user_id", "labels"], postfix_sample=sample, part_postfix="master_part",
@@ -61,7 +61,7 @@ def load_data(data_dir_path: str, parts_num: int = 2, is_single: bool = False, s
lambda x: list(x), axis=1)
sbol_user_features = sbol_user_features[["user_id", "features_part_0", "labels"]]

- if not is_single:
+ if parts_num != 1:
logger.info("Save vfl dataset part 0...")
split_save_datasets(df=sbol_user_features, train_users=users_train, test_users=users_test,
columns=["user_id", "features_part_0"], postfix_sample=sample, part_postfix="part_0",
@@ -85,7 +85,7 @@ def load_data(data_dir_path: str, parts_num: int = 2, is_single: bool = False, s
smm_user_factors = sbol_labels[["user_id"]].merge(smm_user_factors, on="user_id", how="inner")
smm_user_factors.rename(columns={"user_factors": "features_part_1"}, inplace=True)

- if not is_single:
+ if parts_num != 1:
logger.info("Save vfl dataset part 1...")
split_save_datasets(df=smm_user_factors, train_users=users_train, test_users=users_test,
columns=["user_id", "features_part_1"], postfix_sample=sample, part_postfix="part_1",
4 changes: 3 additions & 1 deletion stalactite/base.py
@@ -233,10 +233,12 @@ def make_batcher(
party_members: Optional[List[str]] = None,
is_infer: bool = False
) -> Batcher:
""" Make a make_batcher for training.
""" Make a make_batcher for training or inference.

:param uids: List of unique identifiers of dataset records.
:param party_members: List of party members` identifiers.
+ :param is_infer: Flag indicating whether to use inference mode.
+     Inference mode means that the batcher will have only one epoch and use eval_batch_size.

:return: Batcher instance.
"""
16 changes: 11 additions & 5 deletions stalactite/communications/local.py
@@ -198,18 +198,24 @@ def broadcast(
        return tasks

    def gather(self, tasks: List[Task], recv_results: bool = False) -> List[Task]:
-        _recv_futures = [
-            RecvFuture(method_name=task.method_name, receive_from_id=task.from_id if not recv_results else task.to_id)
-            for task in tasks
-        ]
+        receive_from_ids, _recv_futures = [], []
+        for task in tasks:
+            rfi = task.from_id if not recv_results else task.to_id
+            receive_from_ids.append(rfi)
+            _recv_futures.append(RecvFuture(method_name=task.method_name, receive_from_id=rfi))

        for recv_f in _recv_futures:
            threading.Thread(target=self._get_from_recv, args=(recv_f,), daemon=True).start()
        done_tasks, pending_tasks = concurrent.futures.wait(_recv_futures, timeout=self.recv_timeout)

        if number_to := len(pending_tasks):
            raise TimeoutError(f"{self.participant.id} could not gather tasks from {number_to} members.")
-        return [task.result() for task in done_tasks]
+
+        results = {}
+        for task in done_tasks:
+            results[task.receive_from_id] = task.result()
+
+        return [results[idx] for idx in receive_from_ids]

    def raise_if_not_ready(self):
        """Raise an exception if the communicator was not initialized properly."""
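
For clarity, the rewritten ``gather`` returns results aligned with the order of the requesting ``tasks``; previously, results came back in completion order, which could mismatch the member ordering. The underlying pattern, as a standalone sketch:

.. code-block:: python

    # Collect results keyed by sender id, then emit them in request order.
    def ordered_results(request_ids, completed):
        # completed: mapping of sender id -> result, filled as futures finish
        return [completed[idx] for idx in request_ids]

    assert ordered_results(["member-0", "member-1"], {"member-1": 2, "member-0": 1}) == [1, 2]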
8 changes: 3 additions & 5 deletions stalactite/configs.py
@@ -55,6 +55,7 @@ class CommonConfig(BaseModel):
default=Path(__file__).parent, description="Folder for exporting tests` and experiments` reports"
)
rendezvous_timeout: float = Field(default=3600, description="Initial agents rendezvous timeout in sec")
+ seed: int = Field(default=42, description="Initial random seed")
logging_level: Literal["debug", "info", "warning"] = Field(default="info", description="Logging level")

@model_validator(mode="after")
@@ -81,6 +82,7 @@ class VFLModelConfig(BaseModel):
learning_rate: float = Field(default=0.01, description='Learning rate')
l2_alpha: Optional[float] = Field(default=None, description='Alpha used for L2 regularization')
momentum: Optional[float] = Field(default=0, description='Optimizer momentum')
+ weight_decay: Optional[float] = Field(default=0.01, description='Optimizer weight decay')
do_train: bool = Field(default=True, description='Whether to run a training loop.')
do_predict: bool = Field(default=True, description='Whether to run an inference loop.')
do_save_model: bool = Field(default=True, description='Whether to save the model after training.')
@@ -93,11 +95,7 @@ class DataConfig(BaseModel):
class DataConfig(BaseModel):
"""Experimental data parameters config."""

- random_seed: int = Field(
-     default=0,
-     description="Experiment data random seed (including random, numpy, torch)"
- )
- dataset_size: int = Field(default=1000, description="Number of dataset rows to use")
+ dataset_size: int = Field(default=100, description="Number of dataset rows to use")
host_path_data_dir: str = Field(default='.', description="Path to datasets` directory")
dataset: Literal[
'mnist', 'sbol', 'sbol_smm', 'home_credit', 'home_credit_bureau_pos', 'avito', 'avito_texts_images'] = Field(
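
Taken together, the config changes move seeding from ``data.random_seed`` to ``common.seed`` and introduce an optimizer ``weight_decay``. An illustrative fragment (values are examples only, not defaults):

.. code-block:: yaml

    common:
      world_size: 2
      seed: 22            # replaces the removed data.random_seed

    vfl_model:
      learning_rate: 0.05
      weight_decay: 0.02  # new optimizer parameter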
2 changes: 1 addition & 1 deletion stalactite/data_preprocessors/full_data_tensor.py
@@ -17,7 +17,7 @@ def transform(self, inp_data):

def _transform(self, inp_data):

- num_rows = inp_data.num_rows #todo: refactor
+ num_rows = inp_data.num_rows

tnsr = torch.as_tensor(inp_data[self.input_feature_name][0:num_rows])

1 change: 0 additions & 1 deletion stalactite/data_preprocessors/tabular_preprocessor.py
@@ -49,7 +49,6 @@ def fit_transform(self):
split_dict[feature_name] = standard_scaler.fit_transform(split_dict[feature_name])

split_dict[uids_name] = split_data[uids_name]
- # split_dict[label_name] = [x[6] for x in split_dict[label_name]] # todo: remove (for debugging only)

if self.is_master and isinstance(train_split_data[label_name][0], list):
self.multilabel = True
2 changes: 1 addition & 1 deletion stalactite/helpers.py
@@ -68,6 +68,7 @@ def reporting(config: VFLConfig):
"is_consequently": config.vfl_model.is_consequently,
"model_name": config.vfl_model.vfl_model_name,
"learning_rate": config.vfl_model.learning_rate,
"weight_decay": config.vfl_model.weight_decay,
"dataset": config.data.dataset,

}
@@ -108,4 +109,3 @@

for thread in threads:
thread.join()
