diff --git a/build.properties b/build.properties
index a2498df5..54376951 100644
--- a/build.properties
+++ b/build.properties
@@ -15,7 +15,7 @@ test.result.dir=${target.dir}/test-results
user.name=myui
#java.version=1.6
-project.version=0.4.0
+project.version=0.4.0-1
project.name=Hivemall
project.groupId=hivemall
project.organization.name=Treasure Data, Inc.
diff --git a/pom.xml b/pom.xml
index a3fa7087..0d1b9127 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
io.github.myui
hivemall
- 0.4.0
+ 0.4.0-1
Hivemall
Scalable Machine Learning Library for Apache Hive
diff --git a/src/main/java/hivemall/HivemallConstants.java b/src/main/java/hivemall/HivemallConstants.java
index bab3eda9..11f33335 100644
--- a/src/main/java/hivemall/HivemallConstants.java
+++ b/src/main/java/hivemall/HivemallConstants.java
@@ -20,7 +20,7 @@
public final class HivemallConstants {
- public static final String VERSION = "0.4.0";
+ public static final String VERSION = "0.4.0-1";
public static final String BIAS_CLAUSE = "0";
public static final String CONFKEY_RAND_AMPLIFY_SEED = "hivemall.amplify.seed";
diff --git a/src/main/java/hivemall/fm/FMPredictUDAF.java b/src/main/java/hivemall/fm/FMPredictUDAF.java
index 70ac3428..18b7acfb 100644
--- a/src/main/java/hivemall/fm/FMPredictUDAF.java
+++ b/src/main/java/hivemall/fm/FMPredictUDAF.java
@@ -109,6 +109,11 @@ void iterate(@Nullable DoubleWritable Wj, @Nullable List Vjf, @Nu
if(Xj == null) {
throw new HiveException("Xj should not be null");
}
+ final int factor = Vjf.size();
+ if(factor == 0) {// workaround for TD
+ return;
+ }
+
if(sumVjXj == null) {
int factors = Vjf.size();
this.sumVjXj = Arrays.asList(MutableDouble.initArray(factors, 0.d));
@@ -116,7 +121,6 @@ void iterate(@Nullable DoubleWritable Wj, @Nullable List Vjf, @Nu
}
final double x = Xj.get();
- final int factor = Vjf.size();
if(factor < 1) {
throw new HiveException("# of Factor should be more than 0: " + Vjf.toString());
}
@@ -141,6 +145,7 @@ void merge(PartialResult other) {
this.sumV2X2 = other.sumV2X2;
} else {
add(other.sumVjXj, sumVjXj);
+ assert (sumV2X2 != null);
add(other.sumV2X2, sumV2X2);
}
}
@@ -163,7 +168,10 @@ void merge(PartialResult other) {
return ret;
}
- private static void add(@Nonnull final List src, @Nonnull final List dst) {
+ private static void add(@Nullable final List src, @Nonnull final List dst) {
+ if(src == null) {
+ return;
+ }
for(int i = 0, size = src.size(); i < size; i++) {
MutableDouble s = src.get(i);
assert (s != null);
diff --git a/src/main/java/hivemall/mf/MFPredictionUDF.java b/src/main/java/hivemall/mf/MFPredictionUDF.java
index 56c65244..2f5fb226 100644
--- a/src/main/java/hivemall/mf/MFPredictionUDF.java
+++ b/src/main/java/hivemall/mf/MFPredictionUDF.java
@@ -38,12 +38,17 @@ public FloatWritable evaluate(List Pu, List Qi, double mu) throws
if(Pu == null || Qi == null) {
return null; //throw new HiveException("Pu should not be NULL");
}
- final int factor = Pu.size();
- if(Qi.size() != factor) {
- throw new HiveException("|Pu| " + factor + " was not equal to |Qi| " + Qi.size());
+ final int PuSize = Pu.size();
+ final int QiSize = Qi.size();
+ if(QiSize != PuSize) {
+ throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
}
+ if(PuSize == 0) {// workaround for TD
+ return null;
+ }
+
float ret = (float) mu;
- for(int k = 0; k < factor; k++) {
+ for(int k = 0; k < PuSize; k++) {
ret += Pu.get(k) * Qi.get(k);
}
return new FloatWritable(ret);
@@ -68,12 +73,17 @@ public FloatWritable evaluate(List Pu, List Qi, double Bu, double
return new FloatWritable(ret);
}
- final int factor = Pu.size();
- if(Qi.size() != factor) {
- throw new HiveException("|Pu| " + factor + " was not equal to |Qi| " + Qi.size());
+ final int PuSize = Pu.size();
+ final int QiSize = Qi.size();
+ if(QiSize != PuSize) {
+ throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
}
+ if(PuSize == 0) {// workaround for TD
+ return null;
+ }
+
float ret = (float) (mu + Bu + Bi);
- for(int k = 0; k < factor; k++) {
+ for(int k = 0; k < PuSize; k++) {
ret += Pu.get(k) * Qi.get(k);
}
return new FloatWritable(ret);