Skip to content

Commit

Permalink
added weights to NN
Browse files Browse the repository at this point in the history
  • Loading branch information
Dpbm committed Feb 1, 2024
1 parent e87e7e6 commit 3891ed2
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 16 deletions.
29 changes: 29 additions & 0 deletions helpers/utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
#include <iostream>
#include <random>
#include <fstream>
#include "utils.h"


using std::ofstream;
using std::string;
using std::ios;

namespace Utils {
double random(int start, int end){
std::random_device device;
Expand All @@ -14,4 +21,26 @@ namespace Utils {
unsigned int rand = Utils::random(0, max_range);
return rand-(rand%factor);
}

void append_to_file(string filename, string data){
ofstream file;
file.open(filename, ios::app);

if(!file){
Utils::create_file(filename, data);
return;
}

file << data;
file.close();
}

void create_file(string filename, string data){
ofstream file(filename);

if(file.is_open()){
file << data;
file.close();
}
}
}
6 changes: 6 additions & 0 deletions helpers/utils.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
#pragma once

#include <iostream>

using std::string;

namespace Utils {
unsigned int get_random_pos(unsigned int max_range, unsigned int factor);
double random(int start, int end);
bool passed_debounce_time(int last_tick);
void append_to_file(string filename, string data);
void create_file(string filename, string data);
}
40 changes: 39 additions & 1 deletion machine/machine.cpp
Original file line number Diff line number Diff line change
@@ -1,43 +1,81 @@
#include "machine.h"
#include "layer.h"
#include "weights.h"
#include <stdexcept>
#include <vector>

using Layers::Layer;
using NNWeights::Weights;
using std::vector;
using std::string;
using std::invalid_argument;

namespace Machine {
vector<Layer*>* layers;
vector<Weights*>* weights;
unsigned int total_layers;

void NN::add_layer(unsigned int size){
layers->push_back(new Layer(size));
this->total_layers++;
this->NN::add_weight();
}

void NN::add_layer(Layer* layer){
layers->push_back(layer);
this->total_layers++;
}
this->NN::add_weight();
}

void NN::add_weight(){
if(this->total_layers <= 1)
return;

unsigned int first_layer_size = layers->at(this->total_layers-2)->get_size();
unsigned int second_layer_size = layers->at(this->total_layers-1)->get_size();
this->weights->push_back(new Weights(first_layer_size, second_layer_size));
this->total_weights++;
}

vector<Layer*>* NN::get_layers(){
return this->layers;
}

void NN::save_weights(string filename){
for(Weights* weight: (*this->weights))
weight->save_weights(filename);
}

Layer* NN::get_layer(unsigned int i){
if(this->total_layers == 0 || i > this->total_layers-1)
throw invalid_argument("invalid layer position");

return this->layers->at(i);
}

Weights* NN::get_weight(unsigned int i){
if(this->total_weights == 0 || i > this->total_weights-1)
throw invalid_argument("invalid weights position");
return this->weights->at(i);
}

vector<Weights*>* NN::get_weights(){
return this->weights;
}


unsigned int NN::get_total_layers(){
return this->total_layers;
}

unsigned int NN::get_total_weights(){
return this->total_weights;
}

NN::~NN(){
for(Layer *layer: (*this->layers))
delete layer;
for(Weights *weight: (*this->weights))
delete weight;
}
}
15 changes: 11 additions & 4 deletions machine/machine.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,32 @@
#include <iostream>
#include <vector>
#include "layer.h"
#include "weights.h"

using Layers::Layer;
using NNWeights::Weights;
using std::vector;
using std::string;

namespace Machine {

class NN{

class NN{
public:
void add_layer(unsigned int size);
void add_layer(Layer* layer);
vector<Layer*>* get_layers();
vector<Weights*>* get_weights();
Layer* get_layer(unsigned int i);
Weights* get_weight(unsigned int i);
unsigned int get_total_layers();
unsigned int get_total_weights();
void save_weights(string filename);
~NN();

private:
vector<Layer*>* layers = new vector<Layer*>;
vector<Weights*>* weights = new vector<Weights*>;
unsigned int total_layers = 0;

unsigned int total_weights = 0;
void add_weight();
};
};
25 changes: 14 additions & 11 deletions machine/weights.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#include "weights.h"
#include "../matrix/matrix.h"
#include <iostream>
#include <fstream>
#include <sstream>
#include "../helpers/utils.h"

using Matrices::Matrix;
using std::ofstream;
using std::string;
using std::stringstream;
using Utils::append_to_file;

namespace NNWeights {
unsigned int width, height;
Expand All @@ -26,18 +28,19 @@ namespace NNWeights {
}

void Weights::save_weights(string filename){
ofstream file(filename);
stringstream header;
header << "w:" << this->width << ";h:" << this->height << "\n";

append_to_file(filename, header.str());

if(file.is_open()){
file << "w:" << this->width << ";h:" << this->height << "\n";

for(unsigned int i = 0; i < this->height; i++){
for(unsigned int j = 0; j < this->width; j++)
file << this->weights->get_position_value(i, j) << (j < this->width-1 ? "," : "");
file << "\n";
for(unsigned int i = 0; i < this->height; i++){
for(unsigned int j = 0; j < this->width; j++){
stringstream data;
data << this->weights->get_position_value(i, j) << (j < this->width-1 ? "," : "");
append_to_file(filename, data.str());
}

file.close();
append_to_file(filename, "\n");
}
}

Expand Down
10 changes: 10 additions & 0 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cstdlib>
#include <cstring>
#include <iostream>
#include "machine/machine.h"
#include "matrix/matrix.h"
#include "genetic/gene.h"
#include "genetic/chromosome.h"
Expand Down Expand Up @@ -79,6 +80,15 @@ int main(){
weights->save_weights("test.wg");
delete weights;

Machine::NN *nn = new Machine::NN;
nn->add_layer(6);
nn->add_layer(10);
nn->add_layer(5);
nn->add_layer(3);
nn->save_weights("hello.wg");
delete nn;


/*
char *title = (char*) malloc(10*sizeof(char));
strcpy(title, "snake game");
Expand Down

0 comments on commit 3891ed2

Please sign in to comment.