org.joone.engine
Class RTRL

java.lang.Object
  extended by org.joone.engine.RTRL

public class RTRL
extends java.lang.Object

An RTRL implementation, based mostly on http://www.willamette.edu/~gorr/classes/cs449 and a few other sources. This is a partial implementation: network weights are optimised using an offline RTRL algorithm. The initial states of context nodes are not optimised, although this could easily be added. For now, the initial states are simply assumed to be whatever they are set to in the context layer itself.

RTRL does not rely on a backpropagated error, so backpropagation can and should be turned off to speed things up. Functionality for this is included in the Monitor class, and it is turned off whenever the setMonitor method is called.

To speed things up further, this class includes an experimental lineseek approach. First, the gradient is calculated using the offline RTRL algorithm. Steps are then taken along the gradient for as long as the sum of squared errors decreases, in typical lineseek fashion. As soon as a step results in an increased sum of squared errors, a step back is taken, typically smaller than the step forward, and the gradient is updated once again. The stepping up and down is scaled in the spirit of the RPROP algorithm, so that the learning rate is adjusted after each cycle.

Weights can also be randomised, in the spirit of simulated annealing, at the end of each cycle. As with the lineseek approach, see the constructor for more details. These two features were added to try to speed up convergence, if they help at all; their practical benefit remains highly suspect at best.

This class has a main method which also serves as a demo of RTRL; please refer to it. A suitable net can easily be created using the GUI and then trained using the main method, after a few alterations to the code, for example to the number of patterns, which, amongst other things, is currently hard coded. The main method also shows how to save and restore a network trained via RTRL.

While this class does implement the Serializable interface, that support is highly suspect and the class is not meant to be serialised together with the network.

This implementation is highly academic at present. Any good examples where it can be applied would be much appreciated; I am still looking for them. The initial conditions and the learning rate seem to have such a high impact on convergence as to make this of almost no practical use. Strangely, a higher rather than lower learning rate often seems to be better for convergence.

Support for multiprocessors has now been added.
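The lineseek idea described above can be illustrated on a toy problem. The following is a minimal sketch and not the Joone code: the class LineSeekSketch, its methods, and the grow and shrink factors are all hypothetical, and the error function is a simple quadratic rather than a recurrent network. It also simplifies one detail: a step that increases the error is discarded entirely instead of being followed by a smaller step back.

```java
/** Hypothetical sketch of a line seek with an adaptive step size; not Joone code. */
final class LineSeekSketch {

    /** Sum of squared errors of a toy model in which each weight should equal its target. */
    static double sse(double[] w, double[] target) {
        double sum = 0.0;
        for (int i = 0; i < w.length; i++) {
            double e = w[i] - target[i];
            sum += e * e;
        }
        return sum;
    }

    /** Gradient of sse with respect to w. */
    static double[] gradient(double[] w, double[] target) {
        double[] g = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            g[i] = 2.0 * (w[i] - target[i]);
        }
        return g;
    }

    /**
     * Keeps stepping along the last computed gradient while the error falls.
     * A failed step is discarded (a simplification), the rate is shrunk and
     * the gradient is recomputed at the current weights.
     */
    static double[] train(double[] start, double[] target,
                          double rate, double grow, double shrink, int maxSteps) {
        double[] w = start.clone();
        double[] g = gradient(w, target);
        double best = sse(w, target);
        for (int step = 0; step < maxSteps; step++) {
            double[] trial = new double[w.length];
            for (int i = 0; i < w.length; i++) {
                trial[i] = w[i] - rate * g[i];
            }
            double err = sse(trial, target);
            if (err < best) {
                w = trial;        // accept the step and keep the same gradient
                best = err;
                rate *= grow;     // grow the step, in the spirit of RPROP
            } else {
                rate *= shrink;   // reject the step and shrink the rate
                g = gradient(w, target);
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] target = {1.0, -2.0};
        double[] w = train(new double[] {0.0, 0.0}, target, 0.1, 1.2, 0.5, 200);
        System.out.println("final SSE: " + sse(w, target));
    }
}
```

Reusing a stale gradient for several steps saves work, which matters when each gradient requires a full pass of the patterns; whether it helps in practice is, as noted above, doubtful.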

Author:
mg, http://www.ferra4models.com

Field Summary
protected  NodesAndWeights nodesAndWeights
          The network we are training
protected  java.util.List<java.util.List<NodesAndWeights.Node>> nodesList
          List of lists of nodes that will be updated by each processor
protected  double[][] p
          The p matrix: p[k][ij] is the derivative of node k (in U) with respect to weight ij
protected  int patternCount
          Pattern counter
protected  int processorCount
          Number of processors to use, 1 or less on a uniprocessor
protected  double[][] updateP
          The utility updateP matrix, used when updating the p matrix
protected  java.util.List<java.util.List<NodesAndWeights.Weight>> weightsList
          List of lists of weights that will be updated by each processor
 
Constructor Summary
RTRL(NodesAndWeights nodesAndWeights)
          Create a new instance of RTRL
 
Method Summary
 int getProcessorCount()
          Retrieve processor count
protected  void init()
          Initialise
 void printP(java.io.PrintStream out)
          Helper to print out the p matrix
protected  void resetP()
          Reset the p matrix in preparation for the next cycle - called at the end of a cycle
 void setProcessorCount(int processorCount)
          Set the number of processors to use
 void update(double[] error)
          Update RTRL. Call this with the most recent error pattern as soon as one becomes available.
 void updateCycle(double learningRate)
          Update the weights. Call this once a full set of patterns has been presented to the network.
protected  void updateDeltas(double[] error)
          Update the weights' deltas.
protected  void updateDeltas(double[] error, java.util.List<NodesAndWeights.Weight> weights)
          Update the given weights' deltas.
protected  void updateP()
          Update the p matrix - called after a pattern has been presented to the network
protected  void updateP(java.util.List<NodesAndWeights.Node> nodes)
          Update the p matrix for the given list of nodes
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

nodesAndWeights

protected NodesAndWeights nodesAndWeights
The network we are training


p

protected double[][] p
The p matrix: p[k][ij] is the derivative of node k (in U) with respect to weight ij


updateP

protected double[][] updateP
The utility updateP matrix, used when updating the p matrix


patternCount

protected int patternCount
Pattern counter


processorCount

protected int processorCount
Number of processors to use, 1 or less on a uniprocessor


nodesList

protected java.util.List<java.util.List<NodesAndWeights.Node>> nodesList
List of lists of nodes that will be updated by each processor


weightsList

protected java.util.List<java.util.List<NodesAndWeights.Weight>> weightsList
List of lists of weights that will be updated by each processor
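
How the nodes and weights are divided among processors is not documented here. As an illustration only, one plausible scheme is a round-robin split into processorCount sublists, with a count of 1 or less giving a single list. The class PartitionSketch and its method are hypothetical and do not reflect the actual Joone code.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of splitting work items into per-processor lists; not Joone code. */
final class PartitionSketch {

    /** Distributes items round-robin over max(1, processorCount) sublists. */
    static <T> List<List<T>> partition(List<T> items, int processorCount) {
        int parts = Math.max(1, processorCount);
        List<List<T>> result = new ArrayList<>();
        for (int i = 0; i < parts; i++) {
            result.add(new ArrayList<>());
        }
        for (int i = 0; i < items.size(); i++) {
            result.get(i % parts).add(items.get(i));
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> weights = List.of("w00", "w01", "w10", "w11", "w20");
        System.out.println(partition(weights, 2));
    }
}
```

Round-robin keeps the sublists within one item of each other in size, which is a reasonable default when every item costs about the same to update.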

Constructor Detail

RTRL

public RTRL(NodesAndWeights nodesAndWeights)
Create a new instance of RTRL

Parameters:
nodesAndWeights - the structure of the network to be optimised
Method Detail

init

protected void init()
Initialise


setProcessorCount

public void setProcessorCount(int processorCount)
Set the number of processors to use


getProcessorCount

public int getProcessorCount()
Retrieve processor count


updateP

protected void updateP(java.util.List<NodesAndWeights.Node> nodes)
Update the p matrix for the given list of nodes


updateP

protected void updateP()
Update the p matrix - called after a pattern has been presented to the network
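
For orientation, the textbook RTRL recurrence for the p matrix is p[k][ij](t+1) = f'(s_k) * (sum over l in U of w[k][l] * p[l][ij](t) + z[j](t) if i equals k). The sketch below applies one such step to a small fully recurrent network. It is an assumption-laden illustration and not the Joone code: the class SensitivitySketch, the flat index ij = i * cols + j, and the convention that the first n columns of w are the recurrent connections are all hypothetical. The scratch array next is written separately from p so that every entry is computed from the previous step's values.

```java
/** Hypothetical sketch of one RTRL sensitivity update; not Joone code. */
final class SensitivitySketch {

    /**
     * @param p      p[k][ij] from the previous step, with ij = i * cols + j
     * @param w      w[i][j]: weight into unit i from source j; sources 0..n-1 are the units
     * @param z      source values that fed the current step (unit outputs, then external inputs)
     * @param fPrime derivative of each unit's activation at its current net input
     * @return the updated p matrix
     */
    static double[][] step(double[][] p, double[][] w, double[] z, double[] fPrime) {
        int n = w.length;
        int cols = w[0].length;
        double[][] next = new double[n][n * cols];
        for (int k = 0; k < n; k++) {
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < cols; j++) {
                    int ij = i * cols + j;
                    double sum = 0.0;
                    for (int l = 0; l < n; l++) {
                        sum += w[k][l] * p[l][ij];   // influence carried through the recurrence
                    }
                    if (i == k) {
                        sum += z[j];                 // direct effect of weight ij on unit k
                    }
                    next[k][ij] = fPrime[k] * sum;
                }
            }
        }
        return next;
    }

    public static void main(String[] args) {
        double[][] w = {{0.5, 2.0}};                 // one unit, one external input
        double[][] p = new double[1][2];             // sensitivities start at zero
        p = step(p, w, new double[] {0.3, 1.0}, new double[] {1.0});
        System.out.println(p[0][0] + " " + p[0][1]);
    }
}
```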


updateDeltas

protected void updateDeltas(double[] error,
                            java.util.List<NodesAndWeights.Weight> weights)
Update the given weights' deltas. This is called once a pattern has been presented to the network.

Parameters:
error - most recently seen error pattern

updateDeltas

protected void updateDeltas(double[] error)
Update the weights' deltas. This is called once a pattern has been presented to the network.
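
In textbook RTRL, the delta of weight ij accumulates the error-weighted sensitivities, that is, the sum over k of error[k] * p[k][ij], once per pattern. The following sketch shows that accumulation under stated assumptions: the class DeltaSketch is hypothetical, error[k] is taken to be target minus output and zero for nodes without a target, and p uses a flat ij index. None of this is confirmed by the Joone source.

```java
/** Hypothetical sketch of accumulating weight deltas from the p matrix; not Joone code. */
final class DeltaSketch {

    /** Adds sum_k error[k] * p[k][ij] to delta[ij] for every weight index ij. */
    static void accumulate(double[] delta, double[] error, double[][] p) {
        for (int ij = 0; ij < delta.length; ij++) {
            for (int k = 0; k < error.length; k++) {
                delta[ij] += error[k] * p[k][ij];
            }
        }
    }

    public static void main(String[] args) {
        double[] delta = new double[2];
        double[] error = {1.0, -2.0};
        double[][] p = {{0.5, 1.0}, {0.25, 0.0}};
        accumulate(delta, error, p);
        System.out.println(delta[0] + " " + delta[1]);
    }
}
```

Because the deltas are only summed here, the weights themselves stay fixed until the end of the cycle, which matches the offline style of training described for this class.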

Parameters:
error - most recently seen error pattern

resetP

protected void resetP()
Reset the p matrix in preparation for the next cycle - called at the end of a cycle


update

public void update(double[] error)
Update RTRL. Call this with the most recent error pattern as soon as one becomes available.

Parameters:
error - the most recently seen error pattern

updateCycle

public void updateCycle(double learningRate)
Update the weights. Call this once a full set of patterns has been presented to the network.
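
The documented call order is update once per pattern, then updateCycle once per cycle. The sketch below shows only that ordering. The Trainer interface is a hypothetical stand-in that copies the two public signatures listed on this page, and Recorder is a stub that logs calls; neither is part of Joone, and a real run would obtain each error pattern from the network as it is produced rather than from a prepared list.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of the per-pattern and per-cycle call order; not Joone code. */
final class CycleSketch {

    /** Stand-in exposing only the two documented public signatures. */
    interface Trainer {
        void update(double[] error);
        void updateCycle(double learningRate);
    }

    /** Feeds every error pattern to the trainer, then closes the cycle once. */
    static void runCycle(Trainer trainer, List<double[]> errorPatterns, double learningRate) {
        for (double[] error : errorPatterns) {
            trainer.update(error);
        }
        trainer.updateCycle(learningRate);
    }

    /** Stub trainer that records the order of calls. */
    static final class Recorder implements Trainer {
        final List<String> calls = new ArrayList<>();

        @Override
        public void update(double[] error) {
            calls.add("update");
        }

        @Override
        public void updateCycle(double learningRate) {
            calls.add("updateCycle");
        }
    }

    public static void main(String[] args) {
        Recorder recorder = new Recorder();
        List<double[]> errors = new ArrayList<>();
        errors.add(new double[] {0.1});
        errors.add(new double[] {-0.2});
        runCycle(recorder, errors, 0.5);
        System.out.println(recorder.calls);
    }
}
```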


printP

public void printP(java.io.PrintStream out)
Helper to print out the p matrix

Parameters:
out - stream to which to dump the matrix


Submit Feedback to pmarrone@users.sourceforge.net