Hi!

Just wondering how an RNN could be mixed into an ODEProblem. In the Flux days, it seems a Recur layer needed to be created; however, there is already a Recurrence in Lux.jl.

Training of UDEs with recurrent networks

How can Lux.jl do the job now?
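For context, my understanding is that in Lux a recurrent cell gets iterated over a sequence by wrapping it in Recurrence, roughly like this (a minimal standalone sketch with the built-in GRUCell, not yet tied to any ODEProblem):

using Lux, Random

rng = Xoshiro(0)
model = Recurrence(GRUCell(2 => 8))   # Recurrence loops the cell over the time dimension
ps, st = Lux.setup(rng, model)
x = rand(Float32, 2, 50, 16)          # (features, time, batch)
y, _ = model(x, ps, st)               # y: hidden state after the last time step, size (8, 16)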
I defined my own GRU cell, and it works well when combined with the beginner tutorial Training a Simple LSTM:
using ConcreteStructs: @concrete
using Lux
using Static
using Random

const IntegerType = Union{Integer, Static.StaticInteger}
const BoolType = Union{StaticBool, Bool, Val{true}, Val{false}}

@concrete struct FastGRUCell <: Lux.AbstractRecurrentCell
    train_state <: StaticBool
    in_dims <: IntegerType
    out_dims <: IntegerType
    init_bias
    init_weight
    init_state
    dynamics_nonlinearity
    gating_nonlinearity
    α <: AbstractFloat
    layernormQ::StaticBool
end

function FastGRUCell(
        (in_dims, out_dims)::Pair{<:Lux.IntegerType, <:Lux.IntegerType},
        Δt::T, τ::T, layernormQ::BoolType;
        train_state::BoolType=False(),
        init_weight=Lux.glorot_normal,
        init_bias=Lux.zeros32,
        init_state=zeros32,
        dynamics_nonlinearity=Lux.sigmoid_fast,
        gating_nonlinearity=Lux.tanh_fast) where {T <: AbstractFloat}
    init_weight = ntuple(Returns(init_weight), 3)
    init_bias = ntuple(Returns(init_bias), 3)
    α = Δt / τ
    return FastGRUCell(
        static(train_state),
        in_dims, out_dims, init_bias, init_weight, init_state,
        dynamics_nonlinearity, gating_nonlinearity, α, static(layernormQ))
end

function Lux.initialparameters(rng::AbstractRNG, gru::FastGRUCell)
    # hidden to hidden
    Wz, Wr, Wh = (Lux.init_rnn_weight(
                      rng, init_weight, gru.out_dims, (gru.out_dims, gru.out_dims))
                  for init_weight in gru.init_weight)
    # input to hidden
    Uz, Ur, Uh = (Lux.init_rnn_weight(
                      rng, init_weight, gru.out_dims, (gru.out_dims, gru.in_dims))
                  for init_weight in gru.init_weight)
    ps = (; Wz, Wr, Wh, Uz, Ur, Uh)
    biasz, biasr, biash = (Lux.init_rnn_weight(rng, init_bias, gru.out_dims, gru.out_dims)
                           for init_bias in gru.init_bias)
    ps = merge(ps, (; biasz, biasr, biash))
    Lux.has_train_state(gru) &&
        (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),)))
    return ps
end

Lux.initialstates(rng::AbstractRNG, ::FastGRUCell) = (rng=Lux.Utils.sample_replicate(rng),)

function (gru::FastGRUCell{True})(x::AbstractMatrix, ps, st::NamedTuple)
    hidden_state = Lux.init_trainable_rnn_hidden_state(ps.hidden_state, x)
    return gru((x, (hidden_state,)), ps, st)
end

function (gru::FastGRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple)
    rng = Lux.replicate(st.rng)
    st = merge(st, (; rng))
    hidden_state = Lux.init_rnn_hidden_state(rng, gru, x)
    return gru((x, (hidden_state,)), ps, st)
end

const _FastGRUCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}

function (m::FastGRUCell)((x, (h,))::_FastGRUCellInputType, ps, st::NamedTuple)
    Wzh = fused_dense_bias_activation(identity, ps.Wz, h, ps.biasz)
    Wrh = fused_dense_bias_activation(identity, ps.Wr, h, ps.biasr)
    Uzx = fused_dense_bias_activation(identity, ps.Uz, x, nothing)
    Urx = fused_dense_bias_activation(identity, ps.Ur, x, nothing)
    z = dynamic(m.layernormQ) ?
        m.gating_nonlinearity.(layernorm(Wzh, nothing, nothing) .+ Uzx) :
        (@. m.gating_nonlinearity(Wzh + Uzx))
    r = dynamic(m.layernormQ) ?
        m.gating_nonlinearity.(layernorm(Wrh, nothing, nothing) .+ Urx) :
        (@. m.gating_nonlinearity(Wrh + Urx))
    Whh = fused_dense_bias_activation(identity, ps.Wh, h .* r, ps.biash)
    Uhh = fused_dense_bias_activation(identity, ps.Uh, x, nothing)
    h̃ = dynamic(m.layernormQ) ?
        m.dynamics_nonlinearity.(layernorm(Whh, nothing, nothing) .+ Uhh) :
        (@. m.dynamics_nonlinearity(Whh + Uhh))
    h′ = @. (1 - m.α * z) * h + m.α * z * h̃
    return (h′, (h′,)), st
end
# --------------------------------------------------------------------------------------------------
# adapted from https://lux.csail.mit.edu/stable/tutorials/beginner/3_SimpleRNN#Creating-a-Classifier
using Lux, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:fastgru_cell, :classifier)}
    fastgru_cell::L
    classifier::C
end

function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        # Δt = 0.01, τ = 1.0, layernorm enabled
        FastGRUCell(in_dims => hidden_dims, 0.01f0, 1.0f0, true),
        Dense(hidden_dims => out_dims, sigmoid))
end

function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # Unroll the cell over the time dimension (dim 2) manually
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_fastgru = s.fastgru_cell(x_init, ps.fastgru_cell, st.fastgru_cell)
    for x in x_rest
        (y, carry), st_fastgru = s.fastgru_cell((x, carry), ps.fastgru_cell, st_fastgru)
    end
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, fastgru_cell=st_fastgru))
    return vec(y), st
end
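As an aside, since FastGRUCell subtypes Lux.AbstractRecurrentCell, I believe the manual time loop above could also be expressed with the Recurrence wrapper (untested sketch, same hyperparameters as above; its output is a 1×batch matrix rather than the vec(y) returned by SpiralClassifier):

classifier_compact = Chain(
    Recurrence(FastGRUCell(2 => 8, 0.01f0, 1.0f0, true)),
    Dense(8 => 1, sigmoid))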
# ----- loss
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
# ----- training
function main(model_type)
    dev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            # x: (2, 50, 128), y: (128,)  -- (feature dim, time, trials)
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)
            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
However, when I try to transfer my self-defined GRU cell to the tutorial MNIST Classification using Neural ODEs, I don't know how to even start.
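To make the question more concrete, what I imagine is something along the lines of the toy sketch below, where the hidden state is the ODE state, the input is held fixed, and dh/dt is recovered from the cell's one-step update by dividing out Δt. Every choice here (constant input, ODEProblem/Tsit5 from OrdinaryDiffEq, the closure over the layer state) is a placeholder guess on my part, not something taken from the tutorials:

using OrdinaryDiffEq, Lux, Random

Δt, τ = 0.01f0, 1.0f0
cell = FastGRUCell(2 => 8, Δt, τ, true)
rng = Xoshiro(0)
ps_cell, st_cell = Lux.setup(rng, cell)

x_fixed = rand(rng, Float32, 2, 1)   # placeholder: a single constant input column
h0 = zeros(Float32, 8, 1)            # hidden state as the ODE state

function gru_rhs(h, p, t)
    (h_next, _), _ = cell((x_fixed, (h,)), p, st_cell)
    return (h_next .- h) ./ Δt       # divide out the Euler step to approximate dh/dt
end

prob = ODEProblem(gru_rhs, h0, (0.0f0, 1.0f0), ps_cell)
sol = solve(prob, Tsit5())

What I can't figure out is how to do this properly inside a Lux model (as in the MNIST Neural ODE tutorial), and how the sequence of inputs x_t should enter the dynamics.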
I would really appreciate it if anyone could help me!
Thanks!