OpenMS  2.4.0
SimpleSVM Class Reference

Simple interface to support vector machines for classification (via LIBSVM).

#include <OpenMS/ANALYSIS/SVM/SimpleSVM.h>

Inheritance diagram for SimpleSVM:
DefaultParamHandler

Classes

struct  Prediction
 SVM prediction result.
 

Public Types

typedef std::map< String, std::vector< double > > PredictorMap
 Mapping from predictor name to vector of predictor values.
 

Public Member Functions

 SimpleSVM ()
 Default constructor.
 
 ~SimpleSVM () override
 Destructor.
 
void setup (PredictorMap &predictors, const std::map< Size, Int > &labels)
 Load data and train a model.
 
void predict (std::vector< Prediction > &predictions, std::vector< Size > indexes=std::vector< Size >()) const
 Predict class labels (and probabilities).
 
void getFeatureWeights (std::map< String, double > &feature_weights) const
 Get the weights used for features (predictors) in the SVM model.
 
void writeXvalResults (const String &path) const
 Write cross-validation (parameter optimization) results to a CSV file.
 
Public Member Functions inherited from DefaultParamHandler
 DefaultParamHandler (const String &name)
 Constructor with name that is displayed in error messages.
 
 DefaultParamHandler (const DefaultParamHandler &rhs)
 Copy constructor.
 
virtual ~DefaultParamHandler ()
 Destructor.
 
virtual DefaultParamHandler & operator= (const DefaultParamHandler &rhs)
 Assignment operator.
 
virtual bool operator== (const DefaultParamHandler &rhs) const
 Equality operator.
 
void setParameters (const Param &param)
 Sets the parameters.
 
const Param & getParameters () const
 Non-mutable access to the parameters.
 
const Param & getDefaults () const
 Non-mutable access to the default parameters.
 
const String & getName () const
 Non-mutable access to the name.
 
void setName (const String &name)
 Mutable access to the name.
 
const std::vector< String > & getSubsections () const
 Non-mutable access to the registered subsections.
 

Protected Types

typedef std::vector< std::vector< double > > SVMPerformance
 Classification performance for different parameter combinations (C/gamma).
 

Protected Member Functions

void scaleData_ (PredictorMap &predictors) const
 Scale predictor values to range 0-1.
 
void convertData_ (const PredictorMap &predictors)
 Convert predictors to LIBSVM format.
 
std::pair< double, double > chooseBestParameters_ () const
 Choose best SVM parameters based on cross-validation results.
 
void optimizeParameters_ ()
 Run cross-validation to optimize SVM parameters.
 
Protected Member Functions inherited from DefaultParamHandler
virtual void updateMembers_ ()
 This method is used to update extra member variables at the end of the setParameters() method.
 
void defaultsToParam_ ()
 Updates the parameters after the defaults have been set in the constructor.
 

Static Protected Member Functions

static void printNull_ (const char *)
 Dummy function to suppress LIBSVM output.
 

Protected Attributes

std::vector< std::vector< struct svm_node > > nodes_
 Values of predictors (LIBSVM format)
 
struct svm_problem data_
 SVM training data (LIBSVM format)
 
struct svm_parameter svm_params_
 SVM parameters (LIBSVM format)
 
struct svm_model * model_
 Pointer to SVM model (LIBSVM format)
 
std::vector< String > predictor_names_
 Names of predictors in the model (excluding uninformative ones)
 
Size n_parts_
 Number of partitions for cross-validation.
 
std::vector< double > log2_C_
 Parameter values to try during optimization.
 
std::vector< double > log2_gamma_
 
SVMPerformance performance_
 Cross-validation results.
 
Protected Attributes inherited from DefaultParamHandler
Param param_
 Container for current parameters.
 
Param defaults_
 Container for default parameters. This member should be filled in the constructor of derived classes!
 
std::vector< String > subsections_
 Container for registered subsections. This member should be filled in the constructor of derived classes!
 
String error_name_
 Name that is displayed in error messages during the parameter checking.
 
bool check_defaults_
 If this member is set to false, no parameter checking is done.
 
bool warn_empty_defaults_
 If this member is set to false, no warning is emitted when defaults are empty.
 

Detailed Description

Simple interface to support vector machines for classification (via LIBSVM).

This class supports (multi-class) classification with a linear or RBF kernel. It uses cross-validation to optimize the SVM parameters C and (RBF kernel only) gamma.

SVM models are generated by the setup() method. To simplify the scaling of input data (predictors), the data for the training set and the test set are passed in together as the parameter predictors. Given N observations of M predictors, the data are coded as a map of predictors (size M), each a numeric vector of values for the different observations (size N).

The parameter labels of setup() defines the training set; it contains the indexes of observations (corresponding to positions in the vectors in predictors) together with the class labels for training.
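The layout described above can be illustrated with a minimal sketch that uses plain standard-library types in place of the OpenMS typedefs (assuming String behaves like std::string, Size like std::size_t and Int like int). The predictor names, the values and the helper checkLayout are hypothetical and not part of OpenMS.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Stand-ins for the OpenMS typedefs (assumption, see lead-in).
using PredictorMap = std::map<std::string, std::vector<double>>;
using LabelMap = std::map<std::size_t, int>;

// M = 2 predictors, N = 4 observations (training and test data together).
inline PredictorMap examplePredictors()
{
  PredictorMap predictors;
  predictors["mass_error"] = {0.1, 2.5, 0.3, 1.9};
  predictors["intensity"] = {900.0, 40.0, 750.0, 55.0};
  return predictors;
}

// Observations 0-2 form the training set; observation 3 stays unlabeled.
inline LabelMap exampleLabels()
{
  return {{0, 1}, {1, 0}, {2, 1}};
}

// Hypothetical helper: true if all predictor vectors share one length N
// and every label index refers to a position below N.
inline bool checkLayout(const PredictorMap& predictors, const LabelMap& labels)
{
  if (predictors.empty()) return false;
  const std::size_t n = predictors.begin()->second.size();
  for (const auto& p : predictors)
  {
    if (p.second.size() != n) return false;
  }
  for (const auto& l : labels)
  {
    if (l.first >= n) return false;
  }
  return true;
}
```

In this sketch, data shaped like examplePredictors() and exampleLabels() correspond to the predictors and labels arguments of setup().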

To predict class labels based on a model, use the predict() method. The parameter indexes of predict() takes a vector of indexes corresponding to the observations for which predictions should be made. (With an empty vector, the default, predictions are made for all observations, including those used for training.)
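The meaning of an empty indexes vector can be sketched as follows. The helper resolveIndexes is hypothetical (it is not an OpenMS function) and uses std::out_of_range as a stand-in for the OpenMS exception type; it only mirrors the behaviour described in the paragraph above.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical helper: expand the 'indexes' argument as described above.
// An empty vector selects all observations; otherwise every index is checked.
inline std::vector<std::size_t> resolveIndexes(std::vector<std::size_t> indexes,
                                               std::size_t n_observations)
{
  if (indexes.empty())
  {
    indexes.resize(n_observations);
    std::iota(indexes.begin(), indexes.end(), std::size_t(0)); // 0, 1, ..., N-1
    return indexes;
  }
  for (std::size_t index : indexes)
  {
    if (index >= n_observations)
    {
      throw std::out_of_range("invalid observation index");
    }
  }
  return indexes;
}
```

Prediction results would then be reported in the order of the resolved indexes.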

Parameters of this class are:

Name         | Type       | Default                                   | Restrictions | Description
kernel       | string     | RBF                                       | RBF, linear  | SVM kernel
xval         | int        | 5                                         | min: 1       | Number of partitions for cross-validation (parameter optimization)
log2_C       | float list | [-5, -3, -1, 1, 3, 5, 7, 9, 11, 13, 15]   |              | Values to try for the SVM parameter 'C' during parameter optimization. A value 'x' is used as 'C = 2^x'.
log2_gamma   | float list | [-15, -13, -11, -9, -7, -5, -3, -1, 1, 3] |              | Values to try for the SVM parameter 'gamma' during parameter optimization (RBF kernel only). A value 'x' is used as 'gamma = 2^x'.
epsilon      | float      | 0.001                                     | min: 0       | Stopping criterion
cache_size   | float      | 100                                       | min: 1       | Size of the kernel cache (in MB)
no_shrinking | string     | false                                     | true, false  | Disable the shrinking heuristics
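To make the 'C = 2^x' and 'gamma = 2^x' convention concrete, the following sketch expands the default exponent lists into the grid of (C, gamma) pairs for the RBF kernel. The function names are hypothetical; only the default values are taken from the table.

```cpp
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

// Default exponent lists copied from the parameter table.
inline std::vector<double> defaultLog2C()
{
  return {-5.0, -3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0};
}

inline std::vector<double> defaultLog2Gamma()
{
  return {-15.0, -13.0, -11.0, -9.0, -7.0, -5.0, -3.0, -1.0, 1.0, 3.0};
}

// Hypothetical helper: enumerate every (C, gamma) pair of the grid,
// with C = 2^x and gamma = 2^y.
inline std::vector<std::pair<double, double>> parameterGrid(
    const std::vector<double>& log2_C, const std::vector<double>& log2_gamma)
{
  std::vector<std::pair<double, double>> grid;
  grid.reserve(log2_C.size() * log2_gamma.size());
  for (double x : log2_C)
  {
    for (double y : log2_gamma)
    {
      grid.emplace_back(std::exp2(x), std::exp2(y));
    }
  }
  return grid;
}
```

With the defaults this yields 11 x 10 = 110 combinations, ranging from C = 2^-5 to C = 2^15 and from gamma = 2^-15 to gamma = 2^3.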


Member Typedef Documentation

◆ PredictorMap

typedef std::map<String, std::vector<double> > PredictorMap

Mapping from predictor name to vector of predictor values.

◆ SVMPerformance

typedef std::vector<std::vector<double> > SVMPerformance
protected

Classification performance for different parameter combinations (C/gamma).

Constructor & Destructor Documentation

◆ SimpleSVM()

SimpleSVM ( )

Default constructor.

◆ ~SimpleSVM()

~SimpleSVM ( )
override

Destructor.

Member Function Documentation

◆ chooseBestParameters_()

std::pair<double, double> chooseBestParameters_ ( ) const
protected

Choose best SVM parameters based on cross-validation results.
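One plausible selection rule is sketched below: take the grid cell with the highest cross-validation performance. This is not taken from the OpenMS source; the function pickBest, the row/column orientation (rows indexed by C, columns by gamma) and the tie-breaking rule (first cell wins) are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using SVMPerformance = std::vector<std::vector<double>>;

// Hypothetical selection rule: return (log2_C, log2_gamma) of the grid cell
// with the highest performance. Assumes performance[i][j] belongs to
// log2_C[i] and log2_gamma[j]; ties keep the first cell encountered.
inline std::pair<double, double> pickBest(const SVMPerformance& performance,
                                          const std::vector<double>& log2_C,
                                          const std::vector<double>& log2_gamma)
{
  assert(!performance.empty() && performance.size() == log2_C.size());
  assert(!log2_gamma.empty());
  std::size_t best_c = 0;
  std::size_t best_g = 0;
  double best = -1.0; // below any valid performance value
  for (std::size_t i = 0; i < performance.size(); ++i)
  {
    assert(performance[i].size() == log2_gamma.size());
    for (std::size_t j = 0; j < performance[i].size(); ++j)
    {
      if (performance[i][j] > best)
      {
        best = performance[i][j];
        best_c = i;
        best_g = j;
      }
    }
  }
  return {log2_C[best_c], log2_gamma[best_g]};
}
```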

◆ convertData_()

void convertData_ ( const PredictorMap &  predictors)
protected

Convert predictors to LIBSVM format.
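LIBSVM stores each observation as a sparse array of (index, value) nodes terminated by a node with index -1. The sketch below shows how a column-oriented predictor map could be transposed into that row-oriented form. The Node struct mirrors the two fields of LIBSVM's svm_node; the function toSparseRows, the 1-based feature numbering and the omission of zero values are assumptions, not a description of the OpenMS implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

using PredictorMap = std::map<std::string, std::vector<double>>;

// Local stand-in with the same two fields as LIBSVM's svm_node.
struct Node
{
  int index;
  double value;
};

// Hypothetical helper: one sparse row per observation, one feature index
// per predictor (in map order, starting at 1), each row ending with index -1.
inline std::vector<std::vector<Node>> toSparseRows(const PredictorMap& predictors)
{
  std::vector<std::vector<Node>> rows;
  if (predictors.empty()) return rows;
  const std::size_t n = predictors.begin()->second.size();
  rows.resize(n);
  int feature = 1;
  for (const auto& p : predictors)
  {
    assert(p.second.size() == n); // all predictors cover the same observations
    for (std::size_t obs = 0; obs < n; ++obs)
    {
      if (p.second[obs] != 0.0) // zeros are implicit in a sparse row
      {
        rows[obs].push_back({feature, p.second[obs]});
      }
    }
    ++feature;
  }
  for (auto& row : rows)
  {
    row.push_back({-1, 0.0}); // terminator expected by LIBSVM
  }
  return rows;
}
```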

◆ getFeatureWeights()

void getFeatureWeights ( std::map< String, double > &  feature_weights) const

Get the weights used for features (predictors) in the SVM model.

Currently only supported for two-class classification. If a linear kernel is used, the weights are informative for ranking features.

Exceptions
Exception::Precondition  if no model has been trained, or if the classification involves more than two classes
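For a two-class SVM with a linear kernel, the weight of feature j can be written as the sum over all support vectors of (coefficient of the support vector) x (value of feature j in that support vector). The sketch below computes these sums for dense support vectors. It illustrates this standard identity only; the function linearWeights and its dense input format are hypothetical, and the source does not state how OpenMS derives the weights internally.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: w[j] = sum_i coefficients[i] * support_vectors[i][j].
inline std::vector<double> linearWeights(
    const std::vector<std::vector<double>>& support_vectors,
    const std::vector<double>& coefficients)
{
  assert(support_vectors.size() == coefficients.size());
  std::vector<double> weights;
  if (support_vectors.empty()) return weights;
  weights.assign(support_vectors.front().size(), 0.0);
  for (std::size_t i = 0; i < support_vectors.size(); ++i)
  {
    assert(support_vectors[i].size() == weights.size());
    for (std::size_t j = 0; j < weights.size(); ++j)
    {
      weights[j] += coefficients[i] * support_vectors[i][j];
    }
  }
  return weights;
}
```

Features with a larger absolute weight contribute more to the linear decision function, which is why the weights can be used for ranking in that case.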

◆ optimizeParameters_()

void optimizeParameters_ ( )
protected

Run cross-validation to optimize SVM parameters.
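As a generic illustration of n-fold cross-validation, the sketch below splits the training observations into partitions and scores held-out predictions by accuracy. How the partitions are actually formed is not described here; the round-robin assignment and both function names are assumptions for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: assign training observations 0..n_training-1 to
// n_parts partitions in round-robin order.
inline std::vector<std::vector<std::size_t>> makePartitions(std::size_t n_training,
                                                            std::size_t n_parts)
{
  assert(n_parts >= 1);
  std::vector<std::vector<std::size_t>> parts(n_parts);
  for (std::size_t i = 0; i < n_training; ++i)
  {
    parts[i % n_parts].push_back(i);
  }
  return parts;
}

// Fraction of held-out observations whose predicted label matches the true label.
inline double accuracy(const std::vector<int>& predicted, const std::vector<int>& actual)
{
  assert(predicted.size() == actual.size() && !actual.empty());
  std::size_t correct = 0;
  for (std::size_t i = 0; i < actual.size(); ++i)
  {
    if (predicted[i] == actual[i]) ++correct;
  }
  return static_cast<double>(correct) / static_cast<double>(actual.size());
}
```

For each parameter combination, every partition is held out once while a model is trained on the remaining partitions; the resulting score is what a table of cross-validation results would record per combination.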

◆ predict()

void predict ( std::vector< Prediction > &  predictions,
std::vector< Size >  indexes = std::vector< Size >() 
) const

Predict class labels (and probabilities).

Parameters
predictions  Output vector of prediction results (same order as indexes).
indexes      Vector of observation indexes for which predictions are desired. If empty (default), predictions are made for all observations.
Exceptions
Exception::Precondition  if no model has been trained
Exception::InvalidValue  if an invalid index is used in indexes

◆ printNull_()

static void printNull_ ( const char *  )
inline static protected

Dummy function to suppress LIBSVM output.

◆ scaleData_()

void scaleData_ ( PredictorMap &  predictors) const
protected

Scale predictor values to range 0-1.
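Scaling to the range 0-1 is commonly done per predictor with min-max scaling, (v - min) / (max - min). The sketch below applies this in place. It is a minimal illustration, not the OpenMS implementation; the function scaleToUnitRange and its treatment of constant predictors (left unchanged and reported to the caller) are assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

using PredictorMap = std::map<std::string, std::vector<double>>;

// Hypothetical helper: min-max scale every predictor in place.
// Returns the names of predictors whose values are all equal; these cannot
// be scaled (max - min is zero) and are left unchanged.
inline std::vector<std::string> scaleToUnitRange(PredictorMap& predictors)
{
  std::vector<std::string> constant;
  for (auto& p : predictors)
  {
    std::vector<double>& values = p.second;
    if (values.empty()) continue;
    const auto min_max = std::minmax_element(values.begin(), values.end());
    const double lo = *min_max.first;  // copied before the values are modified
    const double hi = *min_max.second;
    if (hi == lo)
    {
      constant.push_back(p.first);
      continue;
    }
    for (double& v : values)
    {
      v = (v - lo) / (hi - lo);
    }
  }
  return constant;
}
```

Because the minimum and maximum are taken over all observations, training and test data end up on the same scale, which is the reason given above for passing both sets to setup() together.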

◆ setup()

void setup ( PredictorMap &  predictors,
const std::map< Size, Int > &  labels 
)

Load data and train a model.

Parameters
predictors  Mapping from predictor name to vector of predictor values (for different observations). All vectors should have the same length; values will be changed by scaling.
labels      Mapping from observation index to class label in the training set.
Exceptions
Exception::IllegalArgument     if predictors is empty
Exception::InvalidValue        if an invalid index is used in labels
Exception::MissingInformation  if there are fewer than two class labels in labels, or if there are not enough observations for cross-validation

◆ writeXvalResults()

void writeXvalResults ( const String &  path) const

Write cross-validation (parameter optimization) results to a CSV file.
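The exact column layout of the CSV file is not specified here. As one possible layout, the sketch below serializes a grid of cross-validation results with one row per parameter combination. The function gridToCsv, the header names and the row order are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

using SVMPerformance = std::vector<std::vector<double>>;

// Hypothetical helper: one CSV row per (log2_C, log2_gamma) combination.
// Assumes performance[i][j] belongs to log2_C[i] and log2_gamma[j].
inline std::string gridToCsv(const SVMPerformance& performance,
                             const std::vector<double>& log2_C,
                             const std::vector<double>& log2_gamma)
{
  assert(performance.size() == log2_C.size());
  std::ostringstream out;
  out << "log2_C,log2_gamma,performance\n";
  for (std::size_t i = 0; i < performance.size(); ++i)
  {
    assert(performance[i].size() == log2_gamma.size());
    for (std::size_t j = 0; j < performance[i].size(); ++j)
    {
      out << log2_C[i] << ',' << log2_gamma[j] << ',' << performance[i][j] << '\n';
    }
  }
  return out.str();
}
```

The returned string could then be written to the desired path, for example with std::ofstream.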

Member Data Documentation

◆ data_

struct svm_problem data_
protected

SVM training data (LIBSVM format)

◆ log2_C_

std::vector<double> log2_C_
protected

Parameter values to try during optimization.

◆ log2_gamma_

std::vector<double> log2_gamma_
protected

◆ model_

struct svm_model* model_
protected

Pointer to SVM model (LIBSVM format)

◆ n_parts_

Size n_parts_
protected

Number of partitions for cross-validation.

◆ nodes_

std::vector<std::vector<struct svm_node> > nodes_
protected

Values of predictors (LIBSVM format)

◆ performance_

SVMPerformance performance_
protected

Cross-validation results.

◆ predictor_names_

std::vector<String> predictor_names_
protected

Names of predictors in the model (excluding uninformative ones)

◆ svm_params_

struct svm_parameter svm_params_
protected

SVM parameters (LIBSVM format)