diff --git a/CMakeLists.txt b/CMakeLists.txt index ca0f431..54ba0f4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,13 +9,19 @@ set(intrinsic "none" CACHE STRING "The possible intrinsics are the following: no # Default c++ flags. set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS}") -set(CMAKE_CXX_FLAGS_DEBUG "-ggdb -Wall --pedantic ${CMAKE_CXX_FLAGS_DEBUG}") +set(CMAKE_CXX_FLAGS_DEBUG "-ggdb -O0 -Wall --pedantic ${CMAKE_CXX_FLAGS_DEBUG}") set(CMAKE_CXX_FLAGS_RELEASE "-O3 ${CMAKE_CXX_FLAGS_RELEASE}") if (CMAKE_BUILD_TYPE STREQUAL "Debug") add_definitions("-DDEBUG") endif() +if (ENABLE_DB STREQUAL "true") + MESSAGE(STATUS "DB Enabled: ${ENABLE_DB}") + add_definitions("-DENABLE_DB") + set(ENABLE_DB "true") +endif() + add_subdirectory(src) add_subdirectory(lib) add_subdirectory(test) diff --git a/README.md b/README.md index 673c0d7..568e835 100644 --- a/README.md +++ b/README.md @@ -3,32 +3,36 @@ rl [![Build Status](http://ci.joeyandres.com/job/rl-unit-test-master/badge/icon)](http://ci.joeyandres.com/job/rl-unit-test-master/) -Modularized various Reinforcement Learning Algorithm library. -See test/include and test/src for examples. - -Note: This is currently only built for Linux systems. -Some threading libraries are linux specific (even the latest c++ standard which is supposed to be platform independent). +Modularized various Reinforcement Learning Algorithm library. # Compilation and Installation -### Caveat -**rl** have some minor _double precision floating point_ issues in older compilers and os. -The mountain car problem in test won't converge to a an optimal solution in osx and -old g++ compiler (e.g. g++ 4.2.1). For optimal performance, use linux and new g++ compiler. - ### Dependency: #### Required * g++-4.9 or greater or clang. * cmake 3.2.2 or greater. -* boost v1.62 (might work for version < 1.62) +* boost v1.59 or greater. +### Optional: To enable cassandradb +* cassandra v3.9 or greater. 
+* [datastax-cpp-driver](https://github.com/datastax/cpp-driver) v2.5 or greater. ### Installing dependencies Ubuntu 16.04: `sudo apt install g++ cmake libboost-all-dev` -### Building +// TODO: Installing dependencies from http://downloads.datastax.com/cpp-driver/ubuntu/16.04/ +// TODO: Make a script to do this? + +### Building (no cassandradb) +1. `mkdir build` +2. `cd build` +3. `cmake .. -DCMAKE_BUILD_TYPE=Release` +4. `make -j16` +5. `sudo make install` + +### Building (with cassandradb) 1. `mkdir build` 2. `cd build` -3. `cmake ..` +3. `cmake .. -DCMAKE_BUILD_TYPE=Release -DENABLE_DB=true` 4. `make -j16` 5. `sudo make install` diff --git a/include/algorithm/gradient-descent/GradientDescentTileCode.h b/include/algorithm/gradient-descent/GradientDescentTileCode.h index 0f40302..3e3c491 100644 --- a/include/algorithm/gradient-descent/GradientDescentTileCode.h +++ b/include/algorithm/gradient-descent/GradientDescentTileCode.h @@ -39,15 +39,21 @@ namespace algorithm { * \brief Gradient Descent implementation. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. * This also implies ACTION_DIM = D - STATE_DIM. 
*/ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT, + size_t STATE_DIM> class GradientDescentTileCode : - public GradientDescentTileCodeAbstract { + public GradientDescentTileCodeAbstract< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM> { public: using GradientDescentTileCodeAbstract< - D, NUM_TILINGS, STATE_DIM>::GradientDescentTileCodeAbstract; + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::GradientDescentTileCodeAbstract; void updateWeights( const typename GradientDescentAbstract< @@ -63,8 +69,9 @@ class GradientDescentTileCode : const FLOAT reward) override; }; -template -void GradientDescentTileCode::updateWeights( +template +void GradientDescentTileCode< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::updateWeights( const typename GradientDescentAbstract< D, STATE_DIM>::spStateParam& currentStateVector, diff --git a/include/algorithm/gradient-descent/GradientDescentTileCodeAbstract.h b/include/algorithm/gradient-descent/GradientDescentTileCodeAbstract.h index c156e87..593442b 100644 --- a/include/algorithm/gradient-descent/GradientDescentTileCodeAbstract.h +++ b/include/algorithm/gradient-descent/GradientDescentTileCodeAbstract.h @@ -32,10 +32,15 @@ namespace algorithm { * * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. * This also implies ACTION_DIM = D - STATE_DIM. */ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT, + size_t STATE_DIM> class GradientDescentTileCodeAbstract : public GradientDescentAbstract { public: @@ -45,10 +50,11 @@ class GradientDescentTileCodeAbstract : * @param discountRate discount rate for gradient descent. * @param lambda How influential is current state-action to ther state-action. 
*/ - GradientDescentTileCodeAbstract(const spTileCode& tileCode, - rl::FLOAT stepSize, - rl::FLOAT discountRate, - rl::FLOAT lambda); + GradientDescentTileCodeAbstract( + const spTileCode& tileCode, + rl::FLOAT stepSize, + rl::FLOAT discountRate, + rl::FLOAT lambda); /** * Get the value of the parameters in the real space. @@ -69,13 +75,13 @@ class GradientDescentTileCodeAbstract : * \brief Refers to the same object as _courseCode but this one is downcast'd * to spTileCode allowing access to tile code specific methods. */ - spTileCode _tileCode; + spTileCode _tileCode; }; -template +template GradientDescentTileCodeAbstract< - D, NUM_TILINGS, STATE_DIM>::GradientDescentTileCodeAbstract( - const spTileCode &tileCode, + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::GradientDescentTileCodeAbstract( + const spTileCode &tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda) : @@ -88,18 +94,18 @@ GradientDescentTileCodeAbstract< _tileCode = tileCode; } -template +template FLOAT GradientDescentTileCodeAbstract< - D, NUM_TILINGS, STATE_DIM>::getValueFromFeatureVector( + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::getValueFromFeatureVector( const FEATURE_VECTOR& fv) const { return _tileCode->getValueFromFeatureVector(fv); } -template +template FEATURE_VECTOR GradientDescentTileCodeAbstract< - D, NUM_TILINGS, STATE_DIM>::getFeatureVector( + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::getFeatureVector( const floatArray& parameters) const { return _tileCode->getFeatureVector(parameters); } diff --git a/include/algorithm/gradient-descent/GradientDescentTileCodeET.h b/include/algorithm/gradient-descent/GradientDescentTileCodeET.h index 8af94e8..b52ca94 100644 --- a/include/algorithm/gradient-descent/GradientDescentTileCodeET.h +++ b/include/algorithm/gradient-descent/GradientDescentTileCodeET.h @@ -38,12 +38,18 @@ namespace algorithm { /*! \class GradientDescentTileCodeET * \brief Gradient Descent eligibility traces. * \tparam D Number of dimension. 
+ * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. * This also implies ACTION_DIM = D - STATE_DIM. */ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT, + size_t STATE_DIM> class GradientDescentTileCodeET : - public GradientDescentTileCodeAbstract { + public GradientDescentTileCodeAbstract< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM> { public: /** * @param tileCode Type of tile coding. @@ -52,7 +58,7 @@ class GradientDescentTileCodeET : * @param lambda How influential is current state-action to their state-action. */ GradientDescentTileCodeET( - const spTileCode& tileCode, + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda); @@ -109,50 +115,51 @@ class GradientDescentTileCodeET : std::vector _e; //!< Vector of eligibility traces. }; -template -GradientDescentTileCodeET::GradientDescentTileCodeET( - const spTileCode& tileCode, +template +GradientDescentTileCodeET< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::GradientDescentTileCodeET( + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda) : GradientDescentTileCodeAbstract< - D, NUM_TILINGS, STATE_DIM>::GradientDescentTileCodeAbstract( + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::GradientDescentTileCodeAbstract( tileCode, stepSize, discountRate, lambda) { _e = floatVector(this->_courseCode->getSize(), 0); } -template +template void GradientDescentTileCodeET< - D, NUM_TILINGS, STATE_DIM>::incrementEligibilityTraces( + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::incrementEligibilityTraces( const FEATURE_VECTOR& fv) { for (rl::INT f : fv) { ++(this->_e)[f]; } } -template +template void GradientDescentTileCodeET< - D, NUM_TILINGS, STATE_DIM>::replaceEligibilityTraces( + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::replaceEligibilityTraces( const FEATURE_VECTOR& fv) { for (rl::INT f : fv) { this->_e[f] = 1; } } -template +template void GradientDescentTileCodeET< - D, 
NUM_TILINGS, STATE_DIM>::decreaseEligibilityTraces() { + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::decreaseEligibilityTraces() { size_t n = this->getSize(); for (size_t i = 0; i < n; i++) { this->_e[i] *= this->_discountRateTimesLambda; } } -template +template void GradientDescentTileCodeET< - D, NUM_TILINGS, STATE_DIM>::backUpWeights(FLOAT tdError) { + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::backUpWeights(FLOAT tdError) { rl::FLOAT multiplier = (this->_stepSize / NUM_TILINGS) * tdError; size_t n = this->getSize(); for (size_t i = 0; i < n-1; i++) { @@ -160,9 +167,9 @@ GradientDescentTileCodeET< } } -template +template void GradientDescentTileCodeET< - D, NUM_TILINGS, STATE_DIM>::updateWeights( + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::updateWeights( const typename GradientDescentAbstract< D, STATE_DIM>::spStateParam& currentStateVector, @@ -188,9 +195,9 @@ void GradientDescentTileCodeET< decreaseEligibilityTraces(); } -template +template void GradientDescentTileCodeET< - D, NUM_TILINGS, STATE_DIM>::resetEligibilityTraces() { + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::resetEligibilityTraces() { std::fill(&this->_e[0], &this->_e[0] + this->getSize(), 0); } diff --git a/include/algorithm/gradient-descent/QLearningETGD.h b/include/algorithm/gradient-descent/QLearningETGD.h index 69086d6..8a9a67d 100644 --- a/include/algorithm/gradient-descent/QLearningETGD.h +++ b/include/algorithm/gradient-descent/QLearningETGD.h @@ -30,14 +30,19 @@ namespace algorithm { * and learning policy). * \tparam D Number of dimensions. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. This defaults to D-1. * This also implies ACTION_DIM = D - STATE_DIM. 
*/ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> class QLearningETGD : - public ReinforcementLearningGDET { + public ReinforcementLearningGDET { public: - QLearningETGD(const spTileCode& tileCode, + QLearningETGD(const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, @@ -45,16 +50,16 @@ class QLearningETGD : D, STATE_DIM>::spPolicy& policy); }; -template -QLearningETGD::QLearningETGD( - const spTileCode& tileCode, +template +QLearningETGD::QLearningETGD( + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, const typename ReinforcementLearningGDAbstract< D, STATE_DIM>::spPolicy& controlPolicy) : ReinforcementLearningGDET< - D, NUM_TILINGS, STATE_DIM>::ReinforcementLearningGDET( + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::ReinforcementLearningGDET( tileCode, stepSize, discountRate, lambda, controlPolicy) { } diff --git a/include/algorithm/gradient-descent/QLearningETGDFactory.h b/include/algorithm/gradient-descent/QLearningETGDFactory.h index d9abfdd..18e76a6 100644 --- a/include/algorithm/gradient-descent/QLearningETGDFactory.h +++ b/include/algorithm/gradient-descent/QLearningETGDFactory.h @@ -28,16 +28,25 @@ namespace algorithm { * \brief Factory method for QLearningETGD. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in state. * Implies that action is D - STATE_DIM. 
*/ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> class QLearningETGDFactory : public ReinforcementLearningGDFactory< - D, NUM_TILINGS, STATE_DIM, QLearningETGD> { + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM, QLearningETGD> { public: using ReinforcementLearningGDFactory< - D, NUM_TILINGS, STATE_DIM, QLearningETGD>::ReinforcementLearningGDFactory; + D, + NUM_TILINGS, + WEIGHT_CONT, + STATE_DIM, + QLearningETGD>::ReinforcementLearningGDFactory; }; } // namespace algorithm diff --git a/include/algorithm/gradient-descent/QLearningGD.h b/include/algorithm/gradient-descent/QLearningGD.h index cbce443..3e50755 100644 --- a/include/algorithm/gradient-descent/QLearningGD.h +++ b/include/algorithm/gradient-descent/QLearningGD.h @@ -30,13 +30,19 @@ namespace algorithm { * and learning policy). * \tparam D Number of dimensions. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. This defaults to D-1. * This also implies ACTION_DIM = D - STATE_DIM. 
*/ -template -class QLearningGD : public ReinforcementLearningGD { +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> +class QLearningGD : + public ReinforcementLearningGD { public: - QLearningGD(const spTileCode& tileCode, + QLearningGD(const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, @@ -44,15 +50,16 @@ class QLearningGD : public ReinforcementLearningGD { D, STATE_DIM>::spPolicy& policy); }; -template -QLearningGD::QLearningGD( - const spTileCode& tileCode, +template +QLearningGD::QLearningGD( + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, const typename ReinforcementLearningGDAbstract< D, STATE_DIM>::spPolicy& controlPolicy) : - ReinforcementLearningGD::ReinforcementLearningGD( + ReinforcementLearningGD< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::ReinforcementLearningGD( tileCode, stepSize, discountRate, lambda, controlPolicy) { } diff --git a/include/algorithm/gradient-descent/QLearningGDFactory.h b/include/algorithm/gradient-descent/QLearningGDFactory.h index 10f6d96..9651d71 100644 --- a/include/algorithm/gradient-descent/QLearningGDFactory.h +++ b/include/algorithm/gradient-descent/QLearningGDFactory.h @@ -28,16 +28,25 @@ namespace algorithm { * \brief Factory method for QLearningGD. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in state. * Implies that action is D - STATE_DIM. 
*/ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> class QLearningGDFactory : public ReinforcementLearningGDFactory< - D, NUM_TILINGS, STATE_DIM, QLearningGD> { + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM, QLearningGD> { public: using ReinforcementLearningGDFactory< - D, NUM_TILINGS, STATE_DIM, QLearningGD>::ReinforcementLearningGDFactory; + D, + NUM_TILINGS, + WEIGHT_CONT, + STATE_DIM, + QLearningGD>::ReinforcementLearningGDFactory; }; } // namespace algorithm diff --git a/include/algorithm/gradient-descent/ReinforcementLearningGD.h b/include/algorithm/gradient-descent/ReinforcementLearningGD.h index 4de5183..0700af4 100644 --- a/include/algorithm/gradient-descent/ReinforcementLearningGD.h +++ b/include/algorithm/gradient-descent/ReinforcementLearningGD.h @@ -34,15 +34,20 @@ namespace algorithm { * \brief Gradient descent implementation of Reinforcement Learning. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. This defaults to D-1. * This also implies ACTION_DIM = D - STATE_DIM. 
*/ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> class ReinforcementLearningGD : public ReinforcementLearningGDAbstract { public: ReinforcementLearningGD( - const spTileCode& tileCode, + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, @@ -51,9 +56,10 @@ class ReinforcementLearningGD : virtual ~ReinforcementLearningGD(); }; -template -ReinforcementLearningGD::ReinforcementLearningGD( - const spTileCode& tileCode, +template +ReinforcementLearningGD< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::ReinforcementLearningGD( + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, @@ -63,13 +69,13 @@ ReinforcementLearningGD::ReinforcementLearningGD( D, STATE_DIM>::ReinforcementLearningGDAbstract( tileCode, stepSize, discountRate, lambda, policy) { this->_gradientDescent = spGradientDescentAbstract( - new GradientDescentTileCode( + new GradientDescentTileCode( tileCode, stepSize, discountRate, lambda)); } -template +template ReinforcementLearningGD< - D, NUM_TILINGS, STATE_DIM>::~ReinforcementLearningGD() { + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::~ReinforcementLearningGD() { } } // namespace algorithm diff --git a/include/algorithm/gradient-descent/ReinforcementLearningGDAbstract.h b/include/algorithm/gradient-descent/ReinforcementLearningGDAbstract.h index 0cb2267..e0fd6ae 100644 --- a/include/algorithm/gradient-descent/ReinforcementLearningGDAbstract.h +++ b/include/algorithm/gradient-descent/ReinforcementLearningGDAbstract.h @@ -40,6 +40,7 @@ namespace algorithm { /*! \class ReinforcementLearningGDAbstract * \brief Gradient descent implementation of Reinforcement Learning. * \tparam D Number of dimension. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. * This also implies ACTION_DIM = D - STATE_DIM. 
*/ diff --git a/include/algorithm/gradient-descent/ReinforcementLearningGDET.h b/include/algorithm/gradient-descent/ReinforcementLearningGDET.h index 6253b9c..93d0d8d 100644 --- a/include/algorithm/gradient-descent/ReinforcementLearningGDET.h +++ b/include/algorithm/gradient-descent/ReinforcementLearningGDET.h @@ -35,15 +35,20 @@ namespace algorithm { * with Eligibility Traces. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. This defaults to D-1. * This also implies ACTION_DIM = D - STATE_DIM. */ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> class ReinforcementLearningGDET : public ReinforcementLearningGDAbstract { public: ReinforcementLearningGDET( - const spTileCode& tileCode, + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, @@ -52,10 +57,10 @@ class ReinforcementLearningGDET : virtual ~ReinforcementLearningGDET(); }; -template +template ReinforcementLearningGDET< - D, NUM_TILINGS, STATE_DIM>::ReinforcementLearningGDET( - const spTileCode& tileCode, + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::ReinforcementLearningGDET( + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, @@ -65,13 +70,13 @@ ReinforcementLearningGDET< D, STATE_DIM>::ReinforcementLearningGDAbstract( tileCode, stepSize, discountRate, lambda, policy) { this->_gradientDescent = spGradientDescentAbstract( - new GradientDescentTileCodeET( + new GradientDescentTileCodeET( tileCode, stepSize, discountRate, lambda)); } -template +template ReinforcementLearningGDET< - D, NUM_TILINGS, STATE_DIM>::~ReinforcementLearningGDET() { + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::~ReinforcementLearningGDET() { } } // namespace algorithm diff --git a/include/algorithm/gradient-descent/ReinforcementLearningGDFactory.h 
b/include/algorithm/gradient-descent/ReinforcementLearningGDFactory.h index ecf5b40..b43b946 100644 --- a/include/algorithm/gradient-descent/ReinforcementLearningGDFactory.h +++ b/include/algorithm/gradient-descent/ReinforcementLearningGDFactory.h @@ -29,14 +29,16 @@ namespace algorithm { * \brief Factory method for ReinforcementLearningGDFactory. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in state. * Implies that action is D - STATE_DIM. * \tparam REINFORCEMENT_LEARNING_GD ReinforcementLearningGDAbstract child class. */ -template class REINFORCEMENT_LEARNING_GD> class ReinforcementLearningGDFactory : public ReinforcementLearningFactory< @@ -44,19 +46,18 @@ class ReinforcementLearningGDFactory : typename ReinforcementLearningGDAbstract::ActionParam> { public: ReinforcementLearningGDFactory( - const spTileCode& tileCode, + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, const typename ReinforcementLearningGDAbstract< - D, - STATE_DIM>::spPolicy& policy) { + D, STATE_DIM>::spPolicy& policy) { this->_instance = spLearningAlgorithm< typename ReinforcementLearningGDAbstract< D, STATE_DIM>::StateParam, typename ReinforcementLearningGDAbstract< D, STATE_DIM>::ActionParam>( - new REINFORCEMENT_LEARNING_GD( + new REINFORCEMENT_LEARNING_GD( tileCode, stepSize, discountRate, lambda, policy)); } }; diff --git a/include/algorithm/gradient-descent/SarsaETGD.h b/include/algorithm/gradient-descent/SarsaETGD.h index bb51eae..b2d1d98 100644 --- a/include/algorithm/gradient-descent/SarsaETGD.h +++ b/include/algorithm/gradient-descent/SarsaETGD.h @@ -32,14 +32,19 @@ namespace algorithm { * learning and action selection). * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. 
This defaults to D-1. * This also implies ACTION_DIM = D - STATE_DIM. */ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> class SarsaETGD final: - public ReinforcementLearningGDET { + public ReinforcementLearningGDET { public: - SarsaETGD(const spTileCode& tileCode, + SarsaETGD(const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, @@ -47,16 +52,16 @@ class SarsaETGD final: D, STATE_DIM>::spPolicy& policy); }; -template -SarsaETGD::SarsaETGD( - const spTileCode& tileCode, +template +SarsaETGD::SarsaETGD( + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, const typename ReinforcementLearningGDAbstract< D, STATE_DIM>::spPolicy& policy) : ReinforcementLearningGDET< - D, NUM_TILINGS, STATE_DIM>::ReinforcementLearningGDET( + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::ReinforcementLearningGDET( tileCode, stepSize, discountRate, lambda, policy) { this->setLearningPolicy(policy); } diff --git a/include/algorithm/gradient-descent/SarsaETGDFactory.h b/include/algorithm/gradient-descent/SarsaETGDFactory.h index 0f4a48e..53b1342 100644 --- a/include/algorithm/gradient-descent/SarsaETGDFactory.h +++ b/include/algorithm/gradient-descent/SarsaETGDFactory.h @@ -28,15 +28,25 @@ namespace algorithm { * \brief Factory method for SarsaETGD. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in state. * Implies that action is D - STATE_DIM. 
*/ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> class SarsaETGDFactory : - public ReinforcementLearningGDFactory { + public ReinforcementLearningGDFactory< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM, SarsaETGD> { public: using ReinforcementLearningGDFactory< - D, NUM_TILINGS, STATE_DIM, SarsaETGD>::ReinforcementLearningGDFactory; + D, + NUM_TILINGS, + WEIGHT_CONT, + STATE_DIM, + SarsaETGD>::ReinforcementLearningGDFactory; }; } // namespace algorithm diff --git a/include/algorithm/gradient-descent/SarsaGD.h b/include/algorithm/gradient-descent/SarsaGD.h index 9cb3e81..8430562 100644 --- a/include/algorithm/gradient-descent/SarsaGD.h +++ b/include/algorithm/gradient-descent/SarsaGD.h @@ -32,13 +32,19 @@ namespace algorithm { * learning and action selection). * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in State. This defaults to D-1. * This also implies ACTION_DIM = D - STATE_DIM. 
*/ -template -class SarsaGD final: public ReinforcementLearningGD { +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> +class SarsaGD final: + public ReinforcementLearningGD { public: - SarsaGD(const spTileCode& tileCode, + SarsaGD(const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, @@ -46,15 +52,16 @@ class SarsaGD final: public ReinforcementLearningGD { D, STATE_DIM>::spPolicy & policy); }; -template -SarsaGD::SarsaGD( - const spTileCode& tileCode, +template +SarsaGD::SarsaGD( + const spTileCode& tileCode, rl::FLOAT stepSize, rl::FLOAT discountRate, rl::FLOAT lambda, const typename ReinforcementLearningGDAbstract< D, STATE_DIM>::spPolicy& policy) : - ReinforcementLearningGD::ReinforcementLearningGD( + ReinforcementLearningGD< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM>::ReinforcementLearningGD( tileCode, stepSize, discountRate, lambda, policy) { this->setLearningPolicy(policy); } diff --git a/include/algorithm/gradient-descent/SarsaGDFactory.h b/include/algorithm/gradient-descent/SarsaGDFactory.h index 20c5404..9da422b 100644 --- a/include/algorithm/gradient-descent/SarsaGDFactory.h +++ b/include/algorithm/gradient-descent/SarsaGDFactory.h @@ -28,15 +28,25 @@ namespace algorithm { * \brief Factory method for SarsaGD. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam STATE_DIM Number of dimension in state. * Implies that action is D - STATE_DIM. 
*/ -template +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = coding::DEFAULT_TILE_CONT, + size_t STATE_DIM = D-1> class SarsaGDFactory : - public ReinforcementLearningGDFactory { + public ReinforcementLearningGDFactory< + D, NUM_TILINGS, WEIGHT_CONT, STATE_DIM, SarsaGD> { public: using ReinforcementLearningGDFactory< - D, NUM_TILINGS, STATE_DIM, SarsaGD>::ReinforcementLearningGDFactory; + D, + NUM_TILINGS, + WEIGHT_CONT, + STATE_DIM, + SarsaGD>::ReinforcementLearningGDFactory; }; } // namespace algorithm diff --git a/include/coding/TileCode.h b/include/coding/TileCode.h index 0d72ae8..155de08 100644 --- a/include/coding/TileCode.h +++ b/include/coding/TileCode.h @@ -25,10 +25,13 @@ #include #include #include +#include #include "../declares.h" +#include "../utility/IndexAccessorInterface.h" #include "CourseCode.h" #include "DimensionInfo.h" +#include "container/TileCodeContainer.h" using std::array; using std::vector; @@ -37,6 +40,8 @@ using std::shared_ptr; namespace rl { namespace coding { +using DEFAULT_TILE_CONT = vector; + /*! \class TileCode * \brief Base object encapsulate tile coding. * @@ -45,20 +50,28 @@ namespace coding { * * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. */ -template -class TileCode : public CourseCode { +template< + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = DEFAULT_TILE_CONT> +class TileCode : + public CourseCode, + public utility::IndexAccessorInterface { public: /** * @param dimensionalInfos An array of dimensionalInfos. 
*/ explicit TileCode(const array, D>& dimensionalInfos); + TileCode(const array, D>& dimensionalInfos, + size_t sizeHint); - FLOAT& operator[](size_t i); - FLOAT operator[](size_t i) const; + typename WEIGHT_CONT::reference + at(size_t i) override; - virtual FLOAT& at(size_t i); - virtual FLOAT at(size_t i) const; + typename WEIGHT_CONT::value_type + at(size_t i) const override; /** * Hashed the parameter in Real space to a Natural space [0, infinity). @@ -98,13 +111,15 @@ class TileCode : public CourseCode { /** * @return Number of possible grid points. */ - size_t _calculateSizeCache(); + size_t _calculateSizeCache() const; + static size_t _calculateSize( + const array, D>& dims); protected: std::random_device _randomDevice; std::default_random_engine _pseudoRNG; - std::vector _w; //!< Vector of weights. + WEIGHT_CONT _w; //!< Vector of weights. /*! \var _sizeCache * @@ -130,16 +145,27 @@ class TileCode : public CourseCode { array, D> _dimensionalInfos; }; -template -using spTileCode = shared_ptr>; +template< + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = DEFAULT_TILE_CONT> +using spTileCode = shared_ptr>; -template -TileCode::TileCode( +template +TileCode::TileCode( const array, D>& dimensionalInfos) : - _dimensionalInfos(dimensionalInfos) { - // Calculate the size. - _sizeCache = _calculateSizeCache(); + TileCode::TileCode( + dimensionalInfos, + TileCode::_calculateSize(dimensionalInfos)) { +} +template +TileCode::TileCode( + const array, D>& dimensionalInfos, + size_t sizeHint) : + _dimensionalInfos(dimensionalInfos), + _sizeCache(sizeHint), + _w(WEIGHT_CONT(sizeHint, 0)) { // Calculate random offsets. 
std::uniform_real_distribution distribution(0, 1.0F); for (size_t i = 0; i < NUM_TILINGS; i++) { @@ -150,31 +176,29 @@ TileCode::TileCode( * this->_dimensionalInfos[j].getGeneralizationScale()); } } - - _w = floatVector(this->getSize(), 0); } -template -size_t TileCode::getNumTilings() const { +template +size_t TileCode::getNumTilings() const { return NUM_TILINGS; } -template -size_t TileCode::getSize() const { +template +size_t TileCode::getSize() const { return _sizeCache; } -template +template FLOAT -TileCode::getValueFromParameters( +TileCode::getValueFromParameters( const floatArray& parameters) const { FEATURE_VECTOR fv = std::move(this->getFeatureVector(parameters)); return this->getValueFromFeatureVector(fv); } -template -FLOAT TileCode::getValueFromFeatureVector( +template +FLOAT TileCode::getValueFromFeatureVector( const FEATURE_VECTOR& fv) const { rl::FLOAT sum = 0.0F; @@ -185,11 +209,17 @@ FLOAT TileCode::getValueFromFeatureVector( return sum; } -template -size_t TileCode::_calculateSizeCache() { - // Calculate the size. - rl::UINT size = 1; - for (const DimensionInfo& di : this->_dimensionalInfos) { +template +size_t TileCode::_calculateSizeCache() const { + return TileCode::_calculateSize( + _dimensionalInfos); +} + +template +size_t TileCode::_calculateSize( + const array, D>& dims) { + size_t size = 1; + for (auto& di : dims) { size *= di.GetGridCountReal(); } @@ -197,8 +227,8 @@ size_t TileCode::_calculateSizeCache() { return size; } -template -size_t TileCode::paramToGridValue( +template +size_t TileCode::paramToGridValue( rl::FLOAT param, size_t tilingIndex, size_t dimensionIndex) const { auto randomOffset = _randomOffsets.at(tilingIndex).at(dimensionIndex); @@ -218,25 +248,17 @@ size_t TileCode::paramToGridValue( ) * dimGridCountIdeal) / dimRangeMagnitude; // NOLINT: I'd like to make this easy to understand. 
} -template -FLOAT& TileCode::at(size_t i) { +template +typename WEIGHT_CONT::reference +TileCode::at(size_t i) { return _w.at(i); } -template -FLOAT TileCode::at(size_t i) const { +template +typename WEIGHT_CONT::value_type +TileCode::at(size_t i) const { return _w.at(i); } -template -FLOAT& TileCode::operator[](size_t i) { - return this->at(i); -} - -template -FLOAT TileCode::operator[](size_t i) const { - return this->at(i); -} - } // namespace coding } // namespace rl diff --git a/include/coding/TileCodeCorrect.h b/include/coding/TileCodeCorrect.h index 63568a5..b68c684 100644 --- a/include/coding/TileCodeCorrect.h +++ b/include/coding/TileCodeCorrect.h @@ -41,18 +41,19 @@ namespace coding { * * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. */ -template -class TileCodeCorrect : public TileCode { +template +class TileCodeCorrect : public TileCode { public: - using TileCode::TileCode; + using TileCode::TileCode; FEATURE_VECTOR getFeatureVector( const floatArray& parameters) const override; }; -template -FEATURE_VECTOR TileCodeCorrect::getFeatureVector( +template +FEATURE_VECTOR TileCodeCorrect::getFeatureVector( const floatArray& parameters) const { FEATURE_VECTOR fv; fv.resize(NUM_TILINGS); diff --git a/include/coding/TileCodeCorrectFactory.h b/include/coding/TileCodeCorrectFactory.h index 4a992f6..d6506a5 100644 --- a/include/coding/TileCodeCorrectFactory.h +++ b/include/coding/TileCodeCorrectFactory.h @@ -33,11 +33,13 @@ namespace coding { * \brief Factory method for TileCodeCorrect. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. 
*/ -template +template class TileCodeCorrectFactory : - public TileCodeFactory { - using TileCodeFactory::TileCodeFactory; + public TileCodeFactory { + using TileCodeFactory< + D, NUM_TILINGS, WEIGHT_CONT, TileCodeCorrect>::TileCodeFactory; }; } // namespace coding diff --git a/include/coding/TileCodeFactory.h b/include/coding/TileCodeFactory.h index 5a82ee2..58e48b0 100644 --- a/include/coding/TileCodeFactory.h +++ b/include/coding/TileCodeFactory.h @@ -33,16 +33,19 @@ namespace coding { * \brief Factory method for TileCode. Abstract class. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam TILE_CODE_CLASS TileCode child that don't utilizes hashing. */ -template class TILE_CODE_CLASS> -class TileCodeFactory : public FactoryAbstract> { +template + class TILE_CODE_CLASS> +class TileCodeFactory : + public FactoryAbstract> { public: explicit TileCodeFactory( const array, D>& dimensionalInfos) { - this->_instance = spTileCode( - new TILE_CODE_CLASS(dimensionalInfos)); + this->_instance = spTileCode( + new TILE_CODE_CLASS(dimensionalInfos)); } }; diff --git a/include/coding/TileCodeHashedFactory.h b/include/coding/TileCodeHashedFactory.h index 911d3b7..e4935a7 100644 --- a/include/coding/TileCodeHashedFactory.h +++ b/include/coding/TileCodeHashedFactory.h @@ -33,17 +33,20 @@ namespace coding { * \brief Factory method for TileCode. Abstract class. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. * \tparam TILE_CODE_CLASS The TileCode child that utilizes hashing. 
*/ -template class TILE_CODE_CLASS> +template + class TILE_CODE_CLASS> class TileCodeHashedFactory : - public TileCodeFactory { + public TileCodeFactory { public: TileCodeHashedFactory(const array, D>& dimensionalInfos, size_t sizeHint) { - this->_instance = spTileCode( - new TILE_CODE_CLASS(dimensionalInfos, sizeHint)); + this->_instance = spTileCode( + new TILE_CODE_CLASS( + dimensionalInfos, sizeHint)); } }; diff --git a/include/coding/TileCodeMt1993764.h b/include/coding/TileCodeMt1993764.h index ca13a3a..cd327c1 100644 --- a/include/coding/TileCodeMt1993764.h +++ b/include/coding/TileCodeMt1993764.h @@ -34,11 +34,12 @@ namespace coding { * \brief Tile Code using Mt1993764 hash.2 * \tparam D Number of dimensions. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. */ -template -class TileCodeMt1993764 : public TileCode { +template +class TileCodeMt1993764 : public TileCode { public: - using TileCode::TileCode; + using TileCode::TileCode; /** * @param dimensionalInfos @@ -57,18 +58,16 @@ class TileCodeMt1993764 : public TileCode { mutable std::mt19937_64 _prng; }; -template -TileCodeMt1993764::TileCodeMt1993764( +template +TileCodeMt1993764::TileCodeMt1993764( const array, D>& dimensionalInfos, size_t sizeHint) : - TileCode::TileCode(dimensionalInfos) { - if (sizeHint > this->_sizeCache) { - this->_sizeCache = sizeHint; - } + TileCode::TileCode(dimensionalInfos, sizeHint) { } -template -FEATURE_VECTOR TileCodeMt1993764::getFeatureVector( +template +FEATURE_VECTOR +TileCodeMt1993764::getFeatureVector( const floatArray& parameters) const { vector tileComponents(this->getDimension() + 1); FEATURE_VECTOR fv; diff --git a/include/coding/TileCodeMt1993764Factory.h b/include/coding/TileCodeMt1993764Factory.h index c370c30..68a6fff 100644 --- a/include/coding/TileCodeMt1993764Factory.h +++ b/include/coding/TileCodeMt1993764Factory.h @@ -33,15 +33,17 @@ namespace coding { * \brief Factory method for 
TileCodeMt1993764. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. */ -template +template class TileCodeMt1993764Factory : - public TileCodeHashedFactory { + public TileCodeHashedFactory< + D, NUM_TILINGS, WEIGHT_CONT, TileCodeMt1993764> { public: using TileCodeFactory< - D, NUM_TILINGS, TileCodeMt1993764>::TileCodeFactory; + D, NUM_TILINGS, WEIGHT_CONT, TileCodeMt1993764>::TileCodeFactory; using TileCodeHashedFactory< - D, NUM_TILINGS, TileCodeMt1993764>::TileCodeFactory; + D, NUM_TILINGS, WEIGHT_CONT, TileCodeMt1993764>::TileCodeFactory; }; } // namespace coding diff --git a/include/coding/TileCodeMurMur.h b/include/coding/TileCodeMurMur.h index c4a8a2a..cdc581e 100644 --- a/include/coding/TileCodeMurMur.h +++ b/include/coding/TileCodeMurMur.h @@ -35,27 +35,31 @@ namespace coding { * \brief Tile Code using MurMur3 hash. * \tparam D Number of dimensions. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. 
*/ -template -class TileCodeMurMur : public TileCode { +template < + size_t D, + size_t NUM_TILINGS, + class WEIGHT_CONT = DEFAULT_TILE_CONT> +class TileCodeMurMur : public TileCode { public: - using TileCode::TileCode; + using TileCode::TileCode; TileCodeMurMur(const array, D>& dimensionalInfos, size_t sizeHint); FEATURE_VECTOR getFeatureVector( const floatArray& parameters) const override; }; -template -TileCodeMurMur::TileCodeMurMur( +template +TileCodeMurMur::TileCodeMurMur( const array, D>& dimensionalInfos, size_t sizeHint) : - TileCode::TileCode(dimensionalInfos) { - this->_sizeCache = sizeHint; + TileCode::TileCode(dimensionalInfos, sizeHint) { } -template -FEATURE_VECTOR TileCodeMurMur::getFeatureVector( +template +FEATURE_VECTOR +TileCodeMurMur::getFeatureVector( const floatArray& parameters) const { FEATURE_VECTOR fv; diff --git a/include/coding/TileCodeMurMurFactory.h b/include/coding/TileCodeMurMurFactory.h index 5b0a497..4feeabe 100644 --- a/include/coding/TileCodeMurMurFactory.h +++ b/include/coding/TileCodeMurMurFactory.h @@ -33,15 +33,17 @@ namespace coding { * \brief Factory method for TileCodeMurMur. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. 
*/ -template +template class TileCodeMurMurFactory : - public TileCodeHashedFactory { + public TileCodeHashedFactory< + D, NUM_TILINGS, WEIGHT_CONT, TileCodeMurMur> { public: using TileCodeFactory< - D, NUM_TILINGS, TileCodeMurMur>::TileCodeFactory; + D, NUM_TILINGS, WEIGHT_CONT, TileCodeMurMur>::TileCodeFactory; using TileCodeHashedFactory< - D, NUM_TILINGS, TileCodeMurMur>::TileCodeFactory; + D, NUM_TILINGS, WEIGHT_CONT, TileCodeMurMur>::TileCodeFactory; }; } // namespace coding diff --git a/include/coding/TileCodeSuperFastHash.h b/include/coding/TileCodeSuperFastHash.h index 000e943..3610082 100644 --- a/include/coding/TileCodeSuperFastHash.h +++ b/include/coding/TileCodeSuperFastHash.h @@ -35,11 +35,12 @@ namespace coding { * \brief Tile Code using SuperFastHash. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. */ -template -class TileCodeSuperFastHash : public TileCode { +template +class TileCodeSuperFastHash : public TileCode { public: - using TileCode::TileCode; + using TileCode::TileCode; /** * @param dimensionalInfos @@ -55,18 +56,16 @@ class TileCodeSuperFastHash : public TileCode { const floatArray& parameters) const override; }; -template -TileCodeSuperFastHash::TileCodeSuperFastHash( +template +TileCodeSuperFastHash::TileCodeSuperFastHash( const array, D>& dimensionalInfos, size_t sizeHint) : - TileCode::TileCode(dimensionalInfos) { - if (sizeHint > this->_sizeCache) { - this->_sizeCache = sizeHint; - } + TileCode::TileCode(dimensionalInfos, sizeHint) { } -template -FEATURE_VECTOR TileCodeSuperFastHash::getFeatureVector( +template +FEATURE_VECTOR +TileCodeSuperFastHash::getFeatureVector( const floatArray& parameters) const { vector tileComponents(this->getDimension() + 1); FEATURE_VECTOR fv; diff --git a/include/coding/TileCodeSuperFastHashFactory.h b/include/coding/TileCodeSuperFastHashFactory.h index aa98c51..e671215 100644 --- 
a/include/coding/TileCodeSuperFastHashFactory.h +++ b/include/coding/TileCodeSuperFastHashFactory.h @@ -33,15 +33,17 @@ namespace coding { * \brief Factory method for TileCodeSuperFastHash. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. */ -template +template class TileCodeSuperFastHashFactory : - public TileCodeHashedFactory { + public TileCodeHashedFactory< + D, NUM_TILINGS, WEIGHT_CONT, TileCodeSuperFastHash> { public: using TileCodeFactory< - D, NUM_TILINGS, TileCodeSuperFastHash>::TileCodeFactory; + D, NUM_TILINGS, WEIGHT_CONT, TileCodeSuperFastHash>::TileCodeFactory; using TileCodeHashedFactory< - D, NUM_TILINGS, TileCodeSuperFastHash>::TileCodeFactory; + D, NUM_TILINGS, WEIGHT_CONT, TileCodeSuperFastHash>::TileCodeFactory; }; } // namespace coding diff --git a/include/coding/TileCodeUNH.h b/include/coding/TileCodeUNH.h index 0d99255..b141109 100644 --- a/include/coding/TileCodeUNH.h +++ b/include/coding/TileCodeUNH.h @@ -33,11 +33,12 @@ namespace coding { * \brief Tile Code using University New Hampshire hash, or UNH. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. 
*/ -template +template class TileCodeUNH : public TileCode { public: - using TileCode::TileCode; + using TileCode::TileCode; TileCodeUNH(const array, D>& dimensionalInfos, size_t sizeHint); @@ -54,15 +55,11 @@ class TileCodeUNH : public TileCode { vector _normalization; }; -template -TileCodeUNH::TileCodeUNH( +template +TileCodeUNH::TileCodeUNH( const array, D>& dimensionalInfos, size_t sizeHint) : - TileCode::TileCode(dimensionalInfos) { - if (sizeHint > this->_sizeCache) { - this->_sizeCache = sizeHint; - } - + TileCode::TileCode(dimensionalInfos, sizeHint) { _normalization = vector(this->getDimension()); for (size_t i = 0; i < this->_dimensionalInfos.size(); i++) { @@ -71,8 +68,8 @@ TileCodeUNH::TileCodeUNH( } } -template -FEATURE_VECTOR TileCodeUNH::getFeatureVector( +template +FEATURE_VECTOR TileCodeUNH::getFeatureVector( const floatArray& parameters) const { FEATURE_VECTOR fv; diff --git a/include/coding/TileCodeUNHFactory.h b/include/coding/TileCodeUNHFactory.h index 903f25a..ef111dd 100644 --- a/include/coding/TileCodeUNHFactory.h +++ b/include/coding/TileCodeUNHFactory.h @@ -33,13 +33,17 @@ namespace coding { * \brief Factory method for TileCodeUNH. * \tparam D Number of dimension. * \tparam NUM_TILINGS Number of tilings. + * \tparam WEIGHT_CONT The container object to store the weights. 
*/ -template +template class TileCodeUNHFactory : - public TileCodeHashedFactory { + public TileCodeHashedFactory< + D, NUM_TILINGS, WEIGHT_CONT, TileCodeUNH> { public: - using TileCodeFactory::TileCodeFactory; - using TileCodeHashedFactory::TileCodeFactory; + using TileCodeFactory< + D, NUM_TILINGS, WEIGHT_CONT, TileCodeUNH>::TileCodeFactory; + using TileCodeHashedFactory< + D, NUM_TILINGS, WEIGHT_CONT, TileCodeUNH>::TileCodeFactory; }; } // namespace coding diff --git a/include/coding/container/TileCodeContainer.h b/include/coding/container/TileCodeContainer.h new file mode 100644 index 0000000..03b4848 --- /dev/null +++ b/include/coding/container/TileCodeContainer.h @@ -0,0 +1,423 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#ifdef ENABLE_DB + +#include +#include +#include +#include +#include +#include + +#include "../../declares.h" +#include "../../utility/CRUDInterface.h" +#include "TileCodeContainerSegment.h" + +using std::string; +using std::shared_ptr; +using std::tuple; +using std::exception; +using std::vector; + +namespace rl { +namespace coding { + +/*!\class TileCodeContainer + * \brief Wraps TileCodeContainer. + * \tparam ID_CHARS The id of the TileCodeContainer. + * + * Container type for TileCode weight vectors. 
+ */ +template +class TileCodeContainer; + +/*!\class spTileCodeContainer + * \brief Wraps TileCodeContainer. + * \tparam ID_CHARS + */ +template +using spTileCodeContainer = shared_ptr>; + +using TileCodeContainerAllocator = std::allocator; + +template +class TileCodeContainer : + public utility::CRUDInterface { + public: + typedef TileCodeContainerAllocator allocator_type; + typedef typename TileCodeContainerAllocator::value_type value_type; + + // We did not use the allocator here to avoid saving this in ram, defeating + // the purpose of using db. + typedef TileCodeContainerCell reference; + typedef TileCodeContainerCell const_reference; + + typedef typename TileCodeContainerAllocator::difference_type difference_type; + typedef typename TileCodeContainerAllocator::size_type size_type; + + public: + /** + * Creates an entry in db if not exist yet, otherwise + * retrieves the existing db. + * + * @param size Number of weight vector in this container. + * @param initialValue Initial value. + */ + TileCodeContainer(size_t size, FLOAT initialValue); + TileCodeContainer(); + + /** + * @return true if already in db. + */ + bool created() const; + + void create() override; + void read() override; + void update() override; + void delete2() override; + + /** + * @return ID in string. + */ + string getID() const; + + /** + * @return segment count. + */ + size_t getSegmentCount() const; + + /** + * @return Array of TileCodeContainerSegment. + */ + std::vector getSegments() const; + + /** + * @return TileCodeContainerCell entry in db. + */ + TileCodeContainerCell operator[](size_t i) const; + TileCodeContainerCell at(size_t i) const; + + protected: + void _deleteAllSegments(); + + public: + /** + * Initialize the underlying database for TileCodeContainer. + */ + static void createSchema(); + + /** + * Destroys the underlying database for TileCodeContainer. 
+ */ + static void destroySchema(); + + protected: + static void _createTileContainerTable(); + + static string _insertTileCodeContainer(size_t size, FLOAT initialValue); + + static tuple _getTileContainer(string uuid) + throw(exception); + + protected: + size_t _size; + FLOAT _initialValue = 0.0F; + vector _segments; +}; + +template +TileCodeContainer::TileCodeContainer( + size_t size, FLOAT initialValue) : + _size(size), + _initialValue(initialValue) { + db::initialize(); + TileCodeContainer::createSchema(); + + if (created()) { + read(); + } else { + create(); + } +} + +template +TileCodeContainer::TileCodeContainer() { + db::initialize(); + TileCodeContainer::createSchema(); + read(); +} + +template +bool TileCodeContainer::created() const { + string stmtStr = "" + "SELECT COUNT(*) as count\n" + "FROM rl.tilecodecontainer\n" + "WHERE id='" + getID() + "';\n"; + + CassStatement *stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture *resultFuture = + cass_session_execute(db::session, stmt); + + const CassResult *result = cass_future_get_result(resultFuture); + + // This will block until the query has finished. + CassError rc = cass_future_error_code(resultFuture); + + // If there was an error then the result won't be available. + if (result == NULL) { + /* Handle error */ + printf("Query result: %s\n", cass_error_desc(rc)); + cass_future_free(resultFuture); + throw "Query failed."; + } + + // The future can be freed immediately after getting the result object. + cass_future_free(resultFuture); + + // This can be used to retrieve on the first row of the result. + const CassRow *row = cass_result_first_row(result); + + cass_int64_t size; + // Get the column value of "size" by name. + cass_value_get_int64(cass_row_get_column_by_name(row, "count"), &size); + + // This will free the result as well as the string pointed to by 'key'. 
+ cass_result_free(result); + + return size; +} + +template +void TileCodeContainer::create() { + CassStatement* stmt = + cass_statement_new(db::InsertTileCodeContainer.c_str(), 3); + + // Bind the values using the indices of the bind variables. + cass_statement_bind_string(stmt, 0, getID().c_str()); + cass_statement_bind_int64(stmt, 1, _size); + cass_statement_bind_double(stmt, 2, _initialValue); + + CassFuture* queryFuture = cass_session_execute( + db::session, stmt); + + // Statement objects can be freed immediately after being executed. + cass_statement_free(stmt); + + // This will block until the query has finished. + CassError rc = cass_future_error_code(queryFuture); + + if (rc != CASS_OK) { + printf("Query result: %s\n", cass_error_desc(rc)); + throw "Error creating TileCodeContainer"; + } + + cass_future_free(queryFuture); + + // Create the segments. + size_t segmentCount = getSegmentCount(); +#ifdef DEBUG + std::cout << "Segment count: " << segmentCount << std::endl; +#endif + for (size_t i = 0; i < segmentCount; i++) { + TileCodeContainerSegment tccs(getID(), SEGMENT_SIZE, i); + } + + _segments = getSegments(); +} + +template +void TileCodeContainer::read() { + string stmtStr = "" + "SELECT size, initialValue\n" + "FROM rl.tilecodecontainer\n" + "WHERE id='" + getID() + "';\n"; + CassStatement *stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture *resultFuture = + cass_session_execute(db::session, stmt); + + const CassResult *result = cass_future_get_result(resultFuture); + + // This will block until the query has finished. + CassError rc = cass_future_error_code(resultFuture); + + // If there was an error then the result won't be available. + if (result == NULL) { + /* Handle error */ + printf("Query result: %s\n", cass_error_desc(rc)); + cass_future_free(resultFuture); + throw "Query failed."; + } + + // The future can be freed immediately after getting the result object. 
+ cass_future_free(resultFuture); + + // This can be used to retrieve on the first row of the result. + const CassRow *row = cass_result_first_row(result); + + cass_int64_t size; + // Get the column value of "size" by name. + cass_value_get_int64(cass_row_get_column_by_name(row, "size"), &size); + + cass_double_t initialValue; + // Get the column value of "initialValue" by name. + cass_value_get_double( + cass_row_get_column_by_name(row, "initialValue"), &initialValue); + + // This will free the result as well as the string pointed to by 'key'. + cass_result_free(result); + + this->_size = size; + this->_initialValue = initialValue; + _segments = getSegments(); +} + +template +void TileCodeContainer::update() { +} + +template +void TileCodeContainer::delete2() { + this->_deleteAllSegments(); + + string stmtStr = "" + "DELETE FROM rl.tilecodecontainer\n" + "WHERE id='" + getID() + "';\n"; + CassStatement* stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture* resultFuture = + cass_session_execute(db::session, stmt); + + // This will block until the query has finished. + CassError rc = cass_future_error_code(resultFuture); + + const CassResult* result = cass_future_get_result(resultFuture); + + // If there was an error then the result won't be available. + if (result == NULL) { + /* Handle error */ + printf("Query result: %s\n", cass_error_desc(rc)); + cass_future_free(resultFuture); + throw "Query failed."; + } + + cass_statement_free(stmt); + cass_future_free(resultFuture); +} + +template +string TileCodeContainer::getID() const { + std::string str{ { ID_CHARS... 
} }; + return str; +} + +template +size_t TileCodeContainer::getSegmentCount() const { + return std::ceil( + static_cast(_size) / static_cast(rl::SEGMENT_SIZE)); +} + +template +vector +TileCodeContainer::getSegments() const { + vector tileCodeContainerSegments; + string stmtStr = "" + "SELECT tileCodeContainerId, segmentIndex\n" + "FROM rl.tilecodecontainersegment\n" + "WHERE tileCodeContainerId='" + getID() + "'\n" + "ORDER BY segmentIndex ASC\n" + "ALLOW FILTERING;\n"; + + CassStatement* stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture* resultFuture = + cass_session_execute(db::session, stmt); + + // This will block until the query has finished. + CassError rc = cass_future_error_code(resultFuture); + + const CassResult* result = cass_future_get_result(resultFuture); + + // If there was an error then the result won't be available. + if (result == NULL) { + /* Handle error */ + printf("Query result: %s\n", cass_error_desc(rc)); + cass_future_free(resultFuture); + throw "Query failed."; + } + + CassIterator* iterator = cass_iterator_from_result(result); + + while (cass_iterator_next(iterator)) { + const CassRow* row = cass_iterator_get_row(iterator); + // Retrieve and use values from the row. + + cass_int64_t segmentIndex; + cass_value_get_int64( + cass_row_get_column_by_name(row, "segmentIndex"), &segmentIndex); + + tileCodeContainerSegments.push_back( + spTileCodeContainerSegment( + new TileCodeContainerSegment(getID(), segmentIndex))); + } + + // This will free the result as well as the string pointed to by 'key'. 
+ cass_iterator_free(iterator); + cass_result_free(result); + + return tileCodeContainerSegments; +} + +template +void TileCodeContainer::createSchema() { + TileCodeContainer::_createTileContainerTable(); +} + +template +void TileCodeContainer::_createTileContainerTable() { + db::executeStatement(db::TileCodeContainer.c_str()); +} + +template +void TileCodeContainer::_deleteAllSegments() { + for (auto segment : this->getSegments()) { + segment->delete2(); + } +} + +template +TileCodeContainerCell +TileCodeContainer::operator[](size_t i) const { + return this->at(i); +} + +template +TileCodeContainerCell TileCodeContainer::at(size_t i) const { + size_t segmentIndex = i / SEGMENT_SIZE; + size_t innerSegmentIndex = i % SEGMENT_SIZE; + return _segments[segmentIndex]->at(innerSegmentIndex); +} + +} // namespace coding +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/include/coding/container/TileCodeContainerCell.h b/include/coding/container/TileCodeContainerCell.h new file mode 100644 index 0000000..95c5c0f --- /dev/null +++ b/include/coding/container/TileCodeContainerCell.h @@ -0,0 +1,77 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ */ + +#pragma once + +#ifdef ENABLE_DB + +#include +#include +#include + +#include "../../declares.h" + +using std::string; +using std::shared_ptr; + +namespace rl { +namespace coding { + +/*!\class TileCodeContainerCell + * \brief Represents a single data that can be modified or retrieved. + */ +class TileCodeContainerCell { + public: + /** + * @param tileCodeContainerId ID of the parent TileCodeContainer. + * @param segmentIndex Index of this segment. + * @param index Index of the data within segment. + */ + TileCodeContainerCell(string tileCodeContainerId, + size_t segmentIndex, + size_t index); + + /** + * @return the data being represented in this cell. + */ + FLOAT get() const; + + /** + * @param val The value to set this cell. + */ + void set(FLOAT val); + + // operator overload. + TileCodeContainerCell& operator=(FLOAT val); + TileCodeContainerCell& operator+=(FLOAT val); + + // Implicit conversion to FLOAT. + operator FLOAT() const; + + protected: + string _tileCodeContainerId; + size_t _segmentIndex; + size_t _index; +}; + +using spTileCodeContainerCell = shared_ptr; + +} // namespace coding +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/include/coding/container/TileCodeContainerSegment.h b/include/coding/container/TileCodeContainerSegment.h new file mode 100644 index 0000000..3dec4d2 --- /dev/null +++ b/include/coding/container/TileCodeContainerSegment.h @@ -0,0 +1,93 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#ifdef ENABLE_DB + +#include +#include +#include + +#include "../../declares.h" +#include "../../utility/CRUDInterface.h" +#include "../../utility/IndexAccessorInterface.h" +#include "TileCodeContainerCell.h" + +using std::string; +using std::vector; +using std::shared_ptr; + +namespace rl { +namespace coding { + +/*!\class TileCodeContainerSegment + * \brief Represents a group of data accessible via TileCodeContainerCell. + */ +class TileCodeContainerSegment : + public utility::CRUDInterface { + public: + /** + * @param tileCodeContainerId Id of parent TileCodeContainer + * @param size Size of the segment. + * @param index Index of this segment. + */ + TileCodeContainerSegment( + const string& tileCodeContainerId, + size_t size, + size_t index); + + /** + * @param size Size of the segment. + * @param index Index of this segment. + */ + TileCodeContainerSegment( + const string& tileCodeContainer, + size_t index); + + void create() override; + void read() override; + void update() override; + void delete2() override; + + TileCodeContainerCell operator[](size_t i) const; + TileCodeContainerCell at(size_t i) const; + + public: + /** + * Initialize the underlying database for TileCodeContainer. + */ + static void createSchema(); + + /** + * Destroys the underlying database for TileCodeContainer. 
+ */ + static void destroySchema(); + + protected: + string _tileCodeContainerId; + size_t _size; + size_t _segmentIndex; +}; + +using spTileCodeContainerSegment = shared_ptr; + +} // namespace coding +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/include/db/definition.h b/include/db/definition.h new file mode 100644 index 0000000..f66da52 --- /dev/null +++ b/include/db/definition.h @@ -0,0 +1,51 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#ifdef ENABLE_DB + +#include + +using std::string; + +namespace rl { +namespace db { + +/*!\var RLKeySpace + * + * Schema for the rl keyspace. + */ +extern const string RLKeySpace; + +/*!\var TilecodeContainer + * + * Schema for rl.tilecodecontainer. + */ +extern const string TileCodeContainer; + +/*!\var TilecodeContainerSegment + * + * Schema for rl.tilecodecontainersegment. 
+ */ +extern const string TileCodeContainerSegment; + +} // namespace db +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/include/db/manipulation.h b/include/db/manipulation.h new file mode 100644 index 0000000..fcb4549 --- /dev/null +++ b/include/db/manipulation.h @@ -0,0 +1,38 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#ifdef ENABLE_DB + +#include + +#include "../declares.h" + +using std::string; + +namespace rl { +namespace db { + +extern const string InsertTileCodeContainer; +extern const string InsertTileCodeContainerSegment; + +} // db +} // rl + +#endif // #ifdef ENABLE_DB diff --git a/include/db/utility.h b/include/db/utility.h new file mode 100644 index 0000000..e4e3fce --- /dev/null +++ b/include/db/utility.h @@ -0,0 +1,59 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#ifdef ENABLE_DB + +#include +#include + +using std::string; + +namespace rl { +namespace db { + +extern CassCluster* cluster; +extern CassSession* session; +extern CassFuture* connectFuture; + +bool isInitialized(); + +/** + * Initialize connections and schema (if it doesn't exist yet). + */ +void initialize(); + +/** + * Terminate connections. + */ +void terminate(); + +void executeStatement(const string& stmtStr); + +/** + * Creates the keyspace to be used by this app. + */ +void createKeySpace(); + +CassUuid genUuid(); + +} // namespace db +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/include/declares.h b/include/declares.h index e6de82d..bbf7645 100644 --- a/include/declares.h +++ b/include/declares.h @@ -56,17 +56,12 @@ const UINT MAX_EPISODES = 100000; /*! \typedef FEATURE_VECTOR * Feature vector is a data structure for Tile Coding. It is the indices * that contains the data points to be sampled. - * - * TODO(jandres): See #RL-14 */ typedef std::vector FEATURE_VECTOR; /*! \typedef floatFector * \brief A vector of float. * - * TODO(jandres): See #RL-16 - * - * \deprecated */ using floatVector = std::vector; @@ -79,8 +74,6 @@ using floatArray = std::array; /*! \typedef spFloatVector * \brief Wraps floatVector in shared_ptr. 
- * - * TODO(jandres): See #RL-16 */ using spFloatVector = std::shared_ptr>; @@ -180,5 +173,7 @@ using spActionSet = std::set, spActionComp>; template using spActionValueMap = std::map, FLOAT, spActionComp>; +constexpr size_t SEGMENT_SIZE = 100; + } // namespace rl diff --git a/include/utility/CRUDInterface.h b/include/utility/CRUDInterface.h new file mode 100644 index 0000000..3f41cc2 --- /dev/null +++ b/include/utility/CRUDInterface.h @@ -0,0 +1,33 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +namespace rl { +namespace utility { + +class CRUDInterface { + public: + virtual void create() = 0; + virtual void read() = 0; + virtual void update() = 0; + virtual void delete2() = 0; +}; + +} // namespace utility +} // namespace rl diff --git a/include/utility/IndexAccessorInterface.h b/include/utility/IndexAccessorInterface.h new file mode 100644 index 0000000..94c4b94 --- /dev/null +++ b/include/utility/IndexAccessorInterface.h @@ -0,0 +1,48 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#include + +namespace rl { +namespace utility { + +template +class IndexAccessorInterface { + public: + virtual typename CONT::reference operator[](size_t i); + virtual typename CONT::value_type operator[](size_t i) const; + + virtual typename CONT::reference at(size_t i) = 0; + virtual typename CONT::value_type at(size_t i) const = 0; +}; + +template +typename CONT::reference IndexAccessorInterface::operator[](size_t i) { + return this->at(i); +} + +template +typename CONT::value_type +IndexAccessorInterface::operator[](size_t i) const { + return this->at(i); +} + +} // namespace utility +} // namespace rl diff --git a/scripts/runlint.sh b/scripts/runlint.sh index 0791715..15b406e 100755 --- a/scripts/runlint.sh +++ b/scripts/runlint.sh @@ -2,8 +2,8 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -# TODO(jandres): Ignore runtime/references error atm. Have a story to convert these to smart pointers or something. python ${DIR}/cpplint.py \ --recursive \ --extensions=h,cpp \ +--filter=-runtime/string \ ${DIR}/../include/ ${DIR}/../src/ ${DIR}/../test/ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0bc63d1..4573dec 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,7 +2,9 @@ include_directories(${CMAKE_SOURCE_DIR}/include) # Intrinsic Sources. add_subdirectory(agent) +add_subdirectory(coding) add_subdirectory(hash) +add_subdirectory(db) # Non-Intrinsic Sources. 
add_library(rl @@ -12,7 +14,13 @@ add_library(rl $ $ - $) + $ + $ + $) target_link_libraries(rl ${PTHREAD_LIB}) +if (ENABLE_DB STREQUAL "true") + target_link_libraries(rl cassandra) +endif() + install(TARGETS rl DESTINATION lib) diff --git a/src/coding/CMakeLists.txt b/src/coding/CMakeLists.txt new file mode 100644 index 0000000..bbbe5a9 --- /dev/null +++ b/src/coding/CMakeLists.txt @@ -0,0 +1,6 @@ +include_directories(${CMAKE_SOURCE_DIR}/include) + +add_subdirectory(container) + +#file(GLOB SRC_CODING_FILES "*.cpp") +#add_library(rlCoding OBJECT ${SRC_CODING_FILES}) diff --git a/src/coding/container/CMakeLists.txt b/src/coding/container/CMakeLists.txt new file mode 100644 index 0000000..409cd45 --- /dev/null +++ b/src/coding/container/CMakeLists.txt @@ -0,0 +1,4 @@ +include_directories(${CMAKE_SOURCE_DIR}/include) + +file(GLOB SRC_CODING_CONTAINER_FILES "*.cpp") +add_library(rlCodingContainer OBJECT ${SRC_CODING_CONTAINER_FILES}) diff --git a/src/coding/container/TileCodeContainerCell.cpp b/src/coding/container/TileCodeContainerCell.cpp new file mode 100644 index 0000000..fd9a195 --- /dev/null +++ b/src/coding/container/TileCodeContainerCell.cpp @@ -0,0 +1,131 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ */ + +#ifdef ENABLE_DB + +#include +#include + +#include "db/utility.h" +#include "coding/container/TileCodeContainerCell.h" + +using std::string; + +namespace rl { +namespace coding { + +TileCodeContainerCell::TileCodeContainerCell( + string tileCodeContainerId, size_t segmentIndex, size_t index) : + _tileCodeContainerId(tileCodeContainerId), + _segmentIndex(segmentIndex), + _index(index) { +} + +FLOAT TileCodeContainerCell::get() const { + const string dataField = "data" + std::to_string(_index); + string stmtStr = "" + "SELECT tileCodeContainerId, segmentIndex, " + dataField + "\n" + "FROM rl.tilecodecontainersegment\n" + "WHERE tileCodeContainerId = '" + _tileCodeContainerId + "' AND\n" + " segmentIndex = " + std::to_string(_segmentIndex) + ";\n"; + + CassStatement* stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture* resultFuture = + cass_session_execute(db::session, stmt); + + // This will block until the query has finished. + CassError rc = cass_future_error_code(resultFuture); + const CassResult* result = cass_future_get_result(resultFuture); + + // If there was an error then the result won't be available. + if (result == NULL) { + /* Handle error */ + printf("Query result: %s\n", cass_error_desc(rc)); + cass_future_free(resultFuture); + throw "Query failed."; + } + + // The future can be freed immediately after getting the result object. + cass_future_free(resultFuture); + cass_statement_free(stmt); + + // This can be used to retrieve on the first row of the result. + const CassRow* row = cass_result_first_row(result); + + cass_double_t outVal = 0.0F; + cass_value_get_double( + cass_row_get_column_by_name(row, dataField.c_str()), &outVal); + + // Free memory. 
+ cass_result_free(result); + + return outVal; +} + +void TileCodeContainerCell::set(FLOAT val) { + const string dataField = "data" + std::to_string(_index); + string stmtStr = "" + "UPDATE rl.tilecodecontainersegment\n" + "SET " + dataField + " = " + std::to_string(val) + "\n" + "WHERE tileCodeContainerId='" + _tileCodeContainerId + "' AND\n" + " segmentIndex = " + std::to_string(_segmentIndex) + ";\n"; + + CassStatement* stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture* resultFuture = + cass_session_execute(db::session, stmt); + +#ifdef DEBUG + // Block till the result is back. + CassError rc = cass_future_error_code(resultFuture); + const CassResult* result = cass_future_get_result(resultFuture); + + // If there was an error then the result won't be available. + if (result == NULL) { + /* Handle error */ + printf("Query result: %s\n", cass_error_desc(rc)); + cass_future_free(resultFuture); + cass_statement_free(stmt); + throw "Query failed."; + } +#endif + + // The future can be freed immediately after getting the result object. 
+ cass_future_free(resultFuture); + cass_statement_free(stmt); +} + +TileCodeContainerCell& TileCodeContainerCell::operator=(FLOAT val) { + this->set(val); + return *this; +} + +TileCodeContainerCell& TileCodeContainerCell::operator+=(FLOAT val) { + this->set(this->get() + val); + return *this; +} + +TileCodeContainerCell::operator FLOAT() const { + return get(); +} + +} // namespace coding +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/src/coding/container/TileCodeContainerSegment.cpp b/src/coding/container/TileCodeContainerSegment.cpp new file mode 100644 index 0000000..02c1f2c --- /dev/null +++ b/src/coding/container/TileCodeContainerSegment.cpp @@ -0,0 +1,165 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ */ + +#ifdef ENABLE_DB + +#include "coding/container/TileCodeContainerSegment.h" +#include "db/definition.h" +#include "db/manipulation.h" +#include "db/utility.h" + +namespace rl { +namespace coding { + +TileCodeContainerSegment::TileCodeContainerSegment( + const string& tileCodeContainerId, + size_t size, + size_t index) : + _tileCodeContainerId(tileCodeContainerId), + _size(size), + _segmentIndex(index) { + db::initialize(); + TileCodeContainerSegment::createSchema(); + create(); +} + +TileCodeContainerSegment::TileCodeContainerSegment( + const string& tileCodeContainerId, size_t index) : + _tileCodeContainerId(tileCodeContainerId), + _segmentIndex(index) { + db::initialize(); + TileCodeContainerSegment::createSchema(); + read(); +} + +void TileCodeContainerSegment::create() { + CassStatement* stmt = + cass_statement_new(db::InsertTileCodeContainerSegment.c_str(), 103); + + // Bind the values using the indices of the bind variables. + cass_statement_bind_string(stmt, 0, _tileCodeContainerId.c_str()); + cass_statement_bind_int64(stmt, 1, _size); + cass_statement_bind_int64(stmt, 2, _segmentIndex); + for (size_t i = 0; i < SEGMENT_SIZE; i++) { + cass_statement_bind_double(stmt, 3 + i, 0.0F); + } + + CassFuture* queryFuture = cass_session_execute( + db::session, stmt); + + // Statement objects can be freed immediately after being executed. + cass_statement_free(stmt); + + // This will block until the query has finished. + CassError rc = cass_future_error_code(queryFuture); + + if (rc != CASS_OK) { + printf("Query result: %s\n", cass_error_desc(rc)); + cass_future_free(queryFuture); + // TODO(jandres): Exception for failed queries. 
+ throw "Query failed."; + } + + cass_future_free(queryFuture); +} + +void TileCodeContainerSegment::read() { + string stmtStr = "" + "SELECT tileCodeContainerId, size, segmentIndex \n" + "FROM rl.tilecodecontainersegment\n" + "WHERE tileCodeContainerId = '" + _tileCodeContainerId + "' AND\n" + " segmentIndex = " + std::to_string(_segmentIndex) + ";\n"; + CassStatement* stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture* resultFuture = + cass_session_execute(db::session, stmt); + + CassError rc = cass_future_error_code(resultFuture); + const CassResult* result = cass_future_get_result(resultFuture); + + // If there was an error then the result won't be available. + if (result == NULL) { + /* Handle error */ + printf("Query result [read]: %s\n", cass_error_desc(rc)); + cass_future_free(resultFuture); + throw "Query failed."; + } + + // The future can be freed immediately after getting the result object. + cass_future_free(resultFuture); + cass_statement_free(stmt); + + // This can be used to retrieve only the first row of the result. + const CassRow* row = cass_result_first_row(result); + + cass_int64_t size; + // Get the column value of "size" by name. + cass_value_get_int64(cass_row_get_column_by_name(row, "size"), &size); + _size = size; + + cass_int64_t segmentIndex; + cass_value_get_int64( + cass_row_get_column_by_name(row, "segmentIndex"), &segmentIndex); + _segmentIndex = segmentIndex; + + // Free the result now that its values have been copied out. 
+ cass_result_free(result); +} + +void TileCodeContainerSegment::update() { +} + +void TileCodeContainerSegment::delete2() { + string stmtStr = "" + "DELETE FROM rl.tilecodecontainersegment\n" + "WHERE tileCodeContainerId = '" + _tileCodeContainerId + "' AND\n" + " segmentIndex = " + std::to_string(_segmentIndex) + ";\n"; + CassStatement* stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture* resultFuture = + cass_session_execute(db::session, stmt); + + const CassResult* result = cass_future_get_result(resultFuture); + + // If there was an error then the result won't be available. + if (result == NULL) { + /* Handle error */ + cass_future_free(resultFuture); + throw "Query failed."; + } + + cass_statement_free(stmt); + cass_future_free(resultFuture); +} + +TileCodeContainerCell TileCodeContainerSegment::operator[](size_t i) const { + return this->at(i); +} + +TileCodeContainerCell TileCodeContainerSegment::at(size_t i) const { + return TileCodeContainerCell(_tileCodeContainerId, _segmentIndex, i); +} + +void TileCodeContainerSegment::createSchema() { + db::executeStatement(db::TileCodeContainerSegment.c_str()); +} + +} // namespace coding +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/src/db/CMakeLists.txt b/src/db/CMakeLists.txt new file mode 100644 index 0000000..07eb154 --- /dev/null +++ b/src/db/CMakeLists.txt @@ -0,0 +1,4 @@ +include_directories(${CMAKE_SOURCE_DIR}/include) + +file(GLOB SRC_DB_FILES "*.cpp") +add_library(rlDB OBJECT ${SRC_DB_FILES}) diff --git a/src/db/definition.cpp b/src/db/definition.cpp new file mode 100644 index 0000000..60055da --- /dev/null +++ b/src/db/definition.cpp @@ -0,0 +1,150 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifdef ENABLE_DB + +#include "db/definition.h" + +namespace rl { +namespace db { + +const string RLKeySpace = "" + "CREATE KEYSPACE IF NOT EXISTS rl\n" + " WITH replication = {'class': 'SimpleStrategy', 'replication_factor' : 3};"; + +const string TileCodeContainer = "" + "CREATE TABLE IF NOT EXISTS rl.tilecodecontainer (\n" + " id text PRIMARY KEY,\n" + " size bigint,\n" + " initialValue double\n" + ");"; + +const string TileCodeContainerSegment = "" + "CREATE TABLE IF NOT EXISTS rl.tilecodecontainersegment (\n" + " tileCodeContainerId text,\n" + " size bigint,\n" + " segmentIndex bigint,\n" + + " data0 double,\n" + " data1 double,\n" + " data2 double,\n" + " data3 double,\n" + " data4 double,\n" + " data5 double,\n" + " data6 double,\n" + " data7 double,\n" + " data8 double,\n" + " data9 double,\n" + " data10 double,\n" + " data11 double,\n" + " data12 double,\n" + " data13 double,\n" + " data14 double,\n" + " data15 double,\n" + " data16 double,\n" + " data17 double,\n" + " data18 double,\n" + " data19 double,\n" + " data20 double,\n" + " data21 double,\n" + " data22 double,\n" + " data23 double,\n" + " data24 double,\n" + " data25 double,\n" + " data26 double,\n" + " data27 double,\n" + " data28 double,\n" + " data29 double,\n" + " data30 double,\n" + " data31 double,\n" + " data32 double,\n" + " data33 double,\n" + " data34 double,\n" + " data35 double,\n" + " data36 double,\n" + " data37 double,\n" + " data38 double,\n" + " data39 double,\n" + " data40 double,\n" + " data41 double,\n" + " data42 double,\n" + " data43 double,\n" + " data44 double,\n" + " data45 double,\n" + " data46 
double,\n" + " data47 double,\n" + " data48 double,\n" + " data49 double,\n" + " data50 double,\n" + " data51 double,\n" + " data52 double,\n" + " data53 double,\n" + " data54 double,\n" + " data55 double,\n" + " data56 double,\n" + " data57 double,\n" + " data58 double,\n" + " data59 double,\n" + " data60 double,\n" + " data61 double,\n" + " data62 double,\n" + " data63 double,\n" + " data64 double,\n" + " data65 double,\n" + " data66 double,\n" + " data67 double,\n" + " data68 double,\n" + " data69 double,\n" + " data70 double,\n" + " data71 double,\n" + " data72 double,\n" + " data73 double,\n" + " data74 double,\n" + " data75 double,\n" + " data76 double,\n" + " data77 double,\n" + " data78 double,\n" + " data79 double,\n" + " data80 double,\n" + " data81 double,\n" + " data82 double,\n" + " data83 double,\n" + " data84 double,\n" + " data85 double,\n" + " data86 double,\n" + " data87 double,\n" + " data88 double,\n" + " data89 double,\n" + " data90 double,\n" + " data91 double,\n" + " data92 double,\n" + " data93 double,\n" + " data94 double,\n" + " data95 double,\n" + " data96 double,\n" + " data97 double,\n" + " data98 double,\n" + " data99 double,\n" + + " PRIMARY KEY (tileCodeContainerId, segmentIndex)" + ") WITH caching = { 'keys' : 'NONE', 'rows_per_partition' : '120' };"; + +} // namespace db +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/src/db/manipulation.cpp b/src/db/manipulation.cpp new file mode 100644 index 0000000..96fb833 --- /dev/null +++ b/src/db/manipulation.cpp @@ -0,0 +1,153 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifdef ENABLE_DB + +#include "db/manipulation.h" + +namespace rl { +namespace db { + +const string InsertTileCodeContainer = "" + "INSERT INTO rl.tilecodecontainer (" + "id, size, initialValue) VALUES (?, ?, ?);"; + +const string InsertTileCodeContainerSegment = "" + "INSERT INTO rl.tilecodecontainersegment (" + "tileCodeContainerId,\n" + "size,\n" + "segmentIndex,\n" + + "data0,\n" + "data1,\n" + "data2,\n" + "data3,\n" + "data4,\n" + "data5,\n" + "data6,\n" + "data7,\n" + "data8,\n" + "data9,\n" + "data10,\n" + "data11,\n" + "data12,\n" + "data13,\n" + "data14,\n" + "data15,\n" + "data16,\n" + "data17,\n" + "data18,\n" + "data19,\n" + "data20,\n" + "data21,\n" + "data22,\n" + "data23,\n" + "data24,\n" + "data25,\n" + "data26,\n" + "data27,\n" + "data28,\n" + "data29,\n" + "data30,\n" + "data31,\n" + "data32,\n" + "data33,\n" + "data34,\n" + "data35,\n" + "data36,\n" + "data37,\n" + "data38,\n" + "data39,\n" + "data40,\n" + "data41,\n" + "data42,\n" + "data43,\n" + "data44,\n" + "data45,\n" + "data46,\n" + "data47,\n" + "data48,\n" + "data49,\n" + "data50,\n" + "data51,\n" + "data52,\n" + "data53,\n" + "data54,\n" + "data55,\n" + "data56,\n" + "data57,\n" + "data58,\n" + "data59,\n" + "data60,\n" + "data61,\n" + "data62,\n" + "data63,\n" + "data64,\n" + "data65,\n" + "data66,\n" + "data67,\n" + "data68,\n" + "data69,\n" + "data70,\n" + "data71,\n" + "data72,\n" + "data73,\n" + "data74,\n" + "data75,\n" + "data76,\n" + "data77,\n" + "data78,\n" + "data79,\n" + "data80,\n" + "data81,\n" + "data82,\n" + "data83,\n" + "data84,\n" + "data85,\n" + "data86,\n" + "data87,\n" + "data88,\n" + 
"data89,\n" + "data90,\n" + "data91,\n" + "data92,\n" + "data93,\n" + "data94,\n" + "data95,\n" + "data96,\n" + "data97,\n" + "data98,\n" + "data99 \n" + + ") VALUES (?, ?, ?, " + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?," + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?" + ");"; + +} // namespace db +} // namespace rl + +#endif // #if ENABLE_DB diff --git a/src/db/utility.cpp b/src/db/utility.cpp new file mode 100644 index 0000000..3b25485 --- /dev/null +++ b/src/db/utility.cpp @@ -0,0 +1,128 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifdef ENABLE_DB + +#include + +#include "db/utility.h" +#include "db/definition.h" + +namespace rl { +namespace db { + +CassCluster* cluster = nullptr; +CassSession* session = nullptr; +CassFuture* connectFuture = nullptr; + +void initialize() { + if (db::isInitialized()) { + return; + } + + // Setup and connect to cluster. + rl::db::cluster = cass_cluster_new(); + rl::db::session = cass_session_new(); + + // Add contact points. + // TODO(jandres): Make this a config file. 
+ cass_cluster_set_contact_points(db::cluster, "127.0.0.1"); + + // Provide the cluster object as configuration to connect the session. + db::connectFuture = cass_session_connect( + db::session, db::cluster); + + // This operation will block until the result is ready. + CassError rc = cass_future_error_code(rl::db::connectFuture); + +#ifdef DEBUG + std::cout << "TileCodeContainer::initialize" << std::endl; + + if (rc != CASS_OK) { + rl::db::terminate(); + + auto errorStr = cass_error_desc(rc); + std::cerr << "Connect result: " << errorStr << std::endl; + } +#endif + + // Create key schemas. + db::createKeySpace(); +} + +void terminate() { +#ifdef DEBUG + std::cout << "TileCodeContainer::terminate" << std::endl; +#endif + + // Don't do a thing if already terminated. + if (!isInitialized()) { + return; + } + + cass_future_wait(db::connectFuture); + cass_future_free(db::connectFuture); + db::connectFuture = nullptr; + cass_session_free(db::session); + db::session = nullptr; + cass_cluster_free(db::cluster); + db::cluster = nullptr; +} + +bool isInitialized() { + return db::connectFuture != nullptr && + db::session != nullptr && + db::cluster != nullptr; +} + +void executeStatement(const string& stmtStr) { + CassStatement* stmt = cass_statement_new(stmtStr.c_str(), 0); + + CassFuture* resultFuture = + cass_session_execute(db::session, stmt); + + if (cass_future_error_code(resultFuture) == CASS_OK) { + } else { + // Deal With Error. 
+ /* Handle error */ + const char* message; + size_t message_length; + cass_future_error_message(resultFuture, &message, &message_length); + std::cerr << message << std::endl; + } + + cass_statement_free(stmt); + cass_future_free(resultFuture); +} + +void createKeySpace() { + db::executeStatement(db::RLKeySpace.c_str()); +} + +CassUuid genUuid() { + CassUuidGen* uuidGen = cass_uuid_gen_new(); + CassUuid id; + cass_uuid_gen_from_time(uuidGen, 1234, &id); + cass_uuid_gen_free(uuidGen); + return id; +} + +} // namespace db +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/test/src/algorithm/gradient-descent/SarsaETGD_test.cpp b/test/src/algorithm/gradient-descent/SarsaETGD_test.cpp index 8f74bd1..69b1623 100644 --- a/test/src/algorithm/gradient-descent/SarsaETGD_test.cpp +++ b/test/src/algorithm/gradient-descent/SarsaETGD_test.cpp @@ -63,13 +63,13 @@ SCENARIO("Sarsa Eligibility Traces and Gradient Descent converge to a " rl::coding::DimensionInfo(0.0F, 2.0F, 3, 0.0F) }; - // Setup tile coding with 8 offsets. - auto tileCode = TileCodeCorrectFactory<3, 10>(dimensionalInfoVector).get(); - auto sarsa = - SarsaETGDFactory<3, 10>(tileCode, 0.1F, 1.0F, 0.9F, policy).get(); - rl::agent::AgentGD<3> agent(mce, sarsa); + WHEN("We do multiple episodes with Default weight container (vector)") { + // Setup tile coding with 8 offsets. 
+ auto tileCode = TileCodeCorrectFactory<3, 8>(dimensionalInfoVector).get(); + auto sarsa = + SarsaETGDFactory<3, 8>(tileCode, 0.1F, 1.0F, 0.9F, policy).get(); + rl::agent::AgentGD<3> agent(mce, sarsa); - WHEN("We do multiple episodes") { rl::INT iterationCount = 0; for (rl::INT i = 0; i < 1000; i++) { agent.reset(); diff --git a/test/src/algorithm/gradient-descent/SarsaGD_test.cpp b/test/src/algorithm/gradient-descent/SarsaGD_test.cpp index 5dad1f6..d7c216b 100644 --- a/test/src/algorithm/gradient-descent/SarsaGD_test.cpp +++ b/test/src/algorithm/gradient-descent/SarsaGD_test.cpp @@ -59,18 +59,51 @@ SCENARIO("Sarsa Gradient Descent converge to a solution", rl::coding::DimensionInfo(0.0F, 2.0F, 3, 0.0F) }; - // Setup tile coding with 8 offsets. - auto tileCode = TileCodeCorrectFactory<3, 8>(dimensionalInfoVector).get(); + WHEN("We do multiple episodes with Default weight container (vector)") { + // Setup tile coding with 8 offsets. + auto tileCode = TileCodeCorrectFactory<3, 8>(dimensionalInfoVector).get(); + auto sarsa = + SarsaGDFactory<3, 8>(tileCode, 0.1F, 1.0F, 0.9F, policy).get(); + rl::agent::AgentGD<3> agent(mce, sarsa); - auto sarsa = SarsaGDFactory<3, 8>(tileCode, 0.1F, 1.0F, 0.9F, policy).get(); - rl::agent::AgentGD<3> agent(mce, sarsa); + rl::INT iterationCount = 0; + for (rl::INT i = 0; i < 1000; i++) { + agent.reset(); + + iterationCount = agent.executeEpisode(); + } + + THEN("At the end, we solve the Mountain Car environment in 100 " + "iteration") { + REQUIRE(iterationCount <= 100); + } + } + +#ifdef ENABLE_DB + WHEN("We do multiple episodes with TileCodeContainer.") { + // Setup tile coding with 8 offsets. 
+ auto tileCode = + TileCodeCorrectFactory< + 3, 8, rl::coding::TileCodeContainer<'i', 'd', '3'>>( + dimensionalInfoVector).get(); + auto sarsa = + SarsaGDFactory<3, 8, rl::coding::TileCodeContainer<'i', 'd', '3'>>( + tileCode, 0.1F, 1.0F, 0.9F, policy).get(); + rl::agent::AgentGD<3> agent(mce, sarsa); - WHEN("We do multiple episodes") { rl::INT iterationCount = 0; + +#ifdef DEBUG + std::cout << "Starting episodes with SarsaGD and TileCodeContainer " + "(cassandradb)" << std::endl; +#endif for (rl::INT i = 0; i < 1000; i++) { agent.reset(); iterationCount = agent.executeEpisode(); +#ifdef DEBUG + std::cout << iterationCount << std::endl; +#endif } THEN("At the end, we solve the Mountain Car environment in 100 " @@ -78,5 +111,6 @@ SCENARIO("Sarsa Gradient Descent converge to a solution", REQUIRE(iterationCount <= 100); } } +#endif // #ifdef ENABLE_DB } } diff --git a/test/src/coding/TileCodeContainerCell_test.cpp b/test/src/coding/TileCodeContainerCell_test.cpp new file mode 100644 index 0000000..f4da910 --- /dev/null +++ b/test/src/coding/TileCodeContainerCell_test.cpp @@ -0,0 +1,76 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ */ + +#include "rl" +#include "catch.hpp" + +#ifdef ENABLE_DB + +namespace rl { +namespace coding { + +SCENARIO("TileCodeContainerCell, a single weight in TileCodeContainer.", + "[TileCodeContainerCell]") { + GIVEN("A default TileCodeContainerSegment instance") { + rl::coding::TileCodeContainer<'i', 'd', '1'> tcc(1000, 0.0F); + + WHEN("TileCodeContainerCell is initialized with index 0, 23, 50, 99") { + rl::coding::TileCodeContainerCell tccc1(tcc.getID(), 0, 0); + rl::coding::TileCodeContainerCell tccc2(tcc.getID(), 0, 23); + rl::coding::TileCodeContainerCell tccc3(tcc.getID(), 0, 50); + rl::coding::TileCodeContainerCell tccc4(tcc.getID(), 0, 99); + THEN("All have 0 values") { + REQUIRE(tccc1 == 0.0F); + REQUIRE(tccc2 == 0.0F); + REQUIRE(tccc3 == 0.0F); + REQUIRE(tccc4 == 0.0F); + } + } + + tcc.delete2(); + } + + GIVEN("A TileCodeContainerCell instance") { + rl::coding::TileCodeContainer<'i', 'd', '1'> tcc(1000, 0.0F); + + WHEN("I change the value of index 0 from 0.0 to 69.0") { + TileCodeContainerCell tccc(tcc.getID(), 0, 0); + tccc = 69.0F; + + THEN("That index 0 cell is now 69.0") { + REQUIRE(tccc == 69.0F); + } + } + + WHEN("I change the value of index 99 from 0.0 to -37.0") { + TileCodeContainerCell tccc(tcc.getID(), 0, 99); + tccc = -37.0F; + + THEN("That index 99 cell is now -37.0") { + REQUIRE(tccc == -37.0F); + } + } + + tcc.delete2(); + } +} + +} // namespace coding +} // namespace rl + +#endif // #ifdef ENABLE_DB diff --git a/test/src/coding/TileCodeContainerSegment_test.cpp b/test/src/coding/TileCodeContainerSegment_test.cpp new file mode 100644 index 0000000..d3c68d3 --- /dev/null +++ b/test/src/coding/TileCodeContainerSegment_test.cpp @@ -0,0 +1,39 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your 
option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "rl" +#include "catch.hpp" + +#ifdef ENABLE_DB + +SCENARIO("TileCodeContainerSegment saves segment of the weight vector", + "[TileCodeContainerSegment]") { + GIVEN("TileCodeContainerSegment instance") { + rl::coding::TileCodeContainer<'i', 'd', '1'> tcc(1000, 0.0F); + rl::coding::TileCodeContainerSegment tccs(tcc.getID(), 0); + WHEN("Initialized") { + THEN("An element exist") { + REQUIRE(tccs.at(0) == 0.0F); + } + } + + tcc.delete2(); + } +} + +#endif // #ifdef ENABLE_DB diff --git a/test/src/coding/TileCodeContainer_test.cpp b/test/src/coding/TileCodeContainer_test.cpp new file mode 100644 index 0000000..a12ddb1 --- /dev/null +++ b/test/src/coding/TileCodeContainer_test.cpp @@ -0,0 +1,76 @@ +/** + * rl - Reinforcement Learning + * Copyright (C) 2016 Joey Andres + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ */ + +#include "rl" +#include "catch.hpp" + +#ifdef ENABLE_DB + +SCENARIO("TileCodeContainer, encapsulates the weight storage in tile code.", + "[TileCodeContainer]") { + GIVEN("Uninitialized underlying db") { + WHEN("Created") { + THEN("Schema is created") { + rl::db::initialize(); + } + } + } + + GIVEN("Initialized underlying db") { + WHEN("New TileCodeContainer instance is created with the " + "id agnostic constructor") { + rl::coding::TileCodeContainer<'i', 'd', '1'> tcc(1000, 0.0F); + THEN("It is added in db") { + tcc.delete2(); + } + } + + WHEN("Terminate") { + THEN("Connection is ceased.") { + rl::db::terminate(); + } + } + } + + GIVEN("An inserted TileCodeContainer") { + rl::coding::TileCodeContainer<'i', 'd', '2'> tcc(1000, 0.0F); + + WHEN("A member is accessed.") { + THEN("It is first 0.0") { + REQUIRE(tcc[0] == 0.0); + } + } + + WHEN("A member is modified") { + THEN("It is modified") { + tcc[0] = 69.0F; + REQUIRE(tcc[0] == 69.0); + REQUIRE(tcc[1] == 0.0); // Sanity check. + } + } + + WHEN("TileCodeContainer is deleted") { + tcc.delete2(); + THEN("id disappears") { + // Not practical to query Cassandra ourselves. + } + } + } +} + +#endif // #ifdef ENABLE_DB