def test_simple_udt(self):
    """Check that UDT-bearing schemas match between ``self.connect`` and ``self.spark``.

    For every schema below, an empty DataFrame is created through both
    ``self.connect`` and ``self.spark`` and the two resulting ``schema``
    attributes are asserted to be equal.
    """
    from pyspark.ml.linalg import MatrixUDT, VectorUDT

    # Every schema pairs a long "key" column with one column that carries a UDT,
    # either directly or nested inside an array / map. Each case is listed once;
    # repeating an identical schema would not add coverage.
    schemas = [
        StructType().add("key", LongType()).add("val", PythonOnlyUDT()),
        StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())),
        StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT())),
        StructType().add("key", LongType()).add("vec", VectorUDT()),
        StructType().add("key", LongType()).add("mat", MatrixUDT()),
    ]

    for schema in schemas:
        # subTest labels a failure with the offending schema and lets the
        # remaining schemas still be checked.
        with self.subTest(schema=schema):
            cdf = self.connect.createDataFrame(data=[], schema=schema)
            sdf = self.spark.createDataFrame(data=[], schema=schema)

            self.assertEqual(cdf.schema, sdf.schema)