diff --git a/docs/examples/use_cases/pytorch/resnet50/main.py b/docs/examples/use_cases/pytorch/resnet50/main.py
index e69b6ae5b3d..4064d3e2b3f 100644
--- a/docs/examples/use_cases/pytorch/resnet50/main.py
+++ b/docs/examples/use_cases/pytorch/resnet50/main.py
@@ -20,6 +20,7 @@
 try:
     from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
     from nvidia.dali.pipeline import pipeline_def
+    from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
    import nvidia.dali.types as types
     import nvidia.dali.fn as fn
 except ImportError:
@@ -27,6 +28,7 @@
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import torchvision.models as models
+from contextlib import ExitStack
 
 def fast_collate(batch, memory_format):
     """Based on fast_collate from the APEX example
@@ -87,6 +89,8 @@ def parse():
                         help='Runs CPU based version of DALI pipeline.')
     parser.add_argument('--disable_dali', default=False, action='store_true',
                         help='Disable DALI data loader and use native PyTorch one instead.')
+    parser.add_argument('--dali_proxy', default=False, action='store_true',
+                        help='Enable DALI proxy: uses native PyTorch data loader and DALI for preprocessing.')
     parser.add_argument('--prof', default=-1, type=int,
                         help='Only run 10 iterations for profiling.')
     parser.add_argument('--deterministic', action='store_true')
@@ -107,17 +111,7 @@ def to_python_float(t):
     else:
         return t[0]
 
-
-@pipeline_def(exec_dynamic=True)
-def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=False, is_training=True):
-    images, labels = fn.readers.file(file_root=data_dir,
-                                     shard_id=shard_id,
-                                     num_shards=num_shards,
-                                     random_shuffle=is_training,
-                                     pad_last_batch=True,
-                                     name="Reader")
-    dali_device = 'cpu' if dali_cpu else 'gpu'
-    decoder_device = 'cpu' if dali_cpu else 'mixed'
+def image_processing_func(images, crop, size, is_training=True, decoder_device='mixed'):
     # ask HW NVJPEG to allocate memory ahead for the biggest image in the data set to avoid reallocations in runtime
     preallocate_width_hint = 5980 if decoder_device == 'mixed' else 0
     preallocate_height_hint = 6430 if decoder_device == 'mixed' else 0
@@ -130,7 +124,6 @@ def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=Fa
                                                random_area=[0.1, 1.0],
                                                num_attempts=100)
         images = fn.resize(images,
-                           device=dali_device,
                            resize_x=crop,
                            resize_y=crop,
                            interp_type=types.INTERP_TRIANGULAR)
@@ -140,7 +133,6 @@ def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=Fa
                                   device=decoder_device,
                                   output_type=types.RGB)
         images = fn.resize(images,
-                           device=dali_device,
                            size=size,
                            mode="not_smaller",
                            interp_type=types.INTERP_TRIANGULAR)
@@ -153,10 +145,29 @@ def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=Fa
                                       mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                       std=[0.229 * 255,0.224 * 255,0.225 * 255],
                                       mirror=mirror)
-    labels = labels.gpu()
-    return images, labels
+    return images
 
+@pipeline_def(exec_dynamic=True)
+def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=False, is_training=True):
+    images, labels = fn.readers.file(file_root=data_dir,
+                                     shard_id=shard_id,
+                                     num_shards=num_shards,
+                                     random_shuffle=is_training,
+                                     pad_last_batch=True,
+                                     name="Reader")
+    decoder_device = 'cpu' if dali_cpu else 'mixed'
+    images = image_processing_func(images, crop, size, is_training, decoder_device)
+    return images, labels.gpu()
+
+@pipeline_def(exec_dynamic=True)
+def create_dali_proxy_pipeline(crop, size, dali_cpu=False, is_training=True):
+    filepaths = fn.external_source(name="images", no_copy=True)
+    images = fn.io.file.read(filepaths)
+    decoder_device = 'cpu' if dali_cpu else 'mixed'
+    images = image_processing_func(images, crop, size, is_training, decoder_device)
+    return images
+
 def main():
     global best_prec1, args
     best_prec1 = 0
@@ -271,7 +282,9 @@ def resume():
     train_loader = None
     val_loader = None
+    dali_server_train = None
+    dali_server_val = None
 
-    if not args.disable_dali:
+    if not args.disable_dali and not args.dali_proxy:
         train_pipe = create_dali_pipeline(batch_size=args.batch_size,
                                           num_threads=args.workers,
                                           device_id=args.local_rank,
@@ -303,6 +316,72 @@ def resume():
         val_loader = DALIClassificationIterator(val_pipe, reader_name="Reader",
                                                 last_batch_policy=LastBatchPolicy.PARTIAL,
                                                 auto_reset=True)
+    elif args.dali_proxy:
+
+        def read_filepath(path):
+            return np.frombuffer(path.encode(), dtype=np.int8)
+
+        train_pipe = create_dali_proxy_pipeline(
+            batch_size=args.batch_size,
+            num_threads=args.workers,
+            device_id=args.local_rank,
+            seed=12 + args.local_rank,
+            crop=crop_size,
+            size=val_size,
+            dali_cpu=args.dali_cpu,
+            is_training=True)
+
+        dali_server_train = dali_proxy.DALIServer(train_pipe)
+        train_dataset = datasets.ImageFolder(
+            traindir,
+            transform=dali_server_train.proxy,
+            loader=read_filepath,
+        )
+
+        val_pipe = create_dali_proxy_pipeline(
+            batch_size=args.batch_size,
+            num_threads=args.workers,
+            device_id=args.local_rank,
+            seed=12 + args.local_rank,
+            crop=crop_size,
+            size=val_size,
+            dali_cpu=args.dali_cpu,
+            is_training=False)
+
+        dali_server_val = dali_proxy.DALIServer(val_pipe)
+        val_dataset = datasets.ImageFolder(
+            valdir,
+            transform=dali_server_val.proxy,
+            loader=read_filepath
+        )
+
+        train_sampler = None
+        val_sampler = None
+        if args.distributed:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
+
+        train_loader = dali_proxy.DataLoader(
+            dali_server_train,
+            train_dataset,
+            batch_size=args.batch_size,
+            shuffle=(train_sampler is None),
+            num_workers=args.workers,
+            pin_memory=True,
+            sampler=train_sampler,
+            collate_fn=None
+        )
+
+        val_loader = dali_proxy.DataLoader(
+            dali_server_val,
+            val_dataset,
+            batch_size=args.batch_size,
+            shuffle=False,
+            num_workers=args.workers,
+            pin_memory=True,
+            sampler=val_sampler,
+            collate_fn=None
+        )
     else:
         train_dataset = datasets.ImageFolder(traindir,
                                              transforms.Compose([transforms.RandomResizedCrop(crop_size),
@@ -344,35 +423,43 @@ def resume():
                                       backoff_factor=0.5,
                                       growth_interval=100,
                                       enabled=args.fp16_mode)
+
     total_time = AverageMeter()
-    for epoch in range(args.start_epoch, args.epochs):
-        # train for one epoch
-        avg_train_time = train(train_loader, model, criterion, scaler, optimizer, epoch)
-        total_time.update(avg_train_time)
-        if args.test:
-            break
-
-        # evaluate on validation set
-        [prec1, prec5] = validate(val_loader, model, criterion)
-
-        # remember best prec@1 and save checkpoint
-        if args.local_rank == 0:
-            is_best = prec1 > best_prec1
-            best_prec1 = max(prec1, best_prec1)
-            save_checkpoint({
-                'epoch': epoch + 1,
-                'arch': args.arch,
-                'state_dict': model.state_dict(),
-                'best_prec1': best_prec1,
-                'optimizer' : optimizer.state_dict(),
-            }, is_best)
-            if epoch == args.epochs - 1:
-                print('##Top-1 {0}\n'
-                      '##Top-5 {1}\n'
-                      '##Perf {2}'.format(
-                      prec1,
-                      prec5,
-                      args.total_batch_size / total_time.avg))
+
+    with ExitStack() as stack:
+        if dali_server_train:
+            stack.enter_context(dali_server_train)
+        if dali_server_val:
+            stack.enter_context(dali_server_val)
+
+        for epoch in range(args.start_epoch, args.epochs):
+            # train for one epoch
+            avg_train_time = train(train_loader, model, criterion, scaler, optimizer, epoch)
+            total_time.update(avg_train_time)
+            if args.test:
+                break
+
+            # evaluate on validation set
+            [prec1, prec5] = validate(val_loader, model, criterion)
+
+            # remember best prec@1 and save checkpoint
+            if args.local_rank == 0:
+                is_best = prec1 > best_prec1
+                best_prec1 = max(prec1, best_prec1)
+                save_checkpoint({
+                    'epoch': epoch + 1,
+                    'arch': args.arch,
+                    'state_dict': model.state_dict(),
+                    'best_prec1': best_prec1,
+                    'optimizer' : optimizer.state_dict(),
+                }, is_best)
+                if epoch == args.epochs - 1:
+                    print('##Top-1 {0}\n'
+                          '##Top-5 {1}\n'
+                          '##Perf {2}'.format(
+                          prec1,
+                          prec5,
+                          args.total_batch_size / total_time.avg))
 
 class data_prefetcher():
     """Based on prefetcher from the APEX example
@@ -426,14 +513,14 @@ def train(train_loader, model, criterion, scaler, optimizer, epoch):
     model.train()
     end = time.time()
 
-    if args.disable_dali:
+    if args.disable_dali or args.dali_proxy:
         data_iterator = data_prefetcher(train_loader)
         data_iterator = iter(data_iterator)
     else:
         data_iterator = train_loader
 
     for i, data in enumerate(data_iterator):
-        if args.disable_dali:
+        if args.disable_dali or args.dali_proxy:
             input, target = data
             train_loader_len = len(train_loader)
         else:
@@ -532,14 +619,14 @@ def validate(val_loader, model, criterion):
 
     end = time.time()
 
-    if args.disable_dali:
+    if args.disable_dali or args.dali_proxy:
         data_iterator = data_prefetcher(val_loader)
         data_iterator = iter(data_iterator)
     else:
         data_iterator = val_loader
 
     for i, data in enumerate(data_iterator):
-        if args.disable_dali:
+        if args.disable_dali or args.dali_proxy:
             input, target = data
             val_loader_len = len(val_loader)
         else:
diff --git a/qa/TL3_RN50_benchmark/test_pytorch.sh b/qa/TL3_RN50_benchmark/test_pytorch.sh
new file mode 100644
index 00000000000..7a9fca19089
--- /dev/null
+++ b/qa/TL3_RN50_benchmark/test_pytorch.sh
@@ -0,0 +1,77 @@
+#!/bin/bash -e
+
+set -o nounset
+set -o errexit
+set -o pipefail
+
+cd /opt/dali/docs/examples/use_cases/pytorch/resnet50
+
+NUM_GPUS=$(nvidia-smi -L | wc -l)
+
+if [ ! -d "val" ]; then
+    ln -sf /data/imagenet/val-jpeg/ val
+fi
+if [ ! -d "train" ]; then
+    ln -sf /data/imagenet/train-jpeg/ train
+fi
+
+# turn off SHARP to avoid NCCL errors
+export NCCL_NVLS_ENABLE=0
+
+# Function to check the training results from a log file
+check_training_results() {
+    local LOG="$1"
+    # Exit code of the training run, captured at the call site: PIPESTATUS is
+    # already reset by the time the function body executes, so it cannot be read here.
+    local RET="$2"
+
+    if [[ $RET -ne 0 ]]; then
+        echo "Error in training script."
+        return 2
+    fi
+
+    # Define the minimum performance thresholds
+    local MIN_TOP1=20.0
+    local MIN_TOP5=40.0
+    local MIN_PERF=2900
+
+    # Extract relevant information from the log file
+    local TOP1=$(grep "^##Top-1" "$LOG" | awk '{print $2}')
+    local TOP5=$(grep "^##Top-5" "$LOG" | awk '{print $2}')
+    local PERF=$(grep "^##Perf" "$LOG" | awk '{print $2}')
+
+    # Check if the TOP1 and TOP5 values are available
+    if [[ -z "$TOP1" || -z "$TOP5" ]]; then
+        echo "Incomplete output."
+        return 3
+    fi
+
+    # Compare results against the minimum thresholds
+    local TOP1_RESULT=$(echo "$TOP1 $MIN_TOP1" | awk '{if ($1>=$2) {print "OK"} else { print "FAIL" }}')
+    local TOP5_RESULT=$(echo "$TOP5 $MIN_TOP5" | awk '{if ($1>=$2) {print "OK"} else { print "FAIL" }}')
+    local PERF_RESULT=$(echo "$PERF $MIN_PERF" | awk '{if ($1>=$2) {print "OK"} else { print "FAIL" }}')
+
+    # Display results
+    echo
+    printf "TOP-1 Accuracy: %.2f%% (expect at least %f%%) %s\n" $TOP1 $MIN_TOP1 $TOP1_RESULT
+    printf "TOP-5 Accuracy: %.2f%% (expect at least %f%%) %s\n" $TOP5 $MIN_TOP5 $TOP5_RESULT
+    printf "Average perf: %.2f (expect at least %f) samples/sec %s\n" $PERF $MIN_PERF $PERF_RESULT
+
+    # Return 0 only if all results are "OK"; flag failure otherwise
+    if [[ "$TOP1_RESULT" == "OK" && "$TOP5_RESULT" == "OK" && "$PERF_RESULT" == "OK" ]]; then
+        return 0
+    fi
+    return 1
+}
+
+torchrun --nproc_per_node=${NUM_GPUS} main.py -a resnet50 --b 256 --loss-scale 128.0 --workers 8 --lr=0.4 --fp16-mode --epochs 5 ./ 2>&1 | tee dali.log && RET=0 || RET=$?
+check_training_results dali.log $RET && RESULT_DALI=0 || RESULT_DALI=$?
+
+torchrun --nproc_per_node=${NUM_GPUS} main.py -a resnet50 --b 256 --loss-scale 128.0 --workers 8 --lr=0.4 --fp16-mode --epochs 5 --dali_proxy ./ 2>&1 | tee dali_proxy.log && RET=0 || RET=$?
+check_training_results dali_proxy.log $RET && RESULT_DALI_PROXY=0 || RESULT_DALI_PROXY=$?
+
+# Exit 0 only if both runs passed, otherwise propagate the first non-zero code
+if [[ $RESULT_DALI -ne 0 ]]; then
+    exit $RESULT_DALI
+fi
+exit $RESULT_DALI_PROXY
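
Note: the snippet below is a condensed, illustrative sketch of the DALI-proxy pattern that the patch wires into main.py, shown standalone. File paths reach the pipeline through an external_source node fed by the PyTorch data-loader workers, dali_proxy.DALIServer runs the pipeline, the server's proxy object stands in for the usual torchvision transform, and dali_proxy.DataLoader ties the two together. The directory path, batch size, image size, and thread count are placeholders, not values from the patch.

import numpy as np
import torchvision.datasets as datasets
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

@pipeline_def(exec_dynamic=True)
def decode_pipeline():
    # File paths arrive from the PyTorch data-loader workers via the proxy
    filepaths = fn.external_source(name="images", no_copy=True)
    images = fn.io.file.read(filepaths)
    images = fn.decoders.image(images, device="mixed", output_type=types.RGB)
    return fn.resize(images, size=256, mode="not_smaller")

def read_filepath(path):
    # ImageFolder "loader": hand back the path itself (as bytes) instead of decoding
    return np.frombuffer(path.encode(), dtype=np.int8)

pipe = decode_pipeline(batch_size=64, num_threads=4, device_id=0)  # placeholder sizes
with dali_proxy.DALIServer(pipe) as dali_server:
    # dali_server.proxy takes the place of a torchvision transform
    dataset = datasets.ImageFolder("/path/to/images",
                                   transform=dali_server.proxy,
                                   loader=read_filepath)
    loader = dali_proxy.DataLoader(dali_server, dataset,
                                   batch_size=64, num_workers=4)
    for images, labels in loader:
        pass  # images are batches produced by the DALI pipeline on the GPU

This mirrors the structure of the patched main.py: in the example above the server is entered directly with a `with` block, whereas main.py constructs the servers first and enters them later through an ExitStack, since they are only created on the --dali_proxy path.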