Readouts.h
#ifndef READOUTS
#define READOUTS

#include "NemoLSM.h"
#include <Eigen/Dense>  // dense types used below (VectorXd, MatrixXd); likely already pulled in via NemoLSM.h
#include <cblas.h>

/**
 * @class Readout
 * @brief An implementation of a generic readout component.
 *
 * @details The class is an interface to the basic functionality of a
 * readout. That is, it provides:
 *   - Training and filtering capabilities
 *   - Error handling
 * Override the train() virtual function to implement learning.
 *
 * @version 1.0
 * @author  Emmanouil Hourdakis
 * @email   [email protected]
 */
class Readout {
public:
    LSM          *_lsm;              ///< Pointer to the liquid to which the readout is attached
    Training     *_training;         ///< The training data
    VectorXd      _solution;         ///< The solution (readout weight) vector
    unsigned int  _samplingInterval; ///< Sampling interval used when filtering the readout

    /**
     * @function Readout
     * @brief Constructor, used to initialize the member variables.
     *
     * @param lsm              pointer to the liquid
     * @param training         the training data
     * @param samplingInterval the sampling interval used for filtering
     */
    Readout(LSM *lsm, Training *training, unsigned int samplingInterval);

    /// Virtual destructor, so derived readouts can be destroyed through a Readout pointer.
    virtual ~Readout() {}

    /**
     * @function train
     * @brief Pure virtual function; override it to implement the learning rule.
     *
     * @param filter         the filtered (liquid state) data
     * @param trainingVector the training data
     * @return the solved readout weight vector
     */
    virtual VectorXd train(Eigen::MatrixXd filter, VectorXd trainingVector) = 0;

    /**
     * @function trainData
     * @brief Train the readout on the training data using the train() function.
     *
     * @param layerIndex the index of the layer to use for training
     */
    MatrixXd trainData(unsigned int layerIndex);

    /**
     * @function getError
     * @brief Returns the error of the readout.
     *
     * @param solved the solved readout output to compare against the training data
     */
    virtual double getError(MatrixXd solved);
};
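
// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of NemoLSM): as the comment above says, a
// new learning rule is added by deriving from Readout and overriding train().
// RidgeReadout and its _lambda parameter are hypothetical names used purely for
// illustration; the solve uses Eigen's LDLT factorization of the regularized
// normal equations (F^T F + lambda * I) w = F^T y.
// ---------------------------------------------------------------------------
class RidgeReadout : public Readout {
public:
    double _lambda; ///< regularization strength (illustrative parameter)

    RidgeReadout(LSM *lsm, Training *training, unsigned int samplingInterval, double lambda)
        : Readout(lsm, training, samplingInterval), _lambda(lambda) {}

    virtual VectorXd train(Eigen::MatrixXd filter, VectorXd trainingVector) {
        // Regularized normal-equations matrix F^T F + lambda * I.
        Eigen::MatrixXd A = filter.transpose() * filter
                          + _lambda * Eigen::MatrixXd::Identity(filter.cols(), filter.cols());
        _solution = A.ldlt().solve(filter.transpose() * trainingVector);
        return _solution;
    }
};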

/**
 * @class JacobiReadout
 * @brief A Readout descendant that trains the readout using the Jacobi method.
 *
 * @version 1.0
 */
class JacobiReadout : public Readout {
public:
    JacobiReadout(LSM *lsm, Training *training, unsigned int samplingInterval);
    virtual VectorXd train(Eigen::MatrixXd filter, VectorXd trainingVector);
};
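
// Illustrative sketch only; the real JacobiReadout::train is defined in the .cpp and
// may differ. One plausible reading of "Jacobi" here is Eigen's JacobiSVD, whose thin
// SVD yields the minimum-norm least-squares solution of filter * w ~= trainingVector:
inline VectorXd jacobiSolveSketch(const Eigen::MatrixXd &filter, const VectorXd &trainingVector) {
    return filter.jacobiSvd(Eigen::ComputeThinU | Eigen::ComputeThinV).solve(trainingVector);
}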

/**
 * @class CholeskyReadout
 * @brief A Readout descendant that trains the readout using Cholesky factorization.
 *
 * @version 1.0
 */
class CholeskyReadout : public Readout {
public:
    CholeskyReadout(LSM *lsm, Training *training, unsigned int samplingInterval);
    virtual VectorXd train(Eigen::MatrixXd filter, VectorXd trainingVector);
};
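
// Illustrative sketch only; the real CholeskyReadout::train is defined in the .cpp.
// A Cholesky-based solve typically factorizes the normal equations
// (filter^T filter) w = filter^T y, e.g. with Eigen's LLT:
inline VectorXd choleskySolveSketch(const Eigen::MatrixXd &filter, const VectorXd &trainingVector) {
    Eigen::MatrixXd gram = filter.transpose() * filter; // symmetric positive (semi-)definite
    return gram.llt().solve(filter.transpose() * trainingVector);
}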

/**
 * @class NormalEquationsReadout
 * @brief A Readout descendant that trains the readout by solving the normal equations.
 *
 * @version 1.0
 */
class NormalEquationsReadout : public Readout {
public:
    NormalEquationsReadout(LSM *lsm, Training *training, unsigned int samplingInterval);
    virtual VectorXd train(Eigen::MatrixXd filter, VectorXd trainingVector);
};
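
// Illustrative sketch only; the real NormalEquationsReadout::train is defined in the
// .cpp. Forming the normal equations (filter^T filter) w = filter^T y explicitly and
// solving them with a general LU decomposition could look like this:
inline VectorXd normalEquationsSolveSketch(const Eigen::MatrixXd &filter, const VectorXd &trainingVector) {
    Eigen::MatrixXd AtA = filter.transpose() * filter;
    VectorXd        Atb = filter.transpose() * trainingVector;
    return AtA.fullPivLu().solve(Atb);
}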

/**
 * @class BLASLeastSquares
 * @brief A Readout descendant that trains the readout with a BLAS-based least-squares solve.
 *
 * @version 1.0
 */
class BLASLeastSquares : public Readout {
public:
    BLASLeastSquares(LSM *lsm, Training *training, unsigned int samplingInterval);
    virtual VectorXd train(Eigen::MatrixXd filter, VectorXd trainingVector);
};
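
// Illustrative sketch only. <cblas.h> itself provides no least-squares driver, so this
// assumes the LAPACKE interface (LAPACKE_dgels) is also available — an assumption that
// goes beyond what this header includes; the real BLASLeastSquares::train may instead
// build the normal equations with cblas_dgemm / cblas_dgemv and solve them as in the
// sketches above.
#include <lapacke.h>   // LAPACKE_dgels (assumption: LAPACKE is installed alongside CBLAS)

inline VectorXd blasLeastSquaresSketch(Eigen::MatrixXd filter, const VectorXd &trainingVector) {
    lapack_int m = static_cast<lapack_int>(filter.rows());
    lapack_int n = static_cast<lapack_int>(filter.cols());
    // dgels overwrites its inputs; the right-hand side needs room for max(m, n) entries.
    VectorXd b = VectorXd::Zero(m > n ? m : n);
    b.head(m) = trainingVector;
    // Eigen matrices are column-major by default, matching LAPACK_COL_MAJOR.
    LAPACKE_dgels(LAPACK_COL_MAJOR, 'N', m, n, 1,
                  filter.data(), m, b.data(), static_cast<lapack_int>(b.size()));
    return b.head(n); // the first n entries now hold the least-squares solution
}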
#endif