diff --git a/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala b/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala
index b1c77a0ba..92b014e34 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala
@@ -131,7 +131,7 @@ trait SparkCassandraITSpecBase
def pv = conn.withSessionDo(_.getContext.getProtocolVersion)
- def report(message: String): Unit = alert(message)
+ def report(message: String): Unit = cancel(message)
val ks = getKsName
@@ -147,16 +147,23 @@ trait SparkCassandraITSpecBase
/** Skips the given test if the Cluster Version is lower or equal to the given `cassandra` Version or `dse` Version
* (if this is a DSE cluster) */
- def from(cassandra: Version, dse: Version)(f: => Unit): Unit = {
+ def from(cassandra: Version, dse: Version)(f: => Unit): Unit = from(Some(cassandra), Some(dse))(f)
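+ /** Runs the given test only if the cluster meets the minimum version for its flavor (Cassandra or DSE).
+ * When no minimum version is provided for the current flavor, the test is cancelled. */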
+ def from(cassandra: Option[Version] = None, dse: Option[Version] = None)(f: => Unit): Unit = {
if (isDse(conn)) {
- from(dse)(f)
+ dse match {
+ case Some(dseVersion) => from(dseVersion)(f)
+ case None => report("Skipped because no minimum DSE version is defined for this test")
+ }
} else {
- from(cassandra)(f)
+ cassandra match {
+ case Some(cassandraVersion) => from(cassandraVersion)(f)
+ case None => report("Skipped because no minimum Cassandra version is defined for this test")
+ }
}
}
/** Skips the given test if the Cluster Version is lower or equal to the given version */
- def from(version: Version)(f: => Unit): Unit = {
+ private def from(version: Version)(f: => Unit): Unit = {
skip(cluster.getCassandraVersion, version) { f }
}
diff --git a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
index f30f09d0b..a3b30aef9 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
@@ -40,6 +40,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
| d14_varchar varchar,
| d15_varint varint,
| d16_address frozen<address>,
+ | d17_vector frozen<vector<int, 3>>,
| PRIMARY KEY ((k1, k2, k3), c1, c2, c3)
|)
""".stripMargin)
@@ -111,12 +112,12 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
"allow to read regular column definitions" in {
val columns = table.regularColumns
- columns.size shouldBe 16
+ columns.size shouldBe 17
columns.map(_.columnName).toSet shouldBe Set(
"d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
"d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
"d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
- "d15_varint", "d16_address")
+ "d15_varint", "d16_address", "d17_vector")
}
"allow to read proper types of columns" in {
@@ -136,6 +137,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
table.columnByName("d14_varchar").columnType shouldBe VarCharType
table.columnByName("d15_varint").columnType shouldBe VarIntType
table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
+ table.columnByName("d17_vector").columnType shouldBe VectorType[Int](IntType, 3)
}
"allow to list fields of a user defined type" in {
diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala
index 3a9ac7e90..7bf60fe28 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala
@@ -9,7 +9,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption
import com.datastax.oss.driver.api.core.cql.SimpleStatement
import com.datastax.oss.driver.api.core.cql.SimpleStatement._
import com.datastax.spark.connector._
-import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V6_7_0, V3_6_0}
+import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, DSE_V6_7_0, V3_6_0}
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}
import com.datastax.spark.connector.mapper.{DefaultColumnMapper, JavaBeanColumnMapper, JavaTestBean, JavaTestUDTBean}
@@ -794,7 +794,7 @@ class CassandraRDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster
results should contain ((KeyGroup(3, 300), (3, 300, "0003")))
}
- it should "allow the use of PER PARTITION LIMITs " in from(V3_6_0) {
+ it should "allow the use of PER PARTITION LIMITs " in from(cassandra = V3_6_0, dse = DSE_V5_1_0) {
val result = sc.cassandraTable(ks, "clustering_time").perPartitionLimit(1).collect
result.length should be (1)
}
diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
index d882cbbd6..5175173d3 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
@@ -5,7 +5,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption._
import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, BoundStatement}
import com.datastax.oss.driver.api.core.{DefaultConsistencyLevel, DefaultProtocolVersion}
import com.datastax.spark.connector._
-import com.datastax.spark.connector.ccm.CcmConfig.V3_6_0
+import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, V3_6_0}
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.embedded.SparkTemplate._
@@ -425,7 +425,7 @@ class RDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster {
}
- it should "should be joinable with a PER PARTITION LIMIT limit" in from(V3_6_0){
+ it should "should be joinable with a PER PARTITION LIMIT limit" in from(cassandra = V3_6_0, dse = DSE_V5_1_0){
val source = sc.parallelize(keys).map(x => (x, x * 100))
val someCass = source
.joinWithCassandraTable(ks, wideTable, joinColumns = SomeColumns("key", "group"))
diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
new file mode 100644
index 000000000..af860d84d
--- /dev/null
+++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
@@ -0,0 +1,242 @@
+package com.datastax.spark.connector.rdd.typeTests
+
+import com.datastax.oss.driver.api.core.cql.Row
+import com.datastax.oss.driver.api.core.{CqlSession, Version}
+import com.datastax.spark.connector._
+import com.datastax.spark.connector.ccm.CcmConfig
+import com.datastax.spark.connector.cluster.DefaultCluster
+import com.datastax.spark.connector.cql.CassandraConnector
+import com.datastax.spark.connector.datasource.CassandraCatalog
+import com.datastax.spark.connector.mapper.ColumnMapper
+import com.datastax.spark.connector.rdd.{ReadConf, ValidRDDType}
+import com.datastax.spark.connector.rdd.reader.RowReaderFactory
+import com.datastax.spark.connector.types.TypeConverter
+import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}
+
+import scala.collection.convert.ImplicitConversionsToScala._
+import scala.collection.immutable
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+import scala.reflect._
+
+
+abstract class VectorTypeTest[
+ ScalaType: ClassTag : TypeTag,
+ DriverType <: Number : ClassTag,
+ CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper: RowReaderFactory : ValidRDDType](typeName: String) extends SparkCassandraITFlatSpecBase with DefaultCluster
+{
+ /** Skips the given test if the cluster is not Cassandra */
+ override def cassandraOnly(f: => Unit): Unit = super.cassandraOnly(f)
+
+ override lazy val conn = CassandraConnector(sparkConf)
+
+ val VectorTable = "vectors"
+
+ def createVectorTable(session: CqlSession, table: String): Unit = {
+ session.execute(
+ s"""CREATE TABLE IF NOT EXISTS $ks.$table (
+ | id INT PRIMARY KEY,
+ | v VECTOR<$typeName, 3>
+ |)""".stripMargin)
+ }
+
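+ /** Vector columns are only available from Cassandra 5.0 (tested against 5.0-beta1 and later); subclasses may override. */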
+ def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))
+
+ def minDSEVersion: Option[Version] = None
+
+ def vectorFromInts(ints: Int*): Seq[ScalaType]
+
+ def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType
+
+ override lazy val spark = SparkSession.builder()
+ .config(sparkConf)
+ .config("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
+ .withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()
+
+ override def beforeClass() {
+ conn.withSessionDo { session =>
+ session.execute(
+ s"""CREATE KEYSPACE IF NOT EXISTS $ks
+ |WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
+ .stripMargin)
+ }
+ }
+
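+ /** Asserts that the rows, looked up by id = 1..n, contain exactly the expected vectors, in order. */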
+ private def hasVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = {
+ val returnedVectors = for (i <- expectedVectors.indices) yield {
+ rows.find(_.getInt("id") == i + 1).get.getVector("v", implicitly[ClassTag[DriverType]].runtimeClass.asInstanceOf[Class[Number]]).iterator().toSeq
+ }
+
+ returnedVectors should contain theSameElementsInOrderAs expectedVectors
+ }
+
+ "SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_write_caseclass_to_df"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+
+ spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
+ .write
+ .cassandraFormat(table, ks)
+ .mode(SaveMode.Append)
+ .save()
+ hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+ Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
+
+ spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9))))
+ .write
+ .cassandraFormat(table, ks)
+ .mode(SaveMode.Append)
+ .save()
+ hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+ Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9)))
+
+ spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12))))
+ .write
+ .cassandraFormat(table, ks)
+ .mode(SaveMode.Overwrite)
+ .option("confirm.truncate", value = true)
+ .save()
+ hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+ Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12)))
+ }
+ }
+
+ it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_write_tuple_to_df"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+
+ spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+ .toDF("id", "v")
+ .write
+ .cassandraFormat(table, ks)
+ .mode(SaveMode.Append)
+ .save()
+ hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+ Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
+ }
+ }
+
+ it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+
+ spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
+ .saveToCassandra(ks, table)
+ hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+ Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
+ }
+ }
+
+ it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_write_tuple_to_rdd"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+
+ spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+ .saveToCassandra(ks, table)
+ hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+ Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
+ }
+ }
+
+ it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_read_caseclass_from_df"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+ }
+ spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+ .saveToCassandra(ks, table)
+
+ import spark.implicits._
+ spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs
+ Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
+ }
+
+ it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_read_tuple_from_df"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+ }
+ spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+ .saveToCassandra(ks, table)
+
+ import spark.implicits._
+ spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
+ Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
+ }
+
+ it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+ }
+ spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+ .saveToCassandra(ks, table)
+
+ spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs
+ Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
+ }
+
+ it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_read_tuple_from_rdd"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+ }
+ spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+ .saveToCassandra(ks, table)
+
+ spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs
+ Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
+ }
+
+ it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) {
+ val table = s"${typeName.toLowerCase}_read_rows_from_sql"
+ conn.withSessionDo { session =>
+ createVectorTable(session, table)
+ }
+ spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+ .saveToCassandra(ks, table)
+
+ import spark.implicits._
+ spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
+ Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
+ }
+
+}
+
+class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") {
+ override def vectorFromInts(ints: Int*): Seq[Int] = ints
+
+ override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v)
+}
+
+case class IntVectorItem(id: Int, v: Seq[Int])
+
+class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") {
+ override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong)
+
+ override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v)
+}
+
+case class LongVectorItem(id: Int, v: Seq[Long])
+
+class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") {
+ override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)
+
+ override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v)
+}
+
+case class FloatVectorItem(id: Int, v: Seq[Float])
+
+class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") {
+ override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d)
+
+ override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v)
+}
+
+case class DoubleVectorItem(id: Int, v: Seq[Double])
+
diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
index 19764a0a6..657fab249 100644
--- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
+++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
@@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource
import java.util.Locale
import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
import com.datastax.oss.driver.api.core.`type`.DataTypes._
import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}
@@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
case udt: UserDefinedType => fromUdt(udt)
case t: TupleType => fromTuple(t)
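+ // CQL vector<T, n> columns are exposed to Spark SQL as arrays of the element type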
+ case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
case VARINT =>
logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
primitiveCatalystDataType(cassandraType)
diff --git a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
index 1287bf23d..0f279215d 100644
--- a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
+++ b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
@@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
cassandraType match {
case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
+ case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField))
case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple))
diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
index b4afa5da6..3c68b5f52 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
@@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper](
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))
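+ // vectors are read like lists: each element is converted to the requested Scala element type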
+ case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
+ val argConverter = converter(argColumnType, argScalaType)
+ TypeConverter.forType[U](Seq(argConverter))
+
case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))
diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
index 3b69267ae..05f682f47 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
@@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging {
val valueConverter = converter(valueColumnType, valueScalaType)
TypeConverter.javaHashMapConverter(keyConverter, valueConverter)
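+ // vectors are written as CqlVector values of the declared dimension, converting each element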
+ case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) =>
+ val argConverter = converter(argColumnType, argScalaType)
+ TypeConverter.cqlVectorConverter(dimension)(argConverter.asInstanceOf[TypeConverter[Number]])
+
case (tt @ TupleType(argColumnType1, argColumnType2),
TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) =>
val c1 = converter(argColumnType1.columnType, argScalaType1)
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
index 99a3f9fb3..b5aaf57fb 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
@@ -7,7 +7,7 @@ import java.util.{Date, UUID}
import com.datastax.dse.driver.api.core.`type`.DseDataTypes
import com.datastax.oss.driver.api.core.DefaultProtocolVersion.V4
import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType, VectorType => DriverVectorType}
import com.datastax.spark.connector.util._
@@ -77,6 +77,7 @@ object ColumnType {
case mapType: DriverMapType => MapType(fromDriverType(mapType.getKeyType), fromDriverType(mapType.getValueType), mapType.isFrozen)
case userType: DriverUserDefinedType => UserDefinedType(userType)
case tupleType: DriverTupleType => TupleType(tupleType)
+ case vectorType: DriverVectorType => VectorType(fromDriverType(vectorType.getElementType), vectorType.getDimensions)
case dataType => primitiveTypeMap(dataType)
}
@@ -153,6 +154,7 @@ object ColumnType {
val converter: TypeConverter[_] =
dataType match {
case list: DriverListType => TypeConverter.javaArrayListConverter(converterToCassandra(list.getElementType))
+ case vec: DriverVectorType => TypeConverter.cqlVectorConverter(vec.getDimensions)(converterToCassandra(vec.getElementType).asInstanceOf[TypeConverter[Number]])
case set: DriverSetType => TypeConverter.javaHashSetConverter(converterToCassandra(set.getElementType))
case map: DriverMapType => TypeConverter.javaHashMapConverter(converterToCassandra(map.getKeyType), converterToCassandra(map.getValueType))
case udt: DriverUserDefinedType => new UserDefinedType.DriverUDTValueConverter(udt)
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
index 58615965b..ea13093e9 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
@@ -9,7 +9,7 @@ import java.util.{Calendar, Date, GregorianCalendar, TimeZone, UUID}
import com.datastax.dse.driver.api.core.data.geometry.{LineString, Point, Polygon}
import com.datastax.dse.driver.api.core.data.time.DateRange
-import com.datastax.oss.driver.api.core.data.CqlDuration
+import com.datastax.oss.driver.api.core.data.{CqlDuration, CqlVector}
import com.datastax.spark.connector.TupleValue
import com.datastax.spark.connector.UDTValue.UDTValueConverter
import com.datastax.spark.connector.util.ByteBufferUtil
@@ -700,6 +700,7 @@ object TypeConverter {
case x: java.util.List[_] => newCollection(x.asScala)
case x: java.util.Set[_] => newCollection(x.asScala)
case x: java.util.Map[_, _] => newCollection(x.asScala)
+ case x: CqlVector[_] => newCollection(x.asScala)
case x: Iterable[_] => newCollection(x)
}
}
@@ -768,6 +769,29 @@ object TypeConverter {
}
}
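+ /** Converts collection-like values to a driver [[CqlVector]], converting each element with the
+ * implicit element converter; `dimension` is used to pre-size the backing list. */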
+ class CqlVectorConverter[T <: Number : TypeConverter](dimension: Int) extends TypeConverter[CqlVector[T]] {
+ val elemConverter = implicitly[TypeConverter[T]]
+
+ implicit def elemTypeTag: TypeTag[T] = elemConverter.targetTypeTag
+
+ @transient
+ lazy val targetTypeTag = {
+ implicitly[TypeTag[CqlVector[T]]]
+ }
+
+ def newCollection(items: Iterable[Any]): java.util.List[T] = {
+ val buf = new java.util.ArrayList[T](dimension)
+ for (item <- items) buf.add(elemConverter.convert(item))
+ buf
+ }
+
+ def convertPF = {
+ case x: CqlVector[_] => x.asInstanceOf[CqlVector[T]] // optimization: elements are not re-converted, they are assumed to already match the target type
+ case x: java.lang.Iterable[_] => CqlVector.newInstance[T](newCollection(x.asScala))
+ case x: Iterable[_] => CqlVector.newInstance[T](newCollection(x))
+ }
+ }
+
class JavaArrayListConverter[T : TypeConverter] extends CollectionConverter[java.util.ArrayList[T], T] {
@transient
lazy val targetTypeTag = {
@@ -869,6 +893,9 @@ object TypeConverter {
implicit def javaArrayListConverter[T : TypeConverter]: JavaArrayListConverter[T] =
new JavaArrayListConverter[T]
+ implicit def cqlVectorConverter[T <: Number : TypeConverter](dimension: Int): CqlVectorConverter[T] =
+ new CqlVectorConverter[T](dimension)
+
implicit def javaSetConverter[T : TypeConverter]: JavaSetConverter[T] =
new JavaSetConverter[T]
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala
new file mode 100644
index 000000000..8060fe225
--- /dev/null
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala
@@ -0,0 +1,20 @@
+package com.datastax.spark.connector.types
+
+import scala.language.existentials
+import scala.reflect.runtime.universe._
+
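+/** Connector-side representation of the CQL `vector<elemType, dimension>` type, exposed to Scala as `Seq[T]`. */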
+case class VectorType[T](elemType: ColumnType[T], dimension: Int) extends ColumnType[Seq[T]] {
+
+ override def isCollection: Boolean = false
+
+ @transient
+ lazy val scalaTypeTag = {
+ implicit val elemTypeTag = elemType.scalaTypeTag
+ implicitly[TypeTag[Seq[T]]]
+ }
+
+ def cqlTypeName = s"vector<${elemType.cqlTypeName}, ${dimension}>"
+
+ override def converterToCassandra: TypeConverter[_ <: AnyRef] =
+ new TypeConverter.OptionToNullConverter(TypeConverter.seqConverter(elemType.converterToCassandra))
+}