diff --git a/torchrec_dlrm/README.MD b/torchrec_dlrm/README.MD index 28a54454..fc1dd71f 100644 --- a/torchrec_dlrm/README.MD +++ b/torchrec_dlrm/README.MD @@ -275,6 +275,29 @@ torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \ # Criteo Kaggle Display Advertising Challenge dataset usage. +### Preliminary +- Python >= 3.9 +- Cuda >= 12.0 + +### Setup environment +Install PyTorch nightly version +```bash +pip install torch --index-url https://download.pytorch.org/whl/nightly/cu126 +``` +Install FBGEMM-GPU +```bash +pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu126 +``` +Install torchrec from local build +```bash +git clone https://github.com/pytorch/torchrec.git +python -m pip install -e torchrec +``` +Install additional dependencies +```bash +pip install -r requirements.txt +``` + ### Download the dataset. ``` wget http://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz @@ -292,11 +315,18 @@ python -m torchrec.datasets.scripts.npy_preproc_criteo --input_dir $INPUT_PATH - export PREPROCESSED_DATASET=$insert_your_path_here export GLOBAL_BATCH_SIZE=16384 ; export WORLD_SIZE=8 ; +export LEARNING_RATE=0.5 ; torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \ --in_memory_binary_criteo_path $PREPROCESSED_DATASET \ --pin_memory \ --mmap_mode \ --batch_size $((GLOBAL_BATCH_SIZE / WORLD_SIZE)) \ - --learning_rate 1.0 \ - --dataset_name criteo_kaggle + --learning_rate $LEARNING_RATE \ + --dataset_name criteo_kaggle \ + --num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \ + --embedding_dim 128 \ + --over_arch_layer_sizes 1024,1024,512,256,1 \ + --dense_arch_layer_sizes 512,256,128 \ + --epochs 1 \ + --validation_freq_within_epoch 12802 ``` diff --git a/torchrec_dlrm/dlrm_main.py b/torchrec_dlrm/dlrm_main.py index cc727033..460aba85 100644 --- a/torchrec_dlrm/dlrm_main.py +++ b/torchrec_dlrm/dlrm_main.py @@ -334,7 +334,7 @@ def _evaluate( iterator = itertools.islice(iter(eval_dataloader), limit_batches) - auroc = metrics.AUROC(compute_on_step=False, num_classes=2).to(device) + auroc = metrics.AUROC(task="multiclass", num_classes=2).to(device) is_rank_zero = dist.get_rank() == 0 if is_rank_zero: @@ -349,7 +349,8 @@ def _evaluate( try: _loss, logits, labels = pipeline.progress(iterator) preds = torch.sigmoid(logits) - auroc(preds, labels) + preds_reshaped = torch.stack((1 - preds, preds), dim=1) + auroc(preds_reshaped, labels) if is_rank_zero: pbar.update(1) except StopIteration: diff --git a/torchrec_dlrm/requirements.txt b/torchrec_dlrm/requirements.txt index d34e4e72..29da10f7 100644 --- a/torchrec_dlrm/requirements.txt +++ b/torchrec_dlrm/requirements.txt @@ -1,3 +1,2 @@ -fbgemm-gpu==0.3.2 -torchmetrics==0.11.0 -torchrec==0.3.2 +tqdm +torchmetrics diff --git a/torchrec_dlrm/tee b/torchrec_dlrm/tee new file mode 100644 index 00000000..8ea69a84 --- /dev/null +++ b/torchrec_dlrm/tee @@ -0,0 +1,40 @@ +Traceback (most recent call last): + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/__run_xar_main__.py", line 140, in + __invoke_main() + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/__run_xar_main__.py", line 87, in __invoke_main + run_as_main(main_module, main_function) + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/__par__/meta_only/bootstrap.py", line 98, in run_as_main + oss_run_as_main( + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/__par__/bootstrap.py", line 94, in run_as_main + main() + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/main.py", line 118, in main + run_main(get_sub_cmds(), argv) + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/main.py", line 104, in run_main + parser = create_parser(subcmds) + ^^^^^^^^^^^^^^^^^^^^^^ + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/main.py", line 97, in create_parser + cmd.add_arguments(cmd_parser) + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/fb/cmd_run.py", line 54, in add_arguments + _update_scheduler_args(subparser) + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/fb/cmd_run.py", line 104, in _update_scheduler_args + **get_scheduler_factories(), + ^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/schedulers/__init__.py", line 59, in get_scheduler_factories + return load_group( + ^^^^^^^^^^^ + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/util/entrypoints.py", line 90, in load_group + entrypoints = metadata.entry_points().select(group=group) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/importlib_metadata/__init__.py", line 933, in entry_points + return EntryPoints(eps).select(**params) + ^^^^^^^^^^^^^^^^ + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/importlib_metadata/__init__.py", line 931, in + dist.entry_points for dist in _unique(distributions()) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/importlib_metadata/_itertools.py", line 15, in unique_everseen + for element in iterable: + ^^^^^^^^ + File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/importlib_metadata/__init__.py", line 365, in __new__ + if getattr(getattr(cls, name), '__isabstractmethod__', False) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +KeyboardInterrupt