paweller / n_class-m_samples-Balanced-Batch-Generator

This repository contains a Keras-compatible generator that yields balanced batches consisting of n different classes with m samples, respectively.

n_class-m_samples-Balanced-Batch-Generator: Keras-compatible generator to create batches with n classes and m samples per class.

MIT license

Overview

This repository contains an algorithm that creates balanced batches containing classes_per_batch classes and samples_per_class samples of each class, so every batch holds classes_per_batch * samples_per_class samples. If necessary, the algorithm over-samples. If shuffle is set to True, the generated batches are shuffled. The generator supports SISO (single-input, single-output) models.

The generator is compatible with Keras models' fit method.
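The batching idea described above can be sketched with NumPy as follows. This is a minimal illustration, not the repository's implementation: the helper name balanced_batch_indices is hypothetical, and it assumes that "over-sample" means drawing with replacement from a class that has fewer than samples_per_class samples.

```python
import numpy as np


def balanced_batch_indices(labels, classes_per_batch, samples_per_class, rng):
    """Pick indices for one balanced batch (hypothetical helper, not the repo's code).

    Draws `classes_per_batch` distinct classes, then `samples_per_class`
    indices from each; a class with too few samples is drawn with replacement.
    """
    labels = np.asarray(labels)
    classes = np.unique(labels)
    chosen = rng.choice(classes, size=classes_per_batch, replace=False)
    parts = []
    for c in chosen:
        idx = np.flatnonzero(labels == c)
        # Assumption: over-sample only when the class cannot fill its slots on its own.
        parts.append(rng.choice(idx, size=samples_per_class,
                                replace=len(idx) < samples_per_class))
    return np.concatenate(parts)


# Class 1 has only two samples, so it is over-sampled whenever it is chosen.
rng = np.random.default_rng(0)
labels = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
idx = balanced_batch_indices(labels, classes_per_batch=2, samples_per_class=3, rng=rng)
print(labels[idx])
```

Whichever two classes are drawn, each appears exactly three times in labels[idx].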

API

generator = BatchGenerator(
    data=data,
    labels=labels,
    params=params,
    categorical=False,
    seed=None
)

Arguments:

  • data (numpy.ndarray): Input data.
  • labels (numpy.ndarray): Data's labels.
  • params (object): Any object, for example an instance of the Params class from utils.py, that provides the attributes classes_per_batch (int), samples_per_class (int) and shuffle (bool).
  • categorical (bool)(optional)(default=False): If True, the generator yields binary class matrices (one-hot encoded labels). Otherwise, it yields class vectors.
  • seed (optional)(default=None): Random seed.
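Since params only needs attribute access, a stand-in for the Params class can be sketched as below. The JSON keys mirror the attribute names listed above, plus num_epochs, which the training example reads; the load_params function and the concrete values are hypothetical, and the repository's own Params class may behave differently.

```python
import json
from types import SimpleNamespace

# Hypothetical contents of params.json. BatchGenerator needs the first three
# keys; num_epochs is only read by the model.fit call in the example.
PARAMS_JSON = """
{
    "classes_per_batch": 4,
    "samples_per_class": 8,
    "shuffle": true,
    "num_epochs": 10
}
"""


def load_params(text):
    """Parse JSON into an object with attribute access (stand-in for utils.Params)."""
    params = SimpleNamespace(**json.loads(text))
    required = (("classes_per_batch", int), ("samples_per_class", int), ("shuffle", bool))
    for name, expected_type in required:
        if not isinstance(getattr(params, name, None), expected_type):
            raise ValueError(f"params.{name} must be of type {expected_type.__name__}")
    return params


params = load_params(PARAMS_JSON)
print(params.classes_per_batch * params.samples_per_class)  # batch size: 32
```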

Returns:

  • A Keras-compatible generator yielding batches as batch_data, batch_labels.
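A plain Python generator with the same inputs and outputs can be sketched as follows. It is an approximation under stated assumptions, not the BatchGenerator class itself: balanced_batches is a hypothetical name, labels are assumed to be class ids of shape (num_samples,), shuffle is assumed to reorder samples within a batch, and categorical=True builds a one-hot matrix with one column per class present in labels.

```python
import numpy as np


def balanced_batches(data, labels, classes_per_batch, samples_per_class,
                     shuffle=True, categorical=False, seed=None):
    """Yield (batch_data, batch_labels) tuples indefinitely (hypothetical sketch)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    classes = np.unique(labels)  # sorted class ids
    # Precompute the sample indices that belong to each class.
    by_class = {c: np.flatnonzero(labels == c) for c in classes}
    while True:
        chosen = rng.choice(classes, size=classes_per_batch, replace=False)
        idx = np.concatenate([
            rng.choice(by_class[c], size=samples_per_class,
                       replace=len(by_class[c]) < samples_per_class)
            for c in chosen
        ])
        if shuffle:
            rng.shuffle(idx)  # assumption: shuffle the samples within a batch
        batch_labels = labels[idx]
        if categorical:
            # One-hot ("binary class matrix"): map each id to its column index.
            batch_labels = np.eye(len(classes), dtype="float32")[
                np.searchsorted(classes, batch_labels)]
        yield data[idx], batch_labels


data = np.arange(20, dtype="float32").reshape(10, 2)
labels = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])
gen = balanced_batches(data, labels, classes_per_batch=2, samples_per_class=3,
                       categorical=True, seed=0)
x, y = next(gen)
print(x.shape, y.shape)  # (6, 2) (6, 3)
```

Because this generator never runs out of batches, passing it to a Keras model's fit method would also require steps_per_epoch.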

Dummy code example

from tensorflow.keras.utils import to_categorical

from utils import Params, reshape_data
from batch_generator import BatchGenerator

# Load the parameters
params_pd = './params.json'
params = Params(params_pd)

# Load data and labels
train_data = ...
train_labels = ...   # shape (num_labels,)
valid_data = ...
valid_labels = ...   # shape (num_labels,)

# Optional step: To make sure data is formatted correctly, the
# `reshape_data` function from the `utils.py` file can be used.
train_data = reshape_data(train_data)
valid_data = reshape_data(valid_data)

# Create generator objects for training and validation
train_generator = BatchGenerator(
    data=train_data,
    labels=to_categorical(train_labels),
    params=params
)

valid_generator = BatchGenerator(
    data=valid_data,
    labels=to_categorical(valid_labels),
    params=params
)

# Create, compile and fit a sequential model
model = ...

model.compile(
    optimizer=...,
    loss=...
)

history = model.fit(
    x=train_generator,
    validation_data=valid_generator,
    epochs=params.num_epochs
)
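Before training, it can be worth confirming that a batch drawn from a generator really contains classes_per_batch classes with samples_per_class samples each. The helper below is a hypothetical sketch that accepts either class vectors or one-hot label matrices.

```python
from collections import Counter

import numpy as np


def is_balanced(batch_labels, classes_per_batch, samples_per_class):
    """Return True if the batch has exactly n classes with m samples each.

    Accepts class vectors (num_samples,) or one-hot matrices (num_samples, num_classes).
    """
    batch_labels = np.asarray(batch_labels)
    if batch_labels.ndim == 2:
        batch_labels = batch_labels.argmax(axis=1)  # one-hot -> class ids
    counts = Counter(batch_labels.tolist())
    return (len(counts) == classes_per_batch
            and all(n == samples_per_class for n in counts.values()))


print(is_balanced([1, 1, 4, 4], 2, 2))  # True
print(is_balanced([1, 1, 1, 4], 2, 2))  # False
```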
