diff --git a/build.properties b/build.properties index a2498df5..54376951 100644 --- a/build.properties +++ b/build.properties @@ -15,7 +15,7 @@ test.result.dir=${target.dir}/test-results user.name=myui #java.version=1.6 -project.version=0.4.0 +project.version=0.4.0-1 project.name=Hivemall project.groupId=hivemall project.organization.name=Treasure Data, Inc. diff --git a/pom.xml b/pom.xml index a3fa7087..0d1b9127 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ io.github.myui hivemall - 0.4.0 + 0.4.0-1 Hivemall Scalable Machine Learning Library for Apache Hive diff --git a/src/main/java/hivemall/HivemallConstants.java b/src/main/java/hivemall/HivemallConstants.java index bab3eda9..11f33335 100644 --- a/src/main/java/hivemall/HivemallConstants.java +++ b/src/main/java/hivemall/HivemallConstants.java @@ -20,7 +20,7 @@ public final class HivemallConstants { - public static final String VERSION = "0.4.0"; + public static final String VERSION = "0.4.0-1"; public static final String BIAS_CLAUSE = "0"; public static final String CONFKEY_RAND_AMPLIFY_SEED = "hivemall.amplify.seed"; diff --git a/src/main/java/hivemall/fm/FMPredictUDAF.java b/src/main/java/hivemall/fm/FMPredictUDAF.java index 70ac3428..18b7acfb 100644 --- a/src/main/java/hivemall/fm/FMPredictUDAF.java +++ b/src/main/java/hivemall/fm/FMPredictUDAF.java @@ -109,6 +109,11 @@ void iterate(@Nullable DoubleWritable Wj, @Nullable List Vjf, @Nu if(Xj == null) { throw new HiveException("Xj should not be null"); } + final int factor = Vjf.size(); + if(factor == 0) {// workaround for TD + return; + } + if(sumVjXj == null) { int factors = Vjf.size(); this.sumVjXj = Arrays.asList(MutableDouble.initArray(factors, 0.d)); @@ -116,7 +121,6 @@ void iterate(@Nullable DoubleWritable Wj, @Nullable List Vjf, @Nu } final double x = Xj.get(); - final int factor = Vjf.size(); if(factor < 1) { throw new HiveException("# of Factor should be more than 0: " + Vjf.toString()); } @@ -141,6 +145,7 @@ void merge(PartialResult other) { this.sumV2X2 = other.sumV2X2; } else { add(other.sumVjXj, sumVjXj); + assert (sumV2X2 != null); add(other.sumV2X2, sumV2X2); } } @@ -163,7 +168,10 @@ void merge(PartialResult other) { return ret; } - private static void add(@Nonnull final List src, @Nonnull final List dst) { + private static void add(@Nullable final List src, @Nonnull final List dst) { + if(src == null) { + return; + } for(int i = 0, size = src.size(); i < size; i++) { MutableDouble s = src.get(i); assert (s != null); diff --git a/src/main/java/hivemall/mf/MFPredictionUDF.java b/src/main/java/hivemall/mf/MFPredictionUDF.java index 56c65244..2f5fb226 100644 --- a/src/main/java/hivemall/mf/MFPredictionUDF.java +++ b/src/main/java/hivemall/mf/MFPredictionUDF.java @@ -38,12 +38,17 @@ public FloatWritable evaluate(List Pu, List Qi, double mu) throws if(Pu == null || Qi == null) { return null; //throw new HiveException("Pu should not be NULL"); } - final int factor = Pu.size(); - if(Qi.size() != factor) { - throw new HiveException("|Pu| " + factor + " was not equal to |Qi| " + Qi.size()); + final int PuSize = Pu.size(); + final int QiSize = Qi.size(); + if(QiSize != PuSize) { + throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize); } + if(PuSize == 0) {// workaround for TD + return null; + } + float ret = (float) mu; - for(int k = 0; k < factor; k++) { + for(int k = 0; k < PuSize; k++) { ret += Pu.get(k) * Qi.get(k); } return new FloatWritable(ret); @@ -68,12 +73,17 @@ public FloatWritable evaluate(List Pu, List Qi, double Bu, double return new FloatWritable(ret); } - final int factor = Pu.size(); - if(Qi.size() != factor) { - throw new HiveException("|Pu| " + factor + " was not equal to |Qi| " + Qi.size()); + final int PuSize = Pu.size(); + final int QiSize = Qi.size(); + if(QiSize != PuSize) { + throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize); } + if(PuSize == 0) {// workaround for TD + return null; + } + float ret = (float) (mu + Bu + Bi); - for(int k = 0; k < factor; k++) { + for(int k = 0; k < PuSize; k++) { ret += Pu.get(k) * Qi.get(k); } return new FloatWritable(ret);