|
||||||||||
PREV CLASS NEXT CLASS | FRAMES NO FRAMES | |||||||||
SUMMARY: NESTED | FIELD | CONSTR | METHOD | DETAIL: FIELD | CONSTR | METHOD |
java.lang.Object
  extended by weka.core.metrics.MetricLearner
      extended by weka.core.metrics.GDMetricLearner
GDMetricLearner - sets the weights of a metric using gradient descent
Field Summary | |
protected double |
m_epsilon
The convergence criterion for total weight updates |
protected Instances |
m_instances
The training data |
protected double |
m_learningRate
The learning rate |
protected int |
m_maxIterations
Maximum number of iterations |
protected LearnableMetric |
m_metric
The metric that the classifier was used to learn, useful for external-calculation based metrics |
protected int |
m_numNegPairs
|
protected int |
m_numPosPairs
|
protected java.util.ArrayList |
m_pairList
|
protected PairwiseSelector |
m_selector
The pairwise selector used by the metric |
Constructor Summary | |
GDMetricLearner()
Create a new gradient descent metric learner |
Method Summary | |
protected double[] |
calculateGradients(double[] weights)
A helper function that calculates the current gradient value |
static java.lang.String |
concatStringArray(java.lang.String[] strings)
A little helper to create a single String from an array of Strings |
protected java.util.ArrayList |
createPairList(Instances instances,
int numPosPairs,
int numNegPairs)
Create lists of pairs of two kinds: pairs of instances belonging to the same class, and pairs of instances belonging to different classes. |
double |
getDistance(Instance instance1,
Instance instance2)
Use the Classifier for an estimation of distance |
double |
getEpsilon()
Get the convergence criterion |
double |
getLearningRate()
Get the learning rate |
int |
getMaxIterations()
Get the maximum number of update iterations |
int |
getNumNegPairs()
Get the number of different-class training pairs |
int |
getNumPosPairs()
Get the number of same-class training pairs |
java.lang.String[] |
getOptions()
Gets the current settings of GDMetricLearner. |
PairwiseSelector |
getSelector()
Get the pairwise selector |
double |
getSimilarity(Instance instance1,
Instance instance2)
Use the Classifier for an estimation of similarity |
protected static java.lang.String |
getTimestamp()
Gets a string containing current date and time. |
double |
lengthWeighted(Instance instance,
double[] weights)
Get the norm-2 length of an instance assuming all attributes are numeric and utilizing the attribute weights |
java.util.Enumeration |
listOptions()
Returns an enumeration describing the available options. |
protected double[] |
normalizeWeights(double[] weights)
Normalize weights |
void |
printTopAttributes(double[] weights,
int n,
int iteration)
Print the heaviest-weighted attributes for a given set of weights |
void |
setEpsilon(double epsilon)
Set the convergence criterion |
void |
setLearningRate(double learningRate)
Set the learning rate |
void |
setMaxIterations(int maxIterations)
Set the maximum number of update iterations |
void |
setNumNegPairs(int numNegPairs)
Set the number of different-class training pairs |
void |
setNumPosPairs(int numPosPairs)
Set the number of same-class training pairs |
void |
setOptions(java.lang.String[] options)
Parses a given list of options. |
void |
setSelector(PairwiseSelector selector)
Set the pairwise selector |
java.lang.String |
toString()
Obtain a textual description of the metric learner |
void |
trainMetric(LearnableMetric metric,
Instances instances)
Train a given metric using given training instances |
Methods inherited from class weka.core.metrics.MetricLearner |
createDiffInstanceLists, createDiffInstances, createDiffInstances, forName, getAttrInfoForDiffInstance |
Methods inherited from class java.lang.Object |
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait |
Field Detail |
protected LearnableMetric m_metric
protected int m_maxIterations
protected double m_learningRate
protected Instances m_instances
protected java.util.ArrayList m_pairList
protected int m_numPosPairs
protected int m_numNegPairs
protected double m_epsilon
protected PairwiseSelector m_selector
Constructor Detail |
public GDMetricLearner()
Method Detail |
public void trainMetric(LearnableMetric metric, Instances instances) throws java.lang.Exception
trainMetric
in class MetricLearner
metric
- the metric to train
instances
- data to train the metric on
java.lang.Exception
- if training has gone bad.
protected double[] calculateGradients(double[] weights) throws java.lang.Exception
weights
- the current weights vector
java.lang.Exception
protected double[] normalizeWeights(double[] weights)
weights
- an unnormalized array of weights
public double lengthWeighted(Instance instance, double[] weights)
public double getSimilarity(Instance instance1, Instance instance2) throws java.lang.Exception
getSimilarity
in class MetricLearner
instance1
- first instance of a pair
instance2
- second instance of a pair
java.lang.Exception
public double getDistance(Instance instance1, Instance instance2) throws java.lang.Exception
getDistance
in class MetricLearner
instance1
- first instance of a pair
instance2
- second instance of a pair
java.lang.Exception
public void setEpsilon(double epsilon)
epsilon
- the maximum sum of weight updates required for GD to converge
public double getEpsilon()
public void setLearningRate(double learningRate)
learningRate
- the gradient update coefficient
public double getLearningRate()
public void setMaxIterations(int maxIterations)
maxIterations
- the maximum number of gradient updates
public int getMaxIterations()
public void setNumPosPairs(int numPosPairs)
numPosPairs
- the number of same-class training pairs to create for training
public int getNumPosPairs()
public void setNumNegPairs(int numNegPairs)
numNegPairs
- the number of different-class training pairs to create for training
public int getNumNegPairs()
public void setSelector(PairwiseSelector selector)
selector
- the selector for training pairs
public PairwiseSelector getSelector()
public java.lang.String[] getOptions()
getOptions
in interface OptionHandler
public void setOptions(java.lang.String[] options) throws java.lang.Exception
-B classifierstring
setOptions
in interface OptionHandler
options
- the list of options as an array of strings
java.lang.Exception
- if an option is not supported
protected static java.lang.String getTimestamp()
public java.util.Enumeration listOptions()
listOptions
in interface OptionHandler
public java.lang.String toString()
public static java.lang.String concatStringArray(java.lang.String[] strings)
strings
- an array of strings
protected java.util.ArrayList createPairList(Instances instances, int numPosPairs, int numNegPairs)
public void printTopAttributes(double[] weights, int n, int iteration)
|
||||||||||
PREV CLASS NEXT CLASS | FRAMES NO FRAMES | |||||||||
SUMMARY: NESTED | FIELD | CONSTR | METHOD | DETAIL: FIELD | CONSTR | METHOD |