Skip to content

Commit

Permalink
Added library serialization
Browse files Browse the repository at this point in the history
Added serialization sample
Changed version to 1.0.0
  • Loading branch information
ivankitanovski committed Apr 30, 2019
1 parent aa46475 commit 2fa6fc0
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 7 deletions.
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ buildscript {
}

group 'com.github.fair-search'
version '0.0.1'
version '1.0.0'

apply plugin: 'java'
apply plugin: 'maven-publish'
Expand Down Expand Up @@ -59,7 +59,7 @@ publishing {
mavenJava(MavenPublication) {
groupId 'com.github.fair-search'
artifactId 'fairsearchdeltr-java'
version '0.0.1'
version '1.0.0'
from components.java
}
}
Expand Down
89 changes: 85 additions & 4 deletions src/main/java/com/github/fairsearch/deltr/Deltr.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package com.github.fairsearch.deltr;

import com.fasterxml.jackson.annotation.JsonGetter;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.github.fairsearch.deltr.models.DeltrDoc;
import com.github.fairsearch.deltr.models.DeltrTopDocs;
import com.github.fairsearch.deltr.models.TrainStep;
import com.github.fairsearch.deltr.parsers.DeltrDeserializer;
import com.google.common.primitives.Doubles;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
Expand All @@ -23,22 +30,33 @@
* the resulting rankings and thus prevents systematic biases against a protected group in the model,
* even though such bias might be present in the training data.
*/
public class Deltr implements Serializable {

protected static final Logger LOGGER = Logger.getLogger(Deltr.class.getName());
@JsonDeserialize(using = DeltrDeserializer.class)
public class Deltr {

protected static final Logger LOGGER = Logger.getLogger(Deltr.class.getName());
@JsonProperty
private double gamma; //gamma parameter for the cost calculation in the training phase (recommended to be around 1)

@JsonProperty("number_of_iterations")
private int numberOfIterations; // number of iteration in gradient descent
@JsonProperty("learning_rate")
private double learningRate; // learning rate in gradient descent
@JsonProperty
private double lambda; // regularization constant
@JsonProperty("init_var")
private double initVar; // range of values for initialization of weights

@JsonProperty("standardize")
protected boolean shouldStandardize; // boolean indicating whether the data should be standardized or not
@JsonProperty
protected double mu = 0; // mu for standardization
@JsonProperty
protected double sigma = 0; // sigma for standardization

@JsonProperty
protected double[] omega = null;

@JsonIgnore
protected List<TrainStep> log = null;

/**
Expand Down Expand Up @@ -83,6 +101,25 @@ public Deltr(double gamma, int numberOfIterations, double learningRate, double l
this.shouldStandardize = shouldStandardize;
}

/**
* @param gamma gamma parameter for the cost calculation in the training phase (recommended to be around 1)
* @param numberOfIterations number of iteration in gradient descent
* @param learningRate learning rate in gradient descent
* @param lambda regularization constant
* @param initVar range of values for initialization of weights
* @param shouldStandardize boolean indicating whether the data should be standardized or not
* @param mu set mu for standardization
* @param sigma set sigma for standardization
* @param omega set precomputed omega
*/
public Deltr(double gamma, int numberOfIterations, double learningRate, double lambda,
double initVar, boolean shouldStandardize, double mu, double sigma, double[] omega){
this(gamma, numberOfIterations, learningRate, lambda, initVar, shouldStandardize);
this.mu = mu;
this.sigma = sigma;
this.omega = omega;
}

/**
* Trains a DELTR model on a given training set
* @param ranks A list of DeltrTopDocs (query-to-documents) containing `DeltrDoc` instance implementations
Expand Down Expand Up @@ -250,4 +287,48 @@ public double[] getOmega() {
public List<TrainStep> getLog() {
return this.log;
}

/**
* Serializes the object to a JSON string. The `log` is not serialized.
* @return A string representing the object
*/
public String toJson() {
ObjectMapper objectMapper = new ObjectMapper();
try {
return objectMapper.writeValueAsString(this);
} catch (JsonProcessingException e) {
LOGGER.severe(String.format("Exception in parsing: '%s'", e.getMessage()));
}
return null;
}

/**
* Deseralizes a Deltr object from a JSON string.
* @param jsonString The JSON representation of the object
* @return The created Deltr instance
*/
public static Deltr createFromJson(String jsonString) {
ObjectMapper objectMapper = new ObjectMapper();
try {
return objectMapper.readValue(jsonString, Deltr.class);
} catch (IOException e) {
LOGGER.severe(String.format("IOException in parsing: '%s'", e.getMessage()));
}
return null ;
}

@Override
public String toString() {
return "Deltr{" +
"gamma=" + gamma +
", numberOfIterations=" + numberOfIterations +
", learningRate=" + learningRate +
", lambda=" + lambda +
", initVar=" + initVar +
", shouldStandardize=" + shouldStandardize +
", mu=" + mu +
", sigma=" + sigma +
", omega=" + Arrays.toString(omega) +
'}';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.github.fairsearch.deltr.parsers;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.ObjectCodec;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.github.fairsearch.deltr.Deltr;

import java.io.IOException;

public class DeltrDeserializer extends JsonDeserializer<Deltr> {

@Override
public Deltr deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException, JsonProcessingException {
ObjectCodec oc = jp.getCodec();
JsonNode node = oc.readTree(jp);

final double gamma = node.get("gamma").asDouble();
final int numberOfIterations = node.get("number_of_iterations").asInt();
final double learningRate = node.get("learning_rate").asDouble();
final double lambda = node.get("lambda").asDouble();
final double initVar = node.get("init_var").asDouble();
final boolean shouldStandardize = node.get("standardize").asBoolean();

final double mu = node.get("mu").asDouble();
final double sigma = node.get("sigma").asDouble();

double[] omega = new double[node.get("omega").size()];
for(int i=0; i< node.get("omega").size(); i++) {
omega[i] = node.get("omega").get(i).asDouble();
}

return new Deltr(gamma, numberOfIterations, learningRate, lambda, initVar, shouldStandardize, mu, sigma, omega);
}
}
14 changes: 13 additions & 1 deletion src/test/java/com/github/fairsearch/deltr/HelloWorld.java
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
package com.github.fairsearch.deltr;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.fairsearch.deltr.models.DeltrDoc;
import com.github.fairsearch.deltr.models.DeltrDocImpl;
import com.github.fairsearch.deltr.models.DeltrTopDocs;
import com.github.fairsearch.deltr.models.DeltrTopDocsImpl;
import com.github.fairsearch.deltr.models.TrainStep;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class HelloWorld {

public static void main(String[] args) {
public static void main(String[] args) throws IOException {
// create some data
List<DeltrTopDocs> trainSet = new ArrayList<>();

Expand Down Expand Up @@ -129,5 +132,14 @@ public static void main(String[] args) {
// .
// .
// .

//serialize the object
String jsonString = deltr.toJson();
//print the string
System.out.println(jsonString);
//deserialize the object
Deltr again = Deltr.createFromJson(jsonString);
//print the object
System.out.println(again);
}
}

0 comments on commit 2fa6fc0

Please sign in to comment.