diff --git a/src/main/scala/com/redislabs/provider/redis/util/Logging.scala b/src/main/scala/com/redislabs/provider/redis/util/Logging.scala index 814586d8..620b565d 100644 --- a/src/main/scala/com/redislabs/provider/redis/util/Logging.scala +++ b/src/main/scala/com/redislabs/provider/redis/util/Logging.scala @@ -29,6 +29,12 @@ trait Logging { } } + def logWarn(msg: => String): Unit = { + if (logger.isWarnEnabled) { + _logger.warn(msg) + } + } + def logDebug(msg: => String): Unit = { if (logger.isDebugEnabled) { _logger.debug(msg) diff --git a/src/main/scala/com/redislabs/provider/redis/util/SparkUtils.scala b/src/main/scala/com/redislabs/provider/redis/util/SparkUtils.scala new file mode 100644 index 00000000..3b6de2b5 --- /dev/null +++ b/src/main/scala/com/redislabs/provider/redis/util/SparkUtils.scala @@ -0,0 +1,17 @@ +package com.redislabs.provider.redis.util + +import org.apache.spark.sql.types.StructType + +object SparkUtils { + /** + * Sets the schema column positions in the same order as in requiredColumns + * @param schema Current schema + * @param requiredColumns Column positions expected by Catalyst + */ + def alignSchemaWithCatalyst(schema: StructType, requiredColumns: Seq[String]): StructType = { + val fieldsMap = schema.fields.map(f => (f.name, f)).toMap + StructType(requiredColumns.map { c => + fieldsMap(c) + }) + } +} diff --git a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala index c9b0a981..be4d7572 100644 --- a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala +++ b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala @@ -2,6 +2,7 @@ package org.apache.spark.sql.redis import java.nio.charset.StandardCharsets.UTF_8 +import com.redislabs.provider.redis.util.SparkUtils import org.apache.commons.lang3.SerializationUtils import org.apache.spark.sql.Row import 
org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema @@ -34,6 +35,10 @@ class BinaryRedisPersistence extends RedisPersistence[Array[Byte]] { override def decodeRow(keyMap: (String, String), value: Array[Byte], schema: StructType, requiredColumns: Seq[String]): Row = { val valuesArray: Array[Any] = SerializationUtils.deserialize(value) - new GenericRowWithSchema(valuesArray, schema) + // Aligning column positions with what Catalyst is expecting + val alignedSchema = SparkUtils.alignSchemaWithCatalyst(schema, requiredColumns) + val names = schema.fieldNames + val alignedValuesArray = requiredColumns.toArray.map(f => valuesArray(names.indexOf(f))) + new GenericRowWithSchema(alignedValuesArray, alignedSchema) } } diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala index f2c84911..6561fca8 100644 --- a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala +++ b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala @@ -5,7 +5,7 @@ import java.util.{List => JList} import com.redislabs.provider.redis.rdd.Keys import com.redislabs.provider.redis.util.ConnectionUtils.withConnection -import com.redislabs.provider.redis.util.Logging +import com.redislabs.provider.redis.util.{Logging, SparkUtils} import com.redislabs.provider.redis.util.PipelineUtils._ import com.redislabs.provider.redis.{ReadWriteConfig, RedisConfig, RedisDataTypeHash, RedisDataTypeString, RedisEndpoint, RedisNode, toRedisContext} import org.apache.commons.lang3.SerializationUtils @@ -159,19 +159,17 @@ class RedisSourceRelation(override val sqlContext: SQLContext, new GenericRow(Array[Any]()) } } else { - // filter schema columns, it should be in the same order as given 'requiredColumns' - val requiredSchema = { - val fieldsMap = schema.fields.map(f => (f.name, f)).toMap - val requiredFields = requiredColumns.map { c => - fieldsMap(c) - } - StructType(requiredFields) - } - val 
keyType = + /* + For binary it's crucial to have a schema, as we can't infer it and Catalyst requiredColumns doesn't guarantee + the same order. Thus the schema is the only place where we can read correct attribute positions for binary + */ + val (keyType, requiredSchema) = if (persistenceModel == SqlOptionModelBinary) { - RedisDataTypeString + if (this.schema == null) + logWarn("Unable to identify the schema when reading a dataframe in Binary mode. It can cause type inconsistency!") + (RedisDataTypeString, this.schema) } else { - RedisDataTypeHash + (RedisDataTypeHash, SparkUtils.alignSchemaWithCatalyst(this.schema, requiredColumns)) } keysRdd.mapPartitions { partition => // grouped iterator to only allocate memory for a portion of rows