MartinuzziFrancesco / NeuralBasisExpansions.jl

Julia (Flux) implementation of NBeats

NBeats

Implementation of the NBeats model (paper) in Julia (Flux). The package is not yet in the General registry, so install it directly from the repository URL:

using Pkg
Pkg.add(url="https://github.com/MartinuzziFrancesco/NeuralBasisExpansions.jl")

The package is still undergoing heavy testing; expect unexpected behavior.

A full sine-wave example, including the helper functions it uses, is given in the example folder, under readme.jl.
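The helpers generate_sine_data and batch_data are defined in example/readme.jl and are not reproduced here. To give an idea of what the example expects from them, here is a minimal sketch of stand-ins, assuming each sample is a (backcast, forecast) pair of vectors and each batch is a pair of matrices with one sample per column. The bodies, the step size 0.1, and the windowing scheme are assumptions, not the package's code.

```julia
# Hypothetical stand-ins for the helpers in example/readme.jl; the real ones may differ.

# Slice a sine wave into n_samples overlapping (backcast, forecast) window pairs.
function generate_sine_data(n_samples, backcast_length, forecast_length)
    window = backcast_length + forecast_length
    # Sample the sine wave on a fixed grid (step 0.1 is an arbitrary choice).
    series = Float32.(sin.(0.1f0 .* (0:(n_samples + window - 1))))
    return [(series[i:(i + backcast_length - 1)],
             series[(i + backcast_length):(i + window - 1)]) for i in 1:n_samples]
end

# Group samples into batches and stack them column-wise (features × batch).
function batch_data(data, batch_size)
    return [(reduce(hcat, first.(chunk)), reduce(hcat, last.(chunk)))
            for chunk in Iterators.partition(data, batch_size)]
end
```

With these stand-ins, the last batch is smaller than batch_size whenever the number of samples is not a multiple of it.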

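# Load the package, Flux, and Statistics (for `mean`).
# generate_sine_data, batch_data, and evaluate_predictions are the helpers from example/readme.jl.
using NeuralBasisExpansions, Flux, Statistics

# Model parameters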
forecast_length = 5
backcast_length = 2*forecast_length
batch_size = 32
hidden_units = 128
theta_dims = (4, 8)
blocks_per_stack = 3

# Generate and batch the data
data = generate_sine_data(1000, backcast_length, forecast_length)
train_data, test_data = data[1:800], data[801:end]
train_batches = batch_data(train_data, batch_size)
test_batches = batch_data(test_data, batch_size)

# Create the NBeatsNet model
model = NBeatsNet(
    stacks=[generic_basis, trend_basis],
    blocks_stacks=blocks_per_stack,
    forecast_length=forecast_length,
    backcast_length=backcast_length,
    thetas_dim=theta_dims,
    hidden_units=hidden_units
)

# Loss function (computed on the forecast, the second output of the model) and optimizer
loss_fn(x, y) = Flux.mse(model(x)[2], y)
optimizer = Flux.ADAM(0.001)

# Training loop
epochs = 50
for epoch in 1:epochs
    Flux.train!(loss_fn, Flux.params(model), train_batches, optimizer)
    train_loss = mean(loss_fn(x, y) for (x, y) in train_batches)
    test_loss = mean(loss_fn(x, y) for (x, y) in test_batches)
    println("Epoch $epoch: Train Loss = $train_loss, Test Loss = $test_loss")
end

# Forecast using the model (example)
x_test, y_true = test_batches[1]
y_pred = model(x_test)[2]

mse, mae, r_squared = evaluate_predictions(y_true, y_pred)

println("Mean Squared Error: $mse")
println("Mean Absolute Error: $mae")
println("R-squared: $r_squared")
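evaluate_predictions is another helper from example/readme.jl. A minimal sketch of a stand-in is shown here, assuming it returns the mean squared error, the mean absolute error, and the coefficient of determination in that order, as the destructuring in the example suggests. The body is an assumption, not the package's code.

```julia
using Statistics

# Hypothetical stand-in for the helper in example/readme.jl; the real one may differ.
function evaluate_predictions(y_true, y_pred)
    residuals = y_true .- y_pred
    mse = mean(abs2, residuals)                     # mean squared error
    mae = mean(abs, residuals)                      # mean absolute error
    ss_res = sum(abs2, residuals)                   # residual sum of squares
    ss_tot = sum(abs2, y_true .- mean(y_true))      # total sum of squares
    r_squared = 1 - ss_res / ss_tot                 # coefficient of determination
    return mse, mae, r_squared
end
```

Note that this R-squared is undefined when y_true is constant, since the total sum of squares is then zero.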

A quick example with random data to check that the model runs:

using NeuralBasisExpansions

forecast_length = 5
backcast_length = 10
blocks_stacks = 3
thetas_dim = (4, 8)
hidden_units = 256

nbeats_net = NBeatsNet(
    stacks = [trend_basis, seasonality_basis],
    blocks_stacks = blocks_stacks,
    forecast_length = forecast_length,
    backcast_length = backcast_length,
    thetas_dim = thetas_dim,
    share_weights = false,
    hidden_units = hidden_units
)

# Create a batch of input data
batch_size = 3  # Number of instances in the batch
input_data = randn(Float32, backcast_length, batch_size)

backcast_output, forecast_output = nbeats_net(input_data)
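The model returns a backcast and a forecast because NBeats stacks its blocks in a doubly residual fashion: each block subtracts its backcast from the input it received and adds its forecast to a running total. The sketch below illustrates that idea in plain Julia with toy linear blocks. ToyBlock and doubly_residual are hypothetical names introduced for illustration and are not part of NeuralBasisExpansions.jl; whether the package returns the final residual or a summed backcast as its first output should be checked in the package source.

```julia
# Toy block: two weight matrices mapping the block input to a backcast and a forecast.
# Illustrative only; this is not the block type used by NeuralBasisExpansions.jl.
struct ToyBlock
    W_back::Matrix{Float32}   # backcast_length × backcast_length
    W_fore::Matrix{Float32}   # forecast_length × backcast_length
end

(b::ToyBlock)(x) = (b.W_back * x, b.W_fore * x)

# Run the blocks in sequence, threading the residual and summing the forecasts.
function doubly_residual(blocks, x)
    residual = x
    forecast = zeros(Float32, size(blocks[1].W_fore, 1), size(x, 2))
    for block in blocks
        backcast, block_forecast = block(residual)
        residual = residual .- backcast          # pass on what this block did not explain
        forecast = forecast .+ block_forecast    # accumulate the partial forecasts
    end
    return residual, forecast
end

backcast_length, forecast_length, batch_size = 10, 5, 3
blocks = [ToyBlock(0.1f0 .* randn(Float32, backcast_length, backcast_length),
                   0.1f0 .* randn(Float32, forecast_length, backcast_length)) for _ in 1:3]
x = randn(Float32, backcast_length, batch_size)
residual, forecast = doubly_residual(blocks, x)
println(size(residual), " ", size(forecast))
```

The input follows the same layout as input_data above: one column per instance, with backcast_length rows.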

License:MIT License

