Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions torchrec_dlrm/README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,29 @@ torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \

# Criteo Kaggle Display Advertising Challenge dataset usage.

### Prerequisites
- Python >= 3.9
- CUDA >= 12.0

### Setup environment
Install the nightly build of PyTorch:
```bash
pip install torch --index-url https://download.pytorch.org/whl/nightly/cu126
```
Install the nightly build of FBGEMM-GPU (from the same cu126 index as PyTorch):
```bash
pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu126
```
Install TorchRec from a local clone of the repository (editable install):
```bash
git clone https://github.com/pytorch/torchrec.git
python -m pip install -e torchrec
```
Install the remaining dependencies:
```bash
pip install -r requirements.txt
```

### Download the dataset.
```
wget http://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
Expand All @@ -292,11 +315,18 @@ python -m torchrec.datasets.scripts.npy_preproc_criteo --input_dir $INPUT_PATH -
export PREPROCESSED_DATASET=$insert_your_path_here
export GLOBAL_BATCH_SIZE=16384 ;
export WORLD_SIZE=8 ;
export LEARNING_RATE=0.5 ;
torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \
--in_memory_binary_criteo_path $PREPROCESSED_DATASET \
--pin_memory \
--mmap_mode \
--batch_size $((GLOBAL_BATCH_SIZE / WORLD_SIZE)) \
--learning_rate 1.0 \
--dataset_name criteo_kaggle
--learning_rate $LEARNING_RATE \
--dataset_name criteo_kaggle \
--num_embeddings_per_feature 40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36 \
--embedding_dim 128 \
--over_arch_layer_sizes 1024,1024,512,256,1 \
--dense_arch_layer_sizes 512,256,128 \
--epochs 1 \
--validation_freq_within_epoch 12802
```
5 changes: 3 additions & 2 deletions torchrec_dlrm/dlrm_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _evaluate(

iterator = itertools.islice(iter(eval_dataloader), limit_batches)

auroc = metrics.AUROC(compute_on_step=False, num_classes=2).to(device)
auroc = metrics.AUROC(task="multiclass", num_classes=2).to(device)

is_rank_zero = dist.get_rank() == 0
if is_rank_zero:
Expand All @@ -349,7 +349,8 @@ def _evaluate(
try:
_loss, logits, labels = pipeline.progress(iterator)
preds = torch.sigmoid(logits)
auroc(preds, labels)
preds_reshaped = torch.stack((1 - preds, preds), dim=1)
auroc(preds_reshaped, labels)
if is_rank_zero:
pbar.update(1)
except StopIteration:
Expand Down
5 changes: 2 additions & 3 deletions torchrec_dlrm/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
fbgemm-gpu==0.3.2
torchmetrics==0.11.0
torchrec==0.3.2
tqdm
torchmetrics
40 changes: 40 additions & 0 deletions torchrec_dlrm/tee
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
Traceback (most recent call last):
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/__run_xar_main__.py", line 140, in <module>
__invoke_main()
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/__run_xar_main__.py", line 87, in __invoke_main
run_as_main(main_module, main_function)
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/__par__/meta_only/bootstrap.py", line 98, in run_as_main
oss_run_as_main(
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/__par__/bootstrap.py", line 94, in run_as_main
main()
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/main.py", line 118, in main
run_main(get_sub_cmds(), argv)
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/main.py", line 104, in run_main
parser = create_parser(subcmds)
^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/main.py", line 97, in create_parser
cmd.add_arguments(cmd_parser)
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/fb/cmd_run.py", line 54, in add_arguments
_update_scheduler_args(subparser)
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/cli/fb/cmd_run.py", line 104, in _update_scheduler_args
**get_scheduler_factories(),
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/schedulers/__init__.py", line 59, in get_scheduler_factories
return load_group(
^^^^^^^^^^^
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/torchx/util/entrypoints.py", line 90, in load_group
entrypoints = metadata.entry_points().select(group=group)
^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/importlib_metadata/__init__.py", line 933, in entry_points
return EntryPoints(eps).select(**params)
^^^^^^^^^^^^^^^^
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/importlib_metadata/__init__.py", line 931, in <genexpr>
dist.entry_points for dist in _unique(distributions())
^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/importlib_metadata/_itertools.py", line 15, in unique_everseen
for element in iterable:
^^^^^^^^
File "/mnt/xarfuse/uid-693209/fdcdaa60-seed-nspid4026531836_cgpid117217439-ns-4026531841/importlib_metadata/__init__.py", line 365, in __new__
if getattr(getattr(cls, name), '__isabstractmethod__', False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt