Skip to content

Commit

Permalink
Added tests with Syntethic data and logging
Browse files Browse the repository at this point in the history
  • Loading branch information
ivankitanovski committed Apr 29, 2019
1 parent a16f790 commit 361a285
Show file tree
Hide file tree
Showing 11 changed files with 385 additions and 107 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ DeltrTopDocs preidictionSet = new DeltrTopDocsImpl(2); // 2 is the question ID
// let's create each prediction doc manually (docId, judgement/score)

// item 7
DeltrDoc item7 = new DeltrDocImpl(7, 0.9645f); // the curret score is not important
DeltrDoc item7 = new DeltrDocImpl(7, 0.9645f); // the current score is not really important
item7.put("f0", 0.0, false);
item7.put("f1", 0.9645);

Expand Down Expand Up @@ -154,7 +154,7 @@ item12.put("f1", 0.8312);

//add the items in the set
DeltrDoc[] predArr = new DeltrDoc[]{item7, item8, item9, item10, item11, item12};
preidictionSet.setDocs(predArr);
preidictionSet.put(predArr);

DeltrTopDocs reranked = deltr.rank(preidictionSet);
// reranked ->
Expand Down
34 changes: 26 additions & 8 deletions src/main/java/com/github/fairsearch/deltr/Deltr.java
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,14 @@ public DeltrTopDocs rank(DeltrTopDocs docs) {
DeltrDoc doc = docs.doc(i);
for(String key : doc.keys()) {
if(!key.equals(doc.protectedFeatureName()))
doc.set(key, (doc.feature(key) - this.mu)/this.sigma);
doc.put(key, (doc.feature(key) - this.mu)/this.sigma);
}
}
}

//re-calculate the judgement for each document
for(DeltrDoc doc : docs.docs()) {
for(int j=0; j<docs.size(); j++) {
DeltrDoc doc = docs.doc(j);
double dotProduct = 0;
for(int i=0; i<doc.size(); i++) {
dotProduct += doc.feature(i) * this.omega[i];
Expand Down Expand Up @@ -181,19 +182,36 @@ private TrainerData append(TrainerData data) {
//copy query data
int[] tmp = queryIds;
queryIds = new int[tmp.length + data.queryIds.length];
System.arraycopy(data, 0, queryIds, tmp.length, data.queryIds.length);
System.arraycopy(tmp, 0, queryIds, 0, tmp.length);
System.arraycopy(data.queryIds, 0, queryIds, tmp.length, data.queryIds.length);

//copy protected assignFeature data
tmp = protectedElementFeature;
protectedElementFeature = new int[tmp.length + data.protectedElementFeature.length];
System.arraycopy(data, 0, protectedElementFeature, tmp.length, data.protectedElementFeature.length);
System.arraycopy(tmp, 0, protectedElementFeature, 0, tmp.length);
System.arraycopy(data.protectedElementFeature, 0, protectedElementFeature,
tmp.length, data.protectedElementFeature.length);

//create new feature matrix
INDArray newFeatureMatrix = Nd4j.create(featureMatrix.rows() + data.featureMatrix.rows(),
featureMatrix.columns());
INDArray newTrainingScores = Nd4j.create(trainingScores.rows() + data.trainingScores.rows(),
trainingScores.columns());


//copy assignFeature matrix and training scores (if any)
for(int i=0; i<data.featureMatrix.rows(); i++) {
featureMatrix.addRowVector(data.featureMatrix.getRow(i));
if(trainingScores != null && data.trainingScores != null)
trainingScores.addRowVector(data.trainingScores.getRow(i));
for(int i=0; i< featureMatrix.rows(); i++) {
newFeatureMatrix.putRow(i, featureMatrix.getRow(i));
newTrainingScores.putRow(i, trainingScores.getRow(i));
}
for(int i=0; i< data.featureMatrix.rows(); i++) {
newFeatureMatrix.putRow(i + featureMatrix.rows(), data.featureMatrix.getRow(i));
newTrainingScores.putRow(i + featureMatrix.rows(), data.trainingScores.getRow(i));
}

featureMatrix = newFeatureMatrix;
trainingScores = newTrainingScores;

return this;
}
}
Expand Down
Loading

0 comments on commit 361a285

Please sign in to comment.