minor refactor to allow modular functions (#224)
* minor refactor to allow modular functions

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* minor fix in import

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* minor fix to imports

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix linting

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix formatting

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Ssukriti authored Jul 1, 2024
1 parent 0be40e0 commit b655e1a
Showing 3 changed files with 162 additions and 99 deletions.
28 changes: 22 additions & 6 deletions tests/utils/test_preprocessing_utils.py
@@ -13,6 +13,7 @@
)

# Local
from tuning.config import configs
from tuning.utils.preprocessing_utils import (
combine_sequence,
get_data_trainer_kwargs,
@@ -180,14 +181,29 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data):
assert trainer_kwargs["formatting_func"] is not None


# Tests for fetching train args
# Tests for validating data args
# Invalid args return ValueError
@pytest.mark.parametrize(
"dataset_text_field, response_template",
"data_args, packing",
[
("input", None),
(None, "output"),
# dataset_text_field with no response_template
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA,
dataset_text_field="output",
),
False,
),
# response template with no dataset_text_field or formatter
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA,
response_template="\n### Label:",
),
False,
),
],
)
def test_validate_args(dataset_text_field, response_template):
def test_validate_args(data_args, packing):
with pytest.raises(ValueError):
validate_data_args(dataset_text_field, response_template)
validate_data_args(data_args, packing)
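A minimal illustrative sketch (not part of this commit) of a DataArguments combination that the new validate_data_args is expected to accept, for contrast with the failing cases parametrized above; it reuses the TWITTER_COMPLAINTS_DATA fixture, and the test name is hypothetical:

def test_validate_args_accepts_text_field_with_template():
    data_args = configs.DataArguments(
        training_data_path=TWITTER_COMPLAINTS_DATA,
        dataset_text_field="output",
        response_template="\n### Label:",
    )
    # Should not raise: a response template is set together with exactly one of
    # dataset_text_field / data_formatter_template, and packing is disabled.
    validate_data_args(data_args, packing=False)
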
37 changes: 7 additions & 30 deletions tuning/sft_trainer.py
@@ -34,7 +34,7 @@
TrainerCallback,
)
from transformers.utils import is_accelerate_available, logging
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
from trl import SFTConfig, SFTTrainer
import datasets
import fire
import transformers
@@ -62,6 +62,7 @@
USER_ERROR_EXIT_CODE,
write_termination_log,
)
from tuning.utils.preprocessing_utils import get_data_collator, validate_data_args


def train(
@@ -195,14 +196,6 @@ def train(
}
)

# TODO: near term - how response template ids are parsed out needs to be cleaned.
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
# otherwise template is not found. We will create issue to clean this out after we discuss
# data formats and collators we will support.
response_template_ids = tokenizer.encode(
data_args.response_template, add_special_tokens=False
)[2:]

max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
logger.info("Max sequence length is %s", max_seq_length)
if train_args.max_seq_length > tokenizer.model_max_length:
@@ -244,31 +237,14 @@ def train(
packing = True
else:
logger.info("Packing is set to False")
if data_args.response_template is None:
# TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization
# We should do this validation up front, then do the encoding, then handle the collator
raise ValueError("Response template is None, needs to be set for training")
data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)
packing = False

# Currently we support formatted datasets with single sequence instances.
if not (data_args.dataset_text_field or data_args.data_formatter_template):
raise ValueError(
"dataset_text_field and data_formatter_template are None. \
One of them needs to be set for training"
)
# Only one of dataset_text_field or data_formatter_template should be set.
if data_args.dataset_text_field and data_args.data_formatter_template:
raise ValueError(
"dataset_text_field and data_formatter_template are both set,\
but are mutually exclusive options"
)
# Validate if data args are set properly
validate_data_args(data_args, packing)
data_collator = get_data_collator(packing, data_args.response_template, tokenizer)

# load the data by parsing JSON
### TODO: all the JSON file formatting will be moved to a separate function
data_files = {"train": data_args.training_data_path}
if data_args.validation_data_path:
data_files["validation"] = data_args.validation_data_path
@@ -310,6 +286,7 @@ def train(
logger.info(
"Validation dataset length is %s", len(formatted_validation_dataset)
)
### JSON file formatting ends here

if framework is not None and framework.requires_agumentation:
model, (peft_config,) = framework.augmentation(
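A minimal illustrative sketch (not part of this commit) of how the two extracted helpers are meant to be wired together inside train(); the data path and model id below are hypothetical examples:

from transformers import AutoTokenizer

from tuning.config import configs
from tuning.utils.preprocessing_utils import get_data_collator, validate_data_args

data_args = configs.DataArguments(
    training_data_path="twitter_complaints.json",  # hypothetical path
    dataset_text_field="output",
    response_template="\n### Label:",
)
packing = False

# Fails fast with ValueError on invalid combinations, e.g. a response_template with
# neither dataset_text_field nor data_formatter_template set.
validate_data_args(data_args, packing)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model id
data_collator = get_data_collator(packing, data_args.response_template, tokenizer)
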
196 changes: 133 additions & 63 deletions tuning/utils/preprocessing_utils.py
@@ -25,25 +25,141 @@
from tuning.config import configs


def validate_data_args(
def validate_data_args(data_args: configs.DataArguments, packing: bool):

assert isinstance(
data_args.training_data_path, str
), "Training data path has to be set and str"

# Dataset containing single sequence needs a response template for masking
if data_args.response_template is None and data_args.dataset_text_field is not None:
if packing is False:
raise ValueError(
"Since dataset_text_field is provided and packing is disabled, \
needs a corresponding response template for masking"
)

# Currently if packing is false, we require a response_template. This may change in future.
if packing is False:
if data_args.response_template is None:
raise ValueError(
"Response template is None, needs to be set for training \
with packing disabled."
)

if data_args.response_template:
# To use Response template, pass datasets with single sequence instances \
# or a formatter template to create single sequence on the fly.
if not (data_args.dataset_text_field or data_args.data_formatter_template):
raise ValueError(
"dataset_text_field and data_formatter_template are None. \
One of them needs to be set to use response_template"
)
# Only one of dataset_text_field or data_formatter_template should be set.
if data_args.dataset_text_field and data_args.data_formatter_template:
raise ValueError(
"dataset_text_field and data_formatter_template are both set,\
but are mutually exclusive options"
)
# TODO(s): In future, support two more formats:
# 1. Allow no response template, and JSON with input/output fields and mask input

# 2. Allow pretokenized Dataset besides JSON.


def get_data_collator(
packing: bool,
response_template: Optional[str],
tokenizer: AutoTokenizer,
) -> Callable:
"""Create and return the the appropriate collator type based on the configuration for packing,
response_template, and dataset_text_field.
Args:
packing: bool
Whether or not we should apply packing.
response_template: Optional[str]
Response template to be used for formatting by TRL.
tokenizer: AutoTokenizer
Loaded tokenizer object to be used by the collator.
Returns:
Callable
Callable collator to be leveraged by the trainer.
"""
if not packing:
# TODO: near term - how response template ids are parsed out needs to be cleaned.
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
# otherwise template is not found. We will create issue to clean this out after we discuss
# data formats and collators we will support.
if response_template:
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:]
return DataCollatorForCompletionOnlyLM(
response_template=response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)
# TODO with future changes:
# 1. Support no packing and seq2seq collator without response template
# # if dataset_text_field is None and response_template is None:
# # Use the seq2seq data collator;
# # Note that this automatically pads labels with -100
# return DataCollatorForSeq2Seq(
# tokenizer=tokenizer, padding=True, max_length=max_sequence_length
# )
# 2. add anything needed for preprocessed input


###################################################################################
### The functions below are not yet used. Iterative development towards new features


def get_data_collator_temp(
packing: bool,
dataset_text_field: Optional[str],
response_template: Optional[str],
):
# A dataset of single-sequence instances needs both a dataset_text_field and a response template
if dataset_text_field is None and response_template is not None:
raise ValueError(
"Needs a corresponding dataset_text_feld \
in which to look for response_template"
)
if response_template is None and dataset_text_field is not None:
raise ValueError(
"Since dataset_text_field is provided, \
needs a corresponding response template for masking"
)
# Dataset containing JSON with fields and a formatter template
# TO DO load JSON and check input/output field is present
max_sequence_length: int,
tokenizer: AutoTokenizer,
) -> Callable:
"""Create and return the the appropriate collator type based on the configuration for packing,
response_template, and dataset_text_field.
# in future : pretokenized Dataset may be added.
Args:
packing: bool
Whether or not we should apply packing.
dataset_text_field: Optional[str]
Dataset text field to be used for formatting by TRL.
response_template: Optional[str]
Response template to be used for formatting by TRL.
max_sequence_length: int
Max sequence length to be used for sequence tokenization.
tokenizer: AutoTokenizer
Loaded tokenizer object to be used by the collator.
Returns:
Callable
Callable collator to be leveraged by the trainer.
"""
if not packing:
if dataset_text_field is None and response_template is None:
# Use the seq2seq data collator; note that this automatically pads labels with -100
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=True, max_length=max_sequence_length
)
# TODO: near term - how response template ids are parsed out needs to be cleaned.
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
# otherwise template is not found. We will create issue to clean this out after we discuss
# data formats and collators we will support.
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:]
return DataCollatorForCompletionOnlyLM(
response_template=response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)


def get_data_trainer_kwargs(
@@ -82,7 +198,7 @@ def get_data_trainer_kwargs(
Dict[str, Any]
Data related kwargs to be used by the SFT Trainer.
"""
data_collator = get_data_collator(
data_collator = get_data_collator_temp(
packing, dataset_text_field, response_template, max_sequence_length, tokenizer
)
eval_dataset = None
@@ -122,52 +238,6 @@ def get_data_trainer_kwargs(
return data_kwargs


def get_data_collator(
packing: bool,
dataset_text_field: Optional[str],
response_template: Optional[str],
max_sequence_length: int,
tokenizer: AutoTokenizer,
) -> Callable:
"""Create and return the the appropriate collator type based on the configuration for packing,
response_template, and dataset_text_field.
Args:
packing: bool
Whether or not we should apply packing.
dataset_text_field: Optional[str]
Dataset text field to be used for formatting by TRL.
response_template: Optional[str]
Response template to be used for formatting by TRL.
max_sequence_length: int
Max sequence length to be used for sequence tokenization.
tokenizer: AutoTokenizer
Loaded tokenizer object to be used by the collator.
Returns:
Callable
Callable collator to be leveraged by the trainer.
"""
if not packing:
if dataset_text_field is None and response_template is None:
# Use the seq2seq data collator; note that this automatically pads labels with -100
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=True, max_length=max_sequence_length
)
# TODO: near term - how response template ids are parsed out needs to be cleaned.
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
# otherwise template is not found. We will create issue to clean this out after we discuss
# data formats and collators we will support.
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:]
return DataCollatorForCompletionOnlyLM(
response_template=response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)


def get_formatted_dataset(
data_path: str, dataset_text_field: str, tokenizer: AutoTokenizer
) -> Dataset:
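A minimal illustrative sketch (not part of this commit) of why the [2:] slice appears when the response template is encoded: with some tokenizers, encoding the template on its own produces leading ids that do not occur when the template appears mid-sequence, so the prefix ids are stripped before the list is handed to DataCollatorForCompletionOnlyLM. The model id below is a hypothetical example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model id

standalone_ids = tokenizer.encode("\n### Label:", add_special_tokens=False)
in_context_ids = tokenizer.encode("some text\n### Label:", add_special_tokens=False)

# Compare the tails: the in-context encoding typically lacks the standalone prefix ids,
# which is why a sliced id list (rather than the raw string) is passed to the collator.
print(standalone_ids)
print(in_context_ids)
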
