weka.core.metrics
Class GDMetricLearner

java.lang.Object
  extended by weka.core.metrics.MetricLearner
      extended by weka.core.metrics.GDMetricLearner
All Implemented Interfaces:
OptionHandler, java.io.Serializable

public class GDMetricLearner
extends MetricLearner
implements java.io.Serializable, OptionHandler

GDMetricLearner - sets the weights of a metric using gradient descent

See Also:
Serialized Form

Field Summary
protected  double m_epsilon
          The convergence criterion for total weight updates
protected  Instances m_instances
          The training data
protected  double m_learningRate
          The learning rate
protected  int m_maxIterations
          Maximum number of iterations
protected  LearnableMetric m_metric
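          The metric being trained; keeping a reference is useful for metrics whose values are computed externally
protected  int m_numNegPairs
          The number of different-class (negative) training pairs
protected  int m_numPosPairs
          The number of same-class (positive) training pairs
protected  java.util.ArrayList m_pairList
          The list of training pairs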
protected  PairwiseSelector m_selector
          The pairwise selector used by the metric
 
Constructor Summary
GDMetricLearner()
          Create a new gradient descent metric learner
 
Method Summary
protected  double[] calculateGradients(double[] weights)
          A helper that calculates the gradient (the partial derivatives with respect to the weights) at the current weights
static java.lang.String concatStringArray(java.lang.String[] strings)
          A small helper that joins an array of Strings into a single String
protected  java.util.ArrayList createPairList(Instances instances, int numPosPairs, int numNegPairs)
          Create a list of pairs of two kinds: pairs of instances belonging to the same class, and pairs of instances belonging to different classes.
 double getDistance(Instance instance1, Instance instance2)
          Use the Classifier to estimate the distance between two instances
 double getEpsilon()
          Get the convergence criterion
 double getLearningRate()
          Get the learning rate
 int getMaxIterations()
          Get the maximum number of update iterations
 int getNumNegPairs()
          Get the number of different-class training pairs
 int getNumPosPairs()
          Get the number of same-class training pairs
 java.lang.String[] getOptions()
          Gets the current settings of GDMetricLearner.
 PairwiseSelector getSelector()
          Get the pairwise selector
 double getSimilarity(Instance instance1, Instance instance2)
          Use the Classifier to estimate the similarity between two instances
protected static java.lang.String getTimestamp()
          Gets a string containing current date and time.
 double lengthWeighted(Instance instance, double[] weights)
          Get the weighted L2 (Euclidean) norm of an instance, using the attribute weights and assuming all attributes are numeric
 java.util.Enumeration listOptions()
          Returns an enumeration describing the available options.
protected  double[] normalizeWeights(double[] weights)
          Normalize weights
 void printTopAttributes(double[] weights, int n, int iteration)
          Print the heaviest-weighted attributes for a given set of weights
 void setEpsilon(double epsilon)
          Set the convergence criterion
 void setLearningRate(double learningRate)
          Set the learning rate
 void setMaxIterations(int maxIterations)
          Set the maximum number of update iterations
 void setNumNegPairs(int numNegPairs)
          Set the number of different-class training pairs
 void setNumPosPairs(int numPosPairs)
          Set the number of same-class training pairs
 void setOptions(java.lang.String[] options)
          Parses a given list of options.
 void setSelector(PairwiseSelector selector)
          Set the pairwise selector
 java.lang.String toString()
          Obtain a textual description of the metric learner
 void trainMetric(LearnableMetric metric, Instances instances)
          Train a given metric using given training instances
 
Methods inherited from class weka.core.metrics.MetricLearner
createDiffInstanceLists, createDiffInstances, createDiffInstances, forName, getAttrInfoForDiffInstance
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Field Detail

m_metric

protected LearnableMetric m_metric
The metric being trained; keeping a reference is useful for metrics whose values are computed externally


m_maxIterations

protected int m_maxIterations
Maximum number of iterations


m_learningRate

protected double m_learningRate
The learning rate


m_instances

protected Instances m_instances
The training data


m_pairList

protected java.util.ArrayList m_pairList
The list of training pairs

m_numPosPairs

protected int m_numPosPairs
The number of same-class (positive) training pairs

m_numNegPairs

protected int m_numNegPairs
The number of different-class (negative) training pairs

m_epsilon

protected double m_epsilon
The convergence criterion for total weight updates


m_selector

protected PairwiseSelector m_selector
The pairwise selector used by the metric

Constructor Detail

GDMetricLearner

public GDMetricLearner()
Create a new gradient descent metric learner

Method Detail

trainMetric

public void trainMetric(LearnableMetric metric,
                        Instances instances)
                 throws java.lang.Exception
Train a given metric using given training instances

Specified by:
trainMetric in class MetricLearner
Parameters:
metric - the metric to train
instances - data to train the metric on
Throws:
java.lang.Exception - if training fails.
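The training procedure is not spelled out here, but the epsilon, learning-rate and maximum-iterations settings point to a batch gradient-descent loop. The following Java sketch shows one plausible shape of that loop, assuming the weights move against the gradient scaled by the learning rate and that training stops once the summed absolute weight change drops below epsilon or the iteration cap is reached. The gradient function stands in for calculateGradients; GradientDescentSketch and minimize are hypothetical names, not part of the Weka API.

```java
import java.util.Arrays;
import java.util.function.Function;

/** Hypothetical sketch of a batch gradient-descent loop; not the actual GDMetricLearner code. */
class GradientDescentSketch {

    /**
     * Repeatedly steps the weights against the gradient.
     *
     * @param gradient      stand-in for calculateGradients(weights)
     * @param initial       starting weight vector (copied, not modified)
     * @param learningRate  step size (cf. m_learningRate)
     * @param epsilon       stop once the summed |update| falls below this (cf. m_epsilon)
     * @param maxIterations hard cap on the number of updates (cf. m_maxIterations)
     */
    static double[] minimize(Function<double[], double[]> gradient, double[] initial,
                             double learningRate, double epsilon, int maxIterations) {
        double[] weights = Arrays.copyOf(initial, initial.length);
        for (int iter = 0; iter < maxIterations; iter++) {
            double[] grad = gradient.apply(weights);
            double totalUpdate = 0.0;
            for (int i = 0; i < weights.length; i++) {
                double delta = -learningRate * grad[i];   // step downhill
                weights[i] += delta;
                totalUpdate += Math.abs(delta);
            }
            if (totalUpdate < epsilon) {
                break;                                    // assumed convergence test
            }
        }
        return weights;
    }

    public static void main(String[] args) {
        // Toy objective sum_i (w_i - 1)^2 with gradient 2 (w_i - 1); minimum at w = (1, 1).
        Function<double[], double[]> grad = w -> {
            double[] g = new double[w.length];
            for (int i = 0; i < w.length; i++) {
                g[i] = 2 * (w[i] - 1.0);
            }
            return g;
        };
        double[] w = minimize(grad, new double[] {0.0, 3.0}, 0.1, 1e-9, 1000);
        System.out.println(Arrays.toString(w));
    }
}
```

Copying the initial array keeps the caller's weights untouched, so a call with maxIterations set to 0 simply returns a copy of the starting point.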

calculateGradients

protected double[] calculateGradients(double[] weights)
                               throws java.lang.Exception
A helper that calculates the gradient (the partial derivatives with respect to the weights) at the current weights

Parameters:
weights - the current weights vector
Returns:
the values of the partial derivatives
Throws:
java.lang.Exception
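The objective that calculateGradients differentiates is not documented here. As an illustration only, assume a weighted squared-Euclidean distance d_w(x, y) = sum_i w_i (x_i - y_i)^2 and the objective J(w) = sum over same-class pairs of d_w, minus the sum over different-class pairs of d_w, plus lambda * sum_i w_i^2. Its partial derivatives are dJ/dw_i = sum_same (x_i - y_i)^2 - sum_diff (x_i - y_i)^2 + 2 lambda w_i. The class PairwiseGradientSketch, its Pair type and the regularization weight lambda are hypothetical.

```java
import java.util.List;

/** Hypothetical pairwise gradient for a weighted squared-Euclidean metric; not the GDMetricLearner objective. */
class PairwiseGradientSketch {

    /** A training pair of numeric instances and whether they share a class label. */
    static final class Pair {
        final double[] a;
        final double[] b;
        final boolean sameClass;

        Pair(double[] a, double[] b, boolean sameClass) {
            this.a = a;
            this.b = b;
            this.sameClass = sameClass;
        }
    }

    /**
     * Gradient of J(w) = sum_same d_w - sum_diff d_w + lambda * sum_i w_i^2,
     * where d_w(x, y) = sum_i w_i (x_i - y_i)^2.
     */
    static double[] calculateGradients(double[] weights, List<Pair> pairs, double lambda) {
        double[] grad = new double[weights.length];
        for (Pair p : pairs) {
            // Same-class pairs add to J (pulled together); different-class pairs subtract (pushed apart).
            double sign = p.sameClass ? 1.0 : -1.0;
            for (int i = 0; i < weights.length; i++) {
                double diff = p.a[i] - p.b[i];
                grad[i] += sign * diff * diff;
            }
        }
        for (int i = 0; i < weights.length; i++) {
            grad[i] += 2.0 * lambda * weights[i];   // derivative of the L2 penalty
        }
        return grad;
    }
}
```

Without the penalty term this toy objective is linear in w, so some constraint such as normalization would be needed to keep the weights bounded.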

normalizeWeights

protected double[] normalizeWeights(double[] weights)
Normalize weights

Parameters:
weights - an unnormalized array of weights
Returns:
a normalized array of weights
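The normalization scheme used by normalizeWeights is not stated. A minimal sketch, assuming L1 normalization (each weight divided by the sum of absolute weights) and returning a new array; the class name WeightNormalizationSketch is hypothetical, and the real method may clip negative weights or use a different norm.

```java
import java.util.Arrays;

/** Hypothetical L1 normalization; the actual normalizeWeights scheme is not documented. */
class WeightNormalizationSketch {

    /** Returns a new array scaled so that the absolute values sum to 1 (the input is unchanged). */
    static double[] normalizeWeights(double[] weights) {
        double sum = 0.0;
        for (double w : weights) {
            sum += Math.abs(w);
        }
        if (sum == 0.0) {
            return Arrays.copyOf(weights, weights.length);   // all zeros: nothing to scale
        }
        double[] normalized = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            normalized[i] = weights[i] / sum;
        }
        return normalized;
    }
}
```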

lengthWeighted

public double lengthWeighted(Instance instance,
                             double[] weights)
Get the weighted L2 (Euclidean) norm of an instance, using the attribute weights and assuming all attributes are numeric
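A minimal sketch of a weighted L2 length, assuming each weight multiplies the squared attribute value, i.e. sqrt(sum_i w_i x_i^2), and taking a plain double[] of attribute values in place of a Weka Instance. WeightedLengthSketch is a hypothetical name.

```java
/** Hypothetical weighted L2 norm over a plain array of numeric attribute values. */
class WeightedLengthSketch {

    /** Returns sqrt(sum_i weights[i] * values[i]^2); both arrays must have the same length. */
    static double lengthWeighted(double[] values, double[] weights) {
        if (values.length != weights.length) {
            throw new IllegalArgumentException("values and weights differ in length");
        }
        double sum = 0.0;
        for (int i = 0; i < values.length; i++) {
            sum += weights[i] * values[i] * values[i];
        }
        return Math.sqrt(sum);
    }
}
```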


getSimilarity

public double getSimilarity(Instance instance1,
                            Instance instance2)
                     throws java.lang.Exception
Use the Classifier to estimate the similarity between two instances

Specified by:
getSimilarity in class MetricLearner
Parameters:
instance1 - first instance of a pair
instance2 - second instance of a pair
Throws:
java.lang.Exception

getDistance

public double getDistance(Instance instance1,
                          Instance instance2)
                   throws java.lang.Exception
Use the Classifier to estimate the distance between two instances

Specified by:
getDistance in class MetricLearner
Parameters:
instance1 - first instance of a pair
instance2 - second instance of a pair
Throws:
java.lang.Exception

setEpsilon

public void setEpsilon(double epsilon)
Set the convergence criterion

Parameters:
epsilon - the largest total weight update for which gradient descent is considered to have converged

getEpsilon

public double getEpsilon()
Get the convergence criterion

Returns:
the largest total weight update for which gradient descent is considered to have converged

setLearningRate

public void setLearningRate(double learningRate)
Set the learning rate

Parameters:
learningRate - the gradient update coefficient

getLearningRate

public double getLearningRate()
Get the learning rate

Returns:
the gradient update coefficient

setMaxIterations

public void setMaxIterations(int maxIterations)
Set the maximum number of update iterations

Parameters:
maxIterations - the maximum number of gradient updates

getMaxIterations

public int getMaxIterations()
Get the maximum number of update iterations

Returns:
the maximum number of gradient updates

setNumPosPairs

public void setNumPosPairs(int numPosPairs)
Set the number of same-class training pairs

Parameters:
numPosPairs - the number of same-class training pairs to create for training

getNumPosPairs

public int getNumPosPairs()
Get the number of same-class training pairs

Returns:
the number of same-class training pairs to create for training

setNumNegPairs

public void setNumNegPairs(int numNegPairs)
Set the number of different-class training pairs

Parameters:
numNegPairs - the number of different-class training pairs to create for training

getNumNegPairs

public int getNumNegPairs()
Get the number of different-class training pairs

Returns:
the number of different-class training pairs to create for training

setSelector

public void setSelector(PairwiseSelector selector)
Set the pairwise selector

Parameters:
selector - the selector for training pairs

getSelector

public PairwiseSelector getSelector()
Get the pairwise selector

Returns:
the selector for training pairs

getOptions

public java.lang.String[] getOptions()
Gets the current settings of GDMetricLearner.

Specified by:
getOptions in interface OptionHandler
Returns:
an array of strings suitable for passing to setOptions()

setOptions

public void setOptions(java.lang.String[] options)
                throws java.lang.Exception
Parses a given list of options. Valid options are:

-B classifierstring

Specified by:
setOptions in interface OptionHandler
Parameters:
options - the list of options as an array of strings
Throws:
java.lang.Exception - if an option is not supported
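Only the -B classifierstring option is documented above. The sketch below shows how a flag of that form could be extracted from an options array in plain Java; it blanks the flag and its value once consumed so they are not parsed twice, following the convention of Weka's Utils.getOption. OptionParsingSketch is hypothetical and does not reproduce Weka's quoting rules.

```java
/** Hypothetical flag parser for options such as "-B classifierstring". */
class OptionParsingSketch {

    /**
     * Returns the value following "-" + flag, or "" if the flag is absent.
     * The flag and its value are replaced by "" in the array once read.
     */
    static String getOption(char flag, String[] options) {
        String marker = "-" + flag;
        for (int i = 0; i < options.length; i++) {
            if (marker.equals(options[i])) {
                if (i + 1 >= options.length) {
                    throw new IllegalArgumentException("No value given for " + marker);
                }
                String value = options[i + 1];
                options[i] = "";
                options[i + 1] = "";
                return value;
            }
        }
        return "";
    }
}
```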

getTimestamp

protected static java.lang.String getTimestamp()
Gets a string containing current date and time.

Returns:
a string containing the date and time.
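The format returned by getTimestamp is not documented. A minimal sketch using java.time, assuming a yyyy-MM-dd HH:mm:ss pattern; the separate format helper (hypothetical, like the class name TimestampSketch) makes the output checkable with a fixed time.

```java
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;

/** Hypothetical timestamp helper; the pattern is an assumption. */
class TimestampSketch {

    private static final DateTimeFormatter FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /** Current local date and time as a string. */
    static String getTimestamp() {
        return format(LocalDateTime.now());
    }

    /** Formats a given date and time with the assumed pattern. */
    static String format(LocalDateTime time) {
        return time.format(FORMAT);
    }
}
```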

listOptions

public java.util.Enumeration listOptions()
Returns an enumeration describing the available options.

Specified by:
listOptions in interface OptionHandler
Returns:
an enumeration of all the available options.

toString

public java.lang.String toString()
Obtain a textual description of the metric learner

Returns:
a textual description of the metric learner

concatStringArray

public static java.lang.String concatStringArray(java.lang.String[] strings)
A small helper that joins an array of Strings into a single String

Parameters:
strings - an array of strings
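A minimal sketch of concatStringArray, assuming the strings are joined with single spaces (the separator is not documented); ConcatSketch is a hypothetical name.

```java
/** Hypothetical equivalent of concatStringArray; the space separator is an assumption. */
class ConcatSketch {

    /** Joins the strings with single spaces, e.g. {"-B", "x"} becomes "-B x". */
    static String concatStringArray(String[] strings) {
        return String.join(" ", strings);
    }
}
```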

createPairList

protected java.util.ArrayList createPairList(Instances instances,
                                             int numPosPairs,
                                             int numNegPairs)
Create a list of pairs of two kinds: pairs of instances belonging to the same class, and pairs of instances belonging to different classes.
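How createPairList chooses pairs is not described. One simple strategy, sketched below under that assumption, enumerates every unordered pair of instances, splits them into same-class and different-class groups, shuffles each group with a seeded Random, and keeps up to numPosPairs and numNegPairs of each. Class labels are passed as an int[] instead of Weka Instances; PairListSketch, IndexPair and the seed parameter are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical pair sampler; the real createPairList may select pairs differently. */
class PairListSketch {

    /** Index pair (first, second) plus whether both instances share a class label. */
    static final class IndexPair {
        final int first;
        final int second;
        final boolean sameClass;

        IndexPair(int first, int second, boolean sameClass) {
            this.first = first;
            this.second = second;
            this.sameClass = sameClass;
        }
    }

    /** Keeps up to numPosPairs same-class and numNegPairs different-class pairs, chosen at random. */
    static List<IndexPair> createPairList(int[] classLabels, int numPosPairs, int numNegPairs, long seed) {
        List<IndexPair> pos = new ArrayList<>();
        List<IndexPair> neg = new ArrayList<>();
        for (int i = 0; i < classLabels.length; i++) {
            for (int j = i + 1; j < classLabels.length; j++) {
                boolean same = classLabels[i] == classLabels[j];
                if (same) {
                    pos.add(new IndexPair(i, j, true));
                } else {
                    neg.add(new IndexPair(i, j, false));
                }
            }
        }
        Random random = new Random(seed);
        Collections.shuffle(pos, random);
        Collections.shuffle(neg, random);
        List<IndexPair> result = new ArrayList<>(pos.subList(0, Math.min(numPosPairs, pos.size())));
        result.addAll(neg.subList(0, Math.min(numNegPairs, neg.size())));
        return result;
    }
}
```

Enumerating all pairs costs O(n^2) time and memory, which suits small datasets and avoids the non-termination risk of rejection sampling when a class has a single member.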


printTopAttributes

public void printTopAttributes(double[] weights,
                               int n,
                               int iteration)
Print the heaviest-weighted attributes for a given set of weights
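A minimal sketch of ranking attributes by weight, assuming "heaviest" means the largest signed weight and that attribute names are supplied as a String[]; the output format and the names TopAttributesSketch and topAttributes are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical version of printTopAttributes working on attribute names and weights. */
class TopAttributesSketch {

    /** Indices of the n largest weights, heaviest first (ties keep the lower index first). */
    static int[] topAttributes(double[] weights, int n) {
        Integer[] order = new Integer[weights.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on objects is stable, so equal weights stay in attribute order.
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> weights[i]).reversed());
        int k = Math.min(n, order.length);
        int[] top = new int[k];
        for (int i = 0; i < k; i++) {
            top[i] = order[i];
        }
        return top;
    }

    /** Prints the n heaviest attributes for the given iteration. */
    static void printTopAttributes(String[] names, double[] weights, int n, int iteration) {
        System.out.println("Iteration " + iteration + ": top " + n + " attributes");
        for (int idx : topAttributes(weights, n)) {
            System.out.printf("  %s\t%.4f%n", names[idx], weights[idx]);
        }
    }
}
```

Sorting boxed indices rather than the weights themselves keeps the link between each weight and its attribute name.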