java.lang.Object
  org.joone.engine.RTRLLearnerFactory
public class RTRLLearnerFactory
extends java.lang.Object
implements LearnerFactory, NeuralNetListener, java.io.Serializable
An RTRL implementation, based mostly on http://www.willamette.edu/~gorr/classes/cs449 and a few others.

This is a partial RTRL implementation. Network weights are optimised using an offline RTRL algorithm. The initial states of context nodes are not optimised, although support for that could easily be added; for now, initial states are simply assumed to be whatever they are set to in the context layer itself. RTRL does not rely on a backpropagated error, so backpropagation can, and should, be turned off in order to speed things up. Functionality for this is included in the Monitor class, and backpropagation is turned off whenever setMonitor is called.

To speed things up further, this class includes an experimental lineseek approach: first the gradient is calculated using the offline RTRL algorithm, then a step is taken along the gradient for as long as the sum of squared errors decreases, in a lineseek fashion. As soon as a step results in an increased sum of squared errors, a step back is taken, typically smaller than the step forward, and the gradient is updated again. The stepping up and down is scaled in the spirit of the RPROP algorithm, so the learning rate is adjusted after each cycle. Weights can also be randomised at the end of each cycle, in the spirit of simulated annealing. See the constructor for details on both features. They were added to try to speed up convergence - if at all! - and their practical benefit remains highly suspect at best.

This class has a main method which also serves as a demo of the RTRL; please refer to it. A suitable net can easily be created using the GUI and then trained using the main method, with a few alterations to the code (for example, the number of patterns, which, amongst others, is currently hard coded). The main method also shows how to save and restore a network trained via RTRL. While this class does implement the Serializable interface, that is highly suspect, and the class is not meant to be serialised together with the network.

This implementation is highly academic at present. Any good examples where it can be applied will be much appreciated; I am still looking for them. The initial conditions as well as the learning rate seem to have such a high impact on convergence as to make it of almost no practical use. Also, strangely, a higher rather than lower learning rate often seems better for convergence.
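The class's main method is the authoritative demo, but as a rough orientation, the minimal sketch below wires the factory into a Joone network using only methods documented on this page plus standard Joone calls (NeuralNetLoader, addNeuralNetListener, Monitor.setTotCicles, NeuralNet.go). The file name and cycle count are placeholders; how the Monitor hands the factory's Learner to the engine is left to the main method's demo.

```java
import org.joone.engine.Monitor;
import org.joone.engine.RTRLLearnerFactory;
import org.joone.net.NeuralNet;
import org.joone.net.NeuralNetLoader;

public class RTRLTrainingSketch {
    public static void main(String[] args) {
        // Restore a net previously built with the GUI; the file name is a placeholder.
        NeuralNet net = new NeuralNetLoader("mynet.snet").getNeuralNet();

        // lineseek on, verbose on, no weight shocking.
        RTRLLearnerFactory factory = new RTRLLearnerFactory(net, true, true, 0.0);

        // setMonitor also turns backpropagation off, per the class description.
        factory.setMonitor(net.getMonitor());

        // The factory is a NeuralNetListener: it needs the cycle events
        // to reset the p matrix and update the weights.
        net.addNeuralNetListener(factory);

        // Keep the dynamically scaled learning rate within bounds.
        factory.setLowerLearningRate(0.001);
        factory.setUpperLearningRate(1.0);

        Monitor monitor = net.getMonitor();
        monitor.setTotCicles(1000); // arbitrary cycle count for the sketch
        monitor.setLearning(true);
        net.go(); // start training
    }
}
```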
Nested Class Summary

| Modifier and Type | Class and Description |
|---|---|
| protected class | RTRLLearnerFactory.InitialState - An initial state. |
| protected class | RTRLLearnerFactory.Node - A node. |
| class | RTRLLearnerFactory.RTRLLearner - The learner we will return from this factory. |
| protected class | RTRLLearnerFactory.Weight - A weight. |
Field Summary

| Modifier and Type | Field and Description |
|---|---|
| protected double | currentSSE - The current cycle's sum of squared errors, used especially when the experimental lineseek version is enabled |
| protected int | cycleCount - Evaluation count, used when setting up a CFOProblem |
| protected Layer | inputLayer - The input layer |
| protected boolean | interCycleUpdates - True if we regularly update the weights within a cycle, false if we only update once the cycle completes |
| protected double[] | lastError - The error last seen at the output layer T |
| protected java.util.List<RTRLLearnerFactory.RTRLLearner> | learners - List of all learners referenced here, not used per se |
| protected double | learningRate - The learning rate, stored locally to speed things up and adjusted dynamically if the experimental lineseek version is used |
| protected boolean | lineseek - True if the experimental lineseek version is in use |
| protected double | lowerLearningRate - Lower bound for the learning rate |
| protected int | minimumPatternCount - Number of patterns that have to be seen before weights are updated |
| protected double | momentum - The momentum, stored locally to speed things up |
| protected Monitor | monitor - The (hopefully) shared monitor |
| protected NeuralNet | network - The neural network trained here |
| protected Layer | outputLayer - The output layer from which we calculate errors and which we use to determine whether a node is in T |
| protected double[][][] | p - The p matrix |
| protected int | patternCount - The current pattern |
| protected double | previousSSE - The previous cycle's sum of squared errors, used especially when the experimental lineseek version is enabled |
| protected double[][] | q - The q matrix |
| protected java.util.Random | random - A random number generator used to shock the weights |
| protected double | shockFactor - The shock factor |
| protected double | stepDownScale - The scaling to apply when stepping down along an updated gradient, used when the experimental lineseek version is enabled |
| protected double | stepUpScale - The scaling to apply when stepping along the gradient, used when the experimental lineseek version is enabled |
| protected java.util.List<RTRLLearnerFactory.Node> | T - The output nodes, a subset of Z, used to speed up calculations |
| protected java.util.List<RTRLLearnerFactory.Node> | U - The nodes in U, a subset of Z; specifically the first few nodes in Z that are not input nodes |
| protected double[][][] | updateP - A utility matrix used to speed things up when we update p |
| protected double | updateProbability - Probability of updating weights at the next pattern |
| protected double[][] | updateQ - A utility matrix used to speed things up when we update q |
| protected double | upperLearningRate - Upper bound for the learning rate |
| protected boolean | verbose - If true, display progress along the way |
| protected double | weightMagnitude - The maximum weight magnitude |
| protected java.util.List<RTRLLearnerFactory.Weight> | weights - List of all weights, used when setting up a CFOProblem |
| protected java.util.List<RTRLLearnerFactory.Node> | z - The nodes, referred to as the z element |
| protected java.util.List<RTRLLearnerFactory.Node> | Z - The same as z, but ordered so that nodes in U come first |
Constructor Summary

| Constructor and Description |
|---|
| RTRLLearnerFactory(NeuralNet network, boolean verbose) - Creates a new instance of RTRLLearnerFactory. |
| RTRLLearnerFactory(NeuralNet network, boolean lineseek, boolean verbose, double shockFactor) - Creates a new instance of RTRLLearnerFactory. |
Method Summary

| Modifier and Type | Method and Description |
|---|---|
| protected void | attachErrorPatternListener() - Attach a synapse to the output layer to calculate the error pattern and update the RTRL p matrix and weights if online. |
| void | cicleTerminated(NeuralNetEvent e) - Use this event to reset the p matrix and update the weight matrices |
| void | errorChanged(NeuralNetEvent e) - Use this event to update the p and delta matrices |
| protected RTRLLearnerFactory.InitialState | getInitialState(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output) - Determine and return the initial state that impacts on an output node from a given input node firing into it. |
| Learner | getLearner(Monitor monitor) - Return a suitable learner; part of the learner factory interface |
| Monitor | getMonitor() - Retrieve the monitor |
| CFOProblem | getProblem(java.lang.String oatName, javax.swing.event.ChangeListener changeListener, double weightMagnitude, double guessStdev) - Return this network as an OAT CFOProblem |
| double | getStepDownScale() - The current scale-down factor |
| double | getStepUpScale() - The current scale-up factor |
| protected RTRLLearnerFactory.Weight | getWeight(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output) - Determine the weight between an input and an output node. |
| double | getWeightMagnitude() - The maximum weight magnitude; weights are not allowed to exceed this |
| java.util.List<RTRLLearnerFactory.Weight> | getWeights() - Retrieve the weights |
| protected void | init() - Initialise the underlying parameters and structures. |
| protected boolean | isLinked(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output) - Determine if there exists a link or weight between two nodes. |
| protected boolean | isLinkedToInitialState(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output) - Determine if there exists a link or weight between a context and an output node. |
| static void | main(java.lang.String[] args) - Execute one of the testers |
| void | netStarted(NeuralNetEvent e) - Used to set the learning rate and momentum |
| void | netStopped(NeuralNetEvent e) - Print out the weight matrix when done |
| void | netStoppedError(NeuralNetEvent e, java.lang.String error) - Ignored |
| void | printP(java.io.PrintStream out) - Helper to print out the p matrix |
| void | printWeights(java.io.PrintStream out) - Helper function to print out the weight matrix |
| protected void | registerLearnable(RTRLLearnerFactory.RTRLLearner learner) - Register a new learnable. |
| protected void | resetP() - Reset the p matrix in preparation for the next cycle |
| void | setLowerLearningRate(double lowerLearningRate) - Set a lower bound on the (scaled) learning rate |
| void | setMonitor(Monitor monitor) - Set the monitor; also turns off backpropagation |
| void | setStepDownScale(double stepDownScale) - The scale-down factor. |
| void | setStepUpScale(double stepUpScale) - The scale-up factor. |
| void | setUpperLearningRate(double upperLearningRate) - Set an upper bound on the (scaled) learning rate |
| void | setWeight(int i, double weight) - Set a given weight |
| void | setWeightMagnitude(double weightMagnitude) - The maximum weight magnitude; weights are not allowed to exceed this |
| void | shrinkWeights(double scale) - Helper function to shrink the weights. |
| static void | testAll(java.lang.String[] args) - Test all the OAT CFO algorithms on the general network |
| static void | testCustom(java.lang.String[] args) - Custom test |
| static void | testOAT(java.lang.String[] args) - Test of general learning, using something from optalgotoolkit. |
| static void | testRTRL(java.lang.String[] args) - Test of RTRL learning. |
| protected void | updateDeltas() - Update the weights' deltas. |
| protected void | updateP() - Update the p matrix |
Methods inherited from class java.lang.Object

clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Field Detail
protected boolean lineseek
protected transient Monitor monitor
protected java.util.List<RTRLLearnerFactory.RTRLLearner> learners
protected java.util.List<RTRLLearnerFactory.Node> z
protected java.util.List<RTRLLearnerFactory.Node> Z
protected java.util.List<RTRLLearnerFactory.Node> T
protected java.util.List<RTRLLearnerFactory.Node> U
protected Layer inputLayer
protected Layer outputLayer
protected double learningRate
protected double momentum
protected double[][][] p
protected double[][][] updateP
protected double[][] q
protected double[][] updateQ
protected double[] lastError
protected NeuralNet network
protected double currentSSE
protected double previousSSE
protected double stepUpScale
protected double stepDownScale
protected boolean verbose
protected double upperLearningRate
protected double lowerLearningRate
protected int patternCount
protected java.util.Random random
protected double shockFactor
protected double weightMagnitude
protected boolean interCycleUpdates
protected int minimumPatternCount
protected double updateProbability
protected java.util.List<RTRLLearnerFactory.Weight> weights
protected int cycleCount
Constructor Detail

public RTRLLearnerFactory(NeuralNet network, boolean lineseek, boolean verbose, double shockFactor)

Creates a new instance of RTRLLearnerFactory.

Parameters:
network - the neural network to train
lineseek - if true, an experimental version will be used. This implies an offline version. To speed things up, this version of RTRL will traverse along the previously calculated gradient for as long as the goal function decreases, and only when it starts to increase will it again update the gradient. Use with care, but it can speed things up quite a bit.
verbose - if true, will cause the RTRL to display debugging information along the way
shockFactor - if not 0, the amplitude that will be used to periodically - after every cycle - shock the weights

public RTRLLearnerFactory(NeuralNet network, boolean verbose)

Creates a new instance of RTRLLearnerFactory.

Parameters:
network - the neural network to train
verbose - if true, will cause the RTRL to display debugging information along the way
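The lineseek behaviour enabled by the first constructor is easiest to see on a toy problem. The sketch below is illustrative only, not this class's internals: it minimises a one-dimensional quadratic "SSE", stepping along a once-per-cycle gradient while the error drops and rescaling the rate in the RPROP spirit. For simplicity it rejects the worsening step outright rather than taking it and then stepping back, as the description above says the class does; all names and constants here are hypothetical.

```java
// Illustrative lineseek sketch: step along a fixed gradient while the
// error decreases, then shrink the rate and recompute the gradient.
public class LineseekSketch {
    // Toy "sum of squared errors", minimised at w = 3.
    static double sse(double w) { return (w - 3) * (w - 3); }
    static double gradient(double w) { return 2 * (w - 3); }

    public static void main(String[] args) {
        double w = 0.0;
        double rate = 0.1;
        double stepUpScale = 1.2, stepDownScale = 0.5; // in the spirit of RPROP
        for (int cycle = 0; cycle < 20; cycle++) {
            double g = gradient(w);   // like the offline RTRL pass: one gradient per cycle
            double previous = sse(w);
            // Traverse along the gradient for as long as the SSE decreases.
            while (true) {
                double candidate = w - rate * g;
                if (sse(candidate) >= previous) break; // got worse: stop stepping
                w = candidate;
                previous = sse(w);
                rate *= stepUpScale;                   // grow the step
            }
            rate *= stepDownScale;                     // back off before the next cycle
            System.out.printf("cycle %2d: w=%.4f sse=%.6f rate=%.4f%n",
                    cycle, w, previous);
        }
    }
}
```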
Method Detail
protected boolean isLinked(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output)
protected RTRLLearnerFactory.Weight getWeight(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output)
protected boolean isLinkedToInitialState(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output)
protected RTRLLearnerFactory.InitialState getInitialState(RTRLLearnerFactory.Node input, RTRLLearnerFactory.Node output)
public Learner getLearner(Monitor monitor)

Specified by: getLearner in interface LearnerFactory

Parameters:
monitor - the monitor.

protected void registerLearnable(RTRLLearnerFactory.RTRLLearner learner)
protected void init()
protected void updateP()
protected void resetP()
protected void updateDeltas()
protected void attachErrorPatternListener()
public Monitor getMonitor()
public void setMonitor(Monitor monitor)
public void setStepUpScale(double stepUpScale)
public double getStepUpScale()
public void setStepDownScale(double stepDownScale)
public double getStepDownScale()
public void setUpperLearningRate(double upperLearningRate)
public void setLowerLearningRate(double lowerLearningRate)
public void netStarted(NeuralNetEvent e)

Specified by: netStarted in interface NeuralNetListener

public void cicleTerminated(NeuralNetEvent e)

Specified by: cicleTerminated in interface NeuralNetListener

public void netStopped(NeuralNetEvent e)

Specified by: netStopped in interface NeuralNetListener

public void errorChanged(NeuralNetEvent e)

Specified by: errorChanged in interface NeuralNetListener

public void netStoppedError(NeuralNetEvent e, java.lang.String error)

Specified by: netStoppedError in interface NeuralNetListener
public void shrinkWeights(double scale)
public void printWeights(java.io.PrintStream out)
public void printP(java.io.PrintStream out)
public java.util.List<RTRLLearnerFactory.Weight> getWeights()
public void setWeight(int i, double weight)
public double getWeightMagnitude()
public void setWeightMagnitude(double weightMagnitude)
public CFOProblem getProblem(java.lang.String oatName, javax.swing.event.ChangeListener changeListener, double weightMagnitude, double guessStdev)

Parameters:
oatName - a name used when reporting the error on stderr
changeListener - null or a listener that is notified whenever a new error was calculated
weightMagnitude - the magnitude allowed for a weight in this problem, e.g. 10 to force the solution weights to be between -10 and +10
guessStdev - if not zero, this indicates that the network's current weights can be used as an initial guess and that this standard deviation can be used to generate other initial guesses around it
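A hedged fragment showing a getProblem call, using only the signature and parameter meanings above. The returned CFOProblem belongs to the optalgotoolkit ("oat") library, and how it is then fed to a CFO optimiser is outside this page; testOAT and testAll in this class are the worked drivers. The name and values below are arbitrary.

```java
// 'factory' is an RTRLLearnerFactory set up as in the earlier sketch.
// Weight magnitude 10 constrains solution weights to [-10, +10]; a
// non-zero guessStdev of 0.1 seeds guesses around the net's current weights.
CFOProblem problem = factory.getProblem("demo-net", null, 10.0, 0.1);
// Hand 'problem' to an optalgotoolkit CFO algorithm; see testOAT/testAll.
```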
public static void main(java.lang.String[] args)
public static void testRTRL(java.lang.String[] args)
public static void testAll(java.lang.String[] args)
public static void testOAT(java.lang.String[] args)
public static void testCustom(java.lang.String[] args)