Project author: NYU-CDS-Capstone-Project

Project description:
Learning Visual Embeddings
Primary language: Jupyter Notebook
Project URL: git://github.com/NYU-CDS-Capstone-Project/learning_visual_embeddings.git


DS-GA 1006 Capstone Project and Presentation

Learning Visual Embeddings for Reinforcement Learning

Members:

  • Mihir Rana
  • Kenil Tanna

Requirements

For ease of setup, we provide a requirements.yaml file from which conda creates an environment named visual_embeddings with all dependencies and requirements installed. To set it up:

  • Install Anaconda and run:
        conda env create -f requirements.yaml
        conda activate visual_embeddings
  • Optionally, to run on a GPU, install CUDA and cuDNN.
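After installing CUDA and cuDNN, a quick way to confirm that PyTorch (the framework the project's Dataset classes are built on) can actually see them is a small check like the one below. The helper `pick_device` is hypothetical and not part of this repository; it falls back to the CPU whenever PyTorch or a CUDA device is missing.

```python
def pick_device(prefer_cuda: bool = True) -> str:
    """Return "cuda" if PyTorch is installed and sees a CUDA device, else "cpu".

    Hypothetical helper for checking a GPU setup; not part of this repository.
    """
    try:
        import torch  # expected inside the visual_embeddings conda environment
    except ImportError:
        return "cpu"
    if prefer_cuda and torch.cuda.is_available():
        return "cuda"
    return "cpu"


if __name__ == "__main__":
    device = pick_device()
    print("device:", device)
    if device == "cuda":
        import torch
        # Reports whether PyTorch found a usable cuDNN installation
        print("cuDNN available:", torch.backends.cudnn.is_available())
```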

Installation

Again, for simplicity, the code is packaged as a module named visual_embeddings, which you can install with pip in editable (development) mode by running the following command from the main project directory:

    pip install -e .
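To confirm the editable install worked, you can check that the module resolves from a directory other than the project root (so the result reflects the pip install rather than the current working directory). The helper `is_importable` below is a hypothetical sketch built on the standard library's `importlib.util.find_spec`.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if module_name can be found in the current environment."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # visual_embeddings is the module name given in this README
    print("visual_embeddings importable:", is_importable("visual_embeddings"))
```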

Usage

    usage: main.py [-h] [--project-dir PROJECT_DIR] [--data-dir DATA_DIR]
                   [--plots-dir PLOTS_DIR] [--logs-dir LOGS_DIR]
                   [--checkpoints-dir CHECKPOINTS_DIR]
                   [--embeddings-dir EMBEDDINGS_DIR] [--dataset-type DATASET_TYPE]
                   [--dataset DATASET] [--data-ext DATA_EXT] [--offline] [--force]
                   [--cpu] [--cuda] [--device DEVICE] [--device-ids DEVICE_IDS]
                   [--parallel] [--emb-model EMB_MODEL]
                   [--load-ckpt LOAD_CHECKPOINT] [--load-emb-ckpt LOAD_EMB_CKPT]
                   [--load-cls-ckpt LOAD_CLS_CKPT] [--batch-size BATCH_SIZE]
                   [--epochs EPOCHS] [--lr LR] [--flatten] [--num-train NUM_TRAIN]
                   [--num-frames NUM_FRAMES_IN_STACK]
                   [--num-channels NUM_CHANNELS]
                   [--num-pairs NUM_PAIRS_PER_EXAMPLE] [--use-pool] [--use-res]

    Learning Visual Embeddings for Reinforcement Learning

    optional arguments:
      -h, --help            show this help message and exit
      --project-dir PROJECT_DIR
                            path to project directory
      --data-dir DATA_DIR   path to data directory (used if different from "data/")
      --plots-dir PLOTS_DIR
                            path to plots directory (used if different from "logs/plots/")
      --logs-dir LOGS_DIR   path to logs directory (used if different from "logs/")
      --checkpoints-dir CHECKPOINTS_DIR
                            path to checkpoints directory (used if different from "checkpoints/")
      --embeddings-dir EMBEDDINGS_DIR
                            path to embeddings directory (used if different from "checkpoints/embeddings/")
      --dataset-type DATASET_TYPE
                            name of PyTorch Dataset to use:
                            maze | fixed_mmnist | random_mmnist, default=maze
      --dataset DATASET     name of dataset file in "data" directory:
                            mnist_test_seq | moving_bars_20_121 | etc., default=all_mazes_16_3_6
      --data-ext DATA_EXT   extension of dataset file in data directory
      --offline             use offline preprocessing of data loader
      --force               overwrites all existing dumped data sets (if used with `--offline`)
      --cpu                 use CPU
      --cuda                use CUDA, default id: 0
      --device DEVICE       device to train on: cuda | cpu, default=cuda
      --device-ids DEVICE_IDS
                            IDs of GPUs to use
      --parallel            use all GPUs available
      --emb-model EMB_MODEL
                            name of embedding network
      --load-ckpt LOAD_CHECKPOINT
                            name of checkpoint file to load
      --load-emb-ckpt LOAD_EMB_CKPT
                            name of embedding network file to load
      --load-cls-ckpt LOAD_CLS_CKPT
                            name of classification network file to load
      --batch-size BATCH_SIZE
                            input batch size, default=64
      --epochs EPOCHS       number of epochs, default=10
      --lr LR               learning rate, default=1e-4
      --flatten             flatten data into 1 long video
      --num-train NUM_TRAIN
                            number of training examples
      --num-frames NUM_FRAMES_IN_STACK
                            number of stacked frames, default=2
      --num-channels NUM_CHANNELS
                            number of channels in input image, default=1
      --num-pairs NUM_PAIRS_PER_EXAMPLE
                            number of pairs per video, default=5
      --use-pool            use max pooling instead of strided convolutions
      --use-res             use residual layers
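The options above map naturally onto Python's argparse. The sketch below rebuilds a subset of them with the defaults listed in the help text, plus one possible rule for combining --cpu, --cuda and --device. It is a hypothetical reconstruction for illustration, not the parser in main.py; in particular, the precedence inside `resolve_device` is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical subset of main.py's options; defaults copied from its help text."""
    p = argparse.ArgumentParser(
        description="Learning Visual Embeddings for Reinforcement Learning")
    p.add_argument("--dataset-type", default="maze",
                   help="maze | fixed_mmnist | random_mmnist")
    p.add_argument("--dataset", default="all_mazes_16_3_6",
                   help='name of dataset file in "data" directory')
    p.add_argument("--data-ext", help="extension of dataset file in data directory")
    p.add_argument("--cpu", action="store_true", help="use CPU")
    p.add_argument("--cuda", action="store_true", help="use CUDA, default id: 0")
    p.add_argument("--device", default="cuda", help="cuda | cpu")
    p.add_argument("--batch-size", type=int, default=64, help="input batch size")
    p.add_argument("--epochs", type=int, default=10, help="number of epochs")
    p.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    p.add_argument("--flatten", action="store_true",
                   help="flatten data into 1 long video")
    # dest names chosen so argparse prints the same metavars as the help text above
    p.add_argument("--num-frames", dest="num_frames_in_stack", type=int, default=2,
                   help="number of stacked frames")
    p.add_argument("--num-channels", type=int, default=1,
                   help="number of channels in input image")
    p.add_argument("--num-pairs", dest="num_pairs_per_example", type=int, default=5,
                   help="number of pairs per video")
    p.add_argument("--use-pool", action="store_true",
                   help="use max pooling instead of strided convolutions")
    p.add_argument("--use-res", action="store_true", help="use residual layers")
    return p


def resolve_device(args: argparse.Namespace) -> str:
    """Assumed precedence: --cpu wins over --cuda, which wins over --device."""
    if args.cpu:
        return "cpu"
    if args.cuda:
        return "cuda:0"  # help text: "use CUDA, default id: 0"
    return args.device


if __name__ == "__main__":
    # Flags from the fixed-trajectory Moving MNIST command in the Training section
    args = build_parser().parse_args(
        ["--dataset", "mnist_test_seq", "--dataset-type", "fixed_mmnist",
         "--data-ext", ".npy", "--num-frames", "2", "--use-pool"])
    print(args.dataset_type, args.num_frames_in_stack, args.use_pool,
          resolve_device(args))
```

Using `dest="num_frames_in_stack"` is what makes argparse print `NUM_FRAMES_IN_STACK` as the metavar, which is consistent with (though not proof of) how the real help text was produced.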

Training

Minigrid Maze

    python main.py --dataset all_mazes_10000_16_3_6 --dataset-type maze --epochs 15 --num-train 500000 --emb-model emb-cnn1 --num-frames 1 --num-channels 3 --flatten
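The Maze command passes --flatten, documented as "flatten data into 1 long video", together with --num-channels 3. Assuming the maze data is stored as episodes shaped (num_episodes, T, C, H, W), which is a guess about the file layout, flattening can be sketched in NumPy as concatenating the episodes end to end along the time axis. `flatten_episodes` is a hypothetical helper.

```python
import numpy as np


def flatten_episodes(episodes: np.ndarray) -> np.ndarray:
    """Concatenate episodes end to end into one long video.

    Assumes a hypothetical (num_episodes, T, C, H, W) layout and returns
    (num_episodes * T, C, H, W); frame t of episode e lands at index e * T + t.
    """
    if episodes.ndim != 5:
        raise ValueError(f"expected 5-D (N, T, C, H, W), got shape {episodes.shape}")
    n, t = episodes.shape[:2]
    return episodes.reshape(n * t, *episodes.shape[2:])


if __name__ == "__main__":
    # 4 toy episodes of 6 three-channel (num_channels=3) 16x16 frames
    episodes = np.random.rand(4, 6, 3, 16, 16).astype(np.float32)
    print(flatten_episodes(episodes).shape)
```

One trade-off of treating the data as a single video is that frames sampled near an episode boundary may come from two different episodes.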

Moving MNIST (Random Trajectories)

    python main.py --dataset moving_mnist --dataset-type random_mmnist --data-ext .h5 --num-frames 4 --use-pool
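Here --num-frames 4 sets how many consecutive frames form one input, and --num-pairs (default 5) sets how many pairs are drawn per video. One plausible way to build such inputs is sketched below: stack k consecutive frames along the channel axis, then sample pairs of stacks together with their temporal offset. The channel-stacking layout and the use of the offset as a label are assumptions, not taken from the repository, and both helpers are hypothetical.

```python
import numpy as np


def stack_frames(video: np.ndarray, num_frames: int) -> np.ndarray:
    """Turn a (T, C, H, W) video into (T - num_frames + 1, num_frames * C, H, W).

    Stack i holds frames i .. i + num_frames - 1 concatenated along the channel axis.
    """
    t, c, h, w = video.shape
    if not 1 <= num_frames <= t:
        raise ValueError("num_frames must be between 1 and the video length")
    stacks = [video[i:i + num_frames].reshape(num_frames * c, h, w)
              for i in range(t - num_frames + 1)]
    return np.stack(stacks)


def sample_pairs(stacks: np.ndarray, num_pairs: int, rng: np.random.Generator):
    """Draw num_pairs (anchor, other, offset) triples; offset = j - i (hypothetical label)."""
    n = len(stacks)
    i = rng.integers(0, n, size=num_pairs)
    j = rng.integers(0, n, size=num_pairs)
    return stacks[i], stacks[j], j - i


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    video = rng.random((20, 1, 64, 64), dtype=np.float32)  # 20 single-channel frames
    stacks = stack_frames(video, num_frames=4)
    anchors, others, offsets = sample_pairs(stacks, num_pairs=5, rng=rng)
    print(stacks.shape, anchors.shape, offsets)
```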

Moving MNIST (Fixed Trajectories)

    python main.py --dataset mnist_test_seq --dataset-type fixed_mmnist --data-ext .npy --num-frames 2 --use-pool
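The fixed-trajectory command reads mnist_test_seq.npy. The publicly distributed Moving MNIST file of that name stores 10,000 sequences of 20 frames of 64x64 uint8 pixels in time-major order, shape (20, 10000, 64, 64). The sketch below converts such an array to one video per row with a single channel axis (matching the default --num-channels 1). The target layout (N, T, C, H, W) is an assumption about what the fixed_mmnist Dataset expects, and both helpers are hypothetical.

```python
import numpy as np


def to_batch_major(seq: np.ndarray) -> np.ndarray:
    """(T, N, H, W) uint8 -> (N, T, 1, H, W) float32 scaled to [0, 1]."""
    if seq.ndim != 4:
        raise ValueError(f"expected (T, N, H, W), got {seq.shape}")
    videos = np.transpose(seq, (1, 0, 2, 3))[:, :, None]  # swap axes, add channel axis
    return videos.astype(np.float32) / 255.0


def load_moving_mnist(path: str = "data/mnist_test_seq.npy") -> np.ndarray:
    """Load the .npy file; the default path assumes the default data/ directory."""
    return to_batch_major(np.load(path))


if __name__ == "__main__":
    # Synthetic stand-in with the real file's layout: 20 frames x 3 sequences x 64 x 64
    fake = np.random.randint(0, 256, size=(20, 3, 64, 64), dtype=np.uint8)
    print(to_batch_major(fake).shape)
```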

Moving Bars

    python generate_lines_data.py --seq-len 50 --img-dim 121
    python main.py --dataset moving_bars_50_121 --dataset-type fixed_mmnist --data-ext .npy --num-frames 4
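generate_lines_data.py writes the moving-bars data, and the dataset name used in the training command, moving_bars_50_121, appears to encode --seq-len and --img-dim. The script's actual output is not documented here, so the sketch below is a hypothetical stand-in: each video holds one vertical bar sliding one pixel per frame, and the array is laid out time-major (seq_len, num_videos, img_dim, img_dim), assuming it mirrors mnist_test_seq because both are read with --dataset-type fixed_mmnist.

```python
import numpy as np


def moving_bars(seq_len: int, img_dim: int, num_videos: int,
                bar_width: int = 3, seed: int = 0) -> np.ndarray:
    """Return a uint8 array of shape (seq_len, num_videos, img_dim, img_dim).

    Each video holds one vertical bar that moves one pixel per frame in a
    random direction, wrapping around the image edge. Hypothetical generator.
    """
    rng = np.random.default_rng(seed)
    data = np.zeros((seq_len, num_videos, img_dim, img_dim), dtype=np.uint8)
    starts = rng.integers(0, img_dim, size=num_videos)
    directions = rng.choice([-1, 1], size=num_videos)
    for n in range(num_videos):
        for t in range(seq_len):
            left = (starts[n] + directions[n] * t) % img_dim
            cols = (left + np.arange(bar_width)) % img_dim
            data[t, n][:, cols] = 255  # light every row of the bar's columns
    return data


if __name__ == "__main__":
    seq_len, img_dim = 50, 121  # same values as the generate_lines_data.py command
    print(moving_bars(seq_len, img_dim, num_videos=10).shape)
```

Saving the result with `np.save(f"data/moving_bars_{seq_len}_{img_dim}.npy", data)` would give a file named like the one the training command loads.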