PlusML
plusml::MulticlassLogisticRegression Class Reference

Class implementing multiclass logistic regression.

#include <multiclass_logistic_regression.h>

Public Member Functions

 MulticlassLogisticRegression (uint64_t features, uint64_t classes, bool bias_enabled=true)
 Constructor for MulticlassLogisticRegression.
 
uint64_t Features () const
 Get the number of features in each sample expected by the model.
 
void FitSGD (const Eigen::MatrixXf &samples, const Eigen::MatrixXi &targets, float learning_rate, uint64_t batch_size, uint64_t epochs)
 Fit the model using stochastic gradient descent (SGD) with the given hyperparameters.
 
Eigen::MatrixXf Predict (Eigen::MatrixXf samples)
 Get predicted class probabilities for a matrix of samples.
 

Detailed Description

Class implementing multiclass logistic regression. The model is trained with FitSGD() and returns per-class probabilities from Predict().

Constructor & Destructor Documentation

◆ MulticlassLogisticRegression()

plusml::MulticlassLogisticRegression::MulticlassLogisticRegression ( uint64_t features,
uint64_t classes,
bool bias_enabled = true )

Constructor for MulticlassLogisticRegression.

Parameters
features: Number of features in each sample
classes: Number of classes in the classification problem
bias_enabled: Whether to include a bias (intercept) term; defaults to true
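To make the effect of `bias_enabled` concrete, here is a minimal sketch of one plausible parameter layout. The `LogisticParams` struct and its `ParameterCount()` method are hypothetical names, and the layout (one weight per feature-class pair plus one bias per class) is an assumption about how such a model is commonly parameterized, not a description of PlusML's internals.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical parameter layout, assumed for illustration (not PlusML's
// internals): one weight per (feature, class) pair, plus one bias per
// class when bias_enabled is true.
struct LogisticParams {
  uint64_t features;
  uint64_t classes;
  bool bias_enabled;
  // Row-major (features [+1 bias row]) x classes matrix, zero-initialized.
  std::vector<float> weights;

  LogisticParams(uint64_t f, uint64_t c, bool bias = true)
      : features(f), classes(c), bias_enabled(bias),
        weights((f + (bias ? 1 : 0)) * c, 0.0f) {
    // Assumed precondition: at least one feature and two classes.
    assert(f > 0 && c > 1);
  }

  // Total number of trainable parameters.
  uint64_t ParameterCount() const { return weights.size(); }
};
```

Under this assumption, 4 features and 3 classes give (4 + 1) x 3 = 15 parameters with a bias and 4 x 3 = 12 without.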

Member Function Documentation

◆ Features()

uint64_t plusml::MulticlassLogisticRegression::Features ( ) const

Get the number of features in each sample expected by the model.

Returns
Number of features in each sample expected by the model

◆ FitSGD()

void plusml::MulticlassLogisticRegression::FitSGD ( const Eigen::MatrixXf & samples,
const Eigen::MatrixXi & targets,
float learning_rate,
uint64_t batch_size,
uint64_t epochs )

Fit the model using stochastic gradient descent (SGD) with the given hyperparameters.

Parameters
samples: Matrix of samples (MxN, where M is the number of samples and N is the number of features in each sample)
targets: Matrix of targets (Mx1, where M is the number of samples)
learning_rate: Learning rate for SGD
batch_size: Batch size for SGD
epochs: Number of epochs for SGD
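The training loop that FitSGD() names can be sketched as mini-batch SGD on the softmax cross-entropy loss, the usual objective for multiclass logistic regression. The choice of loss and update rule is an assumption, not PlusML's documented implementation. `FitSGDSketch` is a hypothetical name; it uses flat row-major `std::vector` buffers instead of Eigen matrices to stay self-contained, treats each target as a class index in [0, K), and always uses a bias row for brevity.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical mini-batch SGD for softmax (multinomial logistic) regression
// with cross-entropy loss; an illustrative sketch, not PlusML's code.
// samples: row-major M x N; targets: M class indices in [0, K);
// weights: row-major (N + 1) x K, where the last row holds the biases.
void FitSGDSketch(const std::vector<float>& samples,
                  const std::vector<int>& targets, uint64_t n, uint64_t k,
                  std::vector<float>& weights, float learning_rate,
                  uint64_t batch_size, uint64_t epochs, unsigned seed = 0) {
  const uint64_t m = targets.size();
  assert(samples.size() == m * n && weights.size() == (n + 1) * k);
  assert(batch_size > 0);
  std::vector<uint64_t> order(m);
  std::iota(order.begin(), order.end(), 0);
  std::mt19937 rng(seed);
  std::vector<float> logits(k), grad((n + 1) * k);
  for (uint64_t epoch = 0; epoch < epochs; ++epoch) {
    // Visit samples in a new random order each epoch.
    std::shuffle(order.begin(), order.end(), rng);
    for (uint64_t start = 0; start < m; start += batch_size) {
      const uint64_t end = std::min(start + batch_size, m);
      std::fill(grad.begin(), grad.end(), 0.0f);
      for (uint64_t b = start; b < end; ++b) {
        const float* x = &samples[order[b] * n];
        // Linear scores: z_c = bias_c + sum_j x_j * W_jc.
        for (uint64_t c = 0; c < k; ++c) {
          float z = weights[n * k + c];
          for (uint64_t j = 0; j < n; ++j) z += x[j] * weights[j * k + c];
          logits[c] = z;
        }
        // Numerically stable softmax: subtract the max score first.
        const float mx = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (float& z : logits) {
          z = std::exp(z - mx);
          sum += z;
        }
        // Gradient of cross-entropy w.r.t. scores is p - onehot(target).
        for (uint64_t c = 0; c < k; ++c) {
          const float y = targets[order[b]] == static_cast<int>(c) ? 1.0f : 0.0f;
          const float d = logits[c] / sum - y;
          for (uint64_t j = 0; j < n; ++j) grad[j * k + c] += d * x[j];
          grad[n * k + c] += d;
        }
      }
      // Average the gradient over the batch and take one step.
      const float scale = learning_rate / static_cast<float>(end - start);
      for (uint64_t i = 0; i < grad.size(); ++i) weights[i] -= scale * grad[i];
    }
  }
}
```

Shuffling once per epoch and averaging the gradient over each batch are common design choices; the final batch of an epoch is smaller than `batch_size` whenever M is not a multiple of it.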

◆ Predict()

Eigen::MatrixXf plusml::MulticlassLogisticRegression::Predict ( Eigen::MatrixXf samples)

Get predicted class probabilities for a matrix of samples.

Parameters
samples: Matrix of samples (MxN, where M is the number of samples and N is the number of features in each sample)
Returns
Matrix of predicted probabilities (MxK, where M is the number of samples and K is the number of classes)
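The per-class probabilities returned by Predict() are presumably obtained by applying a softmax to per-class linear scores. The sketch below shows that computation under an assumed layout (row-major buffers, biases in the last weight row); `PredictSketch` is a hypothetical helper, not part of PlusML, and it uses `std::vector` instead of Eigen to stay self-contained.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical softmax prediction, assumed for illustration (not PlusML's code).
// samples: row-major M x N; weights: row-major (N + 1) x K with the biases in
// the last row. Returns row-major M x K probabilities; each row sums to 1.
std::vector<float> PredictSketch(const std::vector<float>& samples,
                                 const std::vector<float>& weights,
                                 uint64_t n, uint64_t k) {
  assert(n > 0 && k > 0);
  assert(samples.size() % n == 0 && weights.size() == (n + 1) * k);
  const uint64_t m = samples.size() / n;
  std::vector<float> probs(m * k);
  for (uint64_t i = 0; i < m; ++i) {
    float* row = &probs[i * k];
    // Linear score per class, starting from the bias term.
    for (uint64_t c = 0; c < k; ++c) {
      float z = weights[n * k + c];
      for (uint64_t j = 0; j < n; ++j) z += samples[i * n + j] * weights[j * k + c];
      row[c] = z;
    }
    // Subtract the row maximum before exponentiating to avoid overflow.
    const float mx = *std::max_element(row, row + k);
    float sum = 0.0f;
    for (uint64_t c = 0; c < k; ++c) {
      row[c] = std::exp(row[c] - mx);
      sum += row[c];
    }
    for (uint64_t c = 0; c < k; ++c) row[c] /= sum;
  }
  return probs;
}
```

Taking the argmax of each row of the result gives the predicted class for that sample.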

The documentation for this class was generated from the following files: