Light curve classification using LSTM and Phased LSTM recurrent models.
This repository contains the implementation of the light curve classifier (L+P) described in C. Donoso-Oliva et al., 2021.
NOTE: A new version of the L+P is being developed at this link. It will include new experimental features (not yet evaluated) and follow code best practices.
The fastest way to run the code is to create a conda environment:
conda env create -f environment.yml
The first line of the YAML file defines the environment name. Feel free to change it.
Optionally, you can install the dependencies with pip using the requirements.txt file:
pip install -r requirements.txt
./models/: Contains the LSTM and Phased LSTM model classes.
  plstm.py: Phased LSTM classifier following the architecture described in the paper.
  lstm.py: LSTM classifier following the architecture described in the paper.
./layers/: Custom layers used in this work.
  phased.py: Phased LSTM unit. It consists of an LSTM plus a time gate.
data.py: Contains the functions for loading data and creating records.
get_data.py: Script to download the data.
main.py: Main script that loads the data and instantiates the models for training (use --help to see the running options).
predict: Prediction script that receives 4 arguments through sys.argv.
{train, test}_script.py: Code routines for hyperparameter tuning.
experiments and results: Folders containing the fitted models and the metrics/figures, respectively. They are not included in this repo, but you can download them using the experiments link or the results link. Once downloaded, paste the folders into the root directory.

You can download the data either in raw format or as preprocessed records. Using the records implies using the same folds as the author.
To see the script options:
python get_data.py --help
To download a raw dataset:
python get_data.py --dataset <name>
Alternatively, you can download the data directly from Google Drive.
To download preprocessed records, for example for the linear dataset:
python get_data.py --dataset linear --record
For custom training, convert your data to TFRecord format using the function create_record(light_curves, labels, path='') in ./data.py.
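The Phased LSTM unit in ./layers/phased.py adds a time gate to a standard LSTM. The NumPy snippet below is a minimal sketch of that gate as formulated in the original Phased LSTM paper (Neil et al., 2016), not necessarily the exact code in phased.py. It computes the gate openness k_t from an observation time and uses it to blend the new LSTM state with the previous one. The function names and the parameters tau (period), shift (phase shift), r_on (open ratio) and alpha (leak), along with their default values, are illustrative assumptions.

```python
import numpy as np


def time_gate(t, tau, shift, r_on=0.05, alpha=1e-3):
    """Phased LSTM time gate k_t for observation time(s) t (illustrative sketch).

    tau and shift are the per-neuron period and phase shift; r_on is the
    fraction of the period during which the gate is open; alpha is the small
    leak applied while the gate is closed.
    """
    # Phase of each timestamp inside its neuron's oscillation cycle, in [0, 1).
    phi = np.mod(t - shift, tau) / tau
    rising = 2.0 * phi / r_on          # first half of the open phase: 0 -> 1
    falling = 2.0 - 2.0 * phi / r_on   # second half of the open phase: 1 -> 0
    closed = alpha * phi               # closed phase: small leak only
    return np.where(phi < 0.5 * r_on, rising,
                    np.where(phi < r_on, falling, closed))


def phased_update(k, candidate, previous):
    """Blend the regular LSTM update (candidate) with the previous state.

    Applied to both the cell state and the hidden state: where k is 1 the
    LSTM update passes through, where k is 0 the previous state is kept.
    """
    return k * candidate + (1.0 - k) * previous


# Example: one neuron with period 10 and 20% open ratio, queried at three times.
k = time_gate(np.array([0.5, 1.5, 5.0]), tau=10.0, shift=0.0, r_on=0.2)
print(k)
```

Because the gate depends on the actual timestamp, irregularly sampled light curves can be fed with their real observation times instead of being resampled onto a fixed grid; neurons whose gate is closed at a given time carry their previous state forward almost unchanged.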