org.joone.engine
Class RTRLLearnerFactory

java.lang.Object
  extended by org.joone.engine.RTRLLearnerFactory
All Implemented Interfaces:
java.io.Serializable, java.util.EventListener, LearnerFactory, NeuralNetListener

public class RTRLLearnerFactory
extends java.lang.Object
implements LearnerFactory, NeuralNetListener, java.io.Serializable

An RTRL implementation, based mostly on http://www.willamette.edu/~gorr/classes/cs449 and a few other sources.

This is a partial RTRL implementation: network weights are optimised using an offline RTRL algorithm. The initial states of context nodes are not optimised, though this could easily be added; for now, initial states are simply assumed to be whatever they are set to in the context layer itself. RTRL does not rely on a backpropagated error, so backpropagation can, and should, be turned off in order to speed things up. Functionality for this is included in the Monitor class, and backpropagation is turned off whenever setMonitor is called.

In order to speed things up further, this class includes an experimental lineseek approach. First the gradient is calculated using the offline RTRL algorithm. Steps are then taken along the gradient for as long as the sum of squared errors decreases, in a lineseek fashion. As soon as a step results in an increased sum of squared errors, a step back is taken, typically smaller than the step forward, and the gradient is updated again. The stepping up and down is scaled in the spirit of the RPROP algorithm, so that the learning rate is adjusted after each cycle. Weights can also be randomised at the end of each cycle, in the spirit of simulated annealing. As with the lineseek approach, see the constructor for more details. These two features were really added to try and speed up convergence - if at all! Their practical benefit remains highly suspect at best.

This class has a main method which also serves as a demo of the RTRL implementation; please refer to that. A suitable net can easily be created using the GUI and then trained using the main method, with a few alterations to the code - for example the number of patterns, which, amongst others, is currently hard coded. The main method also shows how to save and restore a network trained via RTRL. While this class does implement the Serializable interface, that is highly suspect and the class is not meant to be serialised together with the network.

This implementation is highly academic at present. Any good examples of where it can be applied will be much appreciated - I am still looking for them. The initial conditions as well as the learning rate seem to have such a high impact on convergence as to make it of almost no practical use. Also, strangely, it often seems that a higher rather than lower learning rate is better for convergence.
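
By way of illustration, here is one plausible way to wire the factory up and train a restored net. This is a minimal sketch, not a transcription of the main method: the file name rtrlNet.snet and the pattern and cycle counts are placeholder values, and the exact wiring (in particular how the monitor hands RTRL learners to the layers) should be checked against the main method in the source.

    import java.io.FileInputStream;
    import java.io.ObjectInputStream;
    import org.joone.engine.Monitor;
    import org.joone.engine.RTRLLearnerFactory;
    import org.joone.net.NeuralNet;

    public class RTRLDemo {
        public static void main(String[] args) throws Exception {
            // Restore a recurrent net previously built and saved with the GUI editor
            ObjectInputStream in =
                    new ObjectInputStream(new FileInputStream("rtrlNet.snet"));
            NeuralNet net = (NeuralNet) in.readObject();
            in.close();

            // Lineseek variant, verbose, no weight shocking
            RTRLLearnerFactory factory = new RTRLLearnerFactory(net, true, true, 0.0);

            Monitor monitor = net.getMonitor();
            factory.setMonitor(monitor);       // also turns off backpropagation
            net.addNeuralNetListener(factory); // factory reacts to cycle/error events

            monitor.setTrainingPatterns(4); // placeholder: patterns per cycle
            monitor.setTotCicles(1000);     // placeholder: training cycles
            monitor.setLearning(true);
            net.go();                       // start training
        }
    }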

Author:
mg, http://www.ferra4models.com
See Also:
Serialized Form

Nested Class Summary
protected  class RTRLLearnerFactory.InitialState
          An initial state.
protected  class RTRLLearnerFactory.Node
          A node.
 class RTRLLearnerFactory.RTRLLearner
          The learner we will return from this factory.
protected  class RTRLLearnerFactory.Weight
          A weight.
 
Field Summary
protected  double currentSSE
          The current cycle's sum of squared errors, used especially when the experimental lineseek version is enabled
protected  int cycleCount
          Evaluation count, used when setting up a CFOProblem
protected  Layer inputLayer
          The input layer
protected  boolean interCycleUpdates
          True if we regularly update the weights within a cycle, false if we only update once the cycle completes
protected  double[] lastError
          The error last seen at the output layer T
protected  java.util.List<RTRLLearnerFactory.RTRLLearner> learners
          List of all learners referenced here, not used per se
protected  double learningRate
          The learning rate, stored locally to speed things up and adjusted dynamically if the experimental lineseek version is used.
protected  boolean lineseek
          True if the experimental lineseek version is in use
protected  double lowerLearningRate
          Lower bound for learning rate
protected  int minimumPatternCount
          Number of patterns that have to be seen before weights are updated
protected  double momentum
          The momentum, stored locally to speed things up
protected  Monitor monitor
          The (hopefully) shared monitor
protected  NeuralNet network
          The neural network trained here
protected  Layer outputLayer
          The output layer from which we calculate errors and which we use to determine if a node is in T
protected  double[][][] p
          The p matrix.
protected  int patternCount
          The current pattern.
protected  double previousSSE
          The previous cycle's sum of squared errors, used especially when the experimental lineseek version is enabled
protected  double[][] q
          The q matrix.
protected  java.util.Random random
          A random number generator used to shock the weights
protected  double shockFactor
          The shock factor
protected  double stepDownScale
          The scaling to apply if we are stepping down along an updated gradient, used when the experimental lineseek version is enabled
protected  double stepUpScale
          The scaling to apply if we are stepping along the gradient, used when the experimental lineseek version is enabled
protected  java.util.List<RTRLLearnerFactory.Node> T
          The output nodes, a subset of Z and used to speed up calculations
protected  java.util.List<RTRLLearnerFactory.Node> U
          The nodes in U, a subset of Z: specifically the first few nodes in Z, namely those that are not input nodes.
protected  double[][][] updateP
          A utility matrix used to speed things up when we update p.
protected  double updateProbability
          Probability of updating weights at next pattern
protected  double[][] updateQ
          A utility matrix used to speed things up when we update q.
protected  double upperLearningRate
          Upper bound for learning rate
protected  boolean verbose
          If true, display progress along the way
protected  double weightMagnitude
          The maximum weight magnitude.
protected  java.util.List<RTRLLearnerFactory.Weight> weights
          List of all weights, used when setting up a CFOProblem
protected  java.util.List<RTRLLearnerFactory.Node> z
          The nodes, referred to as the z element
protected  java.util.List<RTRLLearnerFactory.Node> Z
          The same as z, but this time ordered so that nodes in U come first.
 
Constructor Summary
RTRLLearnerFactory(NeuralNet network, boolean verbose)
          Creates a new instance of RTRLLearnerFactory.
RTRLLearnerFactory(NeuralNet network, boolean lineseek, boolean verbose, double shockFactor)
          Creates a new instance of RTRLLearnerFactory.
 
Method Summary
protected  void attachErrorPatternListener()
          Attach a synapse to the output layer to calculate the error pattern and update the RTRL p matrix and weights if online.
 void cicleTerminated(NeuralNetEvent e)
          Use this event to reset the p and update the weight matrices
 void errorChanged(NeuralNetEvent e)
          Use this event to update the p and delta matrices
protected  RTRLLearnerFactory.InitialState getInitialState(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output)
          Determine and return the initial state that impacts on an output node from a given input node firing into it.
 Learner getLearner(Monitor monitor)
          Return a suitable learner, part of the learner factory interface
 Monitor getMonitor()
          Retrieve the monitor
 CFOProblem getProblem(java.lang.String oatName, javax.swing.event.ChangeListener changeListener, double weightMagnitude, double guessStdev)
          Function to return this network as an oat CFOProblem
 double getStepDownScale()
          The current scale down factor
 double getStepUpScale()
          The current scale up factor
protected  RTRLLearnerFactory.Weight getWeight(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output)
          Determine the weight between an input and an output node.
 double getWeightMagnitude()
          The maximum weight magnitude - weights are not allowed to exceed this
 java.util.List<RTRLLearnerFactory.Weight> getWeights()
          Retrieve the weights
protected  void init()
          Initialise the underlying parameters and structures.
protected  boolean isLinked(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output)
          Determine if there exists a link or weight between two nodes.
protected  boolean isLinkedToInitialState(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output)
          Determine if there exists a link or weight between a context and an output node.
static void main(java.lang.String[] args)
          Execute one of the testers
 void netStarted(NeuralNetEvent e)
          Used to set the learning rate and momentum
 void netStopped(NeuralNetEvent e)
          Print out the weight matrix when done
 void netStoppedError(NeuralNetEvent e, java.lang.String error)
          Ignored
 void printP(java.io.PrintStream out)
          Helper to print out the p matrix
 void printWeights(java.io.PrintStream out)
          Helper function to print out the weight matrix
protected  void registerLearnable(RTRLLearnerFactory.RTRLLearner learner)
          Register a new learnable.
protected  void resetP()
          Reset the p matrix in preparation for the next cycle
 void setLowerLearningRate(double lowerLearningRate)
          Set a lower bound on the (scaled) learning rate
 void setMonitor(Monitor monitor)
          Set the monitor; this also turns off backpropagation
 void setStepDownScale(double stepDownScale)
          The scale down factor.
 void setStepUpScale(double stepUpScale)
          The scale up factor.
 void setUpperLearningRate(double upperLearningRate)
          Set an upper bound on the (scaled) learning rate
 void setWeight(int i, double weight)
          Set a given weight
 void setWeightMagnitude(double weightMagnitude)
          The maximum weight magnitude - weights are not allowed to exceed this
 void shrinkWeights(double scale)
          Helper function to shrink weights.
static void testAll(java.lang.String[] args)
          Test all the oat CFO algorithms on the general network
static void testCustom(java.lang.String[] args)
          Custom test
static void testOAT(java.lang.String[] args)
          Test of general learning, using something from optalgotoolkit.
static void testRTRL(java.lang.String[] args)
          Test of RTRL learning.
protected  void updateDeltas()
          Update the weights' deltas.
protected  void updateP()
          Update the p matrix
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

lineseek

protected boolean lineseek
True if the experimental lineseek version is in use


monitor

protected transient Monitor monitor
The (hopefully) shared monitor


learners

protected java.util.List<RTRLLearnerFactory.RTRLLearner> learners
List of all learners referenced here, not used per se


z

protected java.util.List<RTRLLearnerFactory.Node> z
The nodes, referred to as the z element


Z

protected java.util.List<RTRLLearnerFactory.Node> Z
The same as z, but this time ordered so that nodes in U come first. We are more interested in Z, but use z to extract the nodes and then at a later stage sort it into Z.


T

protected java.util.List<RTRLLearnerFactory.Node> T
The output nodes, a subset of Z and used to speed up calculations


U

protected java.util.List<RTRLLearnerFactory.Node> U
The nodes in U, a subset of Z: specifically the first few nodes in Z, namely those that are not input nodes.


inputLayer

protected Layer inputLayer
The input layer


outputLayer

protected Layer outputLayer
The output layer from which we calculate errors and which we use to determine if a node is in T


learningRate

protected double learningRate
The learning rate, stored locally to speed things up and adjusted dynamically if the experimental lineseek version is used.


momentum

protected double momentum
The momentum, stored locally to speed things up


p

protected double[][][] p
The p matrix. It is a 3-d matrix with dimensions the number of nodes in U x the number of nodes in U x the number of nodes in Z. In the literature its entries are typically written p^k_ij, i.e. indexed with superscript k and subscripts i and j.
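
For reference, the sensitivity update that this matrix supports in standard RTRL (Williams & Zipser, 1989) is usually written as below. This is the textbook form under the usual notation - f'_k the activation derivative of node k, s_k its net input, z_j the node outputs, delta the Kronecker delta - and not a literal transcription of this class's code:

    p^{k}_{ij}(t+1) = f'_k\big(s_k(t)\big)\left[\sum_{l \in U} w_{kl}\, p^{l}_{ij}(t) + \delta_{ik}\, z_j(t)\right], \qquad p^{k}_{ij}(t_0) = 0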


updateP

protected double[][][] updateP
A utility matrix used to speed things up when we update p; we swap it with p on each update.


q

protected double[][] q
The q matrix. A 2-d matrix, very similar to the p matrix, but used for initial states rather than weights. For now we only optimise the weights and not the initial states, so it is not used at present.


updateQ

protected double[][] updateQ
A utility matrix used to speed things up when we update q; we swap it with q on each update. As with q, it is not used at present.


lastError

protected double[] lastError
The error last seen at the output layer T


network

protected NeuralNet network
The neural network trained here


currentSSE

protected double currentSSE
The current cycle's sum of squared errors, used especially when the experimental lineseek version is enabled


previousSSE

protected double previousSSE
The previous cycle's sum of squared errors, used especially when the experimental lineseek version is enabled


stepUpScale

protected double stepUpScale
The scaling to apply if we are stepping along the gradient, used when the experimental lineseek version is enabled


stepDownScale

protected double stepDownScale
The scaling to apply if we are stepping down along an updated gradient, used when the experimental lineseek version is enabled


verbose

protected boolean verbose
If true, display progress along the way


upperLearningRate

protected double upperLearningRate
Upper bound for learning rate


lowerLearningRate

protected double lowerLearningRate
Lower bound for learning rate


patternCount

protected int patternCount
The current pattern. Not really required, sometimes used for debugging purposes.


random

protected java.util.Random random
A random number generator used to shock the weights


shockFactor

protected double shockFactor
The shock factor


weightMagnitude

protected double weightMagnitude
The maximum weight magnitude. Weights are not allowed to exceed this magnitude.


interCycleUpdates

protected boolean interCycleUpdates
True if we regularly update the weights within a cycle, false if we only update once the cycle completes


minimumPatternCount

protected int minimumPatternCount
Number of patterns that have to be seen before weights are updated


updateProbability

protected double updateProbability
Probability of updating weights at next pattern


weights

protected java.util.List<RTRLLearnerFactory.Weight> weights
List of all weights, used when setting up a CFOProblem


cycleCount

protected int cycleCount
Evaluation count, used when setting up a CFOProblem

Constructor Detail

RTRLLearnerFactory

public RTRLLearnerFactory(NeuralNet network,
                          boolean lineseek,
                          boolean verbose,
                          double shockFactor)
Creates a new instance of RTRLLearnerFactory.

Parameters:
network - the neural network to train
lineseek - if true, an experimental version will be used. This implies an offline version. To speed things up, this version of RTRL will traverse along the previously calculated gradient for as long as the goal function decreases and only when it starts to increase will it again update the gradient. Use with care, but it can speed things up quite a bit.
verbose - if true, will cause the RTRL to display debugging information along the way
shockFactor - if not 0, the amplitude used to periodically - after every cycle - shock the weights
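
As a quick illustration of this constructor (assuming an existing NeuralNet called net; the flag combinations and the shock amplitude are arbitrary):

    // Offline lineseek RTRL, quiet, shocking the weights with amplitude 0.1
    RTRLLearnerFactory shocked = new RTRLLearnerFactory(net, true, false, 0.1);

    // Plain offline RTRL, verbose, no shocking
    RTRLLearnerFactory plain = new RTRLLearnerFactory(net, false, true, 0.0);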

RTRLLearnerFactory

public RTRLLearnerFactory(NeuralNet network,
                          boolean verbose)
Creates a new instance of RTRLLearnerFactory. Use this constructor if RTRL will not be used per se, but rather some OAT optimisation algorithm.

Parameters:
network - the neural network to train
verbose - if true, will cause the RTRL to display debugging information along the way
Method Detail

isLinked

protected boolean isLinked(RTRLLearnerFactory.Node input,
                           RTRLLearnerFactory.Node output)
Determine if there exists a link or weight between two nodes. Note that not all links qualify - only links that we are allowed to change. It is not clear whether this has an impact somewhere.

Returns:
true if the input node fires into the output node

getWeight

protected RTRLLearnerFactory.Weight getWeight(RTRLLearnerFactory.Node input,
                                              RTRLLearnerFactory.Node output)
Determine the weight between an input and an output node.

Returns:
null if no link exists

isLinkedToInitialState

protected boolean isLinkedToInitialState(RTRLLearnerFactory.Node input,
                                         RTRLLearnerFactory.Node output)
Determine if there exists a link or weight between a context and an output node. Note that not all links qualify - only links that we are allowed to change.

Returns:
true if the input node is a context layer that fires into the output node

getInitialState

protected RTRLLearnerFactory.InitialState getInitialState(RTRLLearnerFactory.Node input,
                                                          RTRLLearnerFactory.Node output)
Determine and return the initial state that impacts on an output node from a given input node firing into it.

Returns:
null if no such link exists

getLearner

public Learner getLearner(Monitor monitor)
Return a suitable learner, part of the learner factory interface

Specified by:
getLearner in interface LearnerFactory
Parameters:
monitor - the monitor.

registerLearnable

protected void registerLearnable(RTRLLearnerFactory.RTRLLearner learner)
Register a new learnable. This is messaged from the nested class RTRLLearner


init

protected void init()
Initialise the underlying parameters and structures. This method creates and traverses the nodes, classifies them as being in U or not, and assigns k (in z) and i (in U) numbers to them. It also creates and initialises the p, weight and delta matrices.


updateP

protected void updateP()
Update the p matrix


resetP

protected void resetP()
Reset the p matrix in preparation for the next cycle


updateDeltas

protected void updateDeltas()
Update the weights' deltas. This is called once a network run has been completed.


attachErrorPatternListener

protected void attachErrorPatternListener()
Attach a synapse to the output layer to calculate the error pattern and update the RTRL p matrix and weights if online.


getMonitor

public Monitor getMonitor()
Retrieve the monitor


setMonitor

public void setMonitor(Monitor monitor)
Set the monitor; this also turns off backpropagation


setStepUpScale

public void setStepUpScale(double stepUpScale)
The scale up factor. Set the scale factor to apply to the learning rate if this is a lineseek RTRL and the current cycle's step along the gradient improved the goal function. Typically more than 1; the default is 1.5


getStepUpScale

public double getStepUpScale()
The current scale up factor


setStepDownScale

public void setStepDownScale(double stepDownScale)
The scale down factor. Set the scale factor to apply to the learning rate if this is a lineseek RTRL and the current cycle's step along the gradient did not improve the goal function. Typically less than 1 - we overstepped and want to get back to the optimum; the default is 0.7


getStepDownScale

public double getStepDownScale()
The current scale down factor


setUpperLearningRate

public void setUpperLearningRate(double upperLearningRate)
Set an upper bound on the (scaled) learning rate


setLowerLearningRate

public void setLowerLearningRate(double lowerLearningRate)
Set a lower bound on the (scaled) learning rate
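
Together with setStepUpScale and setStepDownScale, these bounds define the lineseek schedule. A quick configuration sketch, assuming an RTRLLearnerFactory called factory - the scale factors echo the documented defaults, while the rate bounds are arbitrary guesses:

    factory.setStepUpScale(1.5);        // grow the rate after an improving step (default 1.5)
    factory.setStepDownScale(0.7);      // shrink it after an overstep (default 0.7)
    factory.setUpperLearningRate(1.0);  // arbitrary upper clamp on the scaled rate
    factory.setLowerLearningRate(1e-6); // arbitrary lower clamp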


netStarted

public void netStarted(NeuralNetEvent e)
Used to set the learning rate and momentum

Specified by:
netStarted in interface NeuralNetListener

cicleTerminated

public void cicleTerminated(NeuralNetEvent e)
Use this event to reset the p and update the weight matrices

Specified by:
cicleTerminated in interface NeuralNetListener

netStopped

public void netStopped(NeuralNetEvent e)
Print out the weight matrix when done

Specified by:
netStopped in interface NeuralNetListener

errorChanged

public void errorChanged(NeuralNetEvent e)
Use this event to update the p and delta matrices

Specified by:
errorChanged in interface NeuralNetListener

netStoppedError

public void netStoppedError(NeuralNetEvent e,
                            java.lang.String error)
Ignored

Specified by:
netStoppedError in interface NeuralNetListener

shrinkWeights

public void shrinkWeights(double scale)
Helper function to shrink weights. It seems weights should really be initialised to small values; this allows the weights to be shrunk by the given scale factor.
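
A one-line usage sketch - assuming, as the wording suggests but the source should confirm, that each weight is multiplied by the given scale:

    factory.shrinkWeights(0.1); // presumably scales every weight to 10% of its value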


printWeights

public void printWeights(java.io.PrintStream out)
Helper function to print out the weight matrix


printP

public void printP(java.io.PrintStream out)
Helper to print out the p matrix


getWeights

public java.util.List<RTRLLearnerFactory.Weight> getWeights()
Retrieve the weights


setWeight

public void setWeight(int i,
                      double weight)
Set a given weight


getWeightMagnitude

public double getWeightMagnitude()
The maximum weight magnitude - weights are not allowed to exceed this


setWeightMagnitude

public void setWeightMagnitude(double weightMagnitude)
The maximum weight magnitude - weights are not allowed to exceed this


getProblem

public CFOProblem getProblem(java.lang.String oatName,
                             javax.swing.event.ChangeListener changeListener,
                             double weightMagnitude,
                             double guessStdev)
Function to return this network as an oat CFOProblem

Parameters:
oatName - a name used when reporting the error on stderr
changeListener - null or a listener that is notified whenever a new error was calculated
weightMagnitude - the magnitude allowed for a weight in this problem, e.g. 10 to force the solution weights to be between -10 and +10
guessStdev - if not zero, this indicates that the network's current weights can be used as an initial guess and that this standard deviation can be used to generate other initial guesses around it
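
A minimal sketch of the OAT route, using the two-argument constructor described above; the problem name is a placeholder and the numeric values are arbitrary:

    // Expose the network as a CFO problem for an optalgotoolkit algorithm
    RTRLLearnerFactory factory = new RTRLLearnerFactory(net, true);
    CFOProblem problem = factory.getProblem(
            "rtrl-demo", // name used when reporting the error on stderr
            null,        // no ChangeListener
            10.0,        // solution weights constrained to [-10, +10]
            0.5);        // seed initial guesses around the current weights
    // hand the problem to an OAT algorithm of choice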

main

public static void main(java.lang.String[] args)
Execute one of the testers


testRTRL

public static void testRTRL(java.lang.String[] args)
Test of RTRL learning. Each command line argument represents a recurrent net that will be loaded and trained using RTRL. It will then be saved back, overriding the original file. You have been warned.


testAll

public static void testAll(java.lang.String[] args)
Test all the oat CFO algorithms on the general network


testOAT

public static void testOAT(java.lang.String[] args)
Test of general learning, using something from optalgotoolkit. Each command line argument pair represents a net that will be loaded and trained using some optalgotoolkit algorithm; the algorithm is the second argument in the pair. The net will then be saved back, overriding the original file. You have been warned.


testCustom

public static void testCustom(java.lang.String[] args)
Custom test


