Denbergvanthijs / imbDRL

Imbalanced Classification with Deep Reinforcement Learning

imbDRL

Imbalanced Classification with Deep Reinforcement Learning.

This repository contains a (Double) Deep Q-Network implementation of binary classification on imbalanced datasets using TensorFlow 2.3+ and TF Agents 0.6+. The Double DQN, as published by van Hasselt et al. (2015), uses a custom environment based on the paper by Lin, Chen & Qi (2019).
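To make the two cited ideas concrete, here is a minimal sketch in plain Python. It is not taken from the imbDRL source and does not use TF Agents. It assumes the reward scheme described by Lin, Chen & Qi (2019), in which minority samples are rewarded with ±1, majority samples with ±λ (the imbalance ratio), and misclassifying a minority sample ends the episode. It also assumes the Double DQN target of van Hasselt et al. (2015), in which the online network selects the next action and the target network evaluates it. The function names, parameter names and the default `minority_label=1` are illustrative assumptions.

```python
from typing import Sequence, Tuple


def imbalance_reward(action: int, label: int, imb_ratio: float,
                     minority_label: int = 1) -> Tuple[float, bool]:
    """Return (reward, episode_ends) for a single classification step.

    Assumed scheme: minority samples give +1 / -1, majority samples give
    +imb_ratio / -imb_ratio, and a misclassified minority sample terminates
    the episode.
    """
    correct = action == label
    if label == minority_label:
        # Minority samples carry full weight; a miss ends the episode.
        return (1.0, False) if correct else (-1.0, True)
    # Majority samples are down-weighted by the imbalance ratio.
    return (imb_ratio, False) if correct else (-imb_ratio, False)


def double_dqn_target(reward: float, gamma: float, terminal: bool,
                      q_online_next: Sequence[float],
                      q_target_next: Sequence[float]) -> float:
    """Compute the Double DQN regression target for one transition."""
    if terminal:
        # No bootstrapping from the next state once the episode has ended.
        return reward
    # Action selection uses the online network ...
    best_action = max(range(len(q_online_next)), key=q_online_next.__getitem__)
    # ... while action evaluation uses the target network.
    return reward + gamma * q_target_next[best_action]
```

With an imbalance ratio of 0.1, a correctly classified majority sample earns 0.1, while a missed minority sample costs 1 and ends the episode, so the agent is pushed to pay attention to the rare class.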

Example scripts for the MNIST, Fashion-MNIST, Credit Card Fraud and Titanic datasets can be found in the ./imbDRL/examples/ddqn/ folder.

Results

The following results were collected with the scripts in the appendix repository, imbDRLAppendix. The experiments were conducted on the latest release of imbDRL and are based on the paper by Lin, Chen & Qi (2019).

[Figure: Results]

Requirements

  • Python 3.7+
  • The required packages as listed in: requirements.txt
  • Logs are by default saved in ./logs/
  • Trained models are by default saved in ./models/
  • Optional: ./data/ folder located at the root of this repository.
    • This folder must contain creditcard.csv downloaded from Kaggle if you would like to use the Credit Card Fraud dataset.
    • Note: creditcard.csv needs to be split into separate train and test files. Please use the function imbDRL.utils.split_csv for this.
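The documented route for splitting is imbDRL.utils.split_csv, whose signature is not shown in this README. For readers who want to see what such a split involves, the following is a stand-alone sketch using only the Python standard library. It is not the library's implementation. It assumes the label column is named `Class` (as in the Kaggle file) and keeps the class ratio the same in both parts, which matters when the minority class is rare. All function names, the 20% test fraction and the seed are hypothetical.

```python
import csv
import random
from collections import defaultdict


def stratified_split(rows, label_key="Class", test_fraction=0.2, seed=42):
    """Split a list of dict rows into (train, test), preserving class ratios."""
    # Group rows by label so each class is split separately.
    by_class = defaultdict(list)
    for row in rows:
        by_class[row[label_key]].append(row)

    rng = random.Random(seed)  # Fixed seed for a reproducible split.
    train, test = [], []
    for group in by_class.values():
        rng.shuffle(group)
        n_test = round(len(group) * test_fraction)
        test.extend(group[:n_test])
        train.extend(group[n_test:])
    return train, test


def split_csv_file(src, train_path, test_path, **kwargs):
    """Read src, split it with stratified_split, and write two CSV files."""
    with open(src, newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        rows = list(reader)

    train, test = stratified_split(rows, **kwargs)

    for path, part in ((train_path, train), (test_path, test)):
        with open(path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(part)
```

A call such as `split_csv_file("./data/creditcard.csv", "./data/credit_train.csv", "./data/credit_test.csv")` would write the two files; the output file names here are placeholders, so check the example scripts for the names they actually expect.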

Getting started

Install via pip:

  • pip install imbDRL

Run any of the following scripts:

  • python ./imbDRL/examples/ddqn/train_credit.py
  • python ./imbDRL/examples/ddqn/train_famnist.py
  • python ./imbDRL/examples/ddqn/train_mnist.py
  • python ./imbDRL/examples/ddqn/train_titanic.py

TensorBoard

To enable TensorBoard, run tensorboard --logdir logs

Tests and linting

Extra arguments for the tools below are set in the ./tox.ini file.

  • Pytest: python -m pytest
  • Flake8: flake8
  • Coverage can be found in the generated ./htmlcov folder

Appendix

The appendix can be found in the imbDRLAppendix repository.

License: Apache License 2.0

Language: Python 100.0%