PlusML
plusml::MulticlassSVM Class Reference

Class implementing a multiclass classification SVM.

#include <multiclass_svm.h>

Public Member Functions

 MulticlassSVM (uint64_t features, uint64_t classes, ClassificationMode mode, bool bias_enabled=true)
 Constructor for MulticlassSVM.
 
uint64_t Features () const
 Get the number of features in each sample for the model.
 
void FitSGD (const Eigen::MatrixXf &samples, const Eigen::MatrixXi &targets, float learning_rate, float l2_alpha, uint64_t batch_size, uint64_t epochs)
 Fit the model using stochastic gradient descent (SGD) with the given hyperparameters.
 
Eigen::MatrixXi Predict (Eigen::MatrixXf samples)
 Get predicted classes for a given matrix of samples.
 

Detailed Description

Class implementing a multiclass classification SVM.

Pass plusml::kOneVsOne or plusml::kOneVsAll to the constructor to select the classification mode.
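The two modes differ in how a K-class problem is split into binary SVMs. The sketch below uses the standard definitions (one-vs-all trains K classifiers, one per class against the rest; one-vs-one trains K(K-1)/2, one per pair of classes). The helper names are hypothetical and not part of the PlusML API, and the pair ordering shown is an assumption rather than PlusML's documented behavior.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: one-vs-all trains one binary SVM per class
// (that class against all the others).
std::uint64_t OneVsAllClassifierCount(std::uint64_t classes) { return classes; }

// Hypothetical helper: one-vs-one trains one binary SVM per unordered
// pair of classes, i.e. K * (K - 1) / 2 classifiers.
std::uint64_t OneVsOneClassifierCount(std::uint64_t classes) {
  return classes * (classes - 1) / 2;
}

// Hypothetical helper: enumerate the (i, j) class pairs with i < j
// that a one-vs-one scheme would train on.
std::vector<std::pair<std::uint64_t, std::uint64_t>> OneVsOnePairs(std::uint64_t classes) {
  std::vector<std::pair<std::uint64_t, std::uint64_t>> pairs;
  for (std::uint64_t i = 0; i < classes; ++i)
    for (std::uint64_t j = i + 1; j < classes; ++j)
      pairs.emplace_back(i, j);
  return pairs;
}
```

For 4 classes this gives 4 one-vs-all classifiers versus 6 one-vs-one classifiers; the gap grows quadratically with the number of classes, although each one-vs-one classifier trains on only the samples of its two classes.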

Constructor & Destructor Documentation

◆ MulticlassSVM()

plusml::MulticlassSVM::MulticlassSVM ( uint64_t features,
uint64_t classes,
ClassificationMode mode,
bool bias_enabled = true )

Constructor for MulticlassSVM.

Parameters
features: Number of features in each sample
classes: Number of classes in the classification problem
mode: Classification mode (plusml::kOneVsOne or plusml::kOneVsAll)
bias_enabled: Whether the model uses a bias term
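Assuming each underlying binary classifier is a linear SVM with decision function f(x) = w·x + b, bias_enabled plausibly controls whether b is learned or held at zero (forcing the separating hyperplane through the origin). The hypothetical LinearScorer struct below illustrates that reading only; it is not PlusML code.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical linear binary SVM scorer, for illustration only.
// Assumes the usual linear decision function f(x) = w . x + b.
struct LinearScorer {
  std::vector<float> weights;  // one weight per feature
  float bias = 0.0f;           // only used when bias_enabled is true
  bool bias_enabled = true;    // mirrors the constructor flag (assumed semantics)

  // x must have at least weights.size() entries.
  float Score(const std::vector<float>& x) const {
    float s = 0.0f;
    for (std::size_t i = 0; i < weights.size(); ++i) s += weights[i] * x[i];
    // With the bias disabled, the hyperplane passes through the origin.
    return bias_enabled ? s + bias : s;
  }
};
```

The sign of Score gives the binary decision; disabling the bias is mainly useful when the features are already centered.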

Member Function Documentation

◆ Features()

uint64_t plusml::MulticlassSVM::Features ( ) const

Get the number of features in each sample for the model.

Returns
Number of features in each sample for the model

◆ FitSGD()

void plusml::MulticlassSVM::FitSGD ( const Eigen::MatrixXf & samples,
const Eigen::MatrixXi & targets,
float learning_rate,
float l2_alpha,
uint64_t batch_size,
uint64_t epochs )

Fit the model using stochastic gradient descent (SGD) with the given hyperparameters.

Parameters
samples: Matrix of samples (MxN, where M is the number of samples and N is the number of features in each sample)
targets: Matrix of targets (Mx1, where M is the number of samples)
learning_rate: Learning rate for SGD
l2_alpha: Coefficient of the L2 regularization term
batch_size: Mini-batch size for SGD
epochs: Number of epochs (full passes over the samples) for SGD
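To show roughly what these hyperparameters control, here is a minimal sketch of mini-batch SGD for a single binary linear SVM, assuming the textbook L2-regularized hinge-loss objective. FitBinarySVM is a hypothetical helper, labels are assumed to be -1/+1, and the data is not shuffled; PlusML's actual loss, update rule, and handling of multiclass targets may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: mini-batch SGD for one binary linear SVM, minimizing
//   l2_alpha / 2 * ||w||^2 + mean_i max(0, 1 - y_i * (w . x_i + b))
// with labels y_i in {-1, +1}. This textbook objective is an assumption,
// not taken from the PlusML sources. Pass b = 0 when bias_enabled is false.
void FitBinarySVM(const std::vector<std::vector<float>>& samples,
                  const std::vector<int>& labels,  // each label is -1 or +1
                  float learning_rate, float l2_alpha,
                  std::uint64_t batch_size, std::uint64_t epochs,
                  std::vector<float>& w, float& b, bool bias_enabled) {
  assert(batch_size > 0 && samples.size() == labels.size());
  const std::size_t n = samples.size();
  for (std::uint64_t epoch = 0; epoch < epochs; ++epoch) {
    // Walk the data in consecutive mini-batches (no shuffling, for brevity).
    for (std::size_t start = 0; start < n; start += batch_size) {
      const std::size_t end = std::min<std::size_t>(n, start + batch_size);
      const float inv_batch = 1.0f / static_cast<float>(end - start);
      // Gradient of the L2 term; the bias is conventionally left unregularized.
      std::vector<float> grad_w(w.size());
      for (std::size_t j = 0; j < w.size(); ++j) grad_w[j] = l2_alpha * w[j];
      float grad_b = 0.0f;
      for (std::size_t i = start; i < end; ++i) {
        float margin = b;
        for (std::size_t j = 0; j < w.size(); ++j) margin += w[j] * samples[i][j];
        margin *= labels[i];
        // The hinge loss contributes only for samples inside the margin or misclassified.
        if (margin < 1.0f) {
          for (std::size_t j = 0; j < w.size(); ++j)
            grad_w[j] -= inv_batch * labels[i] * samples[i][j];
          grad_b -= inv_batch * labels[i];
        }
      }
      for (std::size_t j = 0; j < w.size(); ++j) w[j] -= learning_rate * grad_w[j];
      if (bias_enabled) b -= learning_rate * grad_b;
    }
  }
}
```

In this sketch each epoch performs ceil(M / batch_size) updates, learning_rate scales every step, and a larger l2_alpha pulls w more strongly toward zero.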

◆ Predict()

Eigen::MatrixXi plusml::MulticlassSVM::Predict ( Eigen::MatrixXf samples)

Get predicted classes for a given matrix of samples.

Parameters
samples: Matrix of samples (MxN, where M is the number of samples and N is the number of features in each sample)
Returns
Matrix of predicted classes (Mx1, where M is the number of samples)
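Because Predict returns a single class per sample, the outputs of the binary SVMs must be combined. The sketch below assumes the usual rules: argmax over decision scores for one-vs-all, and majority voting over class pairs for one-vs-one. The function names and the tie-breaking rule (lowest class index wins) are hypothetical, not taken from PlusML.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// One-vs-all: scores[k] is the decision value of the "class k vs rest" SVM.
// The predicted class is the one with the largest score (scores must be non-empty).
std::size_t PredictOneVsAll(const std::vector<float>& scores) {
  std::size_t best = 0;
  for (std::size_t k = 1; k < scores.size(); ++k)
    if (scores[k] > scores[best]) best = k;
  return best;
}

// One-vs-one: for pairs[p] = (i, j), scores[p] > 0 is read as a vote for i,
// otherwise as a vote for j. The class with the most votes wins; ties go to
// the lower class index.
std::size_t PredictOneVsOne(std::size_t classes,
                            const std::vector<std::pair<std::size_t, std::size_t>>& pairs,
                            const std::vector<float>& scores) {
  std::vector<std::size_t> votes(classes, 0);
  for (std::size_t p = 0; p < pairs.size(); ++p)
    ++votes[scores[p] > 0.0f ? pairs[p].first : pairs[p].second];
  std::size_t best = 0;
  for (std::size_t k = 1; k < classes; ++k)
    if (votes[k] > votes[best]) best = k;
  return best;
}
```

Applying one of these per row of samples yields the Mx1 matrix of predicted classes.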

The documentation for this class was generated from the following files: