PlusML
Loading...
Searching...
No Matches
mse_grad.h
1#ifndef MSE_GRAD_H
2#define MSE_GRAD_H
3
4#include "PlusML/gradient/loss_gradient.h"
5#include "PlusML/util.h"
6
7namespace plusml {
11class EXPORT MSEGrad : public LossGradient {
12public:
20 Eigen::MatrixXf Compute(const Eigen::MatrixXf& w,
21 const Eigen::MatrixXf& X,
22 const Eigen::MatrixXf& y) const override;
23
28 void L2Regularization(const float c);
29private:
30 bool l2_regularization_enabled_ = false;
31 float l2_regularization_coefficient_ = 0;
32};
33} //namespace plusml
34
35#endif //MSE_GRAD_H
Base class for loss gradient implementations.
Definition loss_gradient.h:11
Class implementing the Mean Squared Error (MSE) gradient.
Definition mse_grad.h:11