
How can I integrate an RNN into an ODEProblem in Lux.jl? #952


@disadone

Hi!
I'm just wondering how an RNN can be mixed into an ODEProblem.

Back in the Flux days it seems a Recur layer needed to be created (see: Training of UDEs with recurrent networks). However, there is already a Recurrence layer in Lux.jl.

How can Lux.jl do the job now?
I defined my own GRU cell, and it runs well when combined with the beginner tutorial Training a Simple LSTM; the cell and the adapted test script are below.
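
For context, the continuous-time dynamics I ultimately want to hand to an ODEProblem are, as I understand my own cell, dh/dt = (z / τ) .* (h̃ .- h); the FastGRUCell below takes exactly one forward-Euler step of this per call, with step size α = Δt / τ.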

using ConcreteStructs: @concrete
using Lux
using Static
using Random

IntegerType = Union{Integer,Static.StaticInteger}
BoolType = Union{StaticBool, Bool, Val{true},Val{false}}

@concrete struct FastGRUCell <: Lux.AbstractRecurrentCell
    train_state <: StaticBool      # whether the initial hidden state is a trainable parameter
    in_dims <: IntegerType
    out_dims <: IntegerType
    init_bias
    init_weight
    init_state
    dynamics_nonlinearity          # applied to the candidate state h̃
    gating_nonlinearity            # applied to the update/reset gates z and r
    α <: AbstractFloat             # forward-Euler step size, α = Δt / τ
    layernormQ::StaticBool         # whether to layer-normalize the recurrent pre-activations
end

function FastGRUCell(
        (in_dims,out_dims)::Pair{<:Lux.IntegerType,<:Lux.IntegerType},
        Δt::T, τ::T,layernormQ::BoolType;
        train_state::BoolType=False(),
        init_weight=Lux.glorot_normal,
        init_bias=Lux.zeros32,
        init_state=zeros32,
        dynamics_nonlinearity = Lux.sigmoid_fast,
        gating_nonlinearity = Lux.tanh_fast) where T<:AbstractFloat
    init_weight = ntuple(Returns(init_weight),3)
    init_bias = ntuple(Returns(init_bias),3)
    α = Δt/τ
    return FastGRUCell(
        static(train_state),
        in_dims,out_dims,init_bias,init_weight,init_state,
        dynamics_nonlinearity,gating_nonlinearity,α,static(layernormQ)
    )
end

function Lux.initialparameters(rng::AbstractRNG,gru::FastGRUCell)
    # hidden to hidden
    Wz,Wr,Wh = (Lux.init_rnn_weight(
        rng,init_weight,gru.out_dims,(gru.out_dims,gru.out_dims)) for init_weight in gru.init_weight)
    # input to hidden
    Uz,Ur,Uh = (Lux.init_rnn_weight(
        rng,init_weight,gru.out_dims,(gru.out_dims,gru.in_dims)) for init_weight in gru.init_weight)

    ps = (; Wz,Wr,Wh,Uz,Ur,Uh)

    biasz,biasr,biash = (Lux.init_rnn_weight(rng,init_bias,gru.out_dims,gru.out_dims) for init_bias in gru.init_bias)

    ps = merge(ps, (; biasz,biasr,biash))
    Lux.has_train_state(gru) &&  (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),)))
    return ps
end
Lux.initialstates(rng::AbstractRNG,::FastGRUCell) = (rng=Lux.Utils.sample_replicate(rng),)

function (gru::FastGRUCell{True})(x::AbstractMatrix,ps,st::NamedTuple)
    hidden_state = Lux.init_trainable_rnn_hidden_state(ps.hidden_state, x)
    return gru((x, (hidden_state,)), ps, st)
end

function (gru::FastGRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple)
    rng = Lux.replicate(st.rng)
    st = merge(st, (; rng))
    hidden_state = Lux.init_rnn_hidden_state(rng, gru, x)
    return gru((x, (hidden_state,)), ps, st)
end

const _FastGRUCellInputType = Tuple{
    <:AbstractMatrix, Tuple{<:AbstractMatrix}}

function (m::FastGRUCell)(
    (x,(h,))::_FastGRUCellInputType, ps,st::NamedTuple)

    Wzh =  fused_dense_bias_activation(identity,ps.Wz,h,ps.biasz)
    Wrh =  fused_dense_bias_activation(identity,ps.Wr,h,ps.biasr)
    Uzx =  fused_dense_bias_activation(identity,ps.Uz,x,nothing)
    Urx =  fused_dense_bias_activation(identity,ps.Ur,x,nothing)
    
    z = dynamic(m.layernormQ) ? (m.gating_nonlinearity.(layernorm(Wzh,nothing,nothing) .+ Uzx)) : (@. m.gating_nonlinearity(Wzh+Uzx))
    r = dynamic(m.layernormQ) ? (m.gating_nonlinearity.(layernorm(Wrh,nothing,nothing) .+ Urx)) : (@. m.gating_nonlinearity(Wrh+Urx))

    Whh = fused_dense_bias_activation(identity,ps.Wh, h .* r ,ps.biash)
    Uhh = fused_dense_bias_activation(identity,ps.Uh, x ,nothing)
    h̃ = dynamic(m.layernormQ) ? (m.dynamics_nonlinearity.(layernorm(Whh,nothing,nothing) .+ Uhh)) : (@. m.dynamics_nonlinearity(Whh+Uhh))
    h′ = @. (1 - m.α * z) * h + m.α * z * h̃
    return (h′, (h′,)), st
end
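
As a quick sanity check of the cell on its own (just a throwaway sketch; the sizes and the names cell and x are arbitrary):

rng = Xoshiro(0)
cell = FastGRUCell(2 => 8, 0.01f0, 1.0f0, true)
ps, st = Lux.setup(rng, cell)
x = rand(Float32, 2, 16)               # (in_dims, batch)
(y, (h,)), st = cell(x, ps, st)        # y and h are both the new hidden state, size (8, 16)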


# --------------------------------------------------------------------------------------------------
# adapted from https://lux.csail.mit.edu/stable/tutorials/beginner/3_SimpleRNN#Creating-a-Classifier
using Lux, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end

struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:fastgru_cell, :classifier)}
    fastgru_cell::L
    classifier::C
end
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        FastGRUCell(in_dims => hidden_dims, 0.01f0, 1.0f0, true), 
        Dense(hidden_dims => out_dims, sigmoid))
end

function (s::SpiralClassifier)(
    x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}

    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_fastgru = s.fastgru_cell(x_init, ps.fastgru_cell, st.fastgru_cell)

    for x in x_rest
        (y, carry), st_fastgru = s.fastgru_cell((x, carry), ps.fastgru_cell, st_fastgru)
    end

    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, fastgru_cell = st_fastgru))

    return vec(y), st
end
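
# (Aside: I think the manual unrolling above could also be written with the
#  Recurrence wrapper mentioned at the top, roughly
#  Recurrence(FastGRUCell(2 => 8, 0.01f0, 1.0f0, true)) in place of the loop,
#  but I have only tested the explicit version.)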

 # ----- loss
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# ----- training

function main(model_type)
    dev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)

    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            # x: (2, 50, 128) = (features, time, batch); y: (128,)
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)

When I try to transfer my self-defined GRU cell to the tutorial MNIST Classification using Neural ODEs, I don't know where to start.
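
To make the question concrete, here is roughly the kind of wiring I have in mind, loosely following the NeuralODE-style wrapping in that tutorial as I understand it. This is only a sketch, not working code: cell, ps, st0, x_fixed, h0, Δt, and gru_rhs are placeholder names I made up here, the input is held fixed over the whole solve, and how the cell's discrete carry should map onto the ODE state u (and onto a time-varying input sequence) is exactly the part I don't know how to handle.

using OrdinaryDiffEq

Δt = 0.01f0                                 # same step size I passed to FastGRUCell
cell = FastGRUCell(2 => 8, Δt, 1.0f0, true)
ps, st0 = Lux.setup(Xoshiro(0), cell)
x_fixed = rand(Float32, 2, 16)              # pretend the input is constant in time
h0 = zeros(Float32, 8, 16)                  # initial hidden state as the ODE state

# Since h′ = h + α*z*(h̃ - h) is one forward-Euler step, (h′ - h)/Δt recovers dh/dt.
function gru_rhs(u, p, t)
    (h′, _), _ = cell((x_fixed, (u,)), p, st0)
    return (h′ .- u) ./ Δt
end

prob = ODEProblem(gru_rhs, h0, (0.0f0, 1.0f0), ps)
sol = solve(prob, Tsit5(); save_everystep=false)

Presumably training something like this would also need SciMLSensitivity for the adjoints, but first I'd like to know whether this is even the right way to think about combining a Lux recurrent cell with an ODEProblem.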
I would really appreciate it if anyone could help me!
Thanks!
