diff --git a/README.md b/README.md index ca28a01..326e59d 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,9 @@ remaining useful life (RUL) tasks by combining partially and fully monotonic networks. This example looks at predicting the RUL for turbofan engine degradation. +- [Battery State of Charge Estimation Using Monotonic Neural Networks](examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BatteryStateOfChargeEstimationUsingMonotonicNeuralNetworks.md) +This example shows how to train two monotonic neural networks to estimate the state of charge (SOC) of a battery, one to model the charging behavior, and one to model the discharging behavior. In this example, you train the networks to predict the rate of change of the state of charge and force the output to be positive or negative for the charging and discharging networks, respectively. This way, you enforce monotonicity of the battery state of charge by constraining its derivative to be positive or negative. + - [Train Image Classification Lipschitz Constrained Networks and Measure Robustness to Adversarial Examples](examples/lipschitz/classificationDigits/LipschitzClassificationNetworksRobustToAdversarialExamples.md) @@ -125,17 +128,12 @@ more robust classification network. ## Functions -This repository introduces the following functions that are used throughout the -examples: - -- [`buildConstrainedNetwork`](conslearn/buildConstrainedNetwork.m) - Build a multi-layer perceptron (MLP) with constraints on the architecture and initialization of the weights. -- [`buildConvexCNN`](conslearn/buildConvexCNN.m) - Build a fully-inpt convex convolutional neural network (CNN). -- [`trainConstrainedNetwork`](conslearn/trainConstrainedNetwork.m) - Train a - constrained network and maintain the constraint during training. -- [`lipschitzUpperBound`](conslearn/lipschitzUpperBound.m) - Compute an upper - bound on the Lipschitz constant for a Lipschitz neural network. -- [`convexNetworkOutputBounds`](conslearn/convexNetworkOutputBounds.m) - Compute - guaranteed upper and lower bounds on hypercubic grids for convex networks. +This repository introduces the following functions that are used throughout the examples: +- [`buildConstrainedNetwork`](conslearn/buildConstrainedNetwork.m) - Build a multi-layer perceptron (MLP) with specific constraints on the architecture and initialization of the weights. +- [`buildConvexCNN`](conslearn/buildConvexCNN.m) - Build a convolutional neural network (CNN) with convex constraints on the architecture and initialization of the weights. +- [`trainConstrainedNetwork`](conslearn/trainConstrainedNetwork.m) - Train a constrained network and maintain the constraint during training. +- [`lipschitzUpperBound`](conslearn/lipschitzUpperBound.m) - Compute an upper bound on the Lipschitz constant for a Lipschitz neural network. +- [`convexNetworkOutputBounds`](conslearn/convexNetworkOutputBounds.m) - Compute guaranteed upper and lower bounds on hypercubic grids for convex networks. ## Tests diff --git a/conslearn/buildConstrainedNetwork.m b/conslearn/buildConstrainedNetwork.m index cea1a38..17c36e2 100644 --- a/conslearn/buildConstrainedNetwork.m +++ b/conslearn/buildConstrainedNetwork.m @@ -11,9 +11,8 @@ % The network includes either a featureInputLayer or an imageInputLayer, % depending on INPUTSIZE: % -% - If INPUTSIZE is a scalar, then the network has a featureInputLayer. -% -% - If INPUTSIZE is a vector with three elements, then the network has an +% - If INPUTSIZE is a scalar, then the network has a featureInputLayer. - +% If INPUTSIZE is a vector with three elements, then the network has an % imageInputLayer. % % NUMHIDDENUNITS is a vector of integers that corresponds to the sizes @@ -30,7 +29,8 @@ % ConvexNonDecreasingActivation - Convex, non-decreasing % ("fully-convex") activation functions. % ("partially-convex") The options are "softplus" or -% "relu". The default is "softplus". +% "relu". +% The default is "softplus". % Activation - Network activation function. % ("partially-convex") The options are "tanh", "relu" or % "fullsort". The default is "tanh". @@ -80,9 +80,9 @@ % "fullsort". The default is % "fullsort". % UpperBoundLipschitzConstant - Upper bound on the Lipschitz -% constant for the network, as a -% positive real number. The default -% value is 1. +% constant +% for the network, as a positive real +% number. The default value is 1. % pNorm - p-norm value for measuring % distance with respect to the % Lipschitz continuity definition. diff --git a/conslearn/trainConstrainedNetwork.m b/conslearn/trainConstrainedNetwork.m index 61ae9ab..73f9bc2 100644 --- a/conslearn/trainConstrainedNetwork.m +++ b/conslearn/trainConstrainedNetwork.m @@ -31,6 +31,12 @@ % iteration, specified as: "mse", "mae", or % "crossentropy". % The default is "mse". +% L2Regularization - Factor for L2 regularization (weight decay). +% The default is 0. +% ValidationData - Data to use for validation during training, +% specified as a minibatchqueue object. +% ValidationFrequency - Frequency of validation in number of +% iterations. The default is 50. % TrainingMonitor - Flag to display the training progress monitor % showing the training data loss. % The default is true. @@ -73,6 +79,9 @@ trainingOptions.LossMetric {... mustBeTextScalar, ... mustBeMember(trainingOptions.LossMetric,["mse","mae","crossentropy"])} = "mse"; + trainingOptions.L2Regularization (1,1) {mustBeNumeric, mustBeNonnegative} = 0 + trainingOptions.ValidationData minibatchqueue {mustBeScalarOrEmpty} = minibatchqueue.empty + trainingOptions.ValidationFrequency (1,1) {mustBeNumeric, mustBePositive, mustBeInteger} = 50 trainingOptions.TrainingMonitor (1,1) logical = true; trainingOptions.TrainingMonitorLogScale (1,1) logical = true; trainingOptions.ShuffleMinibatches (1,1) logical = false; @@ -84,26 +93,39 @@ % Set up the training progress monitor if trainingOptions.TrainingMonitor monitor = trainingProgressMonitor; + + % Track progress information monitor.Info = ["LearningRate","Epoch","Iteration"]; - monitor.Metrics = "TrainingLoss"; + + % Plot the training and validation metrics on the same plot + monitor.Metrics = ["TrainingLoss", "ValidationLoss"]; + groupSubPlot(monitor, "Loss", ["TrainingLoss", "ValidationLoss"]); + % Apply loss log scale if trainingOptions.TrainingMonitorLogScale - yscale(monitor,"TrainingLoss","log"); + yscale(monitor,"Loss","log"); end + % Specify the horizontal axis label for the training plot. monitor.XLabel = "Iteration"; + % Start the monitor monitor.Status = "Running"; stopButton = @() ~monitor.Stop; else + % Let training run without a monitor by setting stop to false stopButton = @() 1; end + % Prepare the generic hyperparameters maxEpochs = trainingOptions.MaxEpochs; initialLearnRate = trainingOptions.InitialLearnRate; decay = trainingOptions.Decay; metric = trainingOptions.LossMetric; shuffleMinibatches = trainingOptions.ShuffleMinibatches; +l2Regularization = trainingOptions.L2Regularization; +validationData = trainingOptions.ValidationData; +validationFrequency = trainingOptions.ValidationFrequency; % Specify ADAM options avgG = []; @@ -147,7 +169,7 @@ % Evaluate the model gradients, and loss using dlfeval and the % modelLoss function and update the network state. - [lossTrain,gradients,state] = dlfeval(@iModelLoss,net,X,T,metric); + [lossTrain,gradients,state] = dlfeval(dlaccelerate(@iModelLoss),net,X,T,metric,l2Regularization); net.State = state; % Gradient Update @@ -162,10 +184,33 @@ LearningRate=learnRate, ... Epoch=string(epoch) + " of " + string(maxEpochs), ... Iteration=string(iteration)); + recordMetrics(monitor,iteration, ... TrainingLoss=lossTrain); + monitor.Progress = 100*epoch/maxEpochs; end + + % Record validation loss, if requested + if ~isempty(validationData) + if (iteration == 1) || (mod(iteration, validationFrequency) == 0) + + % Reset the validation data + if ~hasdata(validationData) + reset(validationData); + end + + % Compute the validation loss + [X, T] = next(validationData); + lossValidation = iModelLoss(net, X, T, metric, l2Regularization); + + % Update the training monitor + if trainingOptions.TrainingMonitor + recordMetrics(monitor,iteration, ... + ValidationLoss=lossValidation); + end + end + end end end @@ -181,23 +226,29 @@ end %% Helpers -function [loss,gradients,state] = iModelLoss(net,X,T,metric) +function [loss,gradients,state] = iModelLoss(net,X,T,metric,l2Regularization) % Make a forward pass -[Y,state] = forward(net,X); +[Y, state] = forward(net,X); % Compute the loss switch metric case "mse" loss = mse(Y,T); case "mae" - loss = mean(abs(Y-T)); + loss = mean(abs(Y-T), 'all'); case "crossentropy" loss = crossentropy(softmax(Y),T); end -% Compute the gradient of the loss with respect to the learnabless -gradients = dlgradient(loss,net.Learnables); +if nargout > 1 + % Compute the gradient of the loss with respect to the learnables + gradients = dlgradient(loss,net.Learnables); + + % Apply L2 regularization + idxWeights = net.Learnables.Parameter == "Weights"; + gradients(idxWeights,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idxWeights, :), net.Learnables(idxWeights, :)); +end end function proximalOp = iSetupProximalOperator(constraint,trainingOptions) diff --git a/examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.md b/examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.md index fad1e0b..1975ef5 100644 --- a/examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.md +++ b/examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.md @@ -1,11 +1,11 @@ # Train Fully Convex Neural Network for Image Classification -This example shows how to create a fully input convex neural network and train it on CIFAR\-10 data. This example uses fully connected based convex networks, rather than the more typical convolutional networks, proven to give higher accuracy on the training and test data set. The aim of this example is to demonstrate the expressive capabilities convex constrained networks have by classifying natural images and demonstrating high accuracies on the training set. Further discussion on the expressive capabilities of convex networks for tasks including image classification can be found in \[1\]. +This example shows how to create a fully input convex convolutional neural network (FIC-CNN) and train it on CIFAR\-10 data \[1\]. # Prepare Data -Download the CIFAR\-10 data set \[1\]. The data set contains 60,000 images. Each image is 32\-by\-32 pixels in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time. +Download the CIFAR\-10 data set \[2\]. The data set contains 60,000 images. Each image is 32\-by\-32 pixels in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time. ```matlab datadir = "."; @@ -18,16 +18,6 @@ Load the CIFAR\-10 training and test images as 4\-D arrays. The training set con [XTrain,TTrain,XTest,TTest] = loadCIFARData(datadir); ``` -For illustration in this example, subsample this data set evenly in each class. You can increase the number of samples by moving the slider to smaller values. - -```matlab -subSampleFrequency = 10; -XTrain = XTrain(:,:,:,1:subSampleFrequency:end); -XTest = XTest(:,:,:,1:subSampleFrequency:end); -TTrain = TTrain(1:subSampleFrequency:end); -TTest = TTest(1:subSampleFrequency:end); -``` - You can display a random sample of the training images using the following code.
@@ -37,34 +27,37 @@ im = imtile(XTrain(:,:,:,idx),ThumbnailSize=[96,96]);
 imshow(im)
 
-# Define FICNN Network Architecture +# Define FIC-CNN Network Architecture -Use the buildConstrainedNetwork function to create a fully input convex neural network suitable for this data set. +Use the buildConvexCNN function to create a fully input convex convolutional neural network suitable for this data set. -- The CIFAR\-10 images are 32\-by\-32 pixels. Therefore, create a fully convex network specifying the inputSize=[32 32 3]. -- Specify a vector a hidden unit sizes of decreasing value in numHiddenUnits. The final number of outputs of the network must be equal to the number of classes, which in this example is 10. +- The CIFAR\-10 images are 32\-by\-32 pixels, and belong to one of ten classes. Therefore, create a fully convex network specifying the inputSize=[32 32 3] and the numClasses=10. +- For each convolutional layer, specify the filter size in filterSize, the number of filters in numFilters, and the stride size in stride. ```matlab inputSize = [32 32 3]; -numHiddenUnits = [512 128 32 10]; +numClasses = 10; +filterSize = [3; 3; 3; 3; 3; 1; 1]; +numFilters = [96; 96; 192; 192; 192; 192; 10]; +stride = [1; 2; 1; 2; 1; 1; 1]; ``` -Seed the network initialization for reproducibility. +Seed the network initialization for reproducibility. ```matlab rng(0); -ficnnet = buildConstrainedNetwork("fully-convex",inputSize,numHiddenUnits) +ficnnet = buildConvexCNN(inputSize, numClasses, filterSize, numFilters, Stride=stride) ``` ```matlabTextOutput ficnnet = dlnetwork with properties: - Layers: [15x1 nnet.cnn.layer.Layer] - Connections: [17x2 table] - Learnables: [14x3 table] - State: [0x3 table] - InputNames: {'image_input'} - OutputNames: {'add_4'} + Layers: [24x1 nnet.cnn.layer.Layer] + Connections: [23x2 table] + Learnables: [30x3 table] + State: [14x3 table] + InputNames: {'input'} + OutputNames: {'fc_+_end'} Initialized: 1 View summary with summary. @@ -72,25 +65,26 @@ ficnnet = ``` ```matlab -plot(ficnnet) +plot(ficnnet); ```

- +

# Specify Training Options -Train for a specified number of epochs with a mini\-batch size of 256. To attain high training accuracy, you may need to train for a larger number of epochs, for example numEpochs=8000, which could take several hours. +Train for a specified number of epochs with a mini\-batch size of 256. To attain high training accuracy, you may need to train for a larger number of epochs, for example numEpochs=400, which could take several hours. ```matlab -numEpochs = 8000; +numEpochs = 400; miniBatchSize = 256; -initialLearnRate = 0.1; -decay = 0.005; +initialLearnRate = 0.0025; +decay = eps; lossMetric = "crossentropy"; +l2Regularization = 1e-4; ``` Create a minibatchqueue object that processes and manages mini\-batches of images during training. For each mini\-batch: @@ -103,6 +97,7 @@ Create a minibatchqueue object that processes and manages mini\-bat xds = arrayDatastore(XTrain,IterationDimension=4); tds = arrayDatastore(TTrain,IterationDimension=1); cds = combine(xds,tds); + mbqTrain = minibatchqueue(cds,... MiniBatchSize=miniBatchSize,... MiniBatchFcn=@preprocessMiniBatch,... @@ -160,7 +155,7 @@ disp("Training accuracy: " + (1-trainError)*100 + "%") ``` ```matlabTextOutput -Training accuracy: 90.4848% +Training accuracy: 70.2123% ``` Compute the accuracy on the test set. @@ -173,7 +168,7 @@ disp("Test accuracy: " + (1-testError)*100 + "%") ``` ```matlabTextOutput -Test accuracy: 27.4554% +Test accuracy: 66.266% ``` The networks output has been constrained to be convex in every pixel in every colour. Even with this level of restriction, the network is able to fit reasonably well to the training data. You can see poor accuracy on the test data set but, as discussed at the start of the example, it is not anticipated that such a fully input convex network comprising of fully connected operations should generalize well to natural image classification. @@ -197,14 +192,14 @@ cm.RowSummary = "row-normalized"; To summarise, the fully input convex network is able to fit to the training data set, which is labelled natural images. The training can take a considerable amount of time owing to the weight projection to the constrained set after each gradient update, which slows down training convergence. Nevertheless, this example illustrates the flexibility and expressivity convex neural networks have to correctly classifying natural images. -# Supporting Functions -## Mini Batch Preprocessing Function +# Supporting Functions +## Mini\-Batch Preprocessing Function -The preprocessMiniBatch function preprocesses a mini\-batch of predictors and labels using the following steps: +The preprocessMiniBatch function preprocesses a mini\-batch of predictions and labels using the following steps: 1. Preprocess the images using the preprocessMiniBatchPredictors function. 2. Extract the label data from the incoming cell array and concatenate into a categorical array along the second dimension. -3. One\-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output. +3. One\-hot encode the categorical labels into numeric arrays. Encoding in the first dimension produces an encoded array that matches the shape of the network output. ```matlab function [X,T] = preprocessMiniBatch(dataX,dataT) @@ -219,19 +214,20 @@ T = onehotencode(T,1); end ``` -## Mini\-Batch Predictors Preprocessing Function +## Mini\-Batch Predictors Preprocessing Function -The preprocessMiniBatchPredictors function preprocesses a mini\-batch of predictors by extracting the image data from the input cell array and concatenate into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension. You divide by 255 to normalize the pixels to [0,1] range. +The preprocessMiniBatchPredictors function preprocesses a mini\-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension. You divide by 255 to normalize the pixels to [0,1] range. ```matlab function X = preprocessMiniBatchPredictors(dataX) -X = single(cat(4,dataX{1:end}))/255; +X = (single(cat(4,dataX{1:end}))/255); % Normalizes to [0, 1] +X = 2*X - 1; % Normalizes to [-1, 1]. end ``` -# References - -\[1\] Amos, Brandon, et al. Input Convex Neural Networks. arXiv:1609.07152, arXiv, 14 June 2017. arXiv.org, https://doi.org/10.48550/arXiv.1609.07152. +# References +\[1\] Amos, Brandon, et al. "Input Convex Neural Networks." (2017). https://doi.org/10.48550/arXiv.1609.07152. -*Copyright 2024 The MathWorks, Inc.* +\[2\] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf +*Copyright 2024-2025 The MathWorks, Inc.* \ No newline at end of file diff --git a/examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.mlx b/examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.mlx index d495b3e..b3ba9fd 100644 Binary files a/examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.mlx and b/examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.mlx differ diff --git a/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig1.png b/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig1.png index 6f49779..ec550b8 100644 Binary files a/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig1.png and b/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig1.png differ diff --git a/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig2.png b/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig2.png index b11b61f..145cb5b 100644 Binary files a/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig2.png and b/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig2.png differ diff --git a/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig3.png b/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig3.png index 6f1bea2..3dcde0e 100644 Binary files a/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig3.png and b/examples/convex/classificationCIFAR10/figures/TrainICNN_Fig3.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BSOC_25_degrees_testing.mat b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BSOC_25_degrees_testing.mat new file mode 100644 index 0000000..50f2a01 Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BSOC_25_degrees_testing.mat differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BSOC_25_degrees_training.mat b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BSOC_25_degrees_training.mat new file mode 100644 index 0000000..a73276f Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BSOC_25_degrees_training.mat differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BatteryStateOfChargeEstimationUsingMonotonicNeuralNetworks.md b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BatteryStateOfChargeEstimationUsingMonotonicNeuralNetworks.md new file mode 100644 index 0000000..86a6c49 --- /dev/null +++ b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BatteryStateOfChargeEstimationUsingMonotonicNeuralNetworks.md @@ -0,0 +1,535 @@ + +# Battery State of Charge Estimation Using Monotonic Neural Networks + +This example shows how to train two monotonic neural networks to estimate the state of charge (SOC) of a battery, one to model the charging behavior, and one to model the discharging behavior. Guaranteeing monotonic behaviour can be crucial in deploynment to safety\-critical systems. For example, in applications in predictive maintenance, violations in the monotonic decrease of the remaining useful life of system component implies that as time passes and the components degrade, their remaining useful life actually increases. + + +A reasonable requirement on the system is that battery SOC should always be increasing with increasing time in the charging phase and decreasing with increasing time in the discharging phase. This defines a monotonic requirement on the system. This example shows how to create a network architecture that guarantees monotonic output sequences. For an example showing how to model the state of charge of a battery using a traditional deep learning approach, see [Battery State of Charge Estimation Using Deep Learning](https://www.mathworks.com/help/deeplearning/ug/battery-state-of-charge-estimation-using-deep-learning.html). + + +To guarantee monotonicity, train the networks to predict the rate of change of the state of charge and force the output to be positive or negative for the charging and discharging networks, respectively. This is equivalent to constraining the derivative to be postive or negative. If the rate of change is positive, then the state of charge is monotonically increasing, and vice versa. Forcing the outputs to be positive (or negative) is a simple constraint that still allows for a wide variety of expressive neural networks. Crucially, this constraint on the network architecture guarantees the monotonicity at each training iteration, including initialization of the network. This means that your monotonic requirement can be traced through the entire neural network training phase. + +# Preprocess Training Data + +The data used in this example was generated from a Simulink model simulating a battery charging and discharging over several hours. The Simulink model is identical to the model in [Battery State\-of\-Charge Estimation](https://www.mathworks.com/help/simscape-battery/ug/battery-state-of-charge-estimation.html), except for the replacement of the Current profile subsystem with two [Cycler](https://www.mathworks.com/help/simscape-battery/ref/cycler.html) blocks to simulate ideal charging and discharging. + +- The input data has three variables: battery temperature (°C), voltage (V), and current (A) . +- The output data is the state of charge. +- The simulated ambient temperature is 25°C. +- The battery charges for 9000 seconds and then discharges for 3000 seconds. This cycle repeats for 10 hours. +- The initial state of charge is 0.5. +## Load Data + +Load the training data. The data is attached to this example as supporting files. + +```matlab +data = load("BSOC_25_degrees_training.mat"); + +XTrain = data.X; +YTrain = data.Y'; +``` +## Split into Charging and Discharging Subsets + +Define a `cycleLength` and `chargeRatio` for the signal and split the data into charging and charging subsets using `splitChargingCycles`. Integrated in the full system, the charging state would be determined by other components in the system. + +```matlab +cycleLength = 12000; +chargeRatio = 0.75; +[XTrainCharge,XTrainDischarge] = splitChargingCycles(XTrain,cycleLength,chargeRatio); +[YTrainCharge,YTrainDischarge,chargingIdx,dischargingIdx] = splitChargingCycles(YTrain,cycleLength,chargeRatio); +``` + +Visualize the charging and discharging target data using `plotTrainingData`. + +```matlab +plotTrainingData(YTrainCharge,YTrainDischarge,chargingIdx,dischargingIdx) +``` + +![Plot of charging/discharging split in the training data.](./figures/trainingData.png) +## Chunk Data + +Split the data into chunks to prepare it for training. Discard remaining sequences that are too short. + +```matlab +function [XChunkedCell, YChunkedCell] = chunkData(XCell,YCell,chunkSize,stride) + +XChunkedCell = []; +YChunkedCell = []; + +for i=1:numel(XCell) + + X = XCell{i}; + Y = YCell{i}; + + numSamples = length(Y); + numChunks= floor(numSamples/(chunkSize-stride))-1; + XChunked = cell(numChunks,1); + YChunked = cell(numChunks,1); + + for j = 1:numChunks + idxStart = 1+(j-1)*stride; + idxEnd = idxStart+chunkSize-1; + XChunked{j} = X(idxStart:idxEnd,:); + YChunked{j} = Y(idxStart:idxEnd); + end + + XChunkedCell = [XChunkedCell; XChunked]; + YChunkedCell = [YChunkedCell; YChunked]; +end +end + +chunkSize = 200; +stride = 100; + +[XTrainCharge,YTrainCharge] = chunkData(XTrainCharge,YTrainCharge,chunkSize,stride); +[XTrainDischarge,YTrainDischarge] = chunkData(XTrainDischarge,YTrainDischarge,chunkSize,stride); +``` +# Create Constrained Network + +The SOC problem can be viewed as a cumulative problem. SOC increases incrementally over time and you can add this increment to previous values to find the new SOC. The voltage, current, and temperature at a given time determine the increase or decrease in SOC during the next time period. Train a network to predict the different between the SOC at consecutive timesteps. You can then enforce monotonically increasing or decreasing SOC by ensuring that the network output is negative or positive, respectively. + + $$ \Delta y\left(t\right)=y\left(t+1\right)-y\left(t\right)=g\left(x\left(t\right)\right)\;\ldotp $$ + +Calculate the SOC at a time `t`. + + $$ y\left(t\right)=y\left(t-1\right)+\Delta y\left(t-1\right)=y\left(t-1\right)+g\left(x\left(t-1\right)\right)=y\left(1\right)+\sum_{k=1}^t g\left(x\left(k-1\right)\right) $$ +## Data Preprocessing for Constrained Network + +To train the new network g(x), preprocess the target data as follows: + +1. Calculate the difference between the SOC at t and t+1. +2. Normalize the differences since they will be much smaller than the weights. + +To ensure that applying an inverse normalization to the network outputs cannot cause the output to change sign and break monotonicity, normalize the targets using scaling and no offset. + + +To preprocess the target data for training on the differences, define the preprocessing function, `preprocessDiffTargets`. The function finds the difference between the target data at two consecutive time steps and applies normalization. To apply an inverse normalization on the network outputs during testing, the function also outputs the calculated normalization statistic `stdDiff`. + +```matlab +function [XDiff, YDiff, stdDiff] = preprocessDiffTargets(X,Y) + +YDiff = cellfun(@(x) diff(x),Y,UniformOutput=false); + +stdDiff = std(cell2mat(YDiff)); +YDiff = cellfun(@(x) x/stdDiff,YDiff,UniformOutput=false); + +% Remove the final value in the input sequence +XDiff = cellfun(@(x) x(1:end-1,:),X,UniformOutput=false); + +end + +[XTrainChargingDifference, YTrainChargingDifference, stdDiffCharge] = preprocessDiffTargets(XTrainCharge,YTrainCharge); +[XTrainDischargingDifference, YTrainDischargingDifference, stdDiffDischarge] = preprocessDiffTargets(XTrainDischarge,YTrainDischarge); +``` +## Constrain Sign of Network Output + +To enforce monotonicity during charging, you can constrain the charging the network output to always be positive by applying a positive activation function, such as `relu`. + + $$ \Delta y\left(t\right)=y\left(t+1\right)-y\left(t\right)=\textrm{ReLU}\left(h_{\textrm{charging}} \left(x\left(t\right)\right)\right) $$ + +For discharging, multiply the output by `-1` before and after applying the `relu` operation. + + $$ \Delta y\left(t\right)=y\left(t+1\right)-y\left(t\right)=-\textrm{ReLU}\left({-h}_{\textrm{discharging}} \left(x\left(t\right)\right)\right) $$ +## Train Charging Network + +The original recurrent neural network architecture (RNN) used for battery state of charge estimation is described in [Battery State of Charge Estimation Using Deep Learning](https://www.mathworks.com/help/deeplearning/ug/battery-state-of-charge-estimation-using-deep-learning.html). This example adapts that network by training two separate networks for the charging and discharging phases. Each network uses an LSTM network architecture with blocks that consist of an LSTM layer followed by a dropout layer. The dropout layers help to avoid overfitting. + + +Define the network architecture. Create an RNN that consists of two blocks containing an `lstmLayer` followed by a `dropoutLayer`, with decreasing `numHiddenUnits` between the blocks and a dropout probability of 0.2. Since the network predicts the battery state of charge (SOC), add a `fullyConnectedLayer` with an output size of `numRespones`. To ensure the outputs are positive, add a `reluLayer` at the end. + +```matlab +numFeatures = size(XTrainCharge{1},2); +numHiddenUnits = 32; +numResponses = size(YTrainCharge{1},2); + +layers = [sequenceInputLayer(numFeatures,Normalization="rescale-zero-one") + lstmLayer(numHiddenUnits) + dropoutLayer(0.2) + lstmLayer(numHiddenUnits/2) + dropoutLayer(0.2) + fullyConnectedLayer(numResponses) + reluLayer]; +``` + +Specify the training options for the network. Train for 50 epochs with mini\-batches of size 16 using the Adam optimizer. To decrease the learn rate schedule throughout training, set `LearnRateSchedule` to `cosine`. Specify the learning rate as 0.001. To prevent the gradients from exploding, set the gradient threshold to 1. Since the state is not carried between chunks, set "`Shuffle`" to "`every-epoch`" to ensure batches are not dominated by specific sequences in the charging data. Turn on the training progress plot, and turn off the command window output (`Verbose`). + +```matlab +epochs = 50; +miniBatchSize = 16; + +chargingOptions = trainingOptions("adam", ... + MaxEpochs=epochs, ... + GradientThreshold=1, ... + InitialLearnRate=0.001, ... + LearnRateSchedule="cosine", ... + MiniBatchSize=miniBatchSize, ... + Verbose=0, ... + Plots="training-progress", ... + Shuffle="every-epoch"); +``` + +Train the charging network using `trainnet`. + +```matlab +chargingConstrainedNet = trainnet(XTrainChargingDifference,YTrainChargingDifference,layers,"mse",chargingOptions); +``` + +![Training progress plot for the charging network.](./figures/trainingPlotCharging.png) +## Train Discharging Network + +To enforce negative outputs for the discharging network, surround the `reluLayer` with two `scalingLayer` objects with scale `-1`. + +```matlab +layers = [sequenceInputLayer(numFeatures,Normalization="rescale-zero-one") + lstmLayer(numHiddenUnits) + dropoutLayer(0.2) + lstmLayer(numHiddenUnits/2) + dropoutLayer(0.2) + fullyConnectedLayer(numResponses) + scalingLayer(Scale=-1) + reluLayer + scalingLayer(Scale=-1)]; +``` + +Since there are fewer observations for the discharging phase, increase the number of epochs before training the discharging network using `trainnet`. + +```matlab +dischargingOptions = chargingOptions; +dischargingOptions.MaxEpochs = epochs * ceil(numel(XTrainCharge)/numel(XTrainDischarge)); +dischargingConstrainedNet = trainnet(XTrainDischargingDifference,YTrainDischargingDifference,layers,"mse",dischargingOptions); +``` + +![Training progress plot for the discharging network.](./figures/trainingPlotDischarging.png) +# Test Networks +## Train Unconstrained Networks + +To compare the performance of the network with other unconstrained networks, train four unconstrained networks using the supporting function `trainUnconstrainedNetworks`. `chargingConstrainedNet` and `dischargingUnconstrainedNet` use the raw predictors, not the differences, as the network inputs. `chargingUnconstrainedDiffNet` and `dischargingUnconstrainedDiffNet` are trained on the differences but do not constrain the output to be positive or negative. The rest of the network architecture and the training options are the same as the constrained networks. + +```matlab +[chargingUnconstrainedNet,dischargingUnconstrainedNet,chargingUnconstrainedDiffNet,dischargingUnconstrainedDiffNet] = trainUnconstrainedNetworks(XTrainCharge,YTrainCharge, ... + XTrainDischarge,YTrainDischarge, ... + XTrainChargingDifference,YTrainChargingDifference, ... + XTrainDischargingDifference,YTrainDischargingDifference, ... + chargingOptions, dischargingOptions); +``` +## Predict SOC + +The test data is several hours of simulated battery charging data at the same temperature as the test data. + +```matlab +data = load("BSOC_25_degrees_testing.mat"); +XTest = data.X; +YTest = data.Y'; +``` + +To find the SOC for the test input data, split the input data into charging and discharging data and predict the SOC using the corresponding network. Apply inverse normalization to the network output using the normalization statistics calculated by `preprocessDiffTargets.` + +```matlab +function YPred = getCombinedNetworkOutputs(XTest,chargingNet,chargeScale,dischargingNet,dischargeScale,cycleLength,chargeRatio) + +[XTestCharge, XTestDischarge, chargingIdx, dischargingIdx] = splitChargingCycles(XTest,cycleLength,chargeRatio); + +YPredCharge = cell(numel(XTestCharge),1); +YPredDischarge = cell(numel(XTestDischarge),1); + +% Predict for charge cycles +for i = 1:numel(XTestCharge) + [Yout, state] = predict(chargingNet, XTestCharge{i}); + YPredCharge{i} = chargeScale*Yout; + chargingNet.State = state; +end + +% Predict for discharge cycles +for i = 1:numel(XTestDischarge) + [Yout, state] = predict(dischargingNet, XTestDischarge{i}); + YPredDischarge{i} = dischargeScale*Yout; + dischargingNet.State = state; +end + +% Concatenate predictions +YPred = cell(numel(YPredCharge)+numel(YPredDischarge),1); +YPred(chargingIdx) = YPredCharge; +YPred(dischargingIdx) = YPredDischarge; +YPred = cell2mat(YPred); + +end +``` + +To find the SOC value at each time step of the networks trained on the difference, cumulatively sum the network ouputs. + +```matlab +function YPred = getSOCFromDiffOutput(YPredDiff,SOC0) + +YPred = SOC0 + [0; cumsum(YPredDiff)]; +YPred = YPred(1:end-1); + +end +``` +## Compare Networks +### Plot Test Data + +Find the predicted SOC of the test data for each network and plot the results. + +```matlab +YPredUnconstrained = getCombinedNetworkOutputs(XTest,chargingUnconstrainedNet,1,dischargingUnconstrainedNet,1,cycleLength,chargeRatio); + +figure +hold on +plot(YPredUnconstrained) +plot(YTest) +legend(["Predicted SOC" "True SOC"],Location="bestoutside") +title("Unconstrained Network Trained on Raw SOC Values") +hold off +``` + +![Plot comparing true and predicted SOC for the unconstrained network.](./figures/testPlotUnconstrained.png) + +The unconstrained network trained on the raw SOC values shows an overall good fit to the test data. You can see the noise typical of an RNN and the impact of resetting the state between the two networks. + +```matlab +YOutUnconstrainedDiff = getCombinedNetworkOutputs(XTest,chargingUnconstrainedDiffNet,stdDiffCharge,dischargingUnconstrainedDiffNet,stdDiffDischarge,cycleLength,chargeRatio); + +YPredUnconstrainedDiff = getSOCFromDiffOutput(YOutUnconstrainedDiff,YTest(1)); + +figure +hold on +plot(YPredUnconstrainedDiff) +plot(YTest) +legend(["Predicted SOC" "True SOC"],Location="bestoutside") +title("Unconstrained Network Trained on Differences") +hold off +``` + +![Plot comparing true and predicted SOC for the unconstrained network trained on differences.](./figures/testPlotUnconstrainedDiff.png) + +The unconstrained network trained on the differences performs significantly better than the original network. It is less noisy and responds well to the random changes in discharge voltage current. + +```matlab +YOutConstrained = getCombinedNetworkOutputs(XTest,chargingConstrainedNet,stdDiffCharge,dischargingConstrainedNet,stdDiffDischarge,cycleLength,chargeRatio); + +YPredConstrained = getSOCFromDiffOutput(YOutConstrained,YTest(1)); + +figure +hold on +plot(YPredConstrained) +plot(YTest) +legend(["Predicted SOC" "True SOC"],Location="bestoutside") +title("Constrained Network Trained on Differences") +hold off +``` + +![Plot comparing true and predicted SOC for the constrained network trained on differences.](./figures/testPlotConstrainedDiff.png) + +The monotonic network trained on the differences performs comparably to the unconstrained network trained on the differences. + + +Save the results in a table: + +```matlab +predictions = table(["Unconstrained";"Unconstrained Diff";"Constrained"],[{YPredUnconstrained}; {YPredUnconstrainedDiff}; {YPredConstrained}],VariableNames=["Network","YPred"]); +``` +### RMSE + +For each network, calculate the RMSE. + +```matlab +function rmseValue = findRMSE(YTest,YPred) +rmseValue = sqrt(mean((YPred - YTest).^2)); +end + +predictions = [predictions rowfun(@(x) findRMSE(YTest,x{:}),predictions(:,"YPred"),OutputVariableNames="RMSE")]; +``` + +Plot the RMSE for each network: + +```matlab +figure +bar(predictions.Network,predictions.RMSE) +xlabel("Network") +ylabel("RMSE") +``` + +![Plot showing test RMSE for the three trained networks.](./figures/testRMSE.png) + +The two networks trained on the differences have significantly lower RMSEs than the other network. The constrained network performs slightly worse than the unconstrained network. + +### Monotonicity + +Monotonicity is guaranteed for the constrained network. You can sample the training or test set to get an idea of violations of monotonicity for the unconstrained networks. A convenient way to assess violation to monotonicity is to define the `monotonicityScore` to measure the degree of monotonicity of a signal. This is the ratio of intervals between two adjacent signals. + + $$ +\textrm{monotonicityScore} = \frac{1}{N-1} \sum_{n=1}^{N-1} +\begin{cases} +1 & \text{if } y(n+1) \geq y(n) \\ +0 & \text{if } y(n+1) < y(n) +\end{cases} +$$ + +A fully monotonic increasing signal has a `monotonicityScore` of 1 and a fully monotonic decreasing signal has a `monotonicityScore` of 0. + + +Define `getMonotonicityScore` to calculate the `monotoncityScore` of the network output. Use the `cycleLength` and `chargeRatio` to determine if the signal should be monotonically increasing or decreasing. The function also outputs the index of SOC values which break monotonicity. + +```matlab +function [monotonicityScore, brokenMonotocityIdx] = getMonotonicityScore(YPred,cycleLength,chargeRatio) + +[YPredCharge, YPredDischarge, chargingIdx, dischargingIdx] = splitChargingCycles(YPred,cycleLength,chargeRatio); + +YPredDiff = cell(numel(YPredCharge)+numel(YPredDischarge),1); + +YPredDiffCharge = cellfun(@(x) diff(x),YPredCharge,"UniformOutput",false); +YPredDiffDischarge = cellfun(@(x) -diff(x),YPredDischarge,"UniformOutput",false); + +YPredDiff(chargingIdx) = YPredDiffCharge; +YPredDiff(dischargingIdx) = YPredDiffDischarge; + +YPredDiff = cell2mat(YPredDiff); + +monotonicityScore = sum(YPredDiff>=0)/numel(YPredDiff); + +idx = 1:numel(YPredDiff); +brokenMonotocityIdx = {idx(YPredDiff<0)}; + +end + +predictions = [predictions rowfun(@(x) getMonotonicityScore(x{:},cycleLength,chargeRatio),predictions(:,"YPred"),OutputVariableNames=["Monotonicity","BrokenMonotonicityIdx"])]; +``` + +Plot the `monotonicityScore` for the three networks: + +```matlab +figure +bar(predictions.Network,predictions.Monotonicity) +xlabel("Network") +ylabel("Monotonicity Score") +``` + +![Plot showing monotonicity score for the three trained networks.](./figures/monotonicityScore.png) + +The unconstrained network trained on the raw SOC data has a `monotonicityScore` of 0.52 meaning that the signal increases nearly as much as it decreases. The unconstrained network trained on the differences is closer to being monotonic that the original network with a score of 0.99, however, this emphasises that 1% of the test set violates monotonicity. The constrained network is monotonic by construction and thus achieves a score of 1. + + +You can sample a small section of the SOC where monotonicity was violated for the unconstrained network trained on differences. Also plotted in the same region is the monotonic network. + +```matlab +figure +brokenIdx1 = predictions{2,"BrokenMonotonicityIdx"}{1}(1); +pltIdx = brokenIdx1-5:brokenIdx1+10; + +colororder({'k','k'}) +yyaxis left +diffPlot = predictions{2,"YPred"}{1}(pltIdx); +plot(pltIdx,diffPlot,Color="#0072BD") +ylabel("SOC: Unconstrained Diff") + +yyaxis right +monoPlot = predictions{3,"YPred"}{1}(pltIdx); +plot(pltIdx,monoPlot,Color="#0072BD",LineStyle="-.") +ylabel("SOC: Constrained") +diffPlotRange = max(diffPlot)-min(diffPlot); +ylim([median(monoPlot)-diffPlotRange/2 median(monoPlot)+diffPlotRange/2]) + +legend(["Unconstrained Diff" "Constrained"]) + +xlabel("Time (s)") +``` + +![Plot showing the unconstrained network breaking monotonicity.](./figures/brokenMonotonicity.png) + +The results overall show that training on the differences gives better performance than training directly on the data for this SOC data. Constraining the outputs to be monotonic comes with a small cost in performance however, monotonicity can be guaranteed. + +# Helper Functions +## Split Data into Charging Cycles Function + +This function splits data into charging and discharging cycles using information about the simulated charging cycle: `cycleLength` and `chargeRatio`. + +```matlab +function [chargingData, dischargingData, chargingIdx, dischargingIdx] = splitChargingCycles(data,cycleLength,chargeRatio) + +chargeLength = floor(cycleLength*chargeRatio); +dischargeLength = cycleLength-chargeLength; +numCycles = floor(size(data,1)/(cycleLength)); + +chargingData = cell(numCycles+1,1); +dischargingData = cell(numCycles-1,1); + +chargingIdx = zeros(numCycles+1,1); +dischargingIdx = zeros(numCycles-1,1); + +p=1; + +% Split the test data into charge and discharge cycles +for i = 1:numCycles + startIdx = (i - 1) * cycleLength + 1; + endIdx = startIdx + dischargeLength - 1; + + if i>1 + dischargingData{i-1} = data(startIdx:endIdx, :); + dischargingIdx(i-1) = p; + p = p + 1; + else + chargingData{i} = data(startIdx:endIdx, :); + chargingIdx(i) = p; + p = p + 1; + end + + startIdx = endIdx + 1; + + endIdx = startIdx + chargeLength - 1; + if i == numCycles + endIdx = size(data,1); + end + + chargingData{i+1} = data(startIdx:endIdx, :); + chargingIdx(i+1) = p; + p = p + 1; + +end + +end +``` +## Plot Training Data Function + +This function plots the training data to visualize the charging/discharging split. + +```matlab +function plotTrainingData(YTrainCharge,YTrainDischarge,chargingIdx,dischargingIdx) + +YTrainSorted = [YTrainCharge; YTrainDischarge]; +isCharging = [ones(numel(chargingIdx),1); zeros(numel(dischargingIdx),1)]; +[~,sortIdx] = sort([chargingIdx; dischargingIdx]); +YTrainSorted = YTrainSorted(sortIdx); +isCharging = isCharging(sortIdx); +t = 1; + +figure + +for i=1:numel(YTrainSorted) + + yPlot = YTrainSorted{i}; + xPlot = t:(t+numel(yPlot)-1); + + if isCharging(i)==1 + c = "#0072BD"; + d = "Charging"; + p=1; + else + c = "#D95319"; + d = "Discharging"; + p=2; + end + + h(p) = plot(xPlot,yPlot,"Color",c,"DisplayName",d); + hold on + + t = xPlot(end)+1; +end + +hold off +title("Training Data") +xlabel("Observation") +ylabel("SOC (%)") +legend([h(1) h(2)],Location="bestoutside") + +end +``` diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BatteryStateOfChargeEstimationUsingMonotonicNeuralNetworks.mlx b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BatteryStateOfChargeEstimationUsingMonotonicNeuralNetworks.mlx new file mode 100644 index 0000000..40f8573 Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/BatteryStateOfChargeEstimationUsingMonotonicNeuralNetworks.mlx differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/brokenMonotonicity.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/brokenMonotonicity.png new file mode 100644 index 0000000..c5a308f Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/brokenMonotonicity.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/monotonicityScore.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/monotonicityScore.png new file mode 100644 index 0000000..171d2da Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/monotonicityScore.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotConstrainedDiff.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotConstrainedDiff.png new file mode 100644 index 0000000..2a6255b Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotConstrainedDiff.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotUnconstrained.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotUnconstrained.png new file mode 100644 index 0000000..9d52d83 Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotUnconstrained.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotUnconstrainedDiff.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotUnconstrainedDiff.png new file mode 100644 index 0000000..92c277f Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testPlotUnconstrainedDiff.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testRMSE.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testRMSE.png new file mode 100644 index 0000000..56abbdb Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/testRMSE.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingData.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingData.png new file mode 100644 index 0000000..16f895d Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingData.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingPlotCharging.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingPlotCharging.png new file mode 100644 index 0000000..8a92c49 Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingPlotCharging.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingPlotDischarging.png b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingPlotDischarging.png new file mode 100644 index 0000000..3a3693f Binary files /dev/null and b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/figures/trainingPlotDischarging.png differ diff --git a/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/trainUnconstrainedNetworks.m b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/trainUnconstrainedNetworks.m new file mode 100644 index 0000000..81588d9 --- /dev/null +++ b/examples/monotonic/BSOCEstimateUsingMonotonicNetworks/trainUnconstrainedNetworks.m @@ -0,0 +1,40 @@ +function [chargingNet,dischargingNet,chargingDiffNet,dischargingDiffNet] = trainUnconstrainedNetworks(XTrainCharging,YTrainCharging,XTrainDischarging,YTrainDischarging,XTrainChargingDifference,YTrainChargingDifference,XTrainDischargingDifference,YTrainDischargingDifference, chargingOptions, dischargingOptions) + +% Don't plot training progress + +chargingOptions.Plots = "none"; +dischargingOptions.Plots = "none"; + +% Train unconstrained RNN + +numFeatures = size(XTrainCharging{1},2); +numHiddenUnits = 32; +numResponses = size(YTrainCharging{1},2); + +layers = [sequenceInputLayer(numFeatures,"Normalization","rescale-zero-one") + lstmLayer(numHiddenUnits) + dropoutLayer(0.2) + lstmLayer(numHiddenUnits/2) + dropoutLayer(0.2) + fullyConnectedLayer(numResponses) + sigmoidLayer]; + + +chargingNet = trainnet(XTrainCharging,YTrainCharging,layers,"mse",chargingOptions); + +dischargingNet = trainnet(XTrainDischarging,YTrainDischarging,layers,"mse",dischargingOptions); + +% Train network on differences + +layersDiff = [sequenceInputLayer(numFeatures,"Normalization","rescale-zero-one") + lstmLayer(numHiddenUnits) + dropoutLayer(0.2) + lstmLayer(numHiddenUnits/2) + dropoutLayer(0.2) + fullyConnectedLayer(numResponses)]; + +chargingDiffNet = trainnet(XTrainChargingDifference,YTrainChargingDifference,layersDiff,"mse",chargingOptions); + +dischargingDiffNet = trainnet(XTrainDischargingDifference,YTrainDischargingDifference,layersDiff,"mse",dischargingOptions); + +end \ No newline at end of file diff --git a/examples/monotonic/RULEstimateUsingMonotonicNetworks/RULEstimationUsingMonotonicNetworksExample.md b/examples/monotonic/RULEstimateUsingMonotonicNetworks/RULEstimationUsingMonotonicNetworksExample.md index 478b5df..afcf593 100644 --- a/examples/monotonic/RULEstimateUsingMonotonicNetworks/RULEstimationUsingMonotonicNetworksExample.md +++ b/examples/monotonic/RULEstimateUsingMonotonicNetworks/RULEstimationUsingMonotonicNetworksExample.md @@ -242,7 +242,7 @@ options = trainingOptions("adam",... Verbose=0); ``` -Train the network using trainNetwork. It should take about 1\-2 minutes. +Train the network using trainnet. It should take about 1\-2 minutes. ```matlab cnnetRUL = trainnet(XTrain,YTrain,cnnetRUL,"mse",options); diff --git a/examples/monotonic/RULEstimateUsingMonotonicNetworks/RULEstimationUsingMonotonicNetworksExample.mlx b/examples/monotonic/RULEstimateUsingMonotonicNetworks/RULEstimationUsingMonotonicNetworksExample.mlx index 6b305c8..7c13030 100644 Binary files a/examples/monotonic/RULEstimateUsingMonotonicNetworks/RULEstimationUsingMonotonicNetworksExample.mlx and b/examples/monotonic/RULEstimateUsingMonotonicNetworks/RULEstimationUsingMonotonicNetworksExample.mlx differ